import {
  AnalyticsDashletType,
  FilterOperatorType,
  GenericDataItem,
  GenericDataResponse,
  MultiChartsResponse,
  NumberOrString,
  OrderType,
} from '@tensorleap/api-client';
import { orderBy } from 'lodash';
import { useMemo } from 'react';
import { AxisDomain } from 'recharts/types/util/types';
import { labelsColorSupplier } from '../../../core/color-helper';
import { useTheme } from '../../mui';
import {
  CHART_META_FIELDS_KEY,
  DEFAULT_AXIS_SIZE_INTERVAL,
  GRAPH_STYLE,
} from '../common/constants';
import {
  ChartRequestData,
  DataPoint,
  XYAxistType,
  XYChartProps,
} from '../common/interfaces';
import { AreaChart } from '../visualizers/AreaChart';
import { BarChart } from '../visualizers/BarChart';
import { Heatmap } from '../visualizers/Heatmap';
import { LineChart } from '../visualizers/LineChart';
import { PieChart } from '../visualizers/PieChart';
import { TableChart } from '../visualizers/TableChart';
import { isSplitDefined } from '../../../dashboard/dashlet/Analytics/ElasticVis/utils';
import {
  ConfusionMatrixType,
  ConfusionMatrixTypeEnum,
} from '../../../dashboard/dashlet/Analytics/form/ConfusionMatrix';
import { ConfusionMatrixSingleHeatmapTable } from '../visualizers/HeatmapTable';
import { VisualizationFilter } from '../../../core/types/filters';

/**
 * React hook that assigns a display color to every label.
 *
 * Labels take colors from the theme's `lineGraph` palette by position, wrapping
 * around when there are more labels than palette entries. When the palette
 * yields no color (falsy entry or empty palette), `labelsColorSupplier` supplies
 * one for that label. The resulting map is memoized on `labels` and the palette.
 */
export function useChartColorMap(
  labels: string[]
): { [label: string]: string } {
  const { palette } = useTheme();
  const lineGraphPalette = palette.lineGraph;
  return useMemo(() => {
    const colorByLabel: Record<string, string> = {};
    labels.forEach((label, position) => {
      const paletteColor =
        lineGraphPalette[position % lineGraphPalette.length];
      colorByLabel[label] = paletteColor || labelsColorSupplier.get(label);
    });
    return colorByLabel;
  }, [labels, lineGraphPalette]);
}

/** Input for {@link getAxisDomain}. */
export interface AxistDomainRequestParams {
  /** Chart points whose values on `axis` determine the computed domain. */
  points: DataPoint[];
  /** Supplies the point key read for each axis (`xField` for x, `yField` for y). */
  chartRequestData: ChartRequestData;
  /** Axis whose domain is being computed. */
  axis: XYAxistType;
  /** Scale type of the axis; a categorical x axis gets a fixed result. */
  axisType: 'number' | 'category';
  /** Dashlet type; bar charts get extra padding on the x axis. */
  chartType: AnalyticsDashletType;
  /** Explicit lower bound; used instead of the computed minimum when set. */
  minDomainOverride?: NumberOrString;
  /** Explicit upper bound; used instead of the computed maximum when set. */
  maxDomainOverride?: NumberOrString;
}

/**
 * Computes the `[min, max]` domain (range of values) for the x or y axis.
 *
 * Resolution order:
 * 1. When both overrides are set they are returned unchanged.
 * 2. A categorical x axis returns `['dataMin', 'dataMax']` for bar charts and
 *    `undefined` for every other chart type.
 * 3. Otherwise the distinct finite numeric values of the axis field are used
 *    (values that convert to NaN or ±Infinity are ignored):
 *    - no usable values: `'dataMin'` / `'dataMax'` stand in for any missing
 *      override;
 *    - otherwise the bounds are the smallest and largest values. On a bar
 *      chart's x axis they are padded: by roughly 1% of |value| for a single
 *      value, or by the gap between the two largest values otherwise.
 *
 * A set override always replaces the corresponding computed bound. If the
 * resulting numeric bounds are equal, they are widened to
 * `[value - 1, value + 1]` so the axis never collapses to a point.
 */
export function getAxisDomain({
  points,
  chartRequestData,
  axis,
  axisType,
  chartType,
  minDomainOverride,
  maxDomainOverride,
}: AxistDomainRequestParams): AxisDomain | undefined {
  if (minDomainOverride != undefined && maxDomainOverride != undefined) {
    return [minDomainOverride, maxDomainOverride];
  }

  const isBarXAxis = axis === 'x' && chartType === AnalyticsDashletType.Bar;
  if (axis === 'x' && axisType === 'category') {
    return isBarXAxis ? ['dataMin', 'dataMax'] : undefined;
  }

  const pointKey =
    axis === 'x' ? chartRequestData.xField : chartRequestData.yField;
  // Distinct numeric values on this axis. Missing or non-numeric values turn
  // into NaN under Number(); they are dropped so they can neither break the
  // numeric sort below nor leak into the returned domain.
  const values = Array.from(
    new Set(points.map((point: DataPoint) => Number(point?.[pointKey])))
  ).filter((value) => Number.isFinite(value));

  if (values.length === 0) {
    return [minDomainOverride ?? 'dataMin', maxDomainOverride ?? 'dataMax'];
  }

  if (values.length === 1) {
    const value = values[0];
    // Pad a lone bar by ~1% of its magnitude. Math.abs keeps min <= max for
    // negative values; a signed padding would invert the domain.
    const padding = isBarXAxis ? Math.abs(value) / 100 : 0;
    const maxNum = maxDomainOverride ?? value + Math.ceil(padding);
    const minNum = minDomainOverride ?? value - Math.floor(padding);
    if (
      typeof maxNum === 'number' &&
      typeof minNum === 'number' &&
      minNum === maxNum
    ) {
      return [minNum - 1, maxNum + 1];
    }
    return [minNum, maxNum];
  }

  values.sort((a, b) => a - b);
  const lowest = values[0];
  const highest = values[values.length - 1];
  // On a bar x axis, pad both ends by the gap between the two largest values.
  const interval = isBarXAxis ? highest - values[values.length - 2] : 0;
  const maxNum = maxDomainOverride ?? highest + interval;
  const minNum = minDomainOverride ?? lowest - interval;
  if (minNum === maxNum) {
    return [Number(minNum) - 1, Number(maxNum) + 1];
  }
  return [minNum, maxNum];
}

/**
 * Picks the visualizer component that renders a dashlet of `chartType`.
 *
 * `subType` is only consulted for confusion-matrix dashlets: the table sub-type
 * renders `ConfusionMatrixSingleHeatmapTable`, the by-label sub-type renders
 * `Heatmap`, and any other (or missing) sub-type falls back to `LineChart`.
 */
export function getChartComponent(
  chartType: AnalyticsDashletType,
  subType?: ConfusionMatrixType
): (props: XYChartProps) => JSX.Element {
  switch (chartType) {
    case AnalyticsDashletType.Line:
      return LineChart;
    case AnalyticsDashletType.Bar:
      return BarChart;
    case AnalyticsDashletType.Area:
      return AreaChart;
    case AnalyticsDashletType.Donut:
      return PieChart;
    case AnalyticsDashletType.Table:
      return TableChart;
    case AnalyticsDashletType.Heatmap:
      return Heatmap;
    case AnalyticsDashletType.ConfusionMatrix: {
      if (subType === ConfusionMatrixTypeEnum.ConfusionMatrixTable) {
        return ConfusionMatrixSingleHeatmapTable;
      }
      if (subType === ConfusionMatrixTypeEnum.ConfusionMatrixByLabelVis) {
        return Heatmap;
      }
      return LineChart;
    }
  }
}

/**
 * Converts every item of a chart response into a flat data point.
 *
 * - Without an inner split, a point is a shallow copy of the item's `data`.
 * - With an inner split, a point holds only the x value plus the y value stored
 *   under the item's split key (`innerKey`, or else the item's value for the
 *   split field). Each split value therefore becomes its own key.
 *
 * When an item has `metadata`, it is attached under `CHART_META_FIELDS_KEY`.
 * The response itself is never mutated: every returned point is a new object.
 */
function getDataPoints({
  graphData,
  chartRequestData: { xField, yField, innerSplit },
}: GetParsedChartPointsParams): DataPoint[] {
  return graphData.data.map(({ data, metadata, innerKey }: GenericDataItem) => {
    let dataPoint: DataPoint;
    if (isSplitDefined(innerSplit)) {
      const inner = innerKey ?? data[innerSplit.field];
      dataPoint = {
        [xField]: data[xField],
        [inner]: data[yField],
      };
    } else {
      // Copy instead of aliasing `data`: the metadata assignment below would
      // otherwise write into the caller's response object.
      dataPoint = { ...data };
    }
    if (metadata) {
      dataPoint[CHART_META_FIELDS_KEY] = metadata;
    }
    return dataPoint;
  });
}

/**
 * Orders data points for display using lodash `orderBy`.
 *
 * `orderByParam === '_key'` (the default) sorts by the x field; any other value
 * sorts by the y field. The direction comes from `orderParams` and defaults to
 * ascending.
 */
function sortDataPoints(
  dataPoints: DataPoint[],
  chartRequestData: ChartRequestData
): DataPoint[] {
  const {
    xField,
    yField,
    orderParams = OrderType.Asc,
    orderByParam = '_key',
  } = chartRequestData;
  const sortsByKey = orderByParam === '_key';
  return orderBy(dataPoints, sortsByKey ? xField : yField, orderParams);
}
/** Input shared by the chart-point parsing helpers. */
export interface GetParsedChartPointsParams {
  /** Chart response whose `data` items are turned into points. */
  graphData: GenericDataResponse;
  /** Request settings: axis fields, optional inner split, and sort order. */
  chartRequestData: ChartRequestData;
}

/**
 * Turns a chart response into display-ready points: each item is converted to
 * a flat point, then the points are sorted according to `chartRequestData`.
 */
export function getParsedChartPoints(
  params: GetParsedChartPointsParams
): DataPoint[] {
  const unsortedPoints = getDataPoints(params);
  return sortDataPoints(unsortedPoints, params.chartRequestData);
}

/**
 * Same points as `getParsedChartPoints`, typed as plain records of
 * number/string values. Only the static type changes (via a cast); the point
 * objects themselves are returned in a new array, unchanged.
 */
export function getNormalizedParsedChartPoints({
  graphData,
  chartRequestData,
}: GetParsedChartPointsParams): Record<NumberOrString, NumberOrString>[] {
  const parsedPoints = getParsedChartPoints({ graphData, chartRequestData });
  return parsedPoints.map(
    (point) => point as Record<NumberOrString, NumberOrString>
  );
}

/** Input for {@link getAllChartPointsFlat}. */
export interface getAllChartPointsFlatParams {
  /** Multi-chart response; `undefined` yields no points. */
  xyChartsResponse: MultiChartsResponse | undefined;
  /** Dashlet type; table and heatmap dashlets yield no points. */
  chartType: AnalyticsDashletType;
  /** Request settings used to parse and sort every chart's points. */
  chartRequestData: ChartRequestData;
}

/**
 * Parses every chart in a multi-chart response and concatenates their points
 * into a single array, in chart order.
 *
 * Returns `[]` when there is no response, or when the dashlet is a table or a
 * heatmap.
 */
export function getAllChartPointsFlat({
  xyChartsResponse,
  chartRequestData,
  chartType,
}: getAllChartPointsFlatParams): DataPoint[] {
  // The chart type does not vary per chart, so decide once up front instead of
  // re-checking it for every chart inside the flatMap.
  if (
    !xyChartsResponse ||
    chartType === AnalyticsDashletType.Table ||
    chartType === AnalyticsDashletType.Heatmap
  ) {
    return [];
  }
  return xyChartsResponse.charts.flatMap((chart) =>
    getParsedChartPoints({
      graphData: chart.data,
      chartRequestData,
    })
  );
}

/** Input for {@link getAllLabels}. */
export interface GetAllLabelsParams {
  /** Parsed chart points to collect label keys from. */
  dataPoints: DataPoint[];
  /** Dashlet type; table and heatmap dashlets have no labels. */
  chartType: AnalyticsDashletType;
  /** Name of the x field, which is never treated as a label. */
  xField: string;
}

/**
 * Collects the distinct series labels across all data points: every key that
 * is neither the x field nor the metadata key, in order of first appearance.
 * Table and heatmap dashlets always yield an empty list.
 */
export function getAllLabels({
  dataPoints,
  chartType,
  xField,
}: GetAllLabelsParams): string[] {
  const isLabelless =
    chartType === AnalyticsDashletType.Table ||
    chartType === AnalyticsDashletType.Heatmap;
  if (isLabelless) {
    return [];
  }

  const seenLabels = new Set<string>();
  (dataPoints ?? []).forEach((point) => {
    Object.keys(point).forEach((key) => {
      if (key !== xField && key !== CHART_META_FIELDS_KEY) {
        seenLabels.add(key);
      }
    });
  });
  return Array.from(seenLabels);
}

/**
 * Picks a y-axis label font size from the label's character count. Labels up
 * to `mediumFontLabelLength` chars get the large size, labels up to
 * `largeFontLabelLength` chars get the medium size, and longer labels get the
 * small size.
 */
export function getLabelFontSize(label: string): number {
  const { length } = label;
  const fontSizes = GRAPH_STYLE.yAxis.label;
  if (length > GRAPH_STYLE.mediumFontLabelLength) {
    return length > GRAPH_STYLE.largeFontLabelLength
      ? fontSizes.fontSizeSmall
      : fontSizes.fontSizeMedium;
  }
  return fontSizes.fontSizeLarge;
}

/** Input for {@link getOnFirstClickFilter}. */
export type GetOnFirstClickFilterParams = {
  /** The clicked axis value. */
  from: NumberOrString;
  /** Field the resulting filter applies to. */
  field: string;
  /** How the field's values are distributed on the axis. */
  dataDistribution: 'continuous' | 'distinct' | undefined;
  /** Bucket width for numeric filters; falsy means DEFAULT_AXIS_SIZE_INTERVAL. */
  sizeInterval?: number;
};

/**
 * Builds the filter produced by a single click on an axis value.
 *
 * - Distinct data: an Equal filter on `from`.
 * - Otherwise, when `from` is numeric: a Between filter with `gte: from` and
 *   `lt: from + interval`. Here `interval` is `sizeInterval`, or
 *   `DEFAULT_AXIS_SIZE_INTERVAL` when `sizeInterval` is missing or 0.
 * - Otherwise (non-numeric `from`): `undefined`.
 */
export function getOnFirstClickFilter({
  from,
  field,
  dataDistribution,
  sizeInterval,
}: GetOnFirstClickFilterParams): VisualizationFilter | undefined {
  if (dataDistribution === 'distinct') {
    return { field, operator: FilterOperatorType.Equal, value: from };
  }

  const start = Number(from);
  if (Number.isNaN(start)) {
    return undefined;
  }

  const width = Number(sizeInterval || DEFAULT_AXIS_SIZE_INTERVAL);
  return {
    field,
    operator: FilterOperatorType.Between,
    value: { gte: start, lt: start + width },
  };
}

/** Input for {@link getOnlastClickFilter}. */
export type GetOnLastClickFilterParams = {
  /** Axis value where the selection started. */
  from: NumberOrString;
  /** Axis value where the selection ended. */
  to: NumberOrString;
  /** Field the resulting filter applies to. */
  field: string;
  /** How the field's values are distributed on the axis. */
  dataDistribution: 'continuous' | 'distinct' | undefined;
  /** Ordered axis labels used to expand a non-numeric selection. */
  labels: NumberOrString[];
};

/**
 * Builds the filter for a selection that ends at `to`.
 *
 * For continuous data where both endpoints are numeric, the result is a Between
 * filter with `gte` = the smaller endpoint and `lt` = the larger one. In every
 * other case, the labels from `from` to `to` (inclusive, either direction) are
 * converted to strings and returned as an In filter.
 */
export function getOnlastClickFilter({
  dataDistribution,
  from,
  to,
  field,
  labels,
}: GetOnLastClickFilterParams): VisualizationFilter {
  const start = Number(from);
  const end = Number(to);
  const isNumericRange =
    dataDistribution === 'continuous' &&
    !Number.isNaN(start) &&
    !Number.isNaN(end);

  if (!isNumericRange) {
    const selectedLabels = getLabelRangeByFromTo(from, to, labels);
    return {
      field,
      operator: FilterOperatorType.In,
      value: selectedLabels.map(String),
    };
  }

  return {
    field,
    operator: FilterOperatorType.Between,
    value: { gte: Math.min(start, end), lt: Math.max(start, end) },
  };
}

/**
 * Returns the contiguous run of `labels` between `from` and `to`, inclusive,
 * whichever endpoint comes first. Endpoints are matched with strict equality
 * (`indexOf`), so `5` and `'5'` are different labels.
 *
 * If only one endpoint is present in `labels`, just that label is returned. If
 * neither is present, the result is empty.
 */
export function getLabelRangeByFromTo(
  from: NumberOrString,
  to: NumberOrString,
  labels: NumberOrString[]
): NumberOrString[] {
  const fromIndex = labels.indexOf(from);
  const toIndex = labels.indexOf(to);
  // indexOf returns -1 for a missing endpoint. slice() would read -1 as "last
  // element" and return an unrelated or empty range, so handle it explicitly.
  if (fromIndex === -1 || toIndex === -1) {
    const foundIndex = Math.max(fromIndex, toIndex);
    return foundIndex === -1 ? [] : labels.slice(foundIndex, foundIndex + 1);
  }
  return labels.slice(
    Math.min(fromIndex, toIndex),
    Math.max(fromIndex, toIndex) + 1
  );
}
