import { DatasetVersion, ModelGraph, Node } from '@tensorleap/api-client';
import {
  CustomLayerData,
  updateCustomLayerNode,
} from '../layer-details/CustomLayerDetails';
import { first } from 'lodash';
import { findByName } from '../core/named-halper';
import {
  isCustomlossNode,
  isGroundTruthNode,
  isInputsNode,
} from './graph-calculation/utils';
import {
  getDatasetOutputData,
  isCustomLayerNode,
  isPredictionNode,
} from './utils';
import { CustomLossNodeData } from '../layer-details/CustomLossDetails';
import {
  calcConnectionsToUpdate,
  ChangeNodePropFunc,
  createOldReteData,
  mapOldGraphToConnectionsAndNodes,
  NodePropertyType,
  UpdateConnectionFunc,
} from './networkStateUtils';
import { NodeWithLabels } from './descriptor/types';
import { Connection } from './interfaces/Connection';

/**
 * Arguments for `updateGraphWithCodeIntegration`.
 *
 * `nodes` is keyed by node id. The second half of the type selects how
 * updates are applied:
 * - `{ connections }`: updates are applied in place. Changed nodes are
 *   written back into `nodes`, and the `connections` array is rewritten.
 * - `{ updateConnection, changeNodeProperty }`: the caller's own callbacks
 *   receive every update.
 *
 * If an object happens to carry `connections`, the in-place mode is used
 * (`updateGraphWithCodeIntegration` checks `'connections' in ...` first).
 */
type UpdateGraphWithDatasetSetupProps = {
  nodes: Map<string, Node>;
  codeIntegrationVersion?: DatasetVersion;
} & (
  | { connections: Connection[] }
  | {
      updateConnection: UpdateConnectionFunc;
      changeNodeProperty: ChangeNodePropFunc;
    }
);

/**
 * Arguments passed to each per-node-kind `updateDatasetFor*Node` handler.
 * Both callbacks are always provided; each handler uses only the ones it
 * needs.
 */
type UpdateNodeWithDatasetSetupProps = {
  // The node being re-synced (as it was when iteration started).
  node: Node;
  // Source of the setup metadata; when absent, handlers treat every
  // declared list as empty/undefined.
  codeIntegrationVersion?: DatasetVersion;
  updateConnection: UpdateConnectionFunc;
  changeNodeProperty: ChangeNodePropFunc;
};

/**
 * Re-syncs a ground-truth node with the outputs declared in the code
 * integration's setup.
 *
 * The node's output connection names are set to every declared output name.
 * `output_name` is then re-resolved:
 * - it is kept when `findByName` still finds it among the declared outputs;
 * - otherwise it becomes the sole output when exactly one is declared;
 * - otherwise it is cleared to `undefined`.
 *
 * Uses both `updateConnection` and `changeNodeProperty`.
 */
function updateDatasetForGroundTruthNode({
  node,
  codeIntegrationVersion,
  updateConnection,
  changeNodeProperty,
}: UpdateNodeWithDatasetSetupProps) {
  const declaredOutputs =
    codeIntegrationVersion?.metadata.setup?.outputs || [];

  const matchedOutput = findByName(declaredOutputs, node.data.output_name);
  const soleOutput =
    declaredOutputs.length === 1 ? declaredOutputs[0] : undefined;
  const resolvedOutput = matchedOutput || soleOutput;

  const outputNames = declaredOutputs.map((output) => output.name);
  updateConnection(node.id, undefined, outputNames);

  changeNodeProperty({
    nodeId: node.id,
    nodeDataPropsToUpdate: { output_name: resolvedOutput?.name },
  });
}

/**
 * Re-resolves a prediction node's `prediction_type` against the prediction
 * types declared in the code integration's setup.
 *
 * Resolution order:
 * - the current type is kept if a declared type has the same name;
 * - otherwise the sole declared type is used when exactly one exists;
 * - otherwise the value is cleared to `undefined`.
 *
 * Only `changeNodeProperty` is used.
 */
function updateDatasetForPredictionNode({
  node,
  codeIntegrationVersion,
  changeNodeProperty,
}: UpdateNodeWithDatasetSetupProps) {
  const currentType = node.data.prediction_type;
  const declaredTypes =
    codeIntegrationVersion?.metadata.setup?.prediction_types || [];

  const matchedType = declaredTypes.find(
    (candidate) => candidate.name === currentType
  );
  const resolvedType =
    matchedType ?? (declaredTypes.length === 1 ? declaredTypes[0] : undefined);

  changeNodeProperty({
    nodeId: node.id,
    nodeDataPropsToUpdate: { prediction_type: resolvedType?.name },
  });
}
/**
 * Re-resolves a custom-layer node against the custom layers declared in the
 * code integration's model setup. The actual node update is delegated to
 * `updateCustomLayerNode`.
 *
 * Behavior:
 * - Nodes without a `selected` layer are left untouched.
 * - When exactly one custom layer is declared, it is used regardless of
 *   `selected`. Otherwise the layer is looked up by the `selected` name.
 * - Every data field except `selected` and `type` is forwarded to
 *   `updateCustomLayerNode` as the previous property values.
 *
 * Only `changeNodeProperty` is used.
 */
function updateDatasetForCustomLayerNode({
  node,
  codeIntegrationVersion,
  changeNodeProperty,
}: UpdateNodeWithDatasetSetupProps) {
  const {
    selected: selectedLayerName,
    type: _ignoredType,
    ...remainingProps
  } = node.data as CustomLayerData;
  if (!selectedLayerName) {
    return;
  }

  const declaredLayers =
    codeIntegrationVersion?.metadata.modelSetup?.custom_layers;
  const resolvedLayer =
    declaredLayers?.length === 1
      ? first(declaredLayers)
      : findByName(declaredLayers, selectedLayerName);

  updateCustomLayerNode(
    node,
    changeNodeProperty,
    resolvedLayer,
    new Map(Object.entries(remainingProps))
  );
}
/**
 * Re-resolves a custom-loss node against the custom losses declared in the
 * code integration's setup.
 *
 * Resolution:
 * - When exactly one custom loss is declared, it is used regardless of
 *   `selected`. Otherwise the loss is looked up by the `selected` name.
 * - If nothing is found, `name`/`selected` become `undefined` and
 *   `arg_names` becomes `[]`.
 *
 * Updates applied:
 * - The node data is replaced (`override: true`) by exactly `arg_names`,
 *   `user_unique_name` (defaulted to `''`), `name`, `selected` and `type`.
 *   Any other data fields are dropped.
 * - The node's input connection names are then set to the loss's argument
 *   names.
 */
function updateDatasetForCustomLossNode({
  node,
  codeIntegrationVersion,
  changeNodeProperty,
  updateConnection,
}: UpdateNodeWithDatasetSetupProps) {
  const lossData = node.data as CustomLossNodeData;
  const declaredLosses = codeIntegrationVersion?.metadata.setup?.custom_losses;

  const resolvedLoss =
    declaredLosses?.length === 1
      ? first(declaredLosses)
      : findByName(declaredLosses, lossData.selected);

  const { arg_names: argNames = [], name: lossName } = resolvedLoss || {};

  changeNodeProperty({
    nodeId: node.id,
    nodeDataPropsToUpdate: {
      arg_names: argNames,
      user_unique_name: node.data.user_unique_name || '',
      name: lossName,
      selected: lossName,
      type: lossData.type,
    },
    override: true,
  });

  updateConnection(node.id, argNames);
}

/**
 * Sets an inputs (dataset) node's output connection names to the names
 * returned by `getDatasetOutputData` for the code integration's setup.
 *
 * `undefined` is passed for the input names. Presumably this leaves the
 * inputs untouched; confirm against `calcConnectionsToUpdate`.
 */
function updateDatasetForDatasetNode({
  node,
  codeIntegrationVersion,
  updateConnection,
}: UpdateNodeWithDatasetSetupProps) {
  const setup = codeIntegrationVersion?.metadata.setup;
  const datasetOutputs = getDatasetOutputData(setup);
  const outputNames = datasetOutputs.map((output) => output.name);

  updateConnection(node.id, undefined, outputNames);
}

/**
 * Applies code-integration updates to a legacy `ModelGraph`.
 *
 * Steps:
 * 1. Convert the graph into a node map plus a connection list.
 * 2. Update both in place against `code`.
 * 3. Convert the result back into the legacy shape.
 *
 * @param graph - Legacy graph to update (converted, not edited directly).
 * @param code - Code-integration version supplying the setup metadata.
 * @returns The rebuilt legacy graph.
 */
export function updateOldGraphWithCodeIntegration(
  graph: ModelGraph,
  code?: DatasetVersion
): ModelGraph {
  const converted = mapOldGraphToConnectionsAndNodes(graph);

  updateGraphWithCodeIntegration({
    codeIntegrationVersion: code,
    nodes: converted.nodes,
    connections: converted.connections,
  });

  return createOldReteData(converted.nodes, converted.connections);
}

/**
 * Builds callbacks that apply node/connection updates directly to `nodes`
 * and `connections`. Both collections are mutated; the `connections` array
 * keeps its identity.
 */
function createInPlaceGraphUpdaters(
  nodes: Map<string, Node>,
  connections: Connection[]
): {
  updateConnection: UpdateConnectionFunc;
  changeNodeProperty: ChangeNodePropFunc;
} {
  const changeNodeProperty: ChangeNodePropFunc = ({ nodeId, ...updates }) => {
    const target = nodes.get(nodeId);
    if (!target) {
      console.error(`Node with id ${nodeId} not found`);
      return;
    }
    nodes.set(nodeId, updateNodeData({ node: target, ...updates }));
  };

  const updateConnection: UpdateConnectionFunc = (
    nodeId,
    currentInputNames,
    currentOutputNames
  ) => {
    const { previous, added } = calcConnectionsToUpdate({
      connections,
      nodeId,
      currentInputNames,
      currentOutputNames,
    });
    // Build the replacement BEFORE clearing the array. If `previous` ever
    // aliases `connections`, truncating first would silently drop every
    // connection.
    const next = [...previous, ...added];
    connections.length = 0;
    // Push one by one rather than `push(...next)`, which is bounded by the
    // engine's maximum argument count for very large graphs.
    for (const connection of next) {
      connections.push(connection);
    }
  };

  return { updateConnection, changeNodeProperty };
}

/**
 * Re-syncs every node in `nodes` with the setup metadata of
 * `codeIntegrationVersion`.
 *
 * Dispatch is by node kind: inputs, ground truth, prediction, custom layer
 * or custom loss. Other nodes are left alone. Iteration runs over a snapshot
 * of the nodes taken up front.
 *
 * Update modes:
 * - When `connections` is supplied, updates are applied in place. Changed
 *   nodes are written back into `nodes`, and `connections` is rewritten
 *   without replacing the array.
 * - Otherwise, the caller's `updateConnection` / `changeNodeProperty`
 *   callbacks receive every update.
 */
export function updateGraphWithCodeIntegration({
  nodes,
  codeIntegrationVersion,
  ...updateTarget
}: UpdateGraphWithDatasetSetupProps) {
  // The update mode does not depend on the node, so build the callbacks once.
  const { updateConnection, changeNodeProperty } =
    'connections' in updateTarget
      ? createInPlaceGraphUpdaters(nodes, updateTarget.connections)
      : updateTarget;

  for (const node of Array.from(nodes.values())) {
    const updateNodeProps: UpdateNodeWithDatasetSetupProps = {
      node,
      codeIntegrationVersion,
      changeNodeProperty,
      updateConnection,
    };
    if (isInputsNode(node)) {
      updateDatasetForDatasetNode(updateNodeProps);
    } else if (isGroundTruthNode(node)) {
      updateDatasetForGroundTruthNode(updateNodeProps);
    } else if (isPredictionNode(node)) {
      updateDatasetForPredictionNode(updateNodeProps);
    } else if (isCustomLayerNode(node)) {
      updateDatasetForCustomLayerNode(updateNodeProps);
    } else if (isCustomlossNode(node)) {
      updateDatasetForCustomLossNode(updateNodeProps);
    }
  }
}

/** Arguments for `updateNodeData`. */
type UpdateNodeProps = {
  // Node to copy; it is not mutated.
  node: Node;
  // Fields written into the copy's `data`, taking precedence over existing ones.
  nodeDataPropsToUpdate: Record<string, NodePropertyType>;
  // Optional top-level node fields (currently only `labels`) to overwrite.
  nodePropsToUpdate?: Pick<NodeWithLabels, 'labels'>;
  // When true, the existing `data` is discarded instead of merged.
  override?: boolean;
};

/**
 * Returns an updated shallow copy of `node`; the input node is not mutated.
 *
 * How the copy is built:
 * - `nodePropsToUpdate` (if any) overwrites top-level fields.
 * - The copy's `data` is `nodeDataPropsToUpdate` merged on top of the
 *   existing `node.data`. When `override` is truthy, the existing data is
 *   dropped, so the new `data` holds only `nodeDataPropsToUpdate`.
 */
export function updateNodeData({
  node,
  nodeDataPropsToUpdate,
  nodePropsToUpdate,
  override,
}: UpdateNodeProps): Node {
  const baseData = override ? {} : node.data;
  const mergedData = { ...baseData, ...nodeDataPropsToUpdate };

  return { ...node, ...nodePropsToUpdate, data: mergedData };
}
