import { Node, PDiff, TreePair } from "../modules";
import { parse } from "../parser/parse";
import { UniqueDiff } from "../unCycle/UniqueDiff";
import { compareDiffTraces, generateDiffAndSubtreeMap, getNearestCommonAncestor, removeNode } from "./DiffFinder";

/**
 *   Author:  Henry Seed
 *   Date: Thurs 9th August 2018
 *   (c) Jaipuna 2018
 *
 *   This contains the Special Case Diffs for the Amy Diff Finder.
 *   To be used in conjunction with the DiffFinder module for parsing and displaying.
 *
 **/

/**
 * Result produced by the special-case finders in this module.
 */
interface SpecialDiff {
    /** Source side of the recognised step. */
    source: Node;
    /** Target side of the recognised step. */
    target: Node;
    /** Optional placeholder bindings, e.g. "Z1" -> nodes of the bound subtree. */
    map?: Map<string, Node[]>;
    /** Name of the finder that produced this result, e.g. "f_boxForwardSpecialDiff". */
    func: string;
}

/**
 * Returns the largest numeric suffix of any "Z<n>" placeholder in `node`'s tree.
 *
 * Every node returned by `getBFS()` whose name contains a "Z" is inspected. The text after
 * the first "Z" (up to any second "Z") is read with `parseFloat`. Names whose suffix is not
 * numeric, such as a bare "Z", are ignored.
 *
 * @param node root of the tree to scan
 * @returns the highest suffix found, or 0 when there is none
 */
function findHighestZ(node: Node): number {
    let maxZ = 0;

    for (const treenode of node.getBFS()) {
        const nodeName = treenode.name;
        if (!nodeName.includes("Z")) {
            continue;
        }
        // parseFloat never throws; a non-numeric suffix yields NaN, which is skipped explicitly
        const num = parseFloat(nodeName.split("Z")[1]);
        if (!Number.isNaN(num) && num > maxZ) {
            maxZ = num;
        }
    }
    return maxZ;
}

/**
 * Detects the step where an `f_box(Z1)` in the source is opened up.
 *
 * Recognised shapes:
 *   1. `... f_box(Z1) ...` -> `f_box(Z1) == Z1`   (target keeps at least as many boxes)
 *   2. `... f_box(Z1) ...` -> `Z1`                (target is exactly some box's contents)
 *   3. `... f_box(Z1) ...` -> `... Z1 ...`        (trees agree once f_box nodes are removed)
 * Every case also requires that the source is an equation whose two sides agree once
 * f_box nodes are removed. For this check, every comparison operator is first rewritten
 * to "==".
 *
 * @param trees source/target pair for a single step
 * @returns a diff over the placeholder "Z1" (bound in `map` to the box contents), or
 *          undefined when the step has none of the shapes above
 */
export function f_boxForwardSpecialDiff(trees: TreePair): SpecialDiff {
    const sString = trees.source.toString();
    const tString = trees.target.toString();

    // every case needs an f_box( ... ) call in the source
    if (!sString.includes("f_box(")) {
        return undefined;
    }

    const sourceBoxes: Node[] = trees.source.getNodesByName("f_box");
    // default contents for Z1: the first box in the source (may be refined below)
    let f_boxContents: string | undefined = sourceBoxes[0]?.args[0]?.toString();

    // target has fewer f_box()s than source
    const countBoxes = (s: string): number => (s.match(/f_box\(/g) ?? []).length;
    const f_fewer_boxes = countBoxes(sString) > countBoxes(tString);

    // ... f_box(Z1) ... ->  f_box(Z1) == Z1
    let f_next_box_equals_soln = false;
    if (trees.target.name === "==") {
        const targetKids = trees.target.args;
        // check the arity first so a malformed "==" cannot throw on targetKids[0]
        if (targetKids.length > 1 && targetKids[0].name === "f_box") {
            const contents = targetKids[0].args[0];
            if (contents !== undefined && contents.toString() === targetKids[1].toString()) {
                f_boxContents = contents.toString();
                f_next_box_equals_soln = true;
            }
        }
    }

    //  ... f_box(Z1) ... -> Z1
    let f_next_soln = false;
    for (const box of sourceBoxes) {
        const inner = box.args[0];
        if (inner !== undefined && tString === inner.toString()) {
            f_next_soln = true;
            f_boxContents = tString;
        }
    }

    // ... f_box(Z1) ... -> ... Z1 ....
    const f_current_equals_next =
        trees.source.cloneDeep().removeNodeByName("f_box").toString() ===
        trees.target.cloneDeep().removeNodeByName("f_box").toString();

    // ... f_box(Z1) ... -> Z1 == Z1
    // comparisons are normalised to "==" so both sides of the source can be compared
    const allEqualSourceTree = parse(
        sString.replace(/!=/g, "==").replace(/<=/g, "==").replace(/>=/g, "==").replace(/</g, "==").replace(/>/g, "=="),
    );
    let f_lhs_equals_rhs = false;
    if (allEqualSourceTree.name === "==") {
        const allEqualSourceTreeKids = allEqualSourceTree.args;
        f_lhs_equals_rhs =
            allEqualSourceTreeKids[0].cloneDeep().removeNodeByName("f_box").toString() ===
            allEqualSourceTreeKids[1].cloneDeep().removeNodeByName("f_box").toString();
    }

    // every case requires balanced sides and something to bind to Z1
    if (!f_lhs_equals_rhs || f_boxContents === undefined) {
        return undefined;
    }

    // ==============  Check cases ================
    let targetPattern: string | undefined;
    if (!f_fewer_boxes && f_next_box_equals_soln && !f_next_soln && !f_current_equals_next) {
        // Case 1   ... f_box(Z1) ... ->  f_box(Z1) == Z1
        targetPattern = `f_box(Z1) == Z1`;
    } else if (f_fewer_boxes && !f_next_box_equals_soln && f_next_soln && !f_current_equals_next) {
        // Case 2   ... f_box(Z1) ... ->  Z1
        targetPattern = `Z1`;
    } else if (f_fewer_boxes && !f_next_box_equals_soln && !f_next_soln && f_current_equals_next) {
        // Case 3   ... f_box(Z1) ... -> ... Z1 ....
        targetPattern = `(Z1)`;
    }
    if (targetPattern === undefined) {
        return undefined;
    }

    const retMap: Map<string, Node[]> = new Map<string, Node[]>();
    retMap.set("Z1", parse(f_boxContents).getBFS());
    return {
        source: parse(`(f_box(Z1))`),
        target: parse(targetPattern),
        map: retMap,
        func: "f_boxForwardSpecialDiff",
    };
}

/**
 * Detects the reverse of opening a box: the target contains exactly one `f_box(Z1)` and
 * is either `Z1` or `f_box(Z1) == Z1`.
 *
 * The returned pair, contents -> f_box(contents), is built from the concrete box rather
 * than placeholders and carries no map.
 *
 * @param trees source/target pair for a single step
 * @returns the special diff, or undefined in any of these cases:
 *          - the target holds no f_box, an empty f_box(), or more than one f_box
 *          - the target does not have one of the shapes above
 */
function f_boxBackwardSpecialDiff(trees: TreePair): SpecialDiff {
    let contents: Node | undefined;
    let foundF_box: Node | undefined;

    // find the f_box in the target, insisting there is only one
    for (const node of trees.target.getBFS()) {
        if (node.name !== "f_box") {
            continue;
        }
        if (foundF_box !== undefined) {
            return undefined;
        }
        contents = node.args[0];
        foundF_box = node;
    }

    // no f_box, or an f_box() without contents: nothing to build the diff from
    if (contents === undefined || foundF_box === undefined) {
        return undefined;
    }

    const target = trees.target;
    const contentsString = contents.toString();

    // b) the target itself is `f_box(contents) == contents`. This is checked on the tree
    //    so that targets which merely contain that text, such as `2 * f_box(x) == x`,
    //    are not accepted. The target holds a single f_box, so an f_box on the left of
    //    the root "==" is the one found above.
    const isBoxEqualsContents =
        target.name === "==" &&
        target.args.length > 1 &&
        target.args[0].name === "f_box" &&
        target.args[1].toString() === contentsString;

    // a) the target is exactly the contents, or b) above
    if (target.toString() === contentsString || isBoxEqualsContents) {
        return {
            source: parse(contentsString),
            target: parse(foundF_box.toString()),
            func: "f_boxBackwardSpecialDiff",
        };
    }
    return undefined;
}

/**
 * Handles the step where several f_box calls in the source are opened at once.
 *
 * Applies only when all of the following hold:
 * - the source contains at least two f_box nodes;
 * - the target contains no f_box;
 * - the text of every box's contents appears somewhere in the target string.
 * The contents are joined into one product so the transition finder can match it,
 * e.g. `f_box(a * b)` -> `a * b`.
 *
 * @param trees source/target pair for a single step
 * @returns the special diff, or undefined when the step does not have this shape
 */
export function multif_boxSpecialDiff(trees: TreePair): SpecialDiff {
    const targetString = trees.target.toString();

    // only the source may contain f_box
    if (!trees.source.toString().includes("f_box") || targetString.includes("f_box")) {
        return undefined;
    }

    // collect the contents of every f_box; each must appear in the target
    const f_boxInsides: string[] = [];
    for (const node of trees.source.getBFS()) {
        if (node.name !== "f_box") {
            continue;
        }
        const insideNode = node.args[0];
        // an argument-less f_box() has nothing to open
        if (insideNode === undefined) {
            return undefined;
        }
        const inside = insideNode.toString();
        if (!targetString.includes(inside)) {
            return undefined;
        }
        f_boxInsides.push(inside);
    }

    // fewer than two boxes is not this case
    if (f_boxInsides.length < 2) {
        return undefined;
    }

    // join all the insides into one product so the transition finder can find an appropriate match
    const insidesStr = f_boxInsides.join(" * ");
    return { source: parse(`f_box(${insidesStr})`), target: parse(insidesStr), func: "multif_boxSpecialDiff" };
}

/**
 * A special case diff for a step that only expands `f_boxOp` calls.
 *
 * In a copy of the source, every `f_boxOp(a, b, "op")` is rewritten to `a op b`, with
 * quotes stripped from the operator. If the result equals the target, the step is
 * reported as `f_boxOp(Z1) -> Z1`.
 *
 * @param trees source/target pair for a single step
 * @returns the special diff, or undefined in any of these cases:
 *          - the source has no f_boxOp;
 *          - an f_boxOp has fewer than three arguments;
 *          - the rewritten source differs from the target.
 */
export function f_boxOpSpecialDiff(trees: TreePair): SpecialDiff {
    // nothing to rewrite: skip the clones and the traversal entirely
    if (!trees.source.toString().includes("f_boxOp")) {
        return undefined;
    }

    let source: Node = trees.source.cloneDeep();
    const target: Node = trees.target.cloneDeep();

    for (const node of source.getBFS()) {
        if (node.name !== "f_boxOp") {
            continue;
        }
        const [a, b, opNode] = node.args;
        // expected shape is f_boxOp(lhs, rhs, "op"); anything shorter cannot be expanded
        if (opNode === undefined) {
            return undefined;
        }
        const op = opNode.name.replace(/"/g, "");
        // replaceInTree's return value becomes the tree used for the final comparison
        source = node.replaceInTree(parse(`${a} ${op} ${b}`));
    }

    if (source.toString() === target.toString()) {
        return { source: parse(`f_boxOp(Z1)`), target: parse(`Z1`), func: "f_boxOpSpecialDiff" };
    }
    return undefined;
}

/**
 * Special case diff that ignores units.
 *
 * Every `f_unit` node is stripped from copies of both trees. The ordinary diff is then
 * generated on what remains.
 *
 * @param trees source/target pair for a single step
 * @param keepf_format forwarded unchanged to generateDiffAndSubtreeMap
 * @returns the unit-free diff and its map, or undefined in either of these cases:
 *          - either stripped tree is falsy;
 *          - the resulting diff is an identity.
 */
export function unitSpecialDiff(trees: TreePair, keepf_format: boolean): SpecialDiff {
    // removeNode's return value becomes the tree handed to the next removal
    const stripUnits = (tree: Node): Node => {
        let root = tree;
        for (const candidate of tree.getBFS()) {
            if (candidate.name === "f_unit") {
                root = removeNode(candidate, root);
            }
        }
        return root;
    };

    const sourceCopy: Node = trees.source.cloneDeep();
    const targetCopy: Node = trees.target.cloneDeep();
    const unitlessSource = stripUnits(sourceCopy);
    const unitlessTarget = stripUnits(targetCopy);

    if (!unitlessSource || !unitlessTarget) {
        return undefined;
    }

    // now generate the diff normally
    const { diff, map } = generateDiffAndSubtreeMap(unitlessSource, unitlessTarget, true, keepf_format);

    // an identity diff carries no information
    if (diff.source.toString() === diff.target.toString()) {
        return undefined;
    }
    return { source: diff.source, target: diff.target, map, func: "unitSpecialDiff" };
}

/**
 * A special case diff for cancelling struck-out common terms, e.g. f_strike(2) * 2 => 2.
 *
 * Applies only when all of the following hold:
 * - the source contains `f_strike`;
 * - the target does not;
 * - the two strings differ by more than spaces.
 * The pattern is looked up using three things:
 * - the name of the nearest common ancestor of the `f_strike` nodes;
 * - whether any of them contains a variable;
 * - for `*` only, whether an `f_div` sits inside or directly above that ancestor.
 *
 * @param origTrees raw source/target strings for the step
 * @returns the matching pattern, or undefined when none applies
 */
export function commonTermsSpecialDiff(origTrees: { sourceString: string; targetString: string }): SpecialDiff {
    const { sourceString, targetString } = origTrees;

    // only the source may contain f_strike, and the step must change more than spaces
    if (!sourceString.includes("f_strike") || targetString.includes("f_strike")) {
        return undefined;
    }
    const withoutSpaces = (text: string): string => text.replace(/ /g, "");
    if (withoutSpaces(sourceString) === withoutSpaces(targetString)) {
        return undefined;
    }

    const origSource = parse(sourceString, null);

    // the node joining all struck terms selects the pattern family
    const NCA: Node = getNearestCommonAncestor(origSource.getNodesByName("f_strike"));
    if (!NCA) {
        return undefined;
    }
    const NCAName = NCA.name;

    // hang every f_strike under a throwaway "+" node so a single BFS visits all of them
    const f_strikeTree: Node = parse("2 + 2");
    f_strikeTree.args = origSource.getNodesByName("f_strike");
    const hasVars = f_strikeTree.getBFS().some((visited) => visited.isVar());

    const hasf_div =
        NCA.getNodesByName("f_div").length > 0 || (NCA.parent ? NCA.parent.name === "f_div" : false);

    const patterns = new Map<string, [string, string]>([
        ["/,hasNoVars", ["f_strike({a})/f_strike({a})", "1"]],
        ["/,hasVars", ["f_strike(x)/f_strike(x)", "1"]],
        ["f_div,hasNoVars", ["f_div(f_strike({a}), f_strike({a}))", "1"]],
        ["f_div,hasVars", ["f_div(f_strike(x), f_strike(x))", "1"]],
        ["*,hasNoVars,hasf_div", ["f_div({a},f_strike({b})) * f_strike({b})", "{a}"]],
        ["*,hasVars,hasf_div", ["f_div({a},f_strike(x)) * f_strike(x)", "{a}"]],
        ["*,hasNoVars", ["f_strike(1/{a}) * f_strike({a})", "1"]],
        ["*,hasVars", ["f_strike(1/x) * f_strike(x)", "1"]],
        ["+,hasNoVars", ["a + f_strike({b}) - f_strike({b})", "a"]],
        ["+,hasVars", ["a + f_strike(y) - f_strike(y)", "a"]],
        ["-,hasNoVars", ["a + f_strike({b}) - f_strike({b})", "a"]],
        ["-,hasVars", ["a + f_strike(y) - f_strike(y)", "a"]],
    ]);

    const varTag = hasVars ? "hasVars" : "hasNoVars";
    const divTag = NCAName === "*" && hasf_div ? ",hasf_div" : "";
    const pattern = patterns.get(`${NCAName},${varTag}${divTag}`);

    return pattern
        ? { source: parse(pattern[0]), target: parse(pattern[1]), func: "commonTermsSpecialDiff" }
        : undefined;
}

/**
 * A special case diff for striking out terms, e.g. 2 * 2 => f_strike(2) * f_strike(2).
 *
 * Applies only when all of the following hold:
 * - the target contains `f_strike`;
 * - the source does not;
 * - the two strings differ by more than spaces.
 * The pattern is looked up using three things:
 * - the name of the nearest common ancestor of the target's `f_strike` nodes;
 * - whether any of them contains a variable;
 * - for `*` only, whether an `f_div` sits inside or directly above that ancestor.
 *
 * @param origTrees raw source/target strings for the step
 * @returns the matching pattern, or undefined when none applies
 */
export function f_strikeSpecialDiff(origTrees: { sourceString: string; targetString: string }): SpecialDiff {
    const { sourceString, targetString } = origTrees;

    // only the target may contain f_strike, and the step must change more than spaces
    if (sourceString.includes("f_strike") || !targetString.includes("f_strike")) {
        return undefined;
    }
    const withoutSpaces = (text: string): string => text.replace(/ /g, "");
    if (withoutSpaces(sourceString) === withoutSpaces(targetString)) {
        return undefined;
    }

    const origTarget = parse(targetString, null);

    // the node joining all struck terms selects the pattern family
    const NCA: Node = getNearestCommonAncestor(origTarget.getNodesByName("f_strike"));
    if (!NCA) {
        return undefined;
    }
    const NCAName = NCA.name;

    // hang every f_strike under a throwaway "+" node so a single BFS visits all of them
    const f_strikeTree: Node = parse("2 + 2");
    f_strikeTree.args = origTarget.getNodesByName("f_strike");
    const hasVars = f_strikeTree.getBFS().some((visited) => visited.isVar());

    const hasf_div =
        NCA.getNodesByName("f_div").length > 0 || (NCA.parent ? NCA.parent.name === "f_div" : false);

    const patterns = new Map<string, [string, string]>([
        ["/,hasNoVars", [`{a}/{a}`, `f_strike({a})/f_strike({a})`]],
        ["/,hasVars", [`x / x`, `f_strike(x)/f_strike(x)`]],
        ["f_div,hasNoVars", [`f_div({a}, {a})`, `f_div(f_strike({a}), f_strike({a}))`]],
        ["f_div,hasVars", [`f_div(x, x)`, `f_div(f_strike(x), f_strike(x))`]],
        ["*,hasNoVars,hasf_div", [`f_div({a}, {b}) * {b}`, `f_div({a},f_strike({b})) * f_strike({b})`]],
        ["*,hasVars,hasf_div", [`f_div({a},x) * x`, `f_div({a},f_strike(x)) * f_strike(x)`]],
        ["*,hasNoVars", [`1/{a} * {a}`, `f_strike(1/{a}) * f_strike({a})`]],
        ["*,hasVars", [`1/x * x`, `f_strike(1/x) * f_strike(x)`]],
        ["+,hasNoVars", [`a + {b} - {b}`, `a + f_strike({b}) - f_strike({b})`]],
        ["+,hasVars", [`a + y - y`, `a + f_strike(y) - f_strike(y)`]],
        ["-,hasNoVars", [`a + {b} - {b}`, `a + f_strike({b}) - f_strike({b})`]],
        ["-,hasVars", [`a + y - y`, `a + f_strike(y) - f_strike(y)`]],
    ]);

    const varTag = hasVars ? "hasVars" : "hasNoVars";
    const divTag = NCAName === "*" && hasf_div ? ",hasf_div" : "";
    const pattern = patterns.get(`${NCAName},${varTag}${divTag}`);

    return pattern
        ? { source: parse(pattern[0]), target: parse(pattern[1]), func: "f_strikeSpecialDiff" }
        : undefined;
}

/**
 * Replaces each `Z<znum>` placeholder in `text`, for `fromZ <= znum < toZ`, with the string
 * form of the first node bound to it in `map`. Unbound or empty placeholders are left as is.
 * The `(?!\d)` lookahead stops "Z1" from also matching the front of "Z10". The function
 * replacer keeps `$` sequences in the substituted text from being read as replacement patterns.
 */
function substituteZPlaceholders(text: string, map: Map<string, Node[]>, fromZ: number, toZ: number): string {
    let result = text;
    for (let znum = fromZ; znum < toZ; znum++) {
        const bound = map.get(`Z${znum}`);
        if (bound && bound.length > 0) {
            const zVal = bound[0].toString();
            result = result.replace(new RegExp(`Z${znum}(?!\\d)`, "g"), () => zVal);
        }
    }
    return result;
}

/**
 * A special case diff for BEDMAS-style regrouping, where an operation is folded into (or
 * split back out of) a `{...}` group, e.g. `Z1 + {b} - {c}` <-> `Z1 + {b - c}`.
 *
 * The ordinary diff is compared against each pattern pair, in both directions. On a match:
 * - a fresh placeholder `Z<highest + 1>` is bound to the regrouped subtree, with the
 *   existing placeholders substituted back in;
 * - the generalised start/end pair is returned, ordered by the direction that matched.
 *
 * @param trees source/target pair for the step (not read here; kept for a uniform signature)
 * @param normalDiffWMap the ordinary diff for the step and its placeholder map; the map is
 *        copied, so the caller's instance is left unchanged
 * @returns the special diff with the extended map, or undefined when no pattern matches
 */
export function bedmasSpecialDiff(
    trees: TreePair,
    normalDiffWMap: { diff: PDiff; map: Map<string, Node[]> },
): SpecialDiff {
    const diff = normalDiffWMap.diff;
    // copy so binding the new placeholder does not leak into the caller's map
    const map = new Map<string, Node[]>(normalDiffWMap.map);
    const stepUnique: UniqueDiff = new UniqueDiff(diff.sourceString, diff.targetString);
    const highestStepZ = Math.max(findHighestZ(diff.source), findHighestZ(diff.target));
    const newZ = `Z${highestStepZ + 1}`;

    const cases: string[][] = [
        ["Z1 + {b} - {c}", "Z1 + {b - c}"],
        ["{a} + {b - c}", "{a + b - c}"],
        ["Z1 * {b} / {c}", "Z1 * {b / c}"],
        ["{a} * {b / c}", "{a * b / c}"],
    ];

    let type: "totalSolve" | "partSolve" | undefined;
    let direction: "solving" | "expanding" | undefined;

    // we check the diffs forwards and backwards
    for (const steps of cases) {
        const diffTraceSolving = new UniqueDiff(steps[0], steps[1]);
        const diffTraceExpanding = new UniqueDiff(steps[1], steps[0]);

        if (compareDiffTraces(stepUnique, diffTraceSolving)) {
            direction = "solving";
        } else if (compareDiffTraces(stepUnique, diffTraceExpanding)) {
            direction = "expanding";
        } else {
            continue;
        }
        // both directions classify the case by its grouped (second) form
        type = parse(steps[1]).type === "EvalNode" ? "totalSolve" : "partSolve";
        break;
    }

    if (!direction || !type) {
        return undefined;
    }

    const expandedSide = direction === "expanding" ? diff.target : diff.source;
    const opName = expandedSide.name;

    let leftOperand: Node;
    let firstZ: number;
    if (type === "totalSolve") {
        leftOperand = expandedSide.args[0].args[0];
        firstZ = 0;
    } else {
        // the partially solved tree nests the left operand one level deeper
        leftOperand = expandedSide.args[0].args[1].args[0];
        firstZ = 1;
    }
    // in both cases the regrouped subtree is the first child of the right-hand operand
    const regrouped = expandedSide.args[1].args[0];

    const start = parse(`{${leftOperand}} ${opName} {${newZ}}`);
    const end = parse(`{${leftOperand} ${opName} ${newZ}}`);

    const subtreeString = substituteZPlaceholders(regrouped.toString(), map, firstZ, highestStepZ + 1);
    map.set(newZ, parse(subtreeString).getBFS());

    return direction === "solving"
        ? { source: start, target: end, map, func: "bedmasSpecialDiff" }
        : { source: end, target: start, map, func: "bedmasSpecialDiff" };
}

/**
 * Detect special case for f_format (ignoring evaluation), e.g. Z1 => f_format(Z1 ...).
 *
 * A diff is generated between copies of the two trees with f_format kept. The case applies
 * when the diff's target root is `f_format` and its source string does not contain
 * `f_format`.
 *
 * @param treesWf_format source/target pair that still contains f_format nodes
 * @returns the diff, re-parsed from its string form, with its map; or undefined when the
 *          case does not apply
 */
export function f_formatSpecialDiff(treesWf_format: TreePair): SpecialDiff {
    const { diff, map } = generateDiffAndSubtreeMap(
        treesWf_format.source.cloneDeep(),
        treesWf_format.target.cloneDeep(),
        true,
        true,
    );

    // target root must be f_format while the source side has no f_format at all
    const wrapsInFormat = diff.target.name === "f_format";
    if (!wrapsInFormat || diff.sourceString.includes("f_format")) {
        return undefined;
    }

    return {
        source: parse(diff.sourceString),
        target: parse(diff.targetString),
        map,
        func: "f_formatSpecialDiff",
    };
}

//==========================================================================================

/**
 * Tries the f_strike special cases in order and returns the first match.
 * Common-term cancellation is tried first, then introducing strikes.
 *
 * @param origTrees raw source/target strings for the step
 * @returns the first matching pattern with an empty placeholder map, or undefined
 */
export function searchf_strikeCases(origTrees: { sourceString: string; targetString: string }): SpecialDiff {
    // The two finders have opposite preconditions: commonTermsSpecialDiff needs f_strike in
    // the source, while f_strikeSpecialDiff rejects it. So the second one only needs to run
    // when the first one misses.
    const found = commonTermsSpecialDiff(origTrees) ?? f_strikeSpecialDiff(origTrees);
    if (found === undefined) {
        return undefined;
    }
    return {
        source: found.source,
        target: found.target,
        map: new Map<string, Node[]>(),
        func: "searchf_strikeCases",
    };
}

/**
 * Takes the tree pair and checks each of the special cases in priority order, returning the
 * first one that matches.
 *
 * Finders are evaluated lazily, so once one matches, the remaining ones are not run.
 * Some of them generate further diffs.
 * - A match from f_boxBackwardSpecialDiff is expanded into a full diff and map with
 *   generateDiffAndSubtreeMap.
 * - Any other match is returned with its own map, or an empty one.
 *
 * @param trees source/target pair for the step
 * @param treesWf_format tree pair passed to f_formatSpecialDiff
 * @param keepf_format forwarded to generateDiffAndSubtreeMap
 * @returns the first special diff and its map, or `{ diff: undefined }` with an empty map
 */
export function searchAllCases(
    trees: TreePair,
    treesWf_format: TreePair,
    keepf_format: boolean,
): { diff: TreePair; map: Map<string, Node[]> } {
    // Generated before any finder runs, on the uncloned trees.
    // NOTE(review): if generateDiffAndSubtreeMap modifies its inputs, the finders depend on
    // this ordering — confirm before moving it.
    const normalDiffWMap = generateDiffAndSubtreeMap(trees.source, trees.target, true, keepf_format);

    // order matters here: earlier finders take priority
    const finders: Array<() => SpecialDiff | undefined> = [
        () => f_boxForwardSpecialDiff(trees),
        () => f_boxBackwardSpecialDiff(trees),
        () => f_boxOpSpecialDiff(trees),
        () => multif_boxSpecialDiff(trees),
        () => bedmasSpecialDiff(trees, normalDiffWMap),
        () => f_formatSpecialDiff(treesWf_format),
    ];

    for (const find of finders) {
        const special = find();
        if (special === undefined) {
            continue;
        }
        const { source, target, map, func } = special;

        // the backward f_box result is built from concrete trees with no map; derive both
        if (func === "f_boxBackwardSpecialDiff") {
            const diffObj = generateDiffAndSubtreeMap(source, target, false, keepf_format);
            return { diff: diffObj.diff, map: diffObj.map };
        }

        return { diff: { source, target }, map: map ?? new Map() };
    }
    return { diff: undefined, map: new Map() };
}
