import React, { RefObject, useEffect, useMemo, useState } from 'react';
import { useDragLayer } from 'react-dnd';
import clsx from 'clsx';
import { Node } from '@tensorleap/api-client';

import { createStyles, makeStyles } from '../ui/mui';
import { Position, addPositions } from '../core/position';
import { Connection } from './interfaces/Connection';
import {
  ConnectionDragData,
  NodeRepositionDragData,
  NODE_REPOSITION_DRAG,
} from './interfaces/drag-and-drop';
import { PanZoomParams } from './PanZoom';
import { NetworkMapControl } from './hooks';

/**
 * Theme-derived styles for connection curves: stroked with the network
 * editor's stroke colour and never filled.
 */
const useStyles = makeStyles((theme) => {
  const { stroke } = theme.palette.networkEditor;
  return createStyles({
    connection: { stroke, fill: 'none' },
  });
});

/** Collected drag-layer value used whenever no connection curve should be drawn. */
const EMPTY_DRAG_LAYER = {
  dragData: undefined,
  mousePosition: undefined,
};

/**
 * Draws the in-flight curve while the user drags a connection.
 *
 * The curve starts at the centre of the dragged output socket (node position
 * plus socket offset). It ends at the pointer, converted from client
 * coordinates into canvas coordinates by subtracting the pan/zoom container's
 * viewport position and pan offset and dividing by the zoom factor.
 *
 * Node-reposition drags (`NODE_REPOSITION_DRAG`) are ignored here.
 */
export const ConnectionDragLayer = React.memo<{
  nodesMap: ROMap<string, Node>;
  nodeRefs: NetworkMapControl['nodeRefs'];
  panZoomParamsRef: RefObject<PanZoomParams>;
}>(({ nodesMap, nodeRefs, panZoomParamsRef }) => {
  const { dragData, mousePosition } = useDragLayer((monitor) => {
    const clientOffset = monitor.getClientOffset();
    const panZoom = panZoomParamsRef.current;
    const isConnectionDrag =
      monitor.isDragging() && monitor.getItemType() !== NODE_REPOSITION_DRAG;
    if (!clientOffset || !isConnectionDrag || !panZoom) {
      return EMPTY_DRAG_LAYER;
    }

    const { zoom, position: pan, domTarget } = panZoom;
    const [parentX, parentY] = getElementPosition(domTarget.current);
    // Client (viewport) coordinates -> canvas coordinates.
    const pointer: Position = [
      (clientOffset.x - parentX - pan[0]) / zoom,
      (clientOffset.y - parentY - pan[1]) / zoom,
    ];
    return {
      dragData: monitor.getItem() as ConnectionDragData,
      mousePosition: pointer,
    };
  });

  const nodePosition = useMemo(() => {
    if (!dragData) {
      return undefined;
    }
    return nodesMap.get(dragData.outputNodeId)?.position as Position;
  }, [dragData, nodesMap]);

  const outputOffset = useMemo(() => {
    if (!dragData) {
      return undefined;
    }
    const socketElem = nodeRefs.current
      .get(dragData.outputNodeId)
      ?.outputRefs.get(dragData.outputName);
    return getSocketOffset(socketElem);
  }, [dragData, nodeRefs]);

  const startPosition = useMemo(() => {
    if (!nodePosition || !outputOffset) {
      return undefined;
    }
    return addPositions(nodePosition, outputOffset);
  }, [nodePosition, outputOffset]);

  return startPosition && mousePosition ? (
    <SvgCurve start={startPosition} end={mousePosition} />
  ) : null;
});
ConnectionDragLayer.displayName = 'ConnectionDragLayer';

/**
 * Returns the viewport-relative top-left corner of `elem` (from
 * `getBoundingClientRect`), or `[0, 0]` when no element is available.
 */
function getElementPosition(elem?: HTMLElement | null): Position {
  const rect = elem?.getBoundingClientRect();
  return rect ? [rect.x, rect.y] : [0, 0];
}

/** Drag-layer value used when this connection is unaffected by the current drag. */
const DEFAULT_DRAG_LAYER_PROPS = {
  hide: false,
};

/**
 * Renders one connection as a curve from the centre of the output socket to
 * the centre of the input socket. Each end is the node's position plus the
 * socket's offset within the node.
 *
 * Drag behaviour:
 * - While a node is being repositioned (`NODE_REPOSITION_DRAG`), every end
 *   attached to that node follows the drag delta, scaled by the drag item's zoom.
 * - While any other drag is in progress whose item's `existingConnection` is
 *   this connection, the curve is hidden.
 *
 * Nothing is rendered until both socket offsets have been measured and both
 * nodes have a non-empty position.
 */
export const NodeConnection = React.memo<{
  connection: Connection;
  outputNode?: Node;
  inputNode?: Node;
  outputSocketElem?: HTMLDivElement;
  inputSocketElem?: HTMLDivElement;
}>(
  ({
    connection,
    outputNode,
    inputNode,
    outputSocketElem,
    inputSocketElem,
  }) => {
    const [inputOffset, setInputOffset] = useState<Position>();
    const [outputOffset, setOutputOffset] = useState<Position>();

    // Socket offsets are read from the DOM in an effect, i.e. after the
    // socket elements have been committed.
    useEffect(() => {
      setOutputOffset(getSocketOffset(outputSocketElem));
      setInputOffset(getSocketOffset(inputSocketElem));
    }, [outputSocketElem, inputSocketElem]);

    const { outputNodeOffset, inputNodeOffset, hide } = useDragLayer<
      Partial<{
        hide: boolean;
        outputNodeOffset: Position;
        inputNodeOffset: Position;
      }>
    >((monitor) => {
      if (!monitor.isDragging()) {
        return DEFAULT_DRAG_LAYER_PROPS;
      }
      if (monitor.getItemType() !== NODE_REPOSITION_DRAG) {
        return {
          hide:
            (monitor.getItem() as ConnectionDragData).existingConnection ===
            connection,
        };
      }
      const { nodeId, zoom } = monitor.getItem() as NodeRepositionDragData;
      const offset = monitor.getDifferenceFromInitialOffset();
      const movesOutput = outputNode?.id === nodeId;
      const movesInput = inputNode?.id === nodeId;
      if (!offset || (!movesOutput && !movesInput)) {
        return DEFAULT_DRAG_LAYER_PROPS;
      }
      // Screen-space drag delta converted to canvas space. Both ends may move
      // when the connection starts and ends on the dragged node.
      const canvasOffset: Position = [offset.x / zoom, offset.y / zoom];
      const moved: { outputNodeOffset?: Position; inputNodeOffset?: Position } =
        {};
      if (movesOutput) {
        moved.outputNodeOffset = canvasOffset;
      }
      if (movesInput) {
        moved.inputNodeOffset = canvasOffset;
      }
      return moved;
    });

    // Guard the output side the same way as the input side: `outputNode` is
    // optional and `outputOffset` is truthy even without a socket element, so
    // an unguarded call would pass `undefined`/`[]` into `addPositions`.
    const outputSocketPosition = useMemo(
      () =>
        outputOffset &&
        !!outputNode?.position?.length &&
        addPositions(outputNode?.position as Position, outputOffset),
      [outputNode?.position, outputOffset]
    );

    const inputSocketPosition = useMemo(
      () =>
        inputOffset &&
        !!inputNode?.position?.length &&
        addPositions(inputNode?.position as Position, inputOffset),
      [inputNode?.position, inputOffset]
    );

    if (!outputSocketPosition || !inputSocketPosition || hide) {
      return null;
    }

    return (
      <SvgCurve
        start={
          outputNodeOffset
            ? addPositions(outputSocketPosition, outputNodeOffset)
            : outputSocketPosition
        }
        end={
          inputNodeOffset
            ? addPositions(inputSocketPosition, inputNodeOffset)
            : inputSocketPosition
        }
      />
    );
  }
);
NodeConnection.displayName = 'NodeConnection';

/**
 * Returns the centre point of a socket element relative to its offset parent
 * (`offsetLeft`/`offsetTop` plus half of `offsetWidth`/`offsetHeight`), or
 * `[0, 0]` when no element is given.
 */
function getSocketOffset(elem?: HTMLDivElement): Position {
  if (!elem) return [0, 0];
  const halfWidth = elem.offsetWidth / 2;
  const halfHeight = elem.offsetHeight / 2;
  return [elem.offsetLeft + halfWidth, elem.offsetTop + halfHeight];
}

/** Horizontal distance of each bezier control point from its end point. */
const CURVE_POINT_X_DELTA = 100;
/** Vertical distance of each bezier control point from its end point. */
const CURVE_POINT_Y_DELTA = 0;

/**
 * Draws a cubic bezier from `start` to `end` in an overflow-visible SVG.
 * The first control point is shifted right of `start` by
 * `CURVE_POINT_X_DELTA` (and up by `CURVE_POINT_Y_DELTA`). The second is
 * shifted left of `end` by the same amounts in the opposite direction.
 */
const SvgCurve = React.memo<{
  start: Position;
  end: Position;
}>(({ start, end }) => {
  const classes = useStyles();

  const toPoint = (x: number, y: number) => `${x} ${y}`;
  const [startX, startY] = start;
  const [endX, endY] = end;

  const pathData = [
    `M ${toPoint(startX, startY)}`,
    `C ${toPoint(startX + CURVE_POINT_X_DELTA, startY - CURVE_POINT_Y_DELTA)}`,
    ` ${toPoint(endX - CURVE_POINT_X_DELTA, endY + CURVE_POINT_Y_DELTA)}`,
    ` ${toPoint(endX, endY)}`,
  ].join('');

  const svgClassName = clsx(
    'overflow-visible absolute stroke-[5px]',
    classes.connection
  );

  return (
    <svg className={svgClassName}>
      <path d={pathData} />
    </svg>
  );
});
SvgCurve.displayName = 'SvgCurve';
