import { useMemo } from 'react';
import { SampleIdentity, ScatterVizDataState } from '@tensorleap/api-client';
import { useScatterData } from './ScatterDataContext';
import { labelsColorSupplier } from '../core/color-helper';
import { orderBy } from 'lodash';
import { SHAPES, ShapeType } from './ScatterAnalyzerView/ScatterShape';
import { SortTypeEnum } from '../ui/charts/legend/LabelsLegendMenu';

/**
 * Label handed to `labelsColorSupplier` to obtain the default dot color when
 * no color subject is applied to the scatter points.
 */
export const NONE_OPTION = 'None';

/** Radius given to the smallest value when dot size follows a range subject. */
const MIN_DOT_RADIUS = 1;
/** Radius given to the largest value when dot size follows a range subject. */
const MAX_DOT_RADIUS = 5;

/**
 * RGB stops of the heat gradient, ordered from the minimum value (index 0) to
 * the maximum value (index 5). `getRangeDataPointColor` splits the value range
 * into five equal segments and interpolates between neighbouring stops.
 */
const HEAT_COLORS = [
  { r: 106, g: 137, b: 204 },
  { r: 130, g: 204, b: 221 },
  { r: 184, g: 233, b: 148 },
  { r: 250, g: 211, b: 144 },
  { r: 229, g: 80, b: 57 },
  { r: 125, g: 0, b: 0 },
];

/** A single renderable point of the scatter plot. */
export interface DataPoint {
  /** First coordinate of the point. */
  x: number;
  /** Second coordinate of the point. */
  y: number;
  /** Third coordinate of the point. */
  z: number;
  /** Text label of the point; `generateScatterData` sets it to an empty string. */
  label: string;
  /** Identity of the sample this point was produced from, when available. */
  sample?: SampleIdentity;
  /** Dot radius; `generateScatterData` uses 2 when no size mapping applies. */
  radius: number;
  /** Marker shape; `generateScatterData` uses 'circle' when no shape mapping applies. */
  shape: ShapeType;
  /** Color string used to paint the dot. */
  color: string;
}

/**
 * Finds the smallest and largest number in a list with a single pass.
 *
 * @param data - Values to scan.
 * @returns The extremes of `data`. Both are `undefined` for an empty list,
 *   and both become `NaN` once a `NaN` entry has been visited.
 */
export function minMaxValues(
  data: number[]
): {
  min: number | undefined;
  max: number | undefined;
} {
  // Seed both extremes with the first entry (undefined when the list is empty).
  let lowest = data[0];
  let highest = data[0];

  data.forEach((value) => {
    lowest = Math.min(value, lowest);
    highest = Math.max(value, highest);
  });

  return { min: lowest, max: highest };
}

/**
 * Builds the `rgb(...)` string for a value lying between two heat stops.
 *
 * @param i - Index in `HEAT_COLORS` of the stop to interpolate from.
 * @param j - Index in `HEAT_COLORS` of the stop to interpolate towards.
 * @param heatDiff - Fraction of the way from stop `i` to stop `j`.
 * @returns The interpolated color. Black is returned (and an error logged)
 *   when any argument is `NaN`.
 */
function calcColor(i: number, j: number, heatDiff: number): string {
  if (Number.isNaN(i) || Number.isNaN(j) || Number.isNaN(heatDiff)) {
    console.error('got unvalid number', i, j, heatDiff);
    return 'rgb(0,0,0)';
  }

  const lastStop = HEAT_COLORS.length - 1;
  // At either end of the gradient there is nothing to interpolate: the target
  // stop is returned as-is, using all three of its channels.
  if (j === 0 || i === lastStop) {
    const { r, g, b } = HEAT_COLORS[j];
    return `rgb(${r}, ${g}, ${b})`;
  }

  const from = HEAT_COLORS[i];
  const to = HEAT_COLORS[j];
  // Linear interpolation of each channel between the two stops.
  const r = from.r + (to.r - from.r) * heatDiff;
  const g = from.g + (to.g - from.g) * heatDiff;
  const b = from.b + (to.b - from.b) * heatDiff;
  return `rgb(${r}, ${g}, ${b})`;
}

/**
 * Maps a numeric value to its color on the heat gradient.
 *
 * The `[minValue, maxValue]` range is split into five equal segments; the
 * value's segment selects two neighbouring heat stops and its position inside
 * the segment is the interpolation fraction passed to `calcColor`.
 *
 * @param heatValue - Value to colorize. Values outside the range are clamped
 *   to the nearest bound, so they receive the first or last color.
 * @param minValue - Value mapped to the first heat stop.
 * @param maxValue - Value mapped to the last heat stop.
 * @returns The color string produced by `calcColor`. When the range is empty
 *   (`minValue === maxValue`) the first color is returned.
 */
export function getRangeDataPointColor(
  heatValue: number,
  minValue: number,
  maxValue: number
): string {
  // Index of the last heat stop; the range is divided into this many segments.
  const lastStop = 5;
  const heatRange = maxValue - minValue;
  // Width of a single segment (one fifth of the range).
  const heatRatio = heatRange * 0.2;
  // Clamp into the range so the stop index can never fall outside 0..lastStop.
  // NaN passes through Math.min/Math.max untouched and is handled by calcColor.
  const heatOffset =
    Math.min(Math.max(heatValue, minValue), maxValue) - minValue;
  const heatDiff =
    heatRatio === 0 ? 0 : (heatOffset % heatRatio) / heatRatio;
  const heatStop = heatRatio === 0 ? 0 : heatOffset / heatRatio;

  // min value no need to calculate difference, just the first color
  if (heatStop === 0) {
    return calcColor(0, 0, heatDiff);
  }
  // max value no need to calculate difference, just the last color
  if (heatStop === lastStop) {
    return calcColor(lastStop, lastStop, heatDiff);
  }

  // Floating-point rounding may push the quotient a hair above lastStop.
  const x = Math.min(Math.floor(heatStop), lastStop);
  const y = Math.min(x + 1, lastStop);

  return calcColor(x, y, heatDiff);
}

/** Arguments of `generateScatterData`. */
interface GenerateScatterDataProps {
  /** Visualization state providing the coordinates, samples and metadata. */
  vizState: ScatterVizDataState;
  /**
   * Metadata key driving the dot size (when its type is 'range') or the dot
   * shape (any other type).
   */
  sizeOrShape?: string;
  /** Metadata key driving the dot color. */
  dotColor?: string;
  /** Ordering forwarded to `createShapeMapping` when shapes are assigned. */
  sizeOrShapeOrderMethod: SortTypeEnum;
}
/**
 * Converts the scatter visualization state into renderable data points.
 *
 * - Size/shape: a 'range' subject is turned into per-point radiuses, any
 *   other subject into per-point shapes. Without a subject every point has
 *   radius 2 and the 'circle' shape.
 * - Color: a 'labels' subject is colored through `labelsColorSupplier`, any
 *   other subject through the heat gradient spanning the subject's min/max.
 *   Without a subject every point gets the `NONE_OPTION` color.
 *
 * A `sizeOrShape` / `dotColor` key that has no entry in the metadata is
 * treated as if no subject was selected, instead of throwing.
 *
 * @returns One data point per scatter entry, tagged with its index in the
 *   source arrays.
 */
function generateScatterData({
  vizState: { metadata = {}, scatter_data: scatterData, samples },
  sizeOrShape,
  dotColor,
  sizeOrShapeOrderMethod,
}: GenerateScatterDataProps): (DataPoint & { originalIndex: number })[] {
  let dotsRadiuses: number[] | undefined = undefined;
  let indexToShape: Map<number | string, ShapeType> | undefined = undefined;

  const sizeOrShapeSubjectMetadata =
    sizeOrShape === undefined ? undefined : metadata[sizeOrShape];
  // An unknown key leaves both mappings undefined, i.e. the defaults apply.
  if (sizeOrShapeSubjectMetadata) {
    if (sizeOrShapeSubjectMetadata.type !== 'range') {
      ({ indexToShape } = createShapeMapping(
        sizeOrShapeSubjectMetadata.body,
        sizeOrShapeOrderMethod
      ));
    } else {
      dotsRadiuses = createSizeMapping(
        sizeOrShapeSubjectMetadata.body as number[]
      );
    }
  }

  const filteredIndexes = Array.from(
    { length: scatterData.length },
    (_, i) => i
  );

  // Named differently from the module-level `calcColor` to avoid shadowing it.
  let colorOfPoint: (_: number) => string = () =>
    labelsColorSupplier.get(String(NONE_OPTION));

  const dotColorSubjectMetadata =
    dotColor === undefined ? undefined : metadata[dotColor];
  // An unknown key keeps the default color supplier defined above.
  if (dotColorSubjectMetadata) {
    if (dotColorSubjectMetadata.type === 'labels') {
      colorOfPoint = (pointIndex) =>
        labelsColorSupplier.get(
          String(dotColorSubjectMetadata.body[pointIndex])
        );
    } else {
      const allColorValues = dotColorSubjectMetadata.body as number[];

      // The gradient bounds are taken only from the values that are rendered.
      const filteredColorValues = filteredIndexes.map(
        (index) => allColorValues[index]
      );

      const { min = 0, max = 0 } = minMaxValues(filteredColorValues);
      colorOfPoint = (pointIndex) =>
        getRangeDataPointColor(allColorValues[pointIndex], min, max);
    }
  }

  return filteredIndexes.map<DataPoint & { originalIndex: number }>(
    (pointIndex) => {
      const [x, y, z] = scatterData[pointIndex];
      const sample = samples[pointIndex];

      return {
        originalIndex: pointIndex,
        x,
        y,
        z,
        label: '',
        sample,
        radius: dotsRadiuses ? dotsRadiuses[pointIndex] : 2,
        shape: indexToShape?.get(pointIndex) ?? 'circle',
        color: colorOfPoint(pointIndex),
      };
    }
  );
}

/**
 * A data point together with the index it had in the source scatter data
 * arrays.
 */
export type ScatterMapData = DataPoint & {
  originalIndex: number;
};

/**
 * Scales every value linearly into the `[MIN_DOT_RADIUS, MAX_DOT_RADIUS]`
 * interval, based on the smallest and largest value of the list.
 *
 * @param sizeSubjectMetadata - Raw numeric values, one per point.
 * @returns One radius per input value, in the same order. When all values are
 *   equal every radius is `MIN_DOT_RADIUS`.
 */
function createSizeMapping(sizeSubjectMetadata: number[]): number[] {
  const { min: lowest = 0, max: highest = 0 } = minMaxValues(
    sizeSubjectMetadata
  );
  // A flat list has no spread to normalize against.
  const isFlat = highest === lowest;
  const radiusSpan = MAX_DOT_RADIUS - MIN_DOT_RADIUS;

  return sizeSubjectMetadata.map((value) => {
    const fraction = isFlat ? 0 : (value - lowest) / (highest - lowest);
    return fraction * radiusSpan + MIN_DOT_RADIUS;
  });
}

/** Result of `createShapeMapping`. */
type CreateShapeMappingResult = {
  /** Shape of every point, keyed by the point's index in the input list. */
  indexToShape: Map<number | string, ShapeType>;
  /** Shape assigned to every distinct metadata value. */
  metadataToShape: Map<string | number, ShapeType>;
  /** Distinct values that did not get a shape of their own. */
  metadataInOtherGroup: Set<string | number>;
  /** Number of occurrences of every distinct metadata value. */
  appearance: Map<string | number, number>;
};

/**
 * Assigns a marker shape to every distinct metadata value and to every point.
 *
 * Distinct values are ordered either by value (numeric when the value parses
 * as a number, otherwise as-is) or by how often they occur, ascending or
 * descending according to `sortMethod`. Values are then matched with `SHAPES`
 * in that order; values ranked beyond the number of shapes share the last
 * shape and are reported in `metadataInOtherGroup`.
 *
 * @param shapeSubjectMetadata - Metadata value of every point.
 * @param sortMethod - Ordering used to rank the distinct values.
 * @returns The per-point and per-value shape maps, the overflow group and the
 *   occurrence count of every value.
 */
export function createShapeMapping(
  shapeSubjectMetadata: (string | number)[],
  sortMethod: SortTypeEnum
): CreateShapeMappingResult {
  const alphabetical =
    sortMethod === SortTypeEnum.ASC_ALPHABETICALLY ||
    sortMethod === SortTypeEnum.DESC_ALPHABETICALLY;
  const ascending =
    sortMethod === SortTypeEnum.ASC_ALPHABETICALLY ||
    sortMethod === SortTypeEnum.ASC_BY_PRESENCE;
  const direction = ascending ? 'asc' : 'desc';

  // Count how many points carry each distinct value.
  const appearance = new Map<string | number, number>();
  shapeSubjectMetadata.forEach((value) => {
    appearance.set(value, (appearance.get(value) || 0) + 1);
  });

  const sortKey = (option: string | number) => {
    if (!alphabetical) {
      return appearance.get(option);
    }
    // Numeric-looking values are compared as numbers.
    const numeric = Number(option);
    return Number.isNaN(numeric) ? option : numeric;
  };
  const rankedOptions = orderBy(
    Array.from(appearance.keys()),
    sortKey,
    direction
  );

  const metadataInOtherGroup = new Set<string | number>();
  const metadataToShape = new Map<string | number, ShapeType>();
  rankedOptions.forEach((option, rank) => {
    if (rank < SHAPES.length) {
      metadataToShape.set(option, SHAPES[rank]);
      return;
    }
    // Out of dedicated shapes: fall back to the last one.
    metadataInOtherGroup.add(option);
    metadataToShape.set(option, SHAPES[SHAPES.length - 1]);
  });

  const indexToShape = new Map<number | string, ShapeType>();
  shapeSubjectMetadata.forEach((option, pointIndex) => {
    indexToShape.set(pointIndex, metadataToShape.get(option) || SHAPES[0]);
  });

  // NOTE(review): when exactly one value overflowed, the value-to-shape map is
  // emptied before being returned (indexToShape was already built and is not
  // affected). The intent is not evident from this file; confirm with callers.
  if (metadataInOtherGroup.size === 1) {
    metadataToShape.clear();
  }

  return {
    indexToShape,
    metadataToShape,
    metadataInOtherGroup,
    appearance,
  };
}

/**
 * React hook returning the renderable scatter points for the current scatter
 * context. The points are recomputed only when the scatter data, the color
 * subject, the size/shape subject or the size/shape ordering changes.
 */
export function useScatterMapData(): ScatterMapData[] {
  const {
    scatterData: vizState,
    settings,
    sizeOrShapeOrderMethod,
  } = useScatterData();
  const { dotColor, sizeOrShape } = settings;

  return useMemo(
    () =>
      generateScatterData({
        vizState,
        dotColor,
        sizeOrShape,
        sizeOrShapeOrderMethod,
      }),
    [vizState, dotColor, sizeOrShape, sizeOrShapeOrderMethod]
  );
}

/**
 * Collects the metadata values of a single point as strings.
 *
 * @param scatterData - Visualization state holding the metadata subjects.
 *   A state without metadata is treated as having no subjects.
 * @param index - Index of the point, or `undefined` when no point is selected.
 * @returns A map from metadata key to the stringified value at `index`; an
 *   empty object when `index` is `undefined`. A subject with no value at
 *   `index` keeps its key with an `undefined` value.
 */
export function calcMetadataObject(
  scatterData: ScatterVizDataState,
  index: number | undefined
): Record<string, string> {
  const metadataObject: Record<string, string> = {};
  // Nothing to collect without a point; decided once instead of per entry.
  if (index === undefined) {
    return metadataObject;
  }

  // Same default as generateScatterData: absent metadata means no subjects.
  const { metadata = {} } = scatterData;
  for (const [key, value] of Object.entries(metadata)) {
    metadataObject[key] = value.body[index]?.toString();
  }
  return metadataObject;
}
