import { ModelGraph, Node } from '@tensorleap/engine-contract';

/** Horizontal distance between two adjacent depth columns of the automatic layout. */
const X_STEP = 500;
/** Vertical distance that one unit of row offset maps to in the automatic layout. */
const Y_STEP = 500;

/**
 * Layout coordinates computed for a node during reorganization:
 * `x` is its column index (leaf nodes start at 0) and `y` its row offset.
 */
interface Depth {
  x: number;
  y: number;
}

/** String form of the `[0, 0]` position; nodes whose position serializes to it count as unpositioned. */
const NO_POSITION_STRING_ARRAY = [0, 0].join(',');

/**
 * Reports whether the graph already carries a layout, i.e. whether at least
 * one node has a position other than `[0, 0]` (compared via the position's
 * string form). An empty record is reported as not organized.
 */
function isOrganizedGraph(nodes: Record<string, Node>): boolean {
  const hasPlaceholderPosition = (node: Node): boolean =>
    node.position.toString() === NO_POSITION_STRING_ARRAY;
  return Object.values(nodes).some((node) => !hasPlaceholderPosition(node));
}

/**
 * Returns `modelGraph` itself when at least one of its nodes already has a
 * non-`[0, 0]` position. Otherwise returns a shallow copy of the graph whose
 * `nodes` record has been laid out by `reorganizeMap`.
 */
export function reorganizeModelGraphIfNeeded(
  modelGraph: ModelGraph
): ModelGraph {
  const { nodes } = modelGraph;
  if (isOrganizedGraph(nodes)) {
    return modelGraph;
  }
  return { ...modelGraph, nodes: reorganizeMap(nodes) };
}

/**
 * Lays out a record of nodes in columns according to their distance from the
 * graph's sink ("leaf") nodes.
 *
 * Steps:
 * 1. Leaf nodes are nodes none of whose outputs has a connection. Visualizers
 *    (`name === 'Visualizer'`) and nodes with a non-empty `data.output_blocks`
 *    (except a node named `'Dataset'`) are never leaves.
 * 2. `calcNodesDepth` walks upstream from every leaf and assigns each reached
 *    node a column index `x` (0 for leaves) and a row offset `y`.
 * 3. Column `x` is placed at `xOffset - x * X_STEP`, where `xOffset` is the
 *    current x coordinate of the first leaf; row offsets inside a column are
 *    spread by `wideningVerticalNodes` and multiplied by `Y_STEP`.
 * 4. Each node with `output_blocks` is moved to the x of the node referenced by
 *    its first output block, with y decreased by half a `Y_STEP`.
 * 5. Visualizers are positioned by `arrangeVisualizers`.
 *
 * The input record is not mutated: a shallow copy is returned in which every
 * repositioned entry is replaced by a new node object. Nodes not touched by any
 * step keep their current position. When no leaf exists, the original `nodes`
 * object itself is returned.
 *
 * @param nodes - node record; entries are looked up by `node.id`, so keys are
 *   expected to equal node ids.
 * @returns the laid-out node record.
 */
export function reorganizeMap(
  nodes: Record<string, Node>
): Record<string, Node> {
  const clonedNodes = { ...nodes };
  // Snapshot taken before any entry is replaced below.
  const allNodes = Object.values(clonedNodes);

  const sharedBlockIds = new Set<string>(
    allNodes
      .filter((n) => n.data.output_blocks?.length && n.name !== 'Dataset')
      .map(({ id }) => id)
  );
  const visualizersBlockIds = new Set<string>(
    allNodes.filter((n) => n.name === 'Visualizer').map(({ id }) => id)
  );

  const leafNodes = allNodes.filter(({ id, outputs }) => {
    if (sharedBlockIds.has(id) || visualizersBlockIds.has(id)) return false;
    // `every` is vacuously true for a node without any outputs.
    return Object.values(outputs).every(
      (output) => !output.connections?.length
    );
  });

  if (leafNodes.length === 0) {
    return nodes;
  }

  const [firstLeaf] = leafNodes;
  const [xOffset] = firstLeaf.position;
  const nodeToDepth = new Map<string, Depth>();
  const depthToNodes = new Map<number, Set<Node>>();
  leafNodes.forEach((leaf) =>
    calcNodesDepth(leaf, { x: 0, y: 0 }, clonedNodes, nodeToDepth, depthToNodes)
  );

  depthToNodes.forEach((columnNodes, depth) => {
    // Deeper (further upstream) columns are placed toward negative x.
    const xPosition = depth * -X_STEP + xOffset;
    const rows = wideningVerticalNodes(
      Array.from(columnNodes, (node) => ({
        node,
        y: nodeToDepth.get(node.id)?.y ?? 0,
      }))
    );

    rows.forEach(({ node, y }) => {
      clonedNodes[node.id] = {
        ...node,
        position: [xPosition, y * Y_STEP],
      };
    });
  });

  sharedBlockIds.forEach((blockId) => {
    const sharedBlock = clonedNodes[blockId];
    const outputBlocks = sharedBlock.data.output_blocks;
    if (!outputBlocks || outputBlocks.length === 0) return;

    const [firstChild] = outputBlocks;
    const childNode = clonedNodes[firstChild.block_node_id];
    if (!childNode) {
      console.warn('Shared block is not found', firstChild.block_node_id);
      return;
    }
    // Read after the column pass above, so the child's new position is used.
    clonedNodes[blockId] = {
      ...sharedBlock,
      position: [childNode.position[0], childNode.position[1] - Y_STEP / 2],
    };
  });

  arrangeVisualizers(clonedNodes, visualizersBlockIds);

  return clonedNodes;
}

/**
 * Places every visualizer node next to the node that feeds it.
 *
 * Nodes are scanned in `Object.values(clonedNodes)` order. The first node with
 * an output connection to a given visualizer claims it; later producers of the
 * same visualizer are ignored. Each claimed visualizer is put half an `X_STEP`
 * toward positive x from its producer. A single claimed visualizer is shifted
 * half a `Y_STEP` toward negative y. Several are spread around the producer's y
 * using `getYAxisDirection`.
 *
 * Mutates `clonedNodes` in place: visualizer entries are replaced by copies
 * carrying the new position.
 *
 * @param clonedNodes - node record to update.
 * @param visualizersBlockIds - ids of the nodes to treat as visualizers.
 */
function arrangeVisualizers(
  clonedNodes: Record<string, Node>,
  visualizersBlockIds: Set<string>
): void {
  const visitedVisualizers = new Set<string>();
  Object.values(clonedNodes).forEach((currentNode) => {
    // Visualizers fed by this node that no earlier node has claimed yet.
    const outVisualizersConn: string[] = [];
    Object.values(currentNode.outputs).forEach(({ connections }) => {
      // An output may lack a `connections` array; treat that as "no
      // connections" instead of throwing.
      (connections ?? []).forEach(({ node: outputNodeId }) => {
        // Check and mark in one step so a visualizer reached through several
        // connections of the same output is only counted once.
        if (
          visualizersBlockIds.has(outputNodeId) &&
          !visitedVisualizers.has(outputNodeId)
        ) {
          visitedVisualizers.add(outputNodeId);
          outVisualizersConn.push(outputNodeId);
        }
      });
    });

    const [producerX, producerY] = currentNode.position;
    outVisualizersConn.forEach((visualizerId, index) => {
      const yAxis =
        outVisualizersConn.length === 1
          ? -1
          : getYAxisDirection(index, outVisualizersConn.length);
      clonedNodes[visualizerId] = {
        ...clonedNodes[visualizerId],
        position: [producerX + X_STEP / 2, producerY + (Y_STEP / 2) * yAxis],
      };
    });
  });
}

/**
 * Spreads the entries of one layout column so that no two share a row.
 *
 * Entries are processed in ascending `y` order (stable for ties), so equal
 * values are always adjacent. Whenever two neighbours have the same `y`, every
 * entry up to and including the first of the pair has its `y` decreased by 0.5
 * and every later entry has it increased by 0.5. This opens a gap of 1 between
 * the pair while keeping all other distances, so every resulting `y` is
 * distinct and the relative order of the input rows is preserved.
 *
 * The `y` fields of the passed entry objects are updated in place. For two or
 * more entries a new array is returned, holding the same entry objects in
 * ascending `y` order; the caller's array order is left untouched. Inputs of
 * length 0 or 1 are returned as-is.
 */
function wideningVerticalNodes(
  nodes: Array<{ node: Node; y: number }>
): Array<{ node: Node; y: number }> {
  if (nodes.length <= 1) return nodes;

  // Sorting first is what guarantees no overlaps: comparing only neighbours in
  // arbitrary order can miss equal rows or create new collisions.
  const sorted = [...nodes].sort((a, b) => a.y - b.y);
  for (let i = 0; i < sorted.length - 1; i++) {
    if (sorted[i].y === sorted[i + 1].y) {
      for (let j = 0; j <= i; j++) sorted[j].y -= 0.5;
      for (let j = i + 1; j < sorted.length; j++) sorted[j].y += 0.5;
    }
  }
  return sorted;
}

/**
 * Records a column (`depth.x`) and row offset (`depth.y`) for `currentNode`,
 * then recurses upstream into its inputs with the column increased by one.
 *
 * - `nodeToDepth` maps node id -> recorded depth.
 * - `depthToNodes` maps column -> set of nodes currently recorded in it.
 *
 * A node that already has a depth is only re-recorded when the new column is
 * strictly greater, in which case it is removed from its previous column's set.
 * For each input, only the first connection is followed, and a source node that
 * already has a depth is skipped. Each followed input gets a row offset from
 * `getYAxisDirection(k, inputs.length)`, where `k` counts followed inputs.
 * A missing node (no `id`) is reported via `console.error` and skipped.
 */
function calcNodesDepth(
  currentNode: Node,
  depth: Depth,
  allNodes: Record<string, Node>,
  nodeToDepth: Map<string, Depth>,
  depthToNodes: Map<number, Set<Node>>
): void {
  if (!currentNode?.id) {
    console.error("Can't calculate depth for node");
    return;
  }

  const previous = nodeToDepth.get(currentNode.id);
  if (previous) {
    // Keep whichever placement is further from the leaves.
    if (depth.x <= previous.x) return;
    depthToNodes.get(previous.x)?.delete(currentNode);
  }

  nodeToDepth.set(currentNode.id, depth);
  const column = depthToNodes.get(depth.x);
  if (column) {
    column.add(currentNode);
  } else {
    depthToNodes.set(depth.x, new Set([currentNode]));
  }

  const inputs = Object.values(currentNode.inputs);
  let followedCount = 0;
  for (const input of inputs) {
    if (!input.connections || input.connections.length === 0) continue;

    const sourceId = input.connections[0]['node'];
    if (nodeToDepth.has(sourceId)) continue;

    calcNodesDepth(
      allNodes[sourceId],
      {
        x: depth.x + 1,
        y: depth.y + getYAxisDirection(followedCount++, inputs.length),
      },
      allNodes,
      nodeToDepth,
      depthToNodes
    );
  }
}

/**
 * Maps the `index`-th of `len` siblings to a signed row offset centered on 0.
 * Odd counts put the middle item at 0 (len 3 -> -1, 0, 1). Even counts skip 0
 * (len 4 -> -2, -1, 1, 2).
 */
function getYAxisDirection(index: number, len: number): number {
  const middle = Math.floor(len / 2);
  const remainder = len % 2;
  if (index < middle) {
    // Items before the middle get -middle, ..., -1.
    return index - middle;
  }
  if (remainder === 1 && index === middle) {
    // Odd count: the centre item stays on the axis.
    return 0;
  }
  // From the middle onward offsets are positive; an even count skips 0.
  return index - middle + (remainder === 0 ? 1 : 0);
}
