import { Controller, UseFormReturn } from 'react-hook-form';
import { Input } from '../../ui/atoms/Input';
import { parsePositiveInt } from '../helper-functions';
import { DEFAULT_BATCH_SIZE_KEY } from '../RunModel';
import { useLocalStorage } from '../../core/useLocalStorage';

/**
 * Numeric form fields that `TrainPlan` binds inputs to.
 * The parent form may contain more fields (see the `Fields extends TrainPlanFields`
 * constraint used by the component and its props).
 */
type TrainPlanFields = Record<'batchSize' | 'epochs', number>;

/**
 * Props accepted by the `TrainPlan` section.
 *
 * `Fields` is the full value shape of the parent form; it only has to
 * include `batchSize` and `epochs`.
 */
type TrainPlanProps<Fields extends TrainPlanFields> = {
  /** When true, the epochs input is labelled "NUMBER OF ADDITIONAL EPOCHS". */
  isAdditionalEpoch?: boolean;
  /** The parent react-hook-form instance the inputs are registered on. */
  form: UseFormReturn<Fields>;
};

/**
 * Form section that collects the training plan: batch size and number of epochs.
 *
 * Both inputs are bound to the caller's react-hook-form instance through
 * `Controller`, are marked `required`, and store whatever `parsePositiveInt`
 * returns for the typed text. A truthy parsed batch size is additionally passed
 * to the setter from `useLocalStorage(DEFAULT_BATCH_SIZE_KEY)`, presumably so it
 * can be offered as a default later (NOTE(review): confirm against `RunModel`).
 *
 * @param form - Parent form instance; its values must include `batchSize` and `epochs`.
 * @param isAdditionalEpoch - When true, the epochs input is labelled
 *   "NUMBER OF ADDITIONAL EPOCHS" instead of "NUMBER OF EPOCHS".
 */
export function TrainPlan<Fields extends TrainPlanFields>({
  form,
  isAdditionalEpoch,
}: TrainPlanProps<Fields>) {
  // `Controller`'s `name` is typed against the form's field paths; with `Fields`
  // still generic the literal names below presumably cannot be checked, so view
  // the form through the concrete shape this component actually touches.
  const { control } = (form as unknown) as UseFormReturn<TrainPlanFields>;
  // Only the setter is needed here; the stored value itself is not read.
  const [, setDefaultBatchSize] = useLocalStorage(DEFAULT_BATCH_SIZE_KEY);

  return (
    <div className="flex w-full pb-8 flex-col">
      <span className="font-normal text-base tracking-normal">
        Please define the training plan:
      </span>
      <div className="flex gap-2 flex-row mt-4">
        <Controller
          name="batchSize"
          control={control}
          rules={{ required: true }}
          render={({ field, fieldState }) => (
            <Input
              {...field}
              containerProps={{ className: 'flex-1' }}
              type="number"
              label="BATCH SIZE"
              error={fieldState.error && 'BatchSize is Required'}
              onChange={(e) => {
                const numberValue = parsePositiveInt(e.target.value);
                field.onChange(numberValue);
                // Persist only truthy values, so a falsy parse result
                // (presumably empty/invalid input) does not overwrite the
                // stored default.
                if (numberValue) setDefaultBatchSize(numberValue);
              }}
            />
          )}
        />

        <Controller
          name="epochs"
          control={control}
          rules={{ required: true }}
          render={({ field, fieldState }) => (
            <Input
              {...field}
              containerProps={{ className: 'flex-1' }}
              type="number"
              label={
                isAdditionalEpoch
                  ? 'NUMBER OF ADDITIONAL EPOCHS'
                  : 'NUMBER OF EPOCHS'
              }
              error={fieldState.error && 'Epochs is Required'}
              onChange={(e) => {
                // Same event type as the batch-size input, so no cast is needed.
                field.onChange(parsePositiveInt(e.target.value));
              }}
            />
          )}
        />
      </div>
    </div>
  );
}
