import { SlimVisualization } from '@tensorleap/api-client';
import { useMergedObject } from '../../../core/useMergedObject';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { useEpochRange } from '../../add-analysis/panes/useSelectedModels';
import { first, groupBy, orderBy } from 'lodash';

/**
 * One selectable entry: an epoch, optionally paired with one of the
 * visualizations recorded for that epoch.
 */
export type EpochOption = {
  /** Epoch number this option refers to. */
  epoch: number;
  /** Absent when the epoch has no visualization yet. */
  slimVisualization?: SlimVisualization;
};

/** State and actions returned by `useEpochVisualizationState`. */
export type UseEpochVisualizationState = {
  /** Remembers the given option (by epoch and visualization `cid`) as the user's choice. */
  selectEpoch: (epoch: EpochOption) => void;
  /** All options, newest epoch first; within an epoch, newest visualization first. */
  epochsOptions: EpochOption[];
  /** The option currently resolved from the user's choice, or a default when there is none. */
  selectedEpoch: EpochOption;
  /** True when the selected epoch is listed in `inProcessEpochs`. */
  isSelectedEpochUpgrading: boolean;
  /** Selects the latest epoch and calls `upgrade` for it. */
  upgradeToLastEpoch: () => void;
  /** True when the latest epoch has no visualization and is not already in process. */
  needsUpgrade: boolean;
};

/** Inputs for `useEpochVisualizationState`. */
type UseEpochVisualizationStateProps = {
  /** Session run whose epoch range is looked up via `useEpochRange`. */
  sessionRunId: string;
  /** Epochs currently in process; they suppress `needsUpgrade` and drive `isSelectedEpochUpgrading`. */
  inProcessEpochs: number[];
  /** Known visualizations, grouped by their `epoch` field; treated as empty when absent. */
  slimVisualizations?: SlimVisualization[];
  /** Called with an epoch number by `upgradeToLastEpoch`. */
  upgrade: (epoch: number) => void;
};

/**
 * Tracks which epoch/visualization pair is selected for a session run and
 * whether the latest epoch still needs an upgrade.
 *
 * Options are built from the session's epoch range and the known
 * visualizations, ordered newest-first. The user's choice is stored as
 * `{ epoch, visualizationId? }` and resolved against the current options on
 * every change; when it no longer matches, the selection falls back to the
 * newest option that has a visualization.
 *
 * @param props.sessionRunId - Session run whose epoch range is used.
 * @param props.inProcessEpochs - Epochs currently in process.
 * @param props.slimVisualizations - Known visualizations (may be undefined).
 * @param props.upgrade - Invoked with the latest epoch by `upgradeToLastEpoch`.
 * @returns The merged selection state and actions.
 */
export function useEpochVisualizationState({
  sessionRunId,
  inProcessEpochs,
  slimVisualizations,
  upgrade,
}: UseEpochVisualizationStateProps): UseEpochVisualizationState {
  const { getEpochRange } = useEpochRange();

  const options = useMemo(
    () => buildEpochOptions(getEpochRange(sessionRunId), slimVisualizations),
    [sessionRunId, slimVisualizations, getEpochRange]
  );

  // Newest option overall (visualized or not); epoch 0 when the range is empty.
  const latestOption = useMemo<EpochOption>(
    () => first(options) ?? { epoch: 0 },
    [options]
  );

  const [userSelectedEpoch, setUserSelectedEpoch] = useState<
    | {
        epoch: number;
        visualizationId?: string;
      }
    | undefined
  >(undefined);

  const selectedEpoch = useMemo<EpochOption>(() => {
    // Default: the newest option that already has a visualization, otherwise
    // the newest option overall.
    const defaultOption =
      options.find((option) => !!option.slimVisualization) ?? latestOption;

    if (!userSelectedEpoch) {
      return defaultOption;
    }

    const { epoch, visualizationId } = userSelectedEpoch;
    // With no visualizationId, the first (newest) option of that epoch wins.
    const matchingOption = options.find(
      (option) =>
        option.epoch === epoch &&
        (!visualizationId || option.slimVisualization?.cid === visualizationId)
    );

    return matchingOption ?? defaultOption;
  }, [options, latestOption, userSelectedEpoch]);

  const needsUpgrade = useMemo(
    () =>
      !latestOption.slimVisualization &&
      !inProcessEpochs.includes(latestOption.epoch),
    [latestOption, inProcessEpochs]
  );

  const isSelectedEpochUpgrading = inProcessEpochs.includes(
    selectedEpoch.epoch
  );

  const selectEpoch = useCallback(
    ({ epoch, slimVisualization }: EpochOption) => {
      setUserSelectedEpoch({ epoch, visualizationId: slimVisualization?.cid });
    },
    [setUserSelectedEpoch]
  );

  const upgradeToLastEpoch = useCallback(() => {
    // Store only the epoch (no visualizationId) so the selection resolves to
    // whichever visualization is newest for that epoch once one exists.
    setUserSelectedEpoch({ epoch: latestOption.epoch });
    upgrade(latestOption.epoch);
  }, [latestOption, setUserSelectedEpoch, upgrade]);

  useEffect(() => {
    // The stored choice is stale when it resolved to a different
    // visualization than it names (an epoch-only choice that now has a
    // visualization, or a named visualization that fell back to the default).
    // Pinning the resolved option's cid keeps later option refreshes from
    // moving the selection.
    const isUserSelectionStale =
      !!userSelectedEpoch &&
      userSelectedEpoch.visualizationId !==
        selectedEpoch.slimVisualization?.cid;

    if (isUserSelectionStale) {
      selectEpoch(selectedEpoch);
    }
  }, [selectedEpoch, selectEpoch, userSelectedEpoch]);

  return useMergedObject({
    selectEpoch,
    upgradeToLastEpoch,
    isSelectedEpochUpgrading,
    epochsOptions: options,
    selectedEpoch,
    needsUpgrade,
  });
}

/**
 * Builds every selectable epoch option for a session run.
 *
 * Each epoch contributes one option per visualization recorded for it, or a
 * single visualization-less option when it has none. The result is ordered
 * newest-first: by epoch descending, then by the visualization's `createdAt`
 * descending (visualization-less options sort as `new Date(0)`).
 *
 * @param epochs - Epoch numbers to offer.
 * @param slimVisualizations - Known visualizations, grouped by their `epoch`.
 * @returns The sorted options.
 */
function buildEpochOptions(
  epochs: readonly number[],
  slimVisualizations: SlimVisualization[] | undefined
): EpochOption[] {
  const visualizationsByEpoch = groupBy(
    slimVisualizations ?? [],
    (v) => v.epoch
  );

  const unsortedOptions: EpochOption[] = epochs.flatMap((epoch) => {
    const epochVisualizations = visualizationsByEpoch[epoch];
    if (!epochVisualizations) {
      return { epoch };
    }
    return epochVisualizations.map((slimVisualization) => ({
      epoch,
      slimVisualization,
    }));
  });

  return orderBy(
    unsortedOptions,
    [
      'epoch',
      (option) =>
        option.slimVisualization
          ? option.slimVisualization.createdAt
          : new Date(0),
    ],
    ['desc', 'desc']
  );
}
