import { Session, SlimVersion, StatusEnum } from '@tensorleap/api-client';
import {
  Dispatch,
  SetStateAction,
  useCallback,
  useEffect,
  useMemo,
  useState,
} from 'react';
import { Select } from '../../ui/atoms/Select';
import { SelectEpochLineChart } from '../../ui/SelectEpochLineChart';
import { getSessionEpochs } from '../helper-functions';
import { ModelSelector, ModelSelectorProps } from './ModelSelector';
import { useCurrentProject } from '../../core/CurrentProjectContext';
import { Input } from '../../ui/atoms/Input';

/**
 * Props for `SelectSessionWithEpoch`: every `ModelSelector` prop except
 * `onChange` (the component wires that one itself), plus controls for the
 * epoch dropdown and an optional name input.
 */
export type SelectSessionWithEpochProps = Omit<
  ModelSelectorProps,
  'onChange'
> & {
  /** Receives the session picked in the model selector, or `undefined` when the user chooses to create a new session. */
  onSessionChange: (session?: Session) => void;
  /** Receives the version picked in the model selector or chosen for a new session; needed (together with `setCreateNewSession`) for the create-session flow. */
  onVersionChange?: (version: SlimVersion) => void;
  /** Set to `true` when the user chooses to create a new session and to `false` when an existing one is selected. */
  setCreateNewSession?: Dispatch<SetStateAction<boolean>>;

  /** Epoch displayed in the epoch dropdown. */
  epoch?: number;
  /** Receives the epoch picked by the user, and a default epoch whenever the session changes. */
  onEpochChange: (epoch?: number) => void;
  /** Renders the epoch dropdown as read-only. */
  epochReadonly?: boolean;

  /** Label of the name input; the component falls back to `'NAME'`. */
  nameLabel?: string;
  /** Current value of the name input. */
  name?: string;
  /** Setter for the name input; the input is rendered only when this is provided. */
  setName?: (name?: string) => void;
};

/**
 * Returns the highest epoch among the session's weights whose status is
 * `StatusEnum.Finished`, or 0 when there is no session or no finished weight.
 *
 * The empty case is handled explicitly because `Math.max()` with no
 * arguments returns `-Infinity`; 0 is the same fallback used for a missing
 * session.
 */
function getLatestFinishedEpoch(session: Session | undefined): number {
  const finishedEpochs =
    session?.sessionWeights
      ?.filter(({ status }) => status === StatusEnum.Finished)
      .map(({ epoch }) => epoch) ?? [];
  return finishedEpochs.length > 0 ? Math.max(...finishedEpochs) : 0;
}

/**
 * Lets the user pick a model/session (through `ModelSelector`), an epoch of
 * that session and, optionally, a name. Shows `SelectEpochLineChart` for the
 * session run whose weight assets contain the selected epoch's weight.
 *
 * Whenever `session` (or `onEpochChange`) changes, the latest finished epoch
 * of the session is reported through `onEpochChange` and used for the chart.
 */
export function SelectSessionWithEpoch({
  epoch,
  onEpochChange,
  onSessionChange,
  onVersionChange,
  epochReadonly,
  nameLabel = 'NAME',
  name,
  setName,
  setCreateNewSession,
  ...sessionRunSelectorProps
}: SelectSessionWithEpochProps): JSX.Element {
  const { fetchValidProjectCid } = useCurrentProject();
  const projectId = fetchValidProjectCid();
  const { session } = sessionRunSelectorProps;
  const epochOptions = useMemo<number[]>(() => getSessionEpochs(session), [
    session,
  ]);

  // Epoch used to locate the session run displayed in the chart.
  const [selectedEpoch, setSelectedEpoch] = useState<number | undefined>(() =>
    getLatestFinishedEpoch(session)
  );

  // Default to the latest finished epoch whenever the session changes.
  useEffect(() => {
    const latestEpoch = getLatestFinishedEpoch(session);
    onEpochChange(latestEpoch);
    setSelectedEpoch(latestEpoch);
  }, [session, onEpochChange]);

  const handleEpochSelect = useCallback(
    (value: unknown) => {
      const nextEpoch = Number(value);
      onEpochChange(nextEpoch);
      // Keep the chart in sync with the user's pick, not only with the
      // session-change default set by the effect above.
      setSelectedEpoch(nextEpoch);
    },
    [onEpochChange]
  );

  // cid of the session run whose weight assets include the selected epoch's weight.
  const sessionRunId = useMemo(() => {
    const selectedWeight = session?.sessionWeights?.find(
      ({ epoch }) => epoch === selectedEpoch
    );
    if (!selectedWeight) {
      return undefined;
    }
    return session?.sessionRuns?.find(({ weightAssets }) =>
      weightAssets.some(
        (asset) => asset.sessionWeightId === selectedWeight.cid
      )
    )?.cid;
  }, [selectedEpoch, session]);

  return (
    <div className="flex flex-col flex-1 gap-4">
      <div className="flex gap-2">
        <ModelSelector
          {...sessionRunSelectorProps}
          onChange={(model, version) => {
            onSessionChange(model);
            onVersionChange?.(version);
            setCreateNewSession?.(false);
          }}
          onCreateAndSelectSession={(version: SlimVersion) => {
            // The create flow needs both callbacks and a version to proceed.
            if (!onVersionChange || !setCreateNewSession || !version) {
              console.error(
                "Something went wrong, one of these isn't defined",
                { onVersionChange, setCreateNewSession, version }
              );
              return;
            }
            onVersionChange(version);
            onSessionChange(undefined);
            setCreateNewSession(true);
          }}
        />

        {session && !!epochOptions.length && (
          <Select
            className="w-20"
            // `||` makes an epoch of 0 (also the "no finished epoch"
            // fallback) display the first available option instead.
            value={epoch || epochOptions[0]}
            options={epochOptions}
            label="EPOCH"
            readonly={epochReadonly}
            required
            onChange={handleEpochSelect}
            optionToLabel={(option: number) => option.toString()}
          />
        )}
      </div>
      {!!setName && (
        <Input
          label={nameLabel}
          value={name}
          onChange={(e) => setName(e.currentTarget.value)}
        />
      )}

      {session && (
        <div className="flex-1">
          <SelectEpochLineChart
            sessionRunId={sessionRunId}
            projectId={projectId}
          />
        </div>
      )}
    </div>
  );
}
