import { useMemo } from 'react';
import useSWR, { KeyedMutator } from 'swr';
import { REFRESH_INTERVALS } from './consts';
import { useCurrentProject } from '../CurrentProjectContext';
import { useSelectedSessionRuns } from '../../dashboard/add-analysis/panes/useSelectedModels';
import api from '../api-client';
import { PredictionLabels } from '@tensorleap/api-client';

/**
 * Result shape returned by {@link useGetConfusionMatrixLabels}.
 */
export interface UseGetConfusionMatrixLabels {
  /**
   * The data currently held by SWR for this request; `undefined` while SWR
   * has no data for the current key.
   */
  labelsByPrediction: PredictionLabels[] | undefined;
  /** The error reported by SWR for this request, if any. */
  error?: Error;
  /** `true` while there is neither an error nor any data available. */
  isLoading: boolean;
  /** SWR's bound `mutate` function for this request's cache entry. */
  refetch: KeyedMutator<PredictionLabels[]>;
}

/**
 * Loads the confusion-matrix labels for the currently selected session runs
 * of the current project.
 *
 * The request is cached by SWR under a key derived from the de-duplicated
 * session run ids and the project id, and is refreshed using
 * `REFRESH_INTERVALS.charts` as SWR's `refreshInterval`.
 *
 * When there is no current project id, or no session run is selected, the
 * fetcher resolves to an empty array without calling the API.
 *
 * @returns A memoized object exposing the fetched labels, the SWR error,
 * a derived loading flag and SWR's `mutate` (as `refetch`).
 */
export function useGetConfusionMatrixLabels(): UseGetConfusionMatrixLabels {
  const { currentProjectId } = useCurrentProject();
  const selectedSessionRuns = useSelectedSessionRuns();

  // Collapse the selected runs into a list of unique ids (first-seen order).
  const sessionRunIds = useMemo(() => {
    const allCids = selectedSessionRuns.map((run) => run.cid);
    return Array.from(new Set(allCids));
  }, [selectedSessionRuns]);

  // The serialized request parameters double as the SWR cache key.
  const cacheKey = JSON.stringify({ sessionRunIds, currentProjectId });

  const swrResponse = useSWR(
    cacheKey,
    async () => {
      // Only hit the API when both a project and at least one run are known.
      if (currentProjectId && sessionRunIds.length) {
        const response = await api.getConfusionMatrixLabels({
          sessionRunIds,
          projectId: currentProjectId,
        });
        return response.labelsByPrediction;
      }
      return [];
    },
    { refreshInterval: REFRESH_INTERVALS.charts }
  );
  const { data, error, mutate } = swrResponse;

  return useMemo(() => {
    // Loading means SWR has produced neither an error nor any data yet.
    const isLoading = !(error || data);
    return {
      labelsByPrediction: data,
      error,
      isLoading,
      refetch: mutate,
    };
  }, [data, error, mutate]);
}
