import { useCallback, useMemo } from 'react';
import { DatasetSetup } from '@tensorleap/api-client';
import { Position } from '../../../core/position';
import { Plus } from '../../../ui/icons';
import {
  COMPONENT_DESCRIPTORS_MAP,
  isValidNodeName,
} from '../../interfaces/NodeDescriptor';
import { UIComponent, UI_COMPONENTS } from '../../../core/types/ui-components';
import { IMAGE_VISUALIZER_PROPERTIES } from '../../consts';
import { USER_UNIQUE_NAME } from '../../../layer-details/UserUniqueName';
import {
  isInputNode,
  isInputsNode,
  isOptimizerNode,
} from '../../graph-calculation/utils';
import { generateValidateAssetsQuickFix } from './ValidateAssetsErrorData';
import { useNetworkMapContext } from '../../../core/NetworkMapContext';
import { ValidateAssetsStatus } from '../../interfaces/ValidateGraphStatus';
import { Unreachable } from '../../../core/Errors';
import {
  NetworkWizardCategory,
  NetworkWizardData,
  QuickFixProps,
  GraphErrorKind,
  NetworkWizardErrorSeverity,
} from '../types';
import { GraphNodeErrorType, NodeErrorMsg } from '../errors';
import { NetworkTabsEnum } from '../../NetworkDrawer';

/**
 * Offset applied to an existing node's `[x, y]` position when a quick fix
 * places a new node next to it (added to x, subtracted from y).
 */
const NODES_OFFSET = 300;

/**
 * A single node-level error to be turned into a network wizard entry.
 */
export type NodeMessageData = {
  /** Id of the offending node; also used to focus and select it in the graph. */
  nodeId: string;
  /** Error message; compared against `NodeErrorMsg` values to choose a quick fix. */
  msg: string;
  /** Optional category; takes precedence over the category passed to the hook. */
  category?: NetworkWizardCategory;
};

/**
 * Describes which loss node a quick fix should add.
 */
export interface AddLossParams {
  /** Built-in loss component name, or the custom loss name when `isCustomLoss` is set. */
  newNodeName: string;
  /**
   * True: the loss source is the dataset's ground-truth output. False: it is the
   * layer's prediction output. For built-in losses this also picks the connected
   * input (`ground_truth` vs `prediction`). A prediction label is added when false.
   */
  toGT: boolean;
  /** When true, a `CustomLoss` node configured from the dataset's custom loss is created. */
  isCustomLoss: boolean;
}

/**
 * Builds a lookup of every loss type the user may add, keyed by loss name.
 *
 * Built-in losses are the `UI_COMPONENTS` entries whose `type` is `'Loss'` and
 * are tagged `'Loss'`. Custom losses from `datasetSetup` are tagged
 * `'CustomLoss'`. They are written after the built-ins, so a custom loss that
 * shares a name with a built-in one overrides its tag.
 *
 * @param datasetSetup - setup whose `custom_losses` are included; when absent,
 *   only built-in losses are returned.
 * @returns a record mapping each loss name to its kind.
 */
function prepareAllLosses(
  datasetSetup: DatasetSetup | undefined
): Record<string, 'Loss' | 'CustomLoss'> {
  const lossKinds: Record<string, 'Loss' | 'CustomLoss'> = {};

  for (const component of UI_COMPONENTS) {
    if (component.type === 'Loss') {
      lossKinds[component.name] = 'Loss';
    }
  }

  const customLosses = datasetSetup?.custom_losses || [];
  for (const customLoss of customLosses) {
    lossKinds[customLoss.name] = 'CustomLoss';
  }

  return lossKinds;
}

/**
 * Builds network-wizard entries for node-level graph errors.
 *
 * Each item of `nodesMessageData` becomes one `NetworkWizardData` entry. Its
 * `showNode` focuses and selects the offending node. Its `quickFixes` holds at
 * most one fix, chosen by the error message:
 * - `InputIsNotSelected` (dataset has inputs) → pick one of the dataset inputs.
 * - `InputHasntVisualizer` (dataset has inputs) → add an image `Visualizer`.
 * - validate-assets errors → the validate-assets action.
 * - `LossNotConnectedToOptimizer` → connect the existing optimizer, or add a
 *   new optimizer of the chosen type.
 * - `GTHasntLoss` / `NoOutputNode` → add a loss (built-in or custom) fed by the
 *   ground truth / the layer prediction respectively.
 * - `NoIntegrationScript` → open the Code Integration tab.
 *
 * @param title - wizard entry title; defaults to `'INVALID NODE'`.
 * @param isValidateAssetsError - whether these errors come from asset
 *   validation; such entries report `isLoading` while validation is calculating.
 * @param category - fallback category for messages without their own;
 *   `NetworkWizardCategory.MODEL` is used when both are missing.
 * @param nodesMessageData - the node errors to report.
 * @returns one memoized wizard entry per item of `nodesMessageData`.
 */
export function useGraphNodeDictErrorData({
  title = 'INVALID NODE',
  isValidateAssetsError,
  category: categoryFromProps,
  nodesMessageData = [],
}: GraphNodeErrorType & {
  nodesMessageData: NodeMessageData[];
}): NetworkWizardData[] {
  // All graph state and actions come from the same context; read it once.
  const {
    nodes,
    addNewNode,
    addNewConnection,
    changeNodeProperty,
    updateConnection,
    addPredictionLabel,
    getNewNodeId,
    selectNode,
    datasetSetup,
    validateAssetsStatus,
    validateAssetsButtonState,
    handleValidateAssetsClicked,
    setOpenNetworkTab,
    onFitNodeToScreen,
    currentDatasetSetup,
  } = useNetworkMapContext();

  /**
   * Adds an optimizer named `newNodeName` next to the loss node `nodeId`,
   * connects the loss output to it, and selects the new optimizer.
   */
  const addOptimizer = useCallback(
    (newNodeName: UIComponent['name'], nodeId: string): void => {
      const lossNode = nodes.get(nodeId);
      if (!lossNode) {
        console.warn(`Node #${nodeId} does not exist`);
        return;
      }

      const [lossNodePosX, lossNodePosY] = lossNode.position;
      const optimizerPosition: Position = [
        lossNodePosX + NODES_OFFSET,
        lossNodePosY - NODES_OFFSET,
      ];

      // The id is taken before the node is added so the connection and the
      // selection below can reference the node that addNewNode creates.
      const optimizerId = getNewNodeId();
      addNewNode({ name: newNodeName, position: optimizerPosition });
      addNewConnection({
        inputNodeId: optimizerId,
        inputName: '0',
        outputNodeId: nodeId,
        outputName: 'loss',
        isDynamicInput: true,
      });

      onFitNodeToScreen(nodeId);
      selectNode(optimizerId);
    },
    [
      addNewConnection,
      addNewNode,
      getNewNodeId,
      nodes,
      onFitNodeToScreen,
      selectNode,
    ]
  );

  /**
   * Adds an image Visualizer next to the dataset node `nodeId`. The visualizer
   * is fed the node's `output_name`, or the first dataset input when unset.
   */
  const addVisualizer = useCallback(
    (nodeId: string) => {
      const dsNode = nodes.get(nodeId);
      if (!dsNode) {
        console.warn(`Node #${nodeId} does not exist`);
        return;
      }

      // Validate before mutating the graph so a missing dataset does not leave
      // an orphaned, unconnected Visualizer node behind.
      const datasetInputs = datasetSetup?.inputs;
      if (!datasetInputs || !datasetInputs.length) {
        console.error('Dataset node was not found or has no outputs');
        return;
      }

      const [dsNodePosX, dsNodePosY] = dsNode.position;
      const visualizerPosition: Position = [
        dsNodePosX + NODES_OFFSET,
        dsNodePosY - NODES_OFFSET,
      ];

      const visualizerNodeId = getNewNodeId();
      addNewNode({ name: 'Visualizer', position: visualizerPosition });
      changeNodeProperty({
        nodeId: visualizerNodeId,
        nodeDataPropsToUpdate: {
          ...IMAGE_VISUALIZER_PROPERTIES,
          [USER_UNIQUE_NAME]: IMAGE_VISUALIZER_PROPERTIES.name,
        },
      });

      const datasetNodeFirstInputName =
        dsNode.data['output_name'] || datasetInputs[0].name;
      addNewConnection({
        outputNodeId: nodeId,
        outputName: datasetNodeFirstInputName,
        inputNodeId: visualizerNodeId,
        inputName: 'data',
        isDynamicInput: false,
      });

      onFitNodeToScreen(nodeId);
      selectNode(visualizerNodeId);
    },
    [
      addNewConnection,
      addNewNode,
      changeNodeProperty,
      datasetSetup?.inputs,
      getNewNodeId,
      nodes,
      onFitNodeToScreen,
      selectNode,
    ]
  );

  /**
   * Stores `inputName` as the node's `output_name` and updates its connection.
   * Throws `Unreachable` when no input name is supplied.
   */
  const selectInput = useCallback(
    (nodeId: string, inputName?: string) => {
      if (!inputName) throw new Unreachable();
      changeNodeProperty({
        nodeId,
        nodeDataPropsToUpdate: { output_name: inputName },
      });
      updateConnection(nodeId, undefined, [inputName]);
    },
    [changeNodeProperty, updateConnection]
  );

  /**
   * Adds a loss node next to `nodeId`.
   *
   * The output name is the dataset's first setup output when `toGT` is set,
   * otherwise the first output of the layer's component descriptor. Built-in
   * losses are connected to that output. Custom losses are created as a
   * `CustomLoss` node configured with the matching custom loss `arg_names`.
   */
  const addLoss = useCallback(
    (
      { newNodeName, toGT, isCustomLoss }: AddLossParams,
      nodeId: string
    ): void => {
      const layerNode = nodes.get(nodeId);
      if (!layerNode) {
        console.error(`Node #${nodeId} does not exist`);
        return;
      }

      const [layerNodePosX, layerNodePosY] = layerNode.position;
      const lossPosition: Position = [
        layerNodePosX + NODES_OFFSET,
        layerNodePosY - NODES_OFFSET,
      ];

      const datasetNode = Array.from(nodes.values()).find(
        (n) => isInputNode(n) || isInputsNode(n)
      );

      const componentDescriptor = COMPONENT_DESCRIPTORS_MAP.get(layerNode.name);

      const outputName = toGT
        ? datasetNode?.data.datasetVersion?.metadata?.setup.outputs[0]?.name
        : componentDescriptor?.outputs_data.outputs[0]?.name;

      if (!outputName) {
        console.error(`Node #${layerNode.name} does not have output name`);
        return;
      }

      const lossId = getNewNodeId();

      if (isCustomLoss) {
        const customLossInstance = currentDatasetSetup?.custom_losses.find(
          ({ name }) => name === newNodeName
        );

        if (!customLossInstance) {
          console.error(`${newNodeName} is not a valid custom loss name`);
          return;
        }

        addNewNode({
          name: 'CustomLoss',
          position: lossPosition,
          subType: newNodeName,
        });

        changeNodeProperty({
          nodeId: lossId,
          nodeDataPropsToUpdate: {
            name: newNodeName,
            selected: newNodeName,
            [USER_UNIQUE_NAME]: newNodeName,
            arg_names: customLossInstance.arg_names,
          },
        });
      } else {
        addNewNode({ name: newNodeName, position: lossPosition });

        addNewConnection({
          inputNodeId: lossId,
          inputName: toGT ? 'ground_truth' : 'prediction',
          outputNodeId: nodeId,
          outputName,
          isDynamicInput: false,
        });

        if (!toGT) {
          addPredictionLabel(layerNode);
        }
      }

      onFitNodeToScreen(nodeId);
      selectNode(lossId);
    },
    [
      addNewConnection,
      addNewNode,
      addPredictionLabel,
      changeNodeProperty,
      currentDatasetSetup?.custom_losses,
      getNewNodeId,
      nodes,
      onFitNodeToScreen,
      selectNode,
    ]
  );

  const isValidateAssetsErrorAndCalculating =
    isValidateAssetsError &&
    validateAssetsStatus === ValidateAssetsStatus.Calculating;

  return useMemo(() => {
    return nodesMessageData.map(({ nodeId, msg, category }) => {
      /** Picks the quick fix matching `msg`, or `undefined` when none applies. */
      const createQuickFix = (): QuickFixProps | undefined => {
        if (!nodeId) {
          return undefined;
        }
        if (
          !!datasetSetup?.inputs?.length &&
          msg === NodeErrorMsg.InputIsNotSelected
        ) {
          return {
            title: 'Inputs',
            selectOptions: datasetSetup.inputs.map(({ name }) => name),
            onSelect: (value?: string) => {
              selectInput(nodeId, value);
            },
          };
        } else if (
          !!datasetSetup?.inputs?.length &&
          msg === NodeErrorMsg.InputHasntVisualizer
        ) {
          return {
            onSelect: () => addVisualizer(nodeId),
            title: 'Add',
            tooltipMsg: 'Add Visualizer',
            icon: <Plus className="h-5 w-5" />,
          };
        } else if (isValidateAssetsError) {
          return generateValidateAssetsQuickFix(
            validateAssetsButtonState,
            handleValidateAssetsClicked
          );
        } else if (msg === NodeErrorMsg.LossNotConnectedToOptimizer) {
          const optimizerNode = Array.from(nodes.values()).find(
            isOptimizerNode
          );

          if (optimizerNode) {
            // Connect to the dynamic input slot indexed by the number of
            // existing custom input keys (presumably the next free slot).
            const inputName =
              optimizerNode.data?.custom_input_keys?.length?.toString() || '0';
            return {
              title: 'Connect',
              tooltipMsg: 'Connect Existing Optimizer',
              icon: <Plus className="h-5 w-5" />,
              onSelect: () => {
                addNewConnection({
                  inputNodeId: optimizerNode.id,
                  inputName,
                  outputNodeId: nodeId,
                  outputName: 'loss',
                  isDynamicInput: true,
                });
              },
            };
          }

          const optimizerTypeNodeNames = UI_COMPONENTS.filter(
            ({ type }) => type === 'Optimizer'
          ).map(({ name }) => name);
          return {
            title: 'Optimizers',
            selectOptions: optimizerTypeNodeNames,
            onSelect: (value?: string) => {
              if (isValidNodeName(value)) {
                addOptimizer(value, nodeId);
              } else {
                console.warn(`${value} is not a valid node name`);
              }
            },
          };
        } else if (msg === NodeErrorMsg.GTHasntLoss) {
          const allLossNodes = prepareAllLosses(datasetSetup);
          return {
            title: 'Losses',
            selectOptions: Object.keys(allLossNodes),
            onSelect: (value?: string) => {
              if (!value) {
                console.warn('The Loss type is missing');
                return;
              }

              const isCustomLoss = allLossNodes[value] === 'CustomLoss';
              if (isCustomLoss || isValidNodeName(value)) {
                addLoss(
                  {
                    newNodeName: value,
                    toGT: true,
                    isCustomLoss,
                  },
                  nodeId
                );
              } else {
                console.warn(`${value} is not a valid node name`);
              }
            },
          };
        } else if (msg === NodeErrorMsg.NoOutputNode) {
          const allLossNodes = prepareAllLosses(datasetSetup);
          return {
            title: 'Losses',
            selectOptions: Object.keys(allLossNodes),
            onSelect: (newNodeName?: string) => {
              if (!newNodeName) {
                console.warn('The Loss type is missing');
                return;
              }
              const isCustomLoss = allLossNodes[newNodeName] === 'CustomLoss';
              if (isCustomLoss || isValidNodeName(newNodeName)) {
                addLoss({ newNodeName, toGT: false, isCustomLoss }, nodeId);
              } else {
                console.warn(`"${newNodeName}" is not a valid node name`);
              }
            },
          };
        } else if (msg === NodeErrorMsg.NoIntegrationScript) {
          return {
            title: 'Select',
            tooltipMsg: 'Select script',
            icon: <Plus className="h-5 w-5" />,
            onSelect: () => {
              setOpenNetworkTab(NetworkTabsEnum.CodeIntegration);
            },
          };
        }
        // No quick fix is known for this message.
        return undefined;
      };
      const quickFix = createQuickFix();

      return {
        errorType: GraphErrorKind.node,
        category: category || categoryFromProps || NetworkWizardCategory.MODEL,
        showNodeFooter: true,
        title: title,
        message: msg,
        calculateKey: () => msg + nodeId,
        showNode: () => {
          if (nodeId) {
            onFitNodeToScreen(nodeId);
            selectNode(nodeId);
          }
        },
        quickFixes: quickFix ? [quickFix] : [],
        errorSeverity: NetworkWizardErrorSeverity.ERROR,
        isLoading: isValidateAssetsErrorAndCalculating,
        key: msg + nodeId,
      };
    });
  }, [
    nodesMessageData,
    categoryFromProps,
    title,
    isValidateAssetsErrorAndCalculating,
    datasetSetup,
    isValidateAssetsError,
    selectInput,
    addVisualizer,
    validateAssetsButtonState,
    handleValidateAssetsClicked,
    nodes,
    addNewConnection,
    addOptimizer,
    addLoss,
    setOpenNetworkTab,
    onFitNodeToScreen,
    selectNode,
  ]);
}
