import { FilterOperatorType } from '@tensorleap/api-client';
import { useCallback, useMemo, useState } from 'react';
import { CategoricalChartState } from 'recharts/types/chart/generateCategoricalChart';
import {
  BTWVisualizationFilter,
  EQVisualizationFilter,
  INVisualizationFilter,
} from '../../../core/types/filters';
import { mapToVisualizationFilters } from '../../../model-tests/modelTestHelpers';
import { DEFAULT_AXIS_SIZE_INTERVAL } from '../common/constants';
import {
  UseXYChartDragFilterRequest,
  UseXYChartDragFilter,
  DataPoint,
} from '../common/interfaces';
import { useReferenceAreaProps } from '../visualizers/ChartBlocks/ReferenceAreaProps';
import { getParsedChartPoints } from './utils';

/**
 * Drag-to-filter interaction for XY charts.
 *
 * Records the x-axis label under the pointer on mouse-down and on subsequent
 * mouse-moves. On mouse-up, it appends one filter on `chartRequestData.xField`
 * to the current `filters` and passes the whole list to `onFiltersChange`:
 * - click (no drag) on a category axis → Equal filter on the clicked label;
 * - click on a number axis → Between filter `[label, label + interval)`;
 * - drag on a number axis → Between filter from the lower label up to the end
 *   of the interval that starts at the higher label;
 * - any other drag → In filter listing every x value found between the two
 *   labels, in the order they appear in the parsed chart points.
 *
 * `interval` is `chartRequestData.xSizeInterval`, falling back to
 * `DEFAULT_AXIS_SIZE_INTERVAL`.
 *
 * When `onFiltersChange` is not provided, no drag state is recorded and no
 * filter is emitted.
 *
 * @param graphData - Chart data, parsed together with `chartRequestData` via
 *   `getParsedChartPoints` to resolve category ranges.
 * @param chartRequestData - Supplies `xField` (the filtered field) and
 *   `xSizeInterval` (bucket width for numeric clicks/drags).
 * @param axisType - Compared against `'category'` and `'number'` to choose the
 *   filter kind.
 * @param filters - Existing filters; converted with `mapToVisualizationFilters`
 *   before the new filter is appended.
 * @param onFiltersChange - Receives the full updated filter list.
 * @param setIsMouseDown - Set to `true` on mouse-down and `false` on mouse-up
 *   and mouse-leave.
 * @returns Memoized chart mouse handlers plus props for the reference area
 *   that highlights the current drag.
 */
export function useXYChartDragFilter({
  graphData,
  chartRequestData,
  axisType,
  filters = [],
  onFiltersChange,
  setIsMouseDown,
}: UseXYChartDragFilterRequest): UseXYChartDragFilter {
  const allPointsFlat = useMemo(
    (): DataPoint[] => getParsedChartPoints({ graphData, chartRequestData }),
    [chartRequestData, graphData]
  );

  // Labels under the pointer at mouse-down and at the latest mouse-move.
  // `undefined` means "no drag in progress" / "pointer has not moved yet".
  const [dragStart, setDragStart] = useState<number | string | undefined>(
    undefined
  );
  const [dragEnd, setDragEnd] = useState<number | string | undefined>(
    undefined
  );

  const resetDrag = useCallback((): void => {
    if (!onFiltersChange) {
      return;
    }
    setDragStart(undefined);
    setDragEnd(undefined);
  }, [onFiltersChange]);

  const setLocalFilter = useCallback((): void => {
    if (!onFiltersChange || dragStart === undefined) {
      return;
    }
    const updatedFilters = filters.map(mapToVisualizationFilters);
    // Width of one x bucket; numeric labels are treated here as bucket lower
    // bounds.
    const interval = Number(
      chartRequestData.xSizeInterval || DEFAULT_AXIS_SIZE_INTERVAL
    );

    // Released without moving to another label: select a single value/bucket.
    if (dragStart === dragEnd || dragEnd === undefined) {
      if (axisType === 'category') {
        const eqFilter: EQVisualizationFilter = {
          field: chartRequestData.xField,
          operator: FilterOperatorType.Equal,
          value: dragStart,
        };
        updatedFilters.push(eqFilter);
        onFiltersChange(updatedFilters);
      } else if (!isNaN(Number(dragStart))) {
        const gte = Number(dragStart);
        const xBtwFilter: BTWVisualizationFilter = {
          field: chartRequestData.xField,
          operator: FilterOperatorType.Between,
          value: { gte, lt: gte + interval },
        };
        updatedFilters.push(xBtwFilter);
        onFiltersChange(updatedFilters);
      }
      // Always end the drag, even when no filter was emitted (non-numeric
      // label). Otherwise dragStart stays set, mouse-moves keep updating
      // dragEnd, and the reference area derived from them lingers.
      resetDrag();
      return;
    }

    if (
      axisType === 'number' &&
      !isNaN(Number(dragStart)) &&
      !isNaN(Number(dragEnd))
    ) {
      const low = Math.min(Number(dragStart), Number(dragEnd));
      const high = Math.max(Number(dragStart), Number(dragEnd));
      // `high` is the lower bound of the last bucket the pointer reached.
      // Extend by one interval so that bucket is included, consistent with
      // the single-click selection above.
      const btwFilter: BTWVisualizationFilter = {
        field: chartRequestData.xField,
        operator: FilterOperatorType.Between,
        value: { gte: low, lt: high + interval },
      };
      updatedFilters.push(btwFilter);
      onFiltersChange(updatedFilters);
      resetDrag();
      return;
    }

    const terms = allPointsFlat.map((point) => point[chartRequestData.xField]);
    const firstIndex = terms.indexOf(dragStart);
    const lastIndex = terms.indexOf(dragEnd);
    // A label missing from the parsed points makes indexOf return -1. The
    // slice below would then select a wrong (often empty) range, so drop the
    // selection instead of emitting a bogus filter.
    if (firstIndex === -1 || lastIndex === -1) {
      resetDrag();
      return;
    }
    const slice = terms.slice(
      Math.min(firstIndex, lastIndex),
      Math.max(firstIndex, lastIndex) + 1
    );
    // The flattened points may repeat an x value; keep each term only once.
    const filterTerms = Array.from(new Set(slice.map((term) => String(term))));
    const inFilter: INVisualizationFilter = {
      field: chartRequestData.xField,
      operator: FilterOperatorType.In,
      value: filterTerms,
    };
    updatedFilters.push(inFilter);
    onFiltersChange(updatedFilters);
    resetDrag();
  }, [
    onFiltersChange,
    dragStart,
    filters,
    dragEnd,
    axisType,
    allPointsFlat,
    chartRequestData.xField,
    chartRequestData.xSizeInterval,
    resetDrag,
  ]);

  const onMouseDown = useCallback(
    (nextState: CategoricalChartState) => {
      if (!onFiltersChange) {
        return;
      }
      setDragStart(nextState?.activeLabel);
      setIsMouseDown(true);
    },
    [onFiltersChange, setIsMouseDown]
  );

  const onMouseMove = useCallback(
    (nextState: CategoricalChartState) => {
      // Track the end label only while a drag is in progress. Compare with
      // `undefined` rather than truthiness so a drag that starts on label 0
      // or '' is still tracked.
      if (!onFiltersChange || dragStart === undefined) {
        return;
      }
      setDragEnd(nextState?.activeLabel);
    },
    [dragStart, onFiltersChange]
  );

  const onMouseUp = useCallback(() => {
    setIsMouseDown(false);
    setLocalFilter();
  }, [setIsMouseDown, setLocalFilter]);

  const onMouseLeave = useCallback(() => {
    setIsMouseDown(false);
    resetDrag();
  }, [resetDrag, setIsMouseDown]);

  const referenceAreaProps = useReferenceAreaProps({ dragStart, dragEnd });

  return useMemo(
    () => ({
      onMouseDown,
      onMouseMove,
      onMouseUp,
      onMouseLeave,
      referenceAreaProps,
    }),
    [onMouseDown, onMouseLeave, onMouseMove, onMouseUp, referenceAreaProps]
  );
}
