const tf = require("@tensorflow/tfjs-node");

// Tokenization
const natural = require("natural");
const fs = require("fs");

/**
 * Loads the English/German sentence pairs from `./ger_en_trans.csv`,
 * normalises them, converts them to vocabulary indexes and trains the model.
 *
 * Each German sentence is wrapped in the literal marker words `startToken`
 * and `endToken` so the decoder sequences have an explicit start and end.
 *
 * @returns {Promise<void>} Resolves once training has finished; rejects when
 *   loading, tokenising or training fails.
 */
async function readInData() {
  await tf.ready();
  const languageDataSet = tf.data.csv("file://./ger_en_trans.csv");

  // Extract language pairs
  const dataset = languageDataSet.map((record) => ({
    en: cleanData(record.en),
    de: `startToken ${cleanData(record.de)} endToken`,
  }));

  const pairs = await dataset.toArray();
  const indexedDataset = tokenize(pairs);

  // Await the training run so the returned promise covers it and a training
  // failure rejects this function instead of becoming a floating rejection.
  await trainModel(indexedDataset);
}

/**
 * Normalises one raw text cell so that differently-cased spellings of a word
 * end up as the same token.
 *
 * Non-string input is stringified first instead of throwing on
 * `.toLowerCase()`; `null` / `undefined` become the empty string.
 * NOTE(review): the CSV loader presumably yields numbers for numeric-looking
 * cells and `undefined` for empty ones — confirm against the dataset.
 *
 * @param {*} text - Raw cell value.
 * @returns {string} The lower-cased text.
 */
function cleanData(text) {
  // if necessary also remove any characters not a-z and ?!,'
  const raw = text == null ? "" : String(text);
  return raw.toLowerCase();
}

/**
 * Tokenises the sentence pairs, builds one vocabulary per language, stores
 * the index -> word tables as JSON and replaces every word by its index.
 *
 * Side effects: writes `vocabulary/inputVocableSet.json` and
 * `vocabulary/outputVocableSet.json` (the directory is created if missing).
 *
 * @param {Array<{en: string, de: string}>} data - Cleaned sentence pairs.
 * @returns {{en: number[][], de: number[][], model: *}} The indexed English
 *   and German sentences plus the model returned by `createModell`.
 */
function tokenize(data) {
  // NOTE(review): WordTokenizer's default pattern may split words at
  // non-ASCII letters such as German umlauts — verify on the German column.
  const tokenizer = new natural.WordTokenizer();

  // Declared locally; these used to be assigned without a declaration and
  // therefore leaked as implicit globals (a ReferenceError in strict mode).
  const enData = data.map((row) => tokenizer.tokenize(row.en));
  const deData = data.map((row) => tokenizer.tokenize(row.de));

  // Index 0 is reserved for <pad>, matching the 0 used by padSequences.
  const enVocabulary = new Map([["<pad>", 0]]);
  const deVocabulary = new Map([["<pad>", 0]]);

  // Gives every unseen word the next free index (the current map size).
  const fillVocabulary = (langData, vocabMap) => {
    for (const sentence of langData) {
      for (const word of sentence) {
        if (!vocabMap.has(word)) {
          vocabMap.set(word, vocabMap.size);
        }
      }
    }
  };

  fillVocabulary(enData, enVocabulary);
  fillVocabulary(deData, deVocabulary);

  // store the input and output key value pairs (inverted to index -> word)
  // Make sure the target directory exists so the writes cannot fail with ENOENT.
  fs.mkdirSync("vocabulary", { recursive: true });
  fs.writeFileSync(
    "vocabulary/inputVocableSet.json",
    JSON.stringify(switchKeysAndValues(Object.fromEntries(enVocabulary)))
  );
  fs.writeFileSync(
    "vocabulary/outputVocableSet.json",
    JSON.stringify(switchKeysAndValues(Object.fromEntries(deVocabulary)))
  );

  // Replace words with indexes
  const indexedEnData = enData.map((sentence) =>
    sentence.map((word) => enVocabulary.get(word))
  );
  const indexedDeData = deData.map((sentence) =>
    sentence.map((word) => deVocabulary.get(word))
  );

  // NOTE(review): the vocabularies already contain <pad>, so the extra "+ 1"
  // slot looks unused — kept unchanged because it determines layer shapes.
  const model = createModell(
    enVocabulary.size + 1,
    deVocabulary.size + 1,
    findMaxLength(indexedEnData),
    findMaxLength(indexedDeData)
  );

  return { en: indexedEnData, de: indexedDeData, model };
}

/**
 * Builds the three padded sequence sets used for training from the indexed
 * sentences.
 *
 * Previously the padded arrays were computed and then discarded; they are now
 * returned so the function has a usable result.
 *
 * @param {{en: number[][], de: number[][]}} data - Indexed sentences.
 * @returns {{encoderSeq: number[][], decoderInp: number[][], decoderOutput: number[][]}}
 *   Padded encoder input, decoder input and decoder target sequences.
 */
function pad(data) {
  const encoderSeq = padSequences(data.en);
  // Decoder input drops the last token of each sentence, so it still has startToken.
  const decoderInp = padSequences(data.de.map((arr) => arr.slice(0, -1)));
  // Decoder target drops the first token of each sentence, so it still has endToken.
  const decoderOutput = padSequences(data.de.map((arr) => arr.slice(1)));

  return { encoderSeq, decoderInp, decoderOutput };
}

/**
 * Returns a new object in which every own enumerable `key: value` pair of
 * `obj` is stored as `value: key`.
 *
 * Values become property keys and are therefore stringified; when several
 * keys share a value, the last one wins. Uses `Object.entries` so it also
 * works for null-prototype objects and for objects that have a key named
 * `hasOwnProperty`.
 *
 * @param {Object} obj - Object to invert.
 * @returns {Object} The inverted object.
 */
function switchKeysAndValues(obj) {
  return Object.fromEntries(
    Object.entries(obj).map(([key, value]) => [value, key])
  );
}

/**
 * Finds the length of the longest entry in a collection of sequences.
 *
 * @param {Iterable<{length: number}>} langData - Sequences to inspect.
 * @returns {number} The largest `length` found, or 0 for an empty collection.
 */
function findMaxLength(langData) {
  let longest = 0;
  for (const { length } of langData) {
    longest = length > longest ? length : longest;
  }
  return longest;
}

/**
 * Right-pads every sequence with zeros up to the length of the longest one.
 *
 * The input sequences are not modified; each result row is a new array.
 *
 * @param {number[][]} sequences - Sequences of possibly different lengths.
 * @returns {number[][]} Sequences that all share the maximum length.
 */
function padSequences(sequences) {
  const targetLength = findMaxLength(sequences);

  return Array.from(sequences, (row) => {
    const missing = targetLength - row.length;
    if (missing <= 0) {
      // Already at full length: hand back a copy.
      return row.slice(0, targetLength);
    }
    const zeros = new Array(missing).fill(0);
    return row.concat(zeros);
  });
}

/**
 * Builds the (uncompiled) encoder-decoder model.
 *
 * Encoder: embedding followed by two LSTM layers. Decoder: embedding followed
 * by one LSTM that is initialised with the states of the second encoder LSTM,
 * then a softmax dense layer.
 *
 * @param {number} numEncoderTokens - Used as `inputDim` of the encoder embedding.
 * @param {number} numDecoderTokens - Used as `inputDim` of the decoder
 *   embedding; the final dense layer has `numDecoderTokens + 1` units.
 * @param {number} maxEncoderSequenceLen - Passed as `inputLength` of the encoder embedding.
 * @param {number} maxDecoderSequenceLen - Passed as `inputLength` of the decoder embedding.
 * @returns {*} The model created by `tf.model` with inputs
 *   `[encoderInput, decoderInput]` and the dense layer's output.
 */
function createModell(
  numEncoderTokens,
  numDecoderTokens,
  maxEncoderSequenceLen,
  maxDecoderSequenceLen
) {
  // Encoder model
  // shape [null]: the sequence length is left unspecified.
  const encoderInput = tf.input({ shape: [null], name: "encoderInputLayer" });
  const encoderEmbedding = tf.layers
    .embedding({
      inputDim: numEncoderTokens,
      outputDim: 300,
      inputLength: maxEncoderSequenceLen,
      name: "encoderEmbeddingLayer",
    })
    .apply(encoderInput);
  // NOTE(review): with returnState: true this apply() presumably returns
  // [sequence, state_h, state_c] rather than a single tensor — confirm.
  const encoderLstm = tf.layers
    .lstm({
      units: 256,
      activation: "tanh",
      returnSequences: true,
      returnState: true,
      name: "encoderLstm1Layer",
    })
    .apply(encoderEmbedding);
  // Only the states of the second encoder LSTM are kept; its output is unused.
  // NOTE(review): the whole result of the first LSTM is passed on here; tfjs
  // may treat the extra entries as this layer's initial state — confirm that
  // this is intended.
  const [_, state_h, state_c] = tf.layers
    .lstm({
      units: 256,
      activation: "tanh",
      returnState: true,
      name: "encoderLstm2Layer",
    })
    .apply(encoderLstm);
  const encoderStates = [state_h, state_c];

  // Decoder model
  const decoderInput = tf.input({ shape: [null], name: "decoderInputLayer" });
  const decoderEmbedding = tf.layers
    .embedding({
      inputDim: numDecoderTokens,
      outputDim: 300,
      inputLength: maxDecoderSequenceLen,
      name: "decoderEmbeddingLayer",
    })
    .apply(decoderInput);
  const decoderLstm = tf.layers.lstm({
    units: 256,
    activation: "tanh",
    returnState: true,
    returnSequences: true,
    name: "decoderLstmLayer",
  });
  // The decoder starts from the encoder's final states; the decoder's own
  // returned states are discarded.
  const [decoderOutputs, ,] = decoderLstm.apply(decoderEmbedding, {
    initialState: encoderStates,
  });
  // NOTE(review): tokenize() in this file already passes vocabulary size + 1,
  // so this adds a second extra unit — confirm both offsets are intended.
  const decoderDense = tf.layers.dense({
    units: numDecoderTokens + 1,
    activation: "softmax",
    name: "decoderFinalLayer",
  });
  const outputs = decoderDense.apply(decoderOutputs);

  const model = tf.model({ inputs: [encoderInput, decoderInput], outputs });
  return model;
}

/**
 * Pads the indexed sentences, compiles the model and trains it with the
 * decoder target being the decoder input shifted by one token.
 *
 * Fixes over the previous version: the padded arrays are converted to tensors
 * before being handed to `fit()`, the sparse targets get the trailing axis of
 * size 1, the option is spelled `batchSize` (the former `batch_size` key is
 * not a fit option), the tensors are disposed afterwards and the training
 * history is returned instead of being dropped.
 *
 * @param {{en: number[][], de: number[][], model: *}} data - Indexed
 *   sentences and the model to train, as returned by `tokenize`.
 * @returns {Promise<*>} The history object resolved by `model.fit`.
 * @throws {Error} When there are no sentence pairs to train on.
 */
async function trainModel(data) {
  if (data.en.length === 0) {
    throw new Error("trainModel: no sentence pairs to train on");
  }

  const encoderSeq = padSequences(data.en);
  const decoderInp = padSequences(data.de.map((arr) => arr.slice(0, -1))); // Has startToken
  const decoderOutput = padSequences(data.de.map((arr) => arr.slice(1))); // Has endToken

  // fit() expects tensors rather than nested JS arrays.
  const encoderTensor = tf.tensor2d(encoderSeq);
  const decoderInputTensor = tf.tensor2d(decoderInp);
  // With sparseCategoricalCrossentropy the labels must have the output's rank
  // with a last axis of size 1, i.e. [batch, time, 1].
  const targetTensor = tf.tidy(() =>
    tf.tensor2d(decoderOutput).expandDims(-1)
  );

  try {
    data.model.compile({
      optimizer: "rmsprop",
      loss: "sparseCategoricalCrossentropy",
      metrics: ["accuracy"],
    });

    return await data.model.fit(
      [encoderTensor, decoderInputTensor],
      targetTensor,
      {
        epochs: 80,
        batchSize: 450,
      }
    );
  } finally {
    // Release the training tensors whether or not fit() succeeded.
    tf.dispose([encoderTensor, decoderInputTensor, targetTensor]);
  }
}

// Script entry point: run the pipeline and report a failure explicitly
// instead of leaving the returned promise without a rejection handler.
readInData().catch((error) => {
  console.error(error);
  process.exitCode = 1;
});
