import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.InputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;

/**
 * Command-line utility that searches a binary file for length-prefixed,
 * null-terminated UTF-16 strings containing a search string, replaces the
 * search string, and patches the stored length fields to match.
 *
 * The description of the string layout in this class comes from the original
 * author's observations of ATN files (see the comments inside
 * {@link #process}); it is not taken from a format specification.
 */
public class ModifyATN3
{
  /** ASCII signature expected 8 bytes before the character-count field of a little-endian string. */
  private static final byte[] TXTU_BYTES = { 't', 'x', 't', 'u' };

  /**
   * Replaces occurrences of {@code searchString} inside length-prefixed UTF-16
   * strings found in {@code buffer} and updates the associated size fields.
   *
   * The input array is never modified; a copy is edited and returned. Candidate
   * matches for which no terminating UTF-16 null or no matching length prefix
   * can be found are reported on standard output and skipped.
   *
   * @param buffer the file contents to process
   * @param searchString the text to look for; must not be empty
   * @param replacementString the text to substitute (may be empty)
   * @param lenientLittleEndianHandling if {@code false}, a little-endian string is only
   *        accepted when the two byte-count fields and the "txtu" signature in front of
   *        its character count have the expected values; if {@code true} those checks
   *        are skipped (the two byte-count fields are still updated on replacement)
   * @return a new array containing the modified contents
   * @throws IllegalArgumentException if {@code searchString} is empty
   * @throws UnsupportedEncodingException if a UTF-16 charset is unavailable
   */
  public static byte[] process(byte[] buffer, String searchString, String replacementString, boolean lenientLittleEndianHandling) throws UnsupportedEncodingException
  {
    if (searchString.length() == 0)
      throw new IllegalArgumentException("Search string must not be empty");

    int bufferLen = buffer.length;
    byte[] searchStringBytesBE = searchString.getBytes("UTF-16BE");
    byte[] searchStringBytesLE = searchString.getBytes("UTF-16LE");
    int searchStringBytesLen = searchStringBytesBE.length;

    if (searchStringBytesLen != searchStringBytesLE.length)
      throw new RuntimeException("Internal error: Big-endian and little-endian UTF-16 encodings of the search string are not the same length!");
    if ((searchStringBytesLen & 1) != 0)
      throw new RuntimeException("Internal error: UTF-16 encoding of the search string does not have an even number of bytes!");

    byte[] replacementStringBytesBE = replacementString.getBytes("UTF-16BE");
    byte[] replacementStringBytesLE = replacementString.getBytes("UTF-16LE");
    int replacementStringBytesLen = replacementStringBytesBE.length;

    if (replacementStringBytesLen != replacementStringBytesLE.length)
      throw new RuntimeException("Internal error: Big-endian and little-endian UTF-16 encodings of the replacement string are not the same length!");
    if ((replacementStringBytesLen & 1) != 0)
      throw new RuntimeException("Internal error: UTF-16 encoding of the replacement string does not have an even number of bytes!");

    // Work on a copy so the caller's array is left untouched.
    // From here on "buffer" may be larger than the logical content; "bufferLen" is the number of valid bytes.
    buffer = buffer.clone();

    for (int curIndex = 0; curIndex < (bufferLen - searchStringBytesLen); curIndex++)
    {
      boolean originalMatchBE = false;
      if (matchByteSequence(buffer, curIndex, searchStringBytesBE))
        originalMatchBE = true;
      else
        if (!matchByteSequence(buffer, curIndex, searchStringBytesLE))
          continue;

      // Check whether the opposite endianness also matches one byte further on.
      // This is checked because for regular ASCII characters, a big-endian sequence will usually match little-endian at curIndex+1, and ATN uses both endians...
      int matchIndexBE = -1;
      int matchIndexLE = -1;
      if (originalMatchBE)
        matchIndexBE = curIndex;
      else
        matchIndexLE = curIndex;

      if ((curIndex + 1) < (bufferLen - searchStringBytesLen))
      {
        // Check the opposite endian
        if (matchByteSequence(buffer, curIndex + 1, originalMatchBE ? searchStringBytesLE : searchStringBytesBE))
        {
          if (originalMatchBE)
            matchIndexLE = curIndex + 1;
          else
            matchIndexBE = curIndex + 1;
        }
      }

      // From what I can tell, UTF-16 strings in ATN always have a UTF-16 null appended to them; find it to help determine string length
      // The "end" indexes point one past the terminating null (exclusive end of the string data).
      int stringEndIndexBE = -1;
      int stringEndIndexLE = -1;
      // Try both endians as needed
      for (int endian = 0; endian < 2; endian++)
      {
        boolean curBE = (endian == 0);
        int searchStartIndex = curBE ? matchIndexBE : matchIndexLE;
        if (searchStartIndex < 0)
          continue;
        // Step by two so the scan stays aligned with the 16-bit code units of the match
        for (int nullSearchIndex = searchStartIndex + searchStringBytesLen; nullSearchIndex < (bufferLen - 1); nullSearchIndex += 2)
        {
          if ((buffer[nullSearchIndex] == 0) && (buffer[nullSearchIndex + 1] == 0))
          {
            if (curBE)
              stringEndIndexBE = nullSearchIndex + 2;
            else
              stringEndIndexLE = nullSearchIndex + 2;
            break;
          }
        }
      }

      if ((stringEndIndexBE < 0) && (stringEndIndexLE < 0))
      {
        System.out.println("Match found at index " + curIndex + " of output file, but unable to find the ending null.");
        continue;
      }

      // Now move backwards, and try to find the length (Pascal strings)
      // For big-endian strings, the length is stored big-endian.
      //   [big-endian size chars] [big-endian UTF-16 string]
      // For little-endian strings (used for file paths from what I've seen), it gets weird.
      //   [big-endian size bytes + 0x0C] "txtu" [little-endian size bytes + 0x0C] [little-endian size chars] [little-endian UTF-16 string]
      //   Note that three sizes need to be updated.
      //   I don't know if other variations of this are permitted, so leniency can be controlled via "lenientLittleEndianHandling"
      //   To compound the problem, the first character of a little-endian string can look like the length of a big-endian string,
      //   and will be found first when walking backwards. The workaround is that if a BE is found, an LE match must be found in the next iteration.
      // Alternate between checking for little-endian and big-endian depending on current alignment
      boolean checkBE = originalMatchBE;

      // Bit 1 = big-endian search finished, bit 2 = little-endian search finished.
      // An endian with no terminating null starts out finished.
      // If a null character ends up in the string by both BE and LE byte-orderings, abort
      int finishFlags = ((stringEndIndexBE < 0) ? 1 : 0) | ((stringEndIndexLE < 0) ? 2 : 0);

      int sizeIndexBE = -1;
      int sizeIndexLE = -1;
      for (int sizeIndex = curIndex - 4; (sizeIndex >= 0) && (finishFlags != 3); sizeIndex--, checkBE = !checkBE)
      {
        int stringEndIndex = checkBE ? stringEndIndexBE : stringEndIndexLE;
        if (stringEndIndex < 0)
          continue;
        // A UTF-16 null right after the candidate size field means we walked past the start of the string
        if ((buffer[sizeIndex + 4] == 0) && (buffer[sizeIndex + 5] == 0))
        {
          finishFlags |= (checkBE ? 1 : 2);
          continue;
        }

        // The stored character count must equal the distance from the size field to the end of the string
        int expectedLengthBytes = (stringEndIndex - sizeIndex - 4);
        int expectedLengthChars = expectedLengthBytes >>> 1;
        if ((checkBE ? readBigEndian(buffer, sizeIndex) : readLittleEndian(buffer, sizeIndex)) != expectedLengthChars)
        {
          // If the other endian has found a size match, quit searching
          if ((checkBE ? sizeIndexLE : sizeIndexBE) >= 0)
            break;
          continue;
        }

        finishFlags |= (checkBE ? 1 : 2);
        if (checkBE)
          sizeIndexBE = sizeIndex;
        else
        {
          // The little-endian layout needs 12 more bytes in front of the character count
          if (sizeIndex < 12)
            continue;
          if (!lenientLittleEndianHandling)
          {
            int expectedValue = expectedLengthBytes + 0x0C;
            if (readLittleEndian(buffer, sizeIndex - 4) != expectedValue)
            {
              System.out.println("Possible little-endian string found at offset " + curIndex + ", but little-endian byte count is not the expected value.");
              continue;
            }
            if (readBigEndian(buffer, sizeIndex - 12) != expectedValue)
            {
              System.out.println("Possible little-endian string found at offset " + curIndex + ", but big-endian byte count is not the expected value.");
              continue;
            }
            if (!matchByteSequence(buffer, sizeIndex - 8, TXTU_BYTES))
            {
              System.out.println("Possible little-endian string found at offset " + curIndex + ", but signature \"txtu\" not found.");
              continue;
            }
          }
          sizeIndexLE = sizeIndex;
        }
      }

      if ((sizeIndexBE < 0) && (sizeIndexLE < 0))
      {
        System.out.println("Unable to find the start of the string for a match at offset " + curIndex + " in output file.");
        // Skip only this candidate. Stopping the whole scan here would silently drop
        // every later replacement because of a single unrelated byte pattern.
        continue;
      }

      // Little-endian takes precedence
      boolean replaceBE = (sizeIndexLE < 0);
      int matchIndex = replaceBE ? matchIndexBE : matchIndexLE;
      int stringEndIndex = replaceBE ? stringEndIndexBE : stringEndIndexLE;
      int sizeIndex = replaceBE ? sizeIndexBE : sizeIndexLE;
      byte[] replacementStringBytes = replaceBE ? replacementStringBytesBE : replacementStringBytesLE;

      int sizeChangeBytes = replacementStringBytesLen - searchStringBytesLen;
      int expectedLength = (stringEndIndex - sizeIndex - 4 + sizeChangeBytes) >> 1;

      // Patch the character count first; it lies before matchIndex, so it is carried over if the buffer is reallocated below
      if (replaceBE)
        writeBigEndian(expectedLength, buffer, sizeIndex);
      else
        writeLittleEndian(expectedLength, buffer, sizeIndex);
      if ((bufferLen + sizeChangeBytes) > buffer.length)
      {
        // Not enough room: grow by doubling and splice the replacement in while copying
        int newBufferSize = buffer.length;
        while (newBufferSize < (bufferLen + sizeChangeBytes))
          newBufferSize <<= 1;
        byte[] newBuffer = new byte[newBufferSize];
        System.arraycopy(buffer, 0, newBuffer, 0, matchIndex);
        if (replacementStringBytesLen > 0)
          System.arraycopy(replacementStringBytes, 0, newBuffer, matchIndex, replacementStringBytesLen);
        System.arraycopy(buffer, matchIndex + searchStringBytesLen, newBuffer, matchIndex + replacementStringBytesLen, bufferLen - matchIndex - searchStringBytesLen);
        bufferLen += sizeChangeBytes;
        buffer = newBuffer;
      }
      else
      {
        // Enough room: shift the tail in place, then drop the replacement into the gap
        if (sizeChangeBytes != 0)
          System.arraycopy(buffer, matchIndex + searchStringBytesLen, buffer, matchIndex + replacementStringBytesLen, bufferLen - matchIndex - searchStringBytesLen);
        if (replacementStringBytesLen > 0)
          System.arraycopy(replacementStringBytes, 0, buffer, matchIndex, replacementStringBytesLen);
        bufferLen += sizeChangeBytes;
      }

      if (!replaceBE)
      {
        // Update the other size fields
        // NOTE(review): in lenient mode these two fields were not validated before being rewritten
        writeLittleEndian(readLittleEndian(buffer, sizeIndex - 4) + sizeChangeBytes, buffer, sizeIndex - 4);
        writeBigEndian(readBigEndian(buffer, sizeIndex - 12) + sizeChangeBytes, buffer, sizeIndex - 12);
      }
      // Skip over the inserted text so a replacement that contains the search string is not matched again
      curIndex += replacementStringBytesLen - 1;
    }

    // Trim any spare capacity so the returned array holds exactly the valid bytes
    if (buffer.length != bufferLen)
    {
      byte[] newBuffer = new byte[bufferLen];
      System.arraycopy(buffer, 0, newBuffer, 0, bufferLen);
      buffer = newBuffer;
    }
    return buffer;
  }

  /**
   * Reads a 32-bit big-endian integer.
   *
   * @param b the array to read from
   * @param offset index of the most significant byte
   * @return the four bytes at {@code offset..offset+3} combined, first byte most significant
   */
  public static int readBigEndian(byte[] b, int offset)
  {
    int n = 0;
    for (int i = 3; i >= 0; i--)
      n = (n << 8) | (b[offset++] & 0xFF);
    return n;
  }

  /**
   * Reads a 32-bit little-endian integer.
   *
   * @param b the array to read from
   * @param offset index of the least significant byte
   * @return the four bytes at {@code offset..offset+3} combined, first byte least significant
   */
  public static int readLittleEndian(byte[] b, int offset)
  {
    int n = 0;
    for (int i = 3; i >= 0; i--)
      n = (n << 8) | (b[offset + i] & 0xFF);
    return n;
  }

  /**
   * Writes a 32-bit integer in big-endian byte order.
   *
   * @param n the value to store
   * @param b the array to write into
   * @param offset index that receives the most significant byte
   */
  public static void writeBigEndian(int n, byte[] b, int offset)
  {
    for (int i = 3; i >= 0; i--)
    {
      b[offset + i] = (byte)n;
      n >>>= 8;
    }
  }

  /**
   * Writes a 32-bit integer in little-endian byte order.
   *
   * @param n the value to store
   * @param b the array to write into
   * @param offset index that receives the least significant byte
   */
  public static void writeLittleEndian(int n, byte[] b, int offset)
  {
    for (int i = 3; i >= 0; i--)
    {
      b[offset++] = (byte)n;
      n >>>= 8;
    }
  }

  /**
   * Tests whether {@code cmp} occurs in {@code buffer} starting at {@code offset}.
   * No bounds checking is done; the caller must ensure that
   * {@code offset + cmp.length} does not exceed {@code buffer.length}.
   *
   * @param buffer the array to look in
   * @param offset index in {@code buffer} where the comparison starts
   * @param cmp the byte sequence to compare against
   * @return {@code true} if every byte of {@code cmp} equals the corresponding byte of {@code buffer}
   */
  public static boolean matchByteSequence(byte[] buffer, int offset, byte[] cmp)
  {
    for (int i = 0; i < cmp.length; i++)
      if (buffer[offset + i] != cmp[i])
        return false;
    return true;
  }

  /** Prints the command-line syntax to standard output. */
  public static void printUsage()
  {
    System.out.println("ModifyATN3 [--lenient-le] inputFile outputFile searchString replacementString [searchString2 replacementString2 [...]]");
    System.out.println("If filenames or replacement strings contains spaces, put quotes around the entire parameter");
    System.out.println("Ex: input.atn \"output file.atn\" \"single post\" \"pro h\"");
  }

  /**
   * Closes {@code c} if it is not {@code null}, ignoring any exception.
   * Intended for best-effort cleanup in {@code finally} blocks.
   *
   * @param c the resource to close; may be {@code null}
   */
  public static void closeIt(Closeable c)
  {
    if (c != null)
    {
      try
      {  c.close();  }
      catch (Exception ignored)
      {
        // Best-effort cleanup: nothing useful can be done if closing fails here
      }
    }
  }

  /**
   * Entry point: reads the input file, applies each search/replacement pair in
   * order via {@link #process}, and writes the result to the output file.
   *
   * @param argv {@code [--lenient-le] inputFile outputFile (searchString replacementString)+}
   */
  public static void main(String[] argv)
  {
    if (argv.length < 4)
    {
      printUsage();
      return;
    }

    int y = 0;
    String inputFilename = argv[y++];
    boolean lenientLittleEndianHandling = false;
    if (inputFilename.equalsIgnoreCase("--lenient-le"))
    {
      // The option takes one argument slot, so one more argument is required
      if (argv.length < 5)
      {
        printUsage();
        return;
      }
      lenientLittleEndianHandling = true;
      inputFilename = argv[y++];
      System.out.println("Lenient little-endian strings enabled");
    }
    String outputFilename = argv[y++];
    if (((argv.length - y) & 1) != 0)
    {
      System.out.println("The last search string is not matched with a replacement string");
      return;
    }

    boolean succeeded = false;
    FileInputStream fis = null;
    FileOutputStream fos = null;
    try
    {
      System.out.println("Opening input file: " + inputFilename);
      fis = new FileInputStream(inputFilename);
      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      byte[] buffer = new byte[4096];
      int numBytesRead;
      while ((numBytesRead = fis.read(buffer)) != -1)
        baos.write(buffer, 0, numBytesRead);
      closeIt(fis);
      fis = null;
      byte[] curBuffer = baos.toByteArray();
      System.out.println("Read input file (" + curBuffer.length + " bytes)");
      while (y < argv.length)
      {
        String searchString = argv[y++];
        String replacementString = argv[y++];
        System.out.println("Beginning replacement of \"" + searchString + "\" with \"" + replacementString + "\"");
        curBuffer = process(curBuffer, searchString, replacementString, lenientLittleEndianHandling);
      }

      System.out.println("Opening output file: " + outputFilename);
      fos = new FileOutputStream(outputFilename);
      System.out.println("Writing output file (" + curBuffer.length + " bytes)");
      fos.write(curBuffer);
      // Close explicitly so that a failure while closing the output is reported instead of swallowed by closeIt
      fos.close();
      fos = null;
      succeeded = true;
    }
    catch (Exception e)
    {
      e.printStackTrace();
    }
    finally
    {
      closeIt(fis);
      closeIt(fos);
    }
    // Only claim success when the output was actually written
    System.out.println(succeeded ? "Done" : "Failed");
  }
}
