import {
  ComponentType,
  ReactNode,
  useCallback,
  useEffect,
  useMemo,
  useState,
} from 'react';
import { SubmitHandler, useForm } from 'react-hook-form';
import {
  ContinueTrainParams,
  DataStateType,
  EvaluateParams,
  Job,
  Session,
  SlimVersion,
  TrainFromInitialWeightsParams,
  TrainFromScratchParams,
  TrainingParams,
} from '@tensorleap/api-client';
import { Dialog, Button, Divider, IconButton } from '../ui/mui';
import { FwArrowIcon, SaveAs, XClose } from '../ui/icons';
import { useVersionControl } from '../core/VersionControlContext';
import { isEvaluateOrTrainingProcessAllowed } from './helper-functions';
import api from '../core/api-client';
import { useSnackbar } from '../ui/SnackbarContext';
import { EarlyStep } from './RunModel/EarlyStep';
import {
  AddToDashboardFlag,
  SkipMetricsEstimationFlag,
} from './RunModel/Flags';
import { TrainPlan } from './RunModel/TrainPlan';
import { NewModelDetails } from './RunModel/NewModelDetails';
import { SelectSessionWithEpoch } from './RunModel/SelectSessionWithEpoch';
import { RunModelMethod, RunModelTabs } from './RunModel/RunModelTabs';
import { EvaluatePlan } from './RunModel/EvaluatePlan';
import clsx from 'clsx';
import { DIALOG_HEIGHT_CLASSES } from './common';
import { useCurrentProject } from '../core/CurrentProjectContext';
import SvgSaveCommit from '../ui/icons/SaveCommit';
import { ConfirmDialogManyButtons } from '../ui/atoms/ConfirmDialogManyButtons';
import { useVersionControlPanelContext } from '../core/VersionControlPanelContext';
import { useNetworkMapContext } from '../core/NetworkMapContext';
import { useLocalStorage } from '../core/useLocalStorage';

/** `useLocalStorage` key holding the batch size the run-model forms start with. */
export const DEFAULT_BATCH_SIZE_KEY = 'DEFAULT_BATCH_SIZE';

/** `useLocalStorage` key persisting the Evaluate form's skip-metrics-estimation flag. */
export const SKIP_METRICS_ESTIMATION_KEY = 'SKIP_METRICS_ESTIMATION';

/** Props of `RunModelDialog`: the shared form props plus dialog state. */
export type RunModelDialogProps = RunModelFormProps & {
  /** Whether the dialog is open. */
  isOpen: boolean;
  /** Called when the dialog should be dismissed. */
  onClose: () => void;
  /** Tab selected when the dialog first mounts (later changes are ignored). */
  initialRunModelMethod: RunModelMethod;
};

/** Values collected by the three training forms. */
interface TrainDialogFormValues {
  /** Sent as the new session/model name (unused by Continue Training). */
  modelName: string;
  /** Sent as `trainingParams.batch_size`. */
  batchSize: number;
  /** Sent as `trainingParams.epochs`. */
  epochs: number;
  /** When true, `patiencePeriod` is sent as `early_stop_params.patience`. */
  earlyStop: boolean;
  /** Only used when `earlyStop` is enabled. */
  patiencePeriod?: number;
  mainKPI?: 'LOSS' | 'MEAN' | 'ACCURACY';
}

/** Values collected by the Evaluate form. */
interface EvaluateDialogFormValues {
  /** Sent as the evaluate request's `dataStates`. */
  selectedSubsets: DataStateType[];
  /** Sent as the evaluate request's `batchSize`. */
  batchSize: number;
}

/** Tab values for which `RunModelDialog` has a form to render. */
const RUN_MODEL_METHODS: readonly RunModelMethod[] = [
  'Train',
  'Continue',
  'Retrain',
  'Evaluate',
];

/**
 * Narrows a tab-change value to a `RunModelMethod` that `RunModelDialog` can
 * render a form for.
 */
function isRunModelMethod(value?: string): value is RunModelMethod {
  return RUN_MODEL_METHODS.some((method) => method === value);
}

/**
 * Modal for training or evaluating a model.
 *
 * While the network has unsaved changes (`isVersionChanged`), the user is first
 * offered to save, save as a new version, or continue without saving. After
 * that choice (or immediately when nothing changed) the train/eval tabs and
 * the form for the selected method are shown.
 */
export function RunModelDialog({
  isOpen,
  onClose,
  initialRunModelMethod,
  initialVersion,
  initialSession,
}: RunModelDialogProps): JSX.Element {
  const [runModelMethod, setRunModelMethod] = useState<RunModelMethod>(
    initialRunModelMethod
  );

  const { handleSaveNewVersion } = useVersionControlPanelContext();
  const { isVersionChanged, handleSaveClicked } = useNetworkMapContext();

  // Set by "Continue Without Saving", or by the callback handed to
  // `handleSaveClicked` in `saveAndContinueToTrainEvalDialog`.
  const [continueWithoutSaving, setContinueWithoutSaving] = useState(false);

  // Ignore values without a matching form: storing one would leave
  // `CurrentForm` undefined and crash the render below.
  const handleTrainMethodChange = useCallback((changedValue?: string) => {
    if (isRunModelMethod(changedValue)) {
      setRunModelMethod(changedValue);
    }
  }, []);

  const CurrentForm: ComponentType<RunModelFormProps> = useMemo(() => {
    switch (runModelMethod) {
      case 'Train':
        return TrainFromScratch;
      case 'Continue':
        return ContinueTraining;
      case 'Retrain':
        return TrainFromInitialWeights;
      case 'Evaluate':
        return Evaluate;
    }
  }, [runModelMethod]);

  const handleContinueWithoutSaveClicked = useCallback(
    () => setContinueWithoutSaving(true),
    []
  );

  const saveAndContinueToTrainEvalDialog = useCallback(async () => {
    // NOTE(review): assumes handleSaveClicked invokes this callback only after
    // a successful save — confirm in NetworkMapContext.
    await handleSaveClicked(() => {
      setContinueWithoutSaving(true);
    });
  }, [handleSaveClicked]);

  const handleSaveAsNewClicked = useCallback(() => {
    onClose();
    // Deferred, presumably so this dialog finishes closing before the
    // save-as-new flow opens.
    setTimeout(handleSaveNewVersion, 100);
  }, [handleSaveNewVersion, onClose]);

  return (
    <Dialog
      open={isOpen}
      onClose={onClose}
      aria-labelledby="form-dialog-title"
      maxWidth="xl"
    >
      {continueWithoutSaving || !isVersionChanged ? (
        <TrainEvalDialog
          onClose={onClose}
          runModelMethod={runModelMethod}
          handleTrainMethodChange={handleTrainMethodChange}
          initialVersion={initialVersion}
          initialSession={initialSession}
          CurrentForm={CurrentForm}
        />
      ) : (
        <SuggestSaveDialog
          handleContinueWithoutSaveClicked={handleContinueWithoutSaveClicked}
          handleSaveClicked={saveAndContinueToTrainEvalDialog}
          handleSaveAsNewClicked={handleSaveAsNewClicked}
          onClose={onClose}
        />
      )}
    </Dialog>
  );
}

/** Callbacks for the three choices offered when there are unsaved changes. */
interface SuggestSaveDialogProps {
  onClose: () => void;
  handleContinueWithoutSaveClicked: () => void;
  handleSaveClicked: () => Promise<void>;
  handleSaveAsNewClicked: () => void;
}

/**
 * Always-open confirmation shown before the train/eval forms when the network
 * has unsaved changes. It offers: continue without saving, save and continue,
 * or save as a new version.
 */
function SuggestSaveDialog(props: SuggestSaveDialogProps): JSX.Element {
  const {
    onClose,
    handleContinueWithoutSaveClicked,
    handleSaveClicked,
    handleSaveAsNewClicked,
  } = props;

  return (
    <ConfirmDialogManyButtons
      isOpen
      onClose={onClose}
      title="Unsaved changes were recognized, a save must be performed in order to apply them"
      confirmButtons={[
        {
          confirmButtonText: 'Continue Without Saving',
          confirmButtonColor: 'red',
          onConfirm: handleContinueWithoutSaveClicked,
        },
        {
          confirmButtonText: 'Save and continue',
          confirmButtonColor: 'blue',
          confirmButtonIcon: <SvgSaveCommit />,
          onConfirm: handleSaveClicked,
        },
        {
          confirmButtonText: 'Save As New Version',
          confirmButtonColor: 'blue',
          confirmButtonIcon: <SaveAs />,
          onConfirm: handleSaveAsNewClicked,
        },
      ]}
    />
  );
}

/** Inputs for the tabbed train/eval panel. */
interface TrainEvalDialogProps {
  CurrentForm: ComponentType<RunModelFormProps>;
  runModelMethod: RunModelMethod;
  handleTrainMethodChange: (changedValue?: string) => void;
  initialVersion: SlimVersion;
  initialSession?: Session;
  onClose: () => void;
}

/**
 * Panel with a title bar and close button, the method tabs, and the form
 * for the currently selected method.
 */
function TrainEvalDialog(props: TrainEvalDialogProps): JSX.Element {
  const {
    CurrentForm,
    runModelMethod,
    handleTrainMethodChange,
    initialVersion,
    initialSession,
    onClose,
  } = props;

  const titleBar = (
    <div className="flex flex-row h-16 px-8 justify-between items-center bg-gray-850">
      <h5 className="font-normal text-2xl leading-snug uppercase">
        train | eval
      </h5>
      <IconButton onClick={onClose}>
        <XClose />
      </IconButton>
    </div>
  );

  return (
    <div className={clsx('flex flex-col bg-gray-900', DIALOG_HEIGHT_CLASSES)}>
      {titleBar}

      <RunModelTabs onChange={handleTrainMethodChange} value={runModelMethod} />

      <CurrentForm
        onClose={onClose}
        initialSession={initialSession}
        initialVersion={initialVersion}
      />
    </div>
  );
}

type DialogBottomProps = {
  /** Content of the submit button label. */
  submitText: ReactNode;
  /** Enables the submit button. */
  isValid: boolean;
  /** `id` of the `<form>` the submit button submits. */
  formId: string;
  /** Invoked by the cancel button. */
  onClose: () => void;
};

/**
 * Footer shared by the run-model forms. The submit button lives outside the
 * `<form>` and is attached to it through the HTML `form` attribute.
 */
function DialogBottom({
  onClose,
  isValid,
  formId,
  submitText,
}: DialogBottomProps) {
  const cancelButton = (
    <Button
      onClick={onClose}
      className="mr-4"
      color="primary"
      variant="outlined"
    >
      <h6 className="font-medium text-xl uppercase leading-normal">cancel</h6>
    </Button>
  );

  const submitButton = (
    <Button
      disabled={!isValid}
      className="bg-primary-500"
      color="primary"
      variant="contained"
      form={formId}
      type="submit"
      data-track-action="train-dialog-submitted"
    >
      <h6 className="font-medium text-xl leading-normal uppercase">
        {submitText}
      </h6>
      <FwArrowIcon className="ml-2" />
    </Button>
  );

  return (
    <div className="flex flex-row h-20 px-8 justify-end items-center bg-black">
      {cancelButton}
      {submitButton}
    </div>
  );
}

/**
 * Builds a fresh set of initial values for the training forms.
 *
 * @param defaultBatchSize - Batch size to prefill.
 * @returns New values with an empty name, one epoch, and early stopping
 *   enabled (patience 8, main KPI `'LOSS'`).
 */
function defaultTrainFormValues(defaultBatchSize: number): TrainDialogFormValues {
  const values: TrainDialogFormValues = {
    modelName: '',
    batchSize: defaultBatchSize,
    epochs: 1,
    earlyStop: true,
    mainKPI: 'LOSS',
    patiencePeriod: 8,
  };
  return values;
}

/** DOM id of every run-model `<form>`; passed to `DialogBottom` as `formId`. */
const FORM_ID = 'TRAIN_FORM_ID';
/** Layout classes applied to every run-model `<form>`. */
const FORM_CLASSES = 'flex-1 w-[1300px] gap-6 p-8 flex flex-row';
/** Submit-button label of the training forms. */
const SUBMIT_TRAIN_TEXT = 'train';

/** Props received by every run-model form component. */
type RunModelFormProps = {
  /** Called to dismiss the dialog. */
  onClose: () => void;
  /** Version preselected when the form mounts. */
  initialVersion: SlimVersion;
  /** Session preselected when the form mounts, if any. */
  initialSession?: Session;
};

/** Options accepted by the function returned from `useSubmitWrapper`. */
type SubmitWrapperOptions = {
  /** Toggle the created session run via `toggleSelectSessionRun` on success. */
  addToDashboard: boolean;
};

/** Signature of the function returned from `useSubmitWrapper`. */
type SubmitWrapperReturn = (
  submitFunc: () => Promise<Job>,
  options: SubmitWrapperOptions
) => Promise<void>;

/**
 * Returns a function that runs a job-submitting API call and applies the
 * shared post-submit behavior of the run-model forms.
 *
 * On success the dialog is closed via `onClose` and the versions are
 * refetched. When `addToDashboard` is set and the job carries a
 * `sessionRunId`, that id is passed to `toggleSelectSessionRun`. Any thrown
 * value is caught, logged, and reported in an error snackbar.
 *
 * @param onClose - Called once `submitFunc` resolves.
 */
function useSubmitWrapper(onClose: () => void): SubmitWrapperReturn {
  const { enqueueSnackbar } = useSnackbar();
  const {
    refetch: fetchVersions,
    toggleSelectSessionRun,
  } = useVersionControl();

  const errorSnack = useCallback(
    (error: unknown) => {
      enqueueSnackbar(describeSubmitError(error), { variant: 'error' });
    },
    [enqueueSnackbar]
  );

  return useCallback(
    async (
      submitFunc: () => Promise<Job>,
      { addToDashboard }: SubmitWrapperOptions
    ) => {
      try {
        const job = await submitFunc();
        onClose();
        await fetchVersions();
        const sessionRunId = job.sessionRunId;
        if (addToDashboard && sessionRunId) {
          toggleSelectSessionRun(sessionRunId);
        }
      } catch (e) {
        console.error(e);
        // Report every failure, not only `Error` instances, so the user is
        // never left without feedback.
        errorSnack(e);
      }
    },
    [onClose, errorSnack, toggleSelectSessionRun, fetchVersions]
  );
}

/** Best-effort human-readable message for a value thrown during submission. */
function describeSubmitError(error: unknown): string {
  if (error instanceof Error) {
    return error.message || error.toString();
  }
  if (typeof error === 'string' && error) {
    return error;
  }
  return 'Failed to submit the job';
}

/**
 * Form for training a new model from scratch on the selected version.
 *
 * Submits `api.trainFromScratch` through `useSubmitWrapper`. Early stopping
 * params are included only when `earlyStop` is checked.
 */
export function TrainFromScratch({
  onClose,
  initialVersion,
}: RunModelFormProps): JSX.Element {
  const { versions } = useVersionControl();
  const [selectedVersion, setSelectedVersion] = useState(initialVersion);
  const { fetchValidProjectCid } = useCurrentProject();
  const projectId = fetchValidProjectCid();

  // Only the stored value is needed; the setter is unused here.
  const [defaultBatchSize] = useLocalStorage(DEFAULT_BATCH_SIZE_KEY, 8);

  const [addToDashboard, setAddToDashboard] = useState(true);
  const form = useForm<TrainDialogFormValues>({
    mode: 'onChange',
    defaultValues: defaultTrainFormValues(defaultBatchSize),
  });

  const {
    handleSubmit,
    formState: { isValid },
  } = form;

  const submitWrapper = useSubmitWrapper(onClose);

  const onSubmit: SubmitHandler<TrainDialogFormValues> = async ({
    modelName,
    batchSize,
    epochs,
    earlyStop,
    patiencePeriod,
  }) => {
    const trainingParams: TrainingParams = {
      epochs,
      batch_size: batchSize,
      ...(earlyStop && {
        early_stop_params: { patience: patiencePeriod as number },
      }),
    };

    const trainFromScratchParams: TrainFromScratchParams = {
      versionId: selectedVersion.cid,
      projectId,
      sessionName: modelName,
      trainingParams,
    };

    // Awaited so react-hook-form's submit lifecycle covers the request.
    await submitWrapper(() => api.trainFromScratch(trainFromScratchParams), {
      addToDashboard,
    });
  };

  return (
    <>
      <form
        className={FORM_CLASSES}
        id={FORM_ID}
        onSubmit={handleSubmit(onSubmit)}
      >
        <div className="flex-1 flex flex-col">
          <TrainPlan form={form} />
          <Divider orientation="horizontal" />
          <EarlyStep form={form} />
          <Divider orientation="horizontal" />
          <div className="flex">
            <AddToDashboardFlag
              value={addToDashboard}
              onChange={setAddToDashboard}
            />
          </div>
        </div>
        <Divider orientation="vertical" />

        <NewModelDetails
          className="flex-1"
          versions={versions}
          selectedVersion={selectedVersion}
          onVersionChange={setSelectedVersion}
          onModelChange={null}
          form={form}
        />
      </form>
      <DialogBottom
        onClose={onClose}
        submitText={SUBMIT_TRAIN_TEXT}
        isValid={isValid}
        formId={FORM_ID}
      />
    </>
  );
}

/**
 * Form for training a new model on the selected version, initialized from the
 * weights of an existing session at a chosen epoch.
 *
 * Submission is blocked until a session and epoch are selected. The request
 * goes through `api.trainFromInitialWeights` via `useSubmitWrapper`.
 */
export function TrainFromInitialWeights({
  onClose,
  initialVersion,
  initialSession,
}: RunModelFormProps): JSX.Element {
  const { versions } = useVersionControl();
  const [selectedVersion, setSelectedVersion] = useState(initialVersion);
  const [addToDashboard, setAddToDashboard] = useState(true);
  const [selectedSession, setSelectedSession] = useState<Session | undefined>(
    initialSession
  );
  const [selectedEpoch, setSelectedEpoch] = useState<number>();
  const { fetchValidProjectCid } = useCurrentProject();
  const projectId = fetchValidProjectCid();

  // Only the stored value is needed; the setter is unused here.
  const [defaultBatchSize] = useLocalStorage(DEFAULT_BATCH_SIZE_KEY, 8);

  const form = useForm<TrainDialogFormValues>({
    mode: 'onChange',
    defaultValues: defaultTrainFormValues(defaultBatchSize),
  });

  const {
    handleSubmit,
    formState: { isValid },
  } = form;
  const isTrainProcessAllowed = isEvaluateOrTrainingProcessAllowed({
    isValid,
    selectedSession,
    selectedVersion,
  });

  const submitWrapper = useSubmitWrapper(onClose);

  const onSubmit: SubmitHandler<TrainDialogFormValues> = async ({
    modelName,
    batchSize,
    epochs,
    earlyStop,
    patiencePeriod,
  }) => {
    const trainingParams: TrainingParams = {
      epochs,
      batch_size: batchSize,
      ...(earlyStop && {
        early_stop_params: { patience: patiencePeriod as number },
      }),
    };

    if (!selectedSession || selectedEpoch === undefined) {
      console.warn(
        'How trainFromInitialWeights was submitted without a selectedModel/selectedEpoch?'
      );
      return;
    }

    const trainFromInitialWeightsParams: TrainFromInitialWeightsParams = {
      versionId: selectedVersion.cid,
      projectId,
      fromSessionId: selectedSession.cid,
      fromEpoch: selectedEpoch,
      modelName,
      trainingParams,
    };

    // Awaited so react-hook-form's submit lifecycle covers the request.
    await submitWrapper(
      () => api.trainFromInitialWeights(trainFromInitialWeightsParams),
      { addToDashboard }
    );
  };

  // Returns why `version` cannot be picked as the weights source, or
  // undefined when it can.
  const getVersionDisabledMsg = useCallback(
    (version: SlimVersion): string | undefined => {
      if (!selectedVersion) return 'Please select a version first';
      // Versions must share the selected version's network hash. A version
      // with an undefined hash only matches when the selected version's hash
      // is also undefined.
      if (version.hash !== selectedVersion.hash)
        return `The network of this version not match the selected network version (${selectedVersion.branchName})`;
      return undefined;
    },
    [selectedVersion]
  );

  return (
    <>
      <form
        className={FORM_CLASSES}
        id={FORM_ID}
        onSubmit={handleSubmit(onSubmit)}
      >
        <div className="flex-1 flex flex-col">
          <TrainPlan form={form} />
          <Divider orientation="horizontal" />
          <EarlyStep form={form} />
          <Divider orientation="horizontal" />
          <div className="flex">
            <AddToDashboardFlag
              value={addToDashboard}
              onChange={setAddToDashboard}
            />
          </div>
        </div>
        <Divider orientation="vertical" />

        <div className="flex flex-col flex-1 gap-4">
          <NewModelDetails
            versions={versions}
            selectedVersion={selectedVersion}
            onModelChange={setSelectedSession}
            selectedSession={selectedSession}
            onVersionChange={setSelectedVersion}
            form={form}
          />
          <span className="text-base">Copy from session:</span>

          <SelectSessionWithEpoch
            session={selectedSession}
            onSessionChange={setSelectedSession}
            getVersionDisabledMsg={getVersionDisabledMsg}
            onEpochChange={setSelectedEpoch}
            epoch={selectedEpoch}
            initialSelectedVersion={selectedVersion}
          />
        </div>
      </form>
      <DialogBottom
        onClose={onClose}
        submitText={SUBMIT_TRAIN_TEXT}
        isValid={isTrainProcessAllowed}
        formId={FORM_ID}
      />
    </>
  );
}

/**
 * Form for continuing to train an existing session from a selected epoch.
 *
 * Submission is blocked until a session and epoch are selected. The request
 * goes through `api.continueTrain` via `useSubmitWrapper`.
 */
export function ContinueTraining({
  onClose,
  initialVersion,
  initialSession,
}: RunModelFormProps): JSX.Element {
  const [selectedVersion, setSelectedVersion] = useState(initialVersion);
  const [selectedSession, setSelectedSession] = useState<Session | undefined>(
    initialSession
  );
  const [selectedEpoch, setSelectedEpoch] = useState<number>();
  const [addToDashboard, setAddToDashboard] = useState(true);
  const { fetchValidProjectCid } = useCurrentProject();
  const projectId = fetchValidProjectCid();

  // Only the stored value is needed; the setter is unused here.
  const [defaultBatchSize] = useLocalStorage(DEFAULT_BATCH_SIZE_KEY, 8);

  const form = useForm<TrainDialogFormValues>({
    mode: 'onChange',
    defaultValues: defaultTrainFormValues(defaultBatchSize),
  });

  const {
    handleSubmit,
    formState: { isValid },
  } = form;

  const isTrainProcessAllowed = isEvaluateOrTrainingProcessAllowed({
    isValid,
    selectedSession,
    selectedVersion,
  });

  const submitWrapper = useSubmitWrapper(onClose);

  const onSubmit: SubmitHandler<TrainDialogFormValues> = async ({
    batchSize,
    epochs,
    earlyStop,
    patiencePeriod,
  }) => {
    const trainingParams: TrainingParams = {
      epochs,
      batch_size: batchSize,
      ...(earlyStop && {
        early_stop_params: { patience: patiencePeriod as number },
      }),
    };

    if (!selectedSession || selectedEpoch === undefined) {
      console.warn(
        'How continueTraining was submitted without a selectedModel/selectedEpoch?'
      );
      return;
    }

    const continueTrainParams: ContinueTrainParams = {
      projectId,
      versionId: selectedVersion.cid,
      sessionId: selectedSession.cid,
      trainingParams,
      fromEpoch: selectedEpoch,
    };

    // Awaited so react-hook-form's submit lifecycle covers the request.
    await submitWrapper(() => api.continueTrain(continueTrainParams), {
      addToDashboard,
    });
  };

  return (
    <>
      <form
        className={FORM_CLASSES}
        id={FORM_ID}
        onSubmit={handleSubmit(onSubmit)}
      >
        <div className="flex-1 flex flex-col">
          <TrainPlan form={form} isAdditionalEpoch />
          <Divider orientation="horizontal" />
          <EarlyStep form={form} />
          <Divider orientation="horizontal" />
          <div className="flex">
            <AddToDashboardFlag
              value={addToDashboard}
              onChange={setAddToDashboard}
            />
          </div>
        </div>
        <Divider orientation="vertical" />

        <div className="flex flex-col flex-1 gap-4">
          <span className="text-base">Continue from model:</span>
          <SelectSessionWithEpoch
            session={selectedSession}
            onSessionChange={setSelectedSession}
            onVersionChange={setSelectedVersion}
            onEpochChange={setSelectedEpoch}
            epoch={selectedEpoch}
            epochReadonly
            initialSelectedVersion={selectedVersion}
          />
        </div>
      </form>
      <DialogBottom
        onClose={onClose}
        submitText={SUBMIT_TRAIN_TEXT}
        isValid={isTrainProcessAllowed}
        formId={FORM_ID}
      />
    </>
  );
}

/**
 * Form for running an evaluation on the selected version.
 *
 * When `createNewSession` is false, a session and an epoch must be selected;
 * the submit handler refuses to send the request otherwise. The evaluation
 * name is auto-suggested from the selected session until the user edits it.
 * The request goes through `api.evaluate` via `useSubmitWrapper`.
 */
export function Evaluate({
  onClose,
  initialVersion,
  initialSession,
}: RunModelFormProps): JSX.Element {
  const [selectedVersion, setSelectedVersion] = useState(initialVersion);
  const [selectedSession, setSelectedSession] = useState<Session | undefined>(
    initialSession
  );
  const [selectedEpoch, setSelectedEpoch] = useState<number>();
  const [nameWasUpdated, setNameWasUpdated] = useState(false);
  const [evaluationName, setEvaluationName] = useState<string>('');

  const [createNewSession, setCreateNewSession] = useState(false);

  // Suggested name: "<model name or 'session'>-Evaluate-<run count + 1>".
  const calculateBaseEvaluationName = useCallback(() => {
    const sessionRunCount = selectedSession?.sessionRuns?.length || 0;
    return `${selectedSession?.modelName || 'session'}-Evaluate-${
      sessionRunCount + 1
    }`;
  }, [selectedSession?.modelName, selectedSession?.sessionRuns?.length]);

  // Keep following the suggested name until the user types their own.
  useEffect(() => {
    if (!nameWasUpdated) {
      setEvaluationName(calculateBaseEvaluationName());
    }
  }, [nameWasUpdated, calculateBaseEvaluationName]);

  const updateEvaluationName = useCallback((name?: string) => {
    setNameWasUpdated(true);
    setEvaluationName(name || '');
  }, []);

  const [addToDashboard, setAddToDashboard] = useState(true);
  const [skipMetricsEstimation, setSkipMetricsEstimation] = useLocalStorage(
    SKIP_METRICS_ESTIMATION_KEY,
    false
  );
  const { fetchValidProjectCid } = useCurrentProject();
  const projectId = fetchValidProjectCid();

  const submitWrapper = useSubmitWrapper(onClose);

  // Only the stored value is needed; the setter is unused here.
  const [defaultBatchSize] = useLocalStorage(DEFAULT_BATCH_SIZE_KEY, 8);

  const form = useForm<EvaluateDialogFormValues>({
    mode: 'onChange',
    defaultValues: {
      batchSize: defaultBatchSize,
      selectedSubsets: [
        DataStateType.Training,
        DataStateType.Validation,
        DataStateType.Test,
        DataStateType.Unlabeled,
      ],
    },
  });

  const {
    handleSubmit,
    formState: { isValid },
  } = form;

  const isEvaluateProcessAllowed = isEvaluateOrTrainingProcessAllowed({
    isValid,
    selectedSession,
    selectedVersion,
    createNewSession,
  });

  const onSubmit: SubmitHandler<EvaluateDialogFormValues> = async ({
    batchSize,
    selectedSubsets,
  }) => {
    if (selectedSubsets.length === 0 || selectedVersion === undefined) {
      console.error(
        'how evaluate was submitted without selecting a version or subsets?'
      );
      return;
    }

    if (
      !createNewSession &&
      (selectedSession === undefined || selectedEpoch === undefined)
    ) {
      console.error(
        'how evaluate an existing sesion was submitted without selecting a session or epoch?'
      );
      // Don't send an evaluation of an existing session that lacks its
      // session id or epoch.
      return;
    }

    const evaluateRequest: EvaluateParams = {
      projectId,
      versionId: selectedVersion.cid,
      sessionId: selectedSession?.cid,
      batchSize,
      dataStates: selectedSubsets,
      evaluatedEpoch: selectedEpoch,
      name: evaluationName,
      skipMetricsEstimation,
    };

    // Awaited so react-hook-form's submit lifecycle covers the request.
    await submitWrapper(() => api.evaluate(evaluateRequest), {
      addToDashboard,
    });
  };

  return (
    <>
      <form
        id={FORM_ID}
        onSubmit={handleSubmit(onSubmit)}
        className={FORM_CLASSES}
      >
        <div className="flex flex-col gap-8 flex-1 justify-start">
          <EvaluatePlan form={form} />
          <Divider orientation="horizontal" />
          <div className="flex flex-col">
            <SkipMetricsEstimationFlag
              value={skipMetricsEstimation}
              onChange={setSkipMetricsEstimation}
            />
            <AddToDashboardFlag
              value={addToDashboard}
              onChange={setAddToDashboard}
            />
          </div>
        </div>
        <Divider orientation="vertical" />
        <div className="flex flex-col flex-1 gap-4">
          <span className="text-base">
            Please select which model to evaluated:
          </span>
          <SelectSessionWithEpoch
            session={selectedSession}
            onSessionChange={setSelectedSession}
            onVersionChange={setSelectedVersion}
            onEpochChange={setSelectedEpoch}
            epoch={selectedEpoch}
            initialSelectedVersion={selectedVersion}
            name={evaluationName}
            setName={updateEvaluationName}
            nameLabel="EVALUATION NAME"
            allowCreateSession={true}
            createNewSession={createNewSession}
            setCreateNewSession={setCreateNewSession}
          />
        </div>
      </form>

      <DialogBottom
        onClose={onClose}
        isValid={isEvaluateProcessAllowed}
        submitText="evaluate"
        formId={FORM_ID}
      />
    </>
  );
}
