import {
  AnalyticsDashletType,
  ChartData,
  ESFilter,
  MultiChartsResponse,
  NumberOrString,
  OrderType,
  SessionRunData,
  SplitAgg,
  SplitValue,
} from '@tensorleap/api-client';
import clsx from 'clsx';
import { VisualizationFilter } from '../../../core/types/filters';
import { NoDataChart } from '../common/NoDataChart';
import { ChartRequestData } from '../common/interfaces';
import {
  DEFAULT_TRUNCATE_LONG_TAIL,
  LabelsLegend,
} from '../legend/LabelsLegend';
import { useMultiChart, UseSplitXYChart } from '../hooks/useMultiChart';
import { getChartComponent } from '../hooks/utils';
import { useCallback, RefObject, useMemo, useRef, useState } from 'react';
import { exportMultiChartsToCsv } from '../common/exportChartToCsv';
import { CircularProgress, Tooltip } from '../../mui';
import { LABEL_TRUNCATE_LENGTH } from '../common/constants';
import { getSplitLabel } from '../common/utils';
import { useVersionControl } from '../../../core/VersionControlContext';
import { truncateLongtail } from '../../../core/formatters/string-formatting';
import {
  ConfusionMatrixType,
  ConfusionMatrixTypeEnum,
} from '../../../dashboard/dashlet/Analytics/form/ConfusionMatrix';
import { orderBy } from 'lodash';
import { ClassNameProp } from '../../../core/types';
import { useToggle } from '../../../core/useToggle';

/**
 * Props accepted by `MultiCharts` and passed down (as `multiChartsProps`) to
 * the grid and the individual chart cells.
 */
export type MultiChartsProps = {
  /** Request description; given to `useMultiChart` and to every chart component. Its `innerSplit` feeds the legend and the CSV export. */
  chartRequestData: ChartRequestData;
  /** Extra class names appended to the root element of `MultiCharts`. */
  className?: string;
  /** Forwarded unchanged to every chart component. */
  onFiltersChange?: (_: VisualizationFilter[]) => void;
  /** Forwarded unchanged to every chart component. */
  filters?: ESFilter[];
  /** Chart kind; together with `chartSubType` it selects the component via `getChartComponent`. */
  chartType: AnalyticsDashletType;
  /** Optional sub type; the confusion-matrix table variant makes the chart area scrollable. */
  chartSubType?: ConfusionMatrixType;
  /** Whether the labels legend is rendered; `MultiCharts` defaults it to `true`. */
  showLegend?: boolean;
  /** Forwarded to `useMultiChart`. */
  autoScaleY?: boolean;
  /** When set, `MultiCharts` renders the "no data" placeholder instead of the charts. */
  error: Error | undefined;
  /** When true, `MultiCharts` renders a progress spinner instead of the charts. */
  isLoading: boolean;
  /** Split used to derive the grid columns (matched against each chart's `horizontalKey`). */
  horizontalSplit: SplitAgg | null;
  /** Split used to derive the grid rows (matched against each chart's `verticalKey`). */
  verticalSplit: SplitAgg | null;
};

/**
 * Renders a multi-chart response as a grid of charts with an optional labels
 * legend and, outside of preview mode, a "Download as CSV" action.
 *
 * Rendering states, in priority order:
 * 1. `isLoading` - a centered progress spinner.
 * 2. no charts in the response, or `error` is set - the "no data" placeholder.
 * 3. otherwise - the chart grid, the legend (when `showLegend`) and the CSV button.
 *
 * @param props - `MultiChartsProps` plus the fetched `xyChartsResponse` and an
 *   optional `preview` flag. In preview mode the outer padding and the CSV
 *   button are omitted and legend names start hidden.
 */
export function MultiCharts(
  props: MultiChartsProps & {
    xyChartsResponse: MultiChartsResponse | undefined;
    preview?: boolean;
  }
) {
  const {
    xyChartsResponse,
    chartRequestData,
    className,
    chartType,
    showLegend = true,
    autoScaleY,
    isLoading,
    error,
    preview,
  } = props;

  const { selectedSessionRunMap } = useVersionControl();

  // Maps a raw value to the name of the matching selected session run, if any.
  const mapValue = useCallback(
    (value: NumberOrString) => selectedSessionRunMap.get(value as string)?.name,
    [selectedSessionRunMap]
  );

  const multiChart = useMultiChart({
    xyChartsResponse,
    chartRequestData,
    chartType,
    autoScaleY,
    mapValue,
    preview,
  });

  const {
    allLabels,
    colorMap,
    setHoverLabel,
    hiddenLabels,
    setHiddenLabels,
  } = multiChart;
  const multiChartRef = useRef<HTMLDivElement>(null);
  // Legend names are shown by default, except in preview mode.
  const [showNames, toggleShowNames] = useToggle(!preview);
  const [truncatedLongtail, setTruncatedLongtail] = useState<number>(
    DEFAULT_TRUNCATE_LONG_TAIL
  );

  // All hooks are called above this point so that the early returns below
  // never change the hook order between renders.
  if (isLoading) {
    return (
      <div className="flex items-center font-bold justify-center h-full content-center">
        <CircularProgress />
      </div>
    );
  }

  if (!xyChartsResponse?.charts?.length || error) {
    return (
      <div className="flex h-full w-full overflow-hidden">
        <NoDataChart />
      </div>
    );
  }

  const { innerSplit = null } = props.chartRequestData;

  // Tables (plain or confusion-matrix) may exceed the available area and
  // therefore scroll instead of being clipped.
  const overflowingChart =
    chartType === AnalyticsDashletType.Table ||
    props.chartSubType === ConfusionMatrixTypeEnum.ConfusionMatrixTable;

  return (
    <div
      className={clsx(
        'h-full w-full flex flex-1 flex-col',
        !preview && 'p-4',
        className
      )}
      ref={multiChartRef}
    >
      <div
        className={clsx(
          'flex flex-row flex-1',
          overflowingChart ? 'overflow-auto' : 'overflow-hidden'
        )}
      >
        <DisplayChartsOnGrid
          xyChartsResponse={xyChartsResponse}
          multiChart={multiChart}
          chartSubType={props.chartSubType}
          overflowingChart={overflowingChart}
          multiChartRef={multiChartRef}
          multiChartsProps={props}
        />
        {showLegend && (
          <div>
            <LabelsLegend
              showNames={showNames}
              truncatedLongtail={truncatedLongtail}
              setTruncatedLongtail={setTruncatedLongtail}
              toggleShowNames={toggleShowNames}
              labels={allLabels}
              colorMap={colorMap}
              setHoverLabel={setHoverLabel}
              setHiddenLabels={setHiddenLabels}
              hiddenLabels={hiddenLabels}
              innerSplit={innerSplit}
            />
          </div>
        )}
      </div>
      {!preview && (
        <button
          // Explicit type: a bare <button> defaults to "submit" and would
          // submit any form this chart happens to be rendered inside.
          type="button"
          className="flex w-fit h-fit text-3xs text-gray-600 cursor-pointer hover:text-cyan-400"
          onClick={() =>
            exportMultiChartsToCsv(
              xyChartsResponse,
              selectedSessionRunMap,
              innerSplit
            )
          }
        >
          Download as CSV
        </button>
      )}
    </div>
  );
}

/** Props of `DisplayChartsOnGrid`, the component that places every chart on a CSS grid. */
interface HorizontalSplitChartsProps {
  /** Response whose `charts` are laid out on the grid. */
  xyChartsResponse: MultiChartsResponse;
  /** Shared chart state from `useMultiChart`; forwarded to every cell. */
  multiChart: UseSplitXYChart;
  /** Props of the owning `MultiCharts`; the splits are read from here and the object is forwarded to every cell. */
  multiChartsProps: MultiChartsProps;
  /** Forwarded to every cell. */
  chartSubType?: ConfusionMatrixType;
  /** When false, the grid container gets `overflow-hidden`. */
  overflowingChart: boolean;
  /** Ref forwarded to every cell. */
  multiChartRef?: RefObject<HTMLDivElement>;
}

/**
 * Resolves the grid line (row or column) for a split key on one axis.
 *
 * @param key - Split key of the cell on this axis, if the axis is split.
 * @param mapper - Lookup from split key to its position on the axis.
 * @param offset - Amount added to the position (the caller passes 1 when a
 *   label track precedes the data tracks, otherwise 0).
 * @returns The mapped position plus `offset`. Falls back to `1 + offset` when
 *   there is no key, no mapper, or the key is unknown (the latter is also
 *   reported through `console.error`).
 */
function calcAxisLocation(
  key: SplitValue | undefined,
  mapper: Map<SplitValue, number> | undefined,
  offset: number
) {
  const fallback = 1 + offset;

  if (!mapper || key === undefined) {
    return fallback;
  }

  const position = mapper.get(key);
  if (position === undefined) {
    console.error(`Key ${key} not found in mapper`);
    return fallback;
  }

  return position + offset;
}

/**
 * Places every chart of the response on a CSS grid whose columns follow the
 * horizontal split and whose rows follow the vertical split.
 *
 * When an axis has labels, a 30px label track is reserved: horizontal labels
 * occupy the first row, vertical labels the first column. Grid cells that
 * have no matching chart are filled with a "No data" placeholder.
 */
function DisplayChartsOnGrid(props: HorizontalSplitChartsProps) {
  const {
    xyChartsResponse,
    multiChart,
    multiChartsProps,
    chartSubType,
    overflowingChart,
    multiChartRef,
  } = props;
  const { charts = [] } = xyChartsResponse;
  const { horizontalSplit, verticalSplit } = multiChartsProps;

  const layout = useLayoutChartGrid(charts, horizontalSplit, verticalSplit);
  const {
    horizontalLabels,
    verticalLabels,
    horizontalGridMapper,
    verticalGridMapper,
    withoutDataCells,
  } = layout;

  // Vertical labels take up the first column, horizontal labels the first row.
  const colOffset = verticalLabels?.length ? 1 : 0;
  const rowOffset = horizontalLabels?.length ? 1 : 0;

  const calcLocation = useCallback(
    (cell: { horizontalKey?: SplitValue; verticalKey?: SplitValue }) => {
      const gridColumn = calcAxisLocation(
        cell.horizontalKey,
        horizontalGridMapper,
        colOffset
      );
      const gridRow = calcAxisLocation(
        cell.verticalKey,
        verticalGridMapper,
        rowOffset
      );
      return { gridColumn, gridRow };
    },
    [horizontalGridMapper, verticalGridMapper, colOffset, rowOffset]
  );

  const columnCount = horizontalLabels?.length ?? 1;
  const rowCount = verticalLabels?.length ?? 1;
  const gridStyle = {
    gridTemplateColumns: `${
      colOffset ? '30px ' : ''
    } repeat(${columnCount}, 1fr)`,
    gridTemplateRows: `${rowOffset ? '30px ' : ''}repeat(${rowCount}, 1fr)`,
  };
  const gridClassName = clsx(
    'gap-2 grid absolute inset-0',
    !overflowingChart && 'overflow-hidden'
  );

  return (
    <div className="relative flex flex-1 w-full min-h-full max-h-full max-w-full">
      <div style={gridStyle} className={gridClassName}>
        {horizontalLabels?.map((columnLabel, columnIndex) => (
          <ChartGridLabel
            key={String(columnLabel)}
            label={columnLabel}
            gridRow={1}
            gridColumn={columnIndex + 1 + colOffset}
          />
        ))}

        {verticalLabels?.map((rowLabel, rowIndex) => (
          <ChartGridLabel
            key={String(rowLabel)}
            label={rowLabel}
            gridRow={rowIndex + 1 + rowOffset}
            gridColumn={1}
            vertical
          />
        ))}

        {withoutDataCells.map((emptyCell, cellIndex) => (
          <div
            key={cellIndex}
            style={calcLocation(emptyCell)}
            className="w-full h-full flex flex-col items-center justify-center gap-1"
          >
            <span className="text-sm text-gray-600">No data</span>
          </div>
        ))}

        {charts.map((chart, chartIndex) => {
          const { gridColumn, gridRow } = calcLocation(chart);
          return (
            <SingleSplitChart
              key={`${chart.horizontalKey}-${chart.verticalKey}-${chartIndex}`}
              multiChart={multiChart}
              multiChartsProps={multiChartsProps}
              chart={chart}
              chartSubType={chartSubType}
              multiChartRef={multiChartRef}
              gridColumn={gridColumn}
              gridRow={gridRow}
            />
          );
        })}
      </div>
    </div>
  );
}

/** Props of `ChartGridLabel`, a single row/column header cell of the chart grid. */
type ChartGridLabelProps = {
  /** Label value; shown truncated in the cell and in full as the tooltip title. */
  label: SplitValue;
  /** When true the label is rotated by -90 degrees and sized as a 30px wide column cell. */
  vertical?: boolean;
  /** CSS `grid-row` applied to the label element. */
  gridRow: number;
  /** CSS `grid-column` applied to the label element. */
  gridColumn: number;
} & ClassNameProp;

/**
 * Header cell of the chart grid: a truncated label wrapped in a tooltip that
 * shows the full label. `vertical` rotates the text for row headers.
 */
function ChartGridLabel(props: ChartGridLabelProps) {
  const { label, className, vertical, gridRow, gridColumn } = props;

  const orientationClasses = vertical
    ? '-rotate-90 w-[30px] h-full'
    : 'w-full h-[30px]';
  const labelClassName = clsx(
    'flex justify-center items-center text-gray-500 whitespace-nowrap',
    orientationClasses,
    className
  );
  const displayedText = truncateLongtail({
    value: label,
    startSubsetLength: LABEL_TRUNCATE_LENGTH,
  });

  return (
    <Tooltip title={label} arrow>
      <span style={{ gridRow, gridColumn }} className={labelClassName}>
        {displayedText}
      </span>
    </Tooltip>
  );
}

/**
 * Builds an unambiguous string identifier for a grid cell from its vertical
 * and horizontal split keys.
 *
 * The pair is serialized as a JSON array rather than joined with a separator:
 * a plain `${vertical}-${horizontal}` join maps distinct cells such as
 * ("a-b", "c") and ("a", "b-c") - or the number 1 and the string "1" - to the
 * same identifier, which would make cell bookkeeping merge unrelated cells.
 *
 * @param verticalKey - Split key of the cell on the vertical axis.
 * @param horizontalKey - Split key of the cell on the horizontal axis.
 * @returns An opaque string; equal for equal key pairs, different otherwise.
 */
function createKeyFromVerticalAndHorizontal(
  verticalKey: SplitValue,
  horizontalKey: SplitValue
) {
  return JSON.stringify([verticalKey, horizontalKey]);
}

/**
 * Derives the grid layout for a set of charts: per-axis labels and key to
 * position lookups, plus the list of grid cells that have no chart.
 *
 * @param charts - Charts to lay out; their `horizontalKey` / `verticalKey`
 *   define the columns / rows.
 * @param horizontalSplit - Split describing the horizontal axis, or null.
 * @param verticalSplit - Split describing the vertical axis, or null.
 * @returns Labels and grid mappers for both axes (undefined for an axis on
 *   which no chart carries a key) and `withoutDataCells`, which is empty
 *   unless both axes have keys.
 */
function useLayoutChartGrid(
  charts: ChartData[],
  horizontalSplit: SplitAgg | null,
  verticalSplit: SplitAgg | null
) {
  const { selectedSessionRunMap } = useVersionControl();

  const axes = useMemo(
    () => ({
      horizontal: calcKeysLabelsAndMapGrid(
        charts,
        'horizontalKey',
        horizontalSplit,
        selectedSessionRunMap
      ),
      vertical: calcKeysLabelsAndMapGrid(
        charts,
        'verticalKey',
        verticalSplit,
        selectedSessionRunMap
      ),
    }),
    [charts, selectedSessionRunMap, horizontalSplit, verticalSplit]
  );
  const { horizontal, vertical } = axes;

  const withoutDataCells = useMemo(() => {
    if (!horizontal || !vertical) {
      return [];
    }

    // Start from the full horizontal x vertical product ...
    const emptyCells = new Map<
      string,
      { horizontalKey: SplitValue; verticalKey: SplitValue }
    >();
    for (const horizontalKey of horizontal.keys) {
      for (const verticalKey of vertical.keys) {
        emptyCells.set(
          createKeyFromVerticalAndHorizontal(verticalKey, horizontalKey),
          { horizontalKey, verticalKey }
        );
      }
    }

    // ... and drop every cell that is occupied by a chart.
    for (const chart of charts) {
      emptyCells.delete(
        createKeyFromVerticalAndHorizontal(
          chart.verticalKey!,
          chart.horizontalKey!
        )
      );
    }

    return Array.from(emptyCells.values());
  }, [charts, horizontal, vertical]);

  return {
    withoutDataCells,
    horizontalGridMapper: horizontal?.gridMapper,
    horizontalLabels: horizontal?.labels,
    verticalGridMapper: vertical?.gridMapper,
    verticalLabels: vertical?.labels,
  };
}

/**
 * Collects the distinct split keys of one axis and derives everything the
 * grid needs for that axis.
 *
 * @param charts - Charts whose keys are collected.
 * @param key - Which chart property holds the key for this axis.
 * @param split - Split of this axis; its `order` sorts the keys (`'desc'`
 *   when absent).
 * @param selectedSessionRunMap - Passed through to `getSplitLabel`.
 * @returns `undefined` when no chart has a key on this axis; otherwise the
 *   ordered `keys`, their display `labels`, and `gridMapper`, which maps each
 *   key to its 1-based position in `keys`.
 */
function calcKeysLabelsAndMapGrid(
  charts: ChartData[],
  key: 'horizontalKey' | 'verticalKey',
  split: SplitAgg | null,
  selectedSessionRunMap: Map<string, SessionRunData>
) {
  const distinctKeys = new Set<SplitValue>();
  for (const chart of charts) {
    const axisValue: SplitValue | undefined = chart[key];
    if (axisValue !== undefined) {
      distinctKeys.add(axisValue);
    }
  }

  if (distinctKeys.size === 0) {
    return undefined;
  }

  const keys = orderBy(
    Array.from(distinctKeys),
    undefined,
    split?.order ?? 'desc'
  );

  // Labels of the vertical axis are produced with the opposite order flag.
  // NOTE(review): presumably getSplitLabel words its label by order - confirm.
  const labelSplit: SplitAgg | null =
    key === 'verticalKey' && split
      ? {
          ...split,
          order:
            split.order === OrderType.Asc ? OrderType.Desc : OrderType.Asc,
        }
      : split;

  const labels = keys.map((axisKey) =>
    getSplitLabel(axisKey, selectedSessionRunMap, labelSplit)
  );

  const gridMapper = new Map<SplitValue, number>(
    keys.map((axisKey, position): [SplitValue, number] => [
      axisKey,
      position + 1,
    ])
  );

  return { keys, labels, gridMapper };
}

/** Props of `SingleSplitChart`, one chart cell of the grid. */
interface SingleSplitChartProps {
  /** Shared chart state from `useMultiChart` (axis domains, colors, hidden/hovered labels, ...). */
  multiChart: UseSplitXYChart;
  /** Props of the owning `MultiCharts`; filters, chart type and request data are read from here. */
  multiChartsProps: MultiChartsProps;
  /** Chart whose `data` is rendered in this cell. */
  chart: ChartData;
  /** CSS `grid-column` applied to the cell. */
  gridColumn: number;
  /** CSS `grid-row` applied to the cell. */
  gridRow: number;
  /** Together with the chart type it selects the chart component. */
  chartSubType?: ConfusionMatrixType;
  /** Ref forwarded to the chart component. */
  multiChartRef?: RefObject<HTMLDivElement>;
}
/**
 * Renders one cell of the chart grid at the given grid position.
 *
 * The chart component is looked up from the chart type / sub type; when no
 * component (or no chart) is available the "no data" placeholder is shown.
 */
function SingleSplitChart(props: SingleSplitChartProps) {
  const {
    multiChart,
    multiChartsProps,
    gridColumn,
    gridRow,
    chart,
    chartSubType,
    multiChartRef,
  } = props;
  const { onFiltersChange, filters, chartType, chartRequestData } =
    multiChartsProps;

  const ChartComponent = useMemo(
    () => getChartComponent(chartType, chartSubType),
    [chartSubType, chartType]
  );

  const isTable = chartType === AnalyticsDashletType.Table;
  const cellClassName = clsx(
    'flex flex-col flex-1 w-full h-full gap-1',
    isTable && 'max-h-96'
  );

  const content =
    chart && ChartComponent ? (
      <ChartComponent
        className="flex flex-1"
        graphData={chart.data}
        onFiltersChange={onFiltersChange}
        filters={filters}
        showXAxisLine
        showYLabel
        xAxisDomain={multiChart.splitXAxisDomain}
        yAxisDomain={multiChart.splitYAxisDomain}
        hiddenLabels={multiChart.hiddenLabels}
        hoverLabel={multiChart.hoverLabel}
        colorMap={multiChart.colorMap}
        chartRequestData={chartRequestData}
        axisType={multiChart.splitAxisType}
        multiChartRef={multiChartRef}
        mapValue={multiChart.mapValue}
        preview={multiChart.preview}
      />
    ) : (
      <NoDataChart />
    );

  return (
    <div style={{ gridColumn, gridRow }} className={cellClassName}>
      {content}
    </div>
  );
}
