import { MedicalReportCodeAddedBy, MedicalReportCodeState, MedicalReportProcess } from "@models/medical-report-process";


/**
 * Precision / recall / F1 scores for a set of predicted codes compared
 * against a set of reference ("true") codes.
 */
interface MedicalReportMetrics {
  /** Fraction of distinct predicted codes that are also reference codes. */
  precision: number;
  /** Fraction of distinct reference codes that were also predicted. */
  recall: number;
  /** Harmonic mean of precision and recall. */
  f1: number;
}

/**
 * Scores predicted codes against reference ("true") codes using set semantics.
 *
 * Both inputs are de-duplicated before comparison, so a repeated code counts once.
 *
 * - precision = TP / (TP + FP)
 * - recall    = TP / (TP + FN)
 * - f1        = 2 · precision · recall / (precision + recall)
 *
 * Any ratio whose denominator is zero is reported as 0 instead of NaN.
 * The resulting scores are also written to the console.
 *
 * @param trueCodes - Reference codes considered correct.
 * @param predictedCodes - Codes produced by the system under evaluation.
 * @returns Precision, recall and F1, each in [0, 1].
 */
export function calculateMetrics(trueCodes: string[], predictedCodes: string[]): MedicalReportMetrics {
  // Build each set once so membership checks are O(1) and the whole computation is linear.
  const trueSet = new Set(trueCodes);
  const predictedSet = new Set(predictedCodes);

  // Codes that occur in both the reference and the predicted set.
  let truePositives = 0;
  for (const code of trueSet) {
    if (predictedSet.has(code)) {
      truePositives++;
    }
  }
  // Predicted codes that are not reference codes (correct even when trueCodes is empty).
  const falsePositives = predictedSet.size - truePositives;
  // Reference codes that were not predicted (correct even when predictedCodes is empty).
  const falseNegatives = trueSet.size - truePositives;

  // Division that yields 0 instead of NaN/Infinity when the denominator is 0.
  const ratio = (numerator: number, denominator: number): number =>
    denominator === 0 ? 0 : numerator / denominator;

  const precision = ratio(truePositives, truePositives + falsePositives);
  const recall = ratio(truePositives, truePositives + falseNegatives);
  const f1 = ratio(2 * precision * recall, precision + recall);

  console.log(`(Recall, Precision, F1-Score): (${recall.toFixed(3)}, ${precision.toFixed(3)}, ${f1.toFixed(3)})`);

  return { precision, recall, f1 };
}

/**
 * Averages precision, recall and F1 component-wise over the supplied metrics.
 *
 * `undefined` entries are skipped: they contribute neither to the sums nor to the divisor.
 *
 * @param metricsArray - Metrics to average; missing (`undefined`) entries are ignored.
 * @returns The component-wise mean, or all zeros when no metrics are present.
 */
function averageMetrics(...metricsArray: (MedicalReportMetrics | undefined)[]): MedicalReportMetrics {
  // A type predicate narrows the element type without an unchecked `as` cast.
  const present = metricsArray.filter(
    (metrics): metrics is MedicalReportMetrics => metrics != null,
  );

  const count = present.length;
  if (count === 0) {
    return { precision: 0, recall: 0, f1: 0 };
  }

  const sum = present.reduce<MedicalReportMetrics>(
    (acc, metrics) => ({
      precision: acc.precision + metrics.precision,
      recall: acc.recall + metrics.recall,
      f1: acc.f1 + metrics.f1,
    }),
    { precision: 0, recall: 0, f1: 0 },
  );

  return {
    precision: sum.precision / count,
    recall: sum.recall / count,
    f1: sum.f1 / count,
  };
}

/**
 * Evaluates the platform-suggested ICD and OPS codes of a single report process.
 *
 * For each code family, codes whose status is CHECKED form the reference set and
 * codes added by the PLATFORM form the predictions. A family without any CHECKED
 * codes yields `undefined` instead of metrics and is therefore left out of the average.
 *
 * @param reportProcess - Report whose `icdCodes` / `opsCodes` are evaluated; either list may be absent.
 * @returns Per-family metrics (when a reference set exists) plus their average.
 */
export function calculateMedicalReportMetrics(
  reportProcess: MedicalReportProcess,
): { icd?: MedicalReportMetrics; ops?: MedicalReportMetrics; average: MedicalReportMetrics } {
  const { icdCodes, opsCodes } = reportProcess;

  // Reference sets: codes whose status is CHECKED.
  const icdReference = icdCodes?.filter(({ status }) => status === MedicalReportCodeState.CHECKED).map(({ name }) => name) || [];
  // Predictions: codes added by the platform.
  const icdPredicted = icdCodes?.filter(({ addedBy }) => addedBy === MedicalReportCodeAddedBy.PLATFORM).map(({ name }) => name) || [];
  const opsReference = opsCodes?.filter(({ status }) => status === MedicalReportCodeState.CHECKED).map(({ name }) => name) || [];
  const opsPredicted = opsCodes?.filter(({ addedBy }) => addedBy === MedicalReportCodeAddedBy.PLATFORM).map(({ name }) => name) || [];

  // Scores are only produced when there is at least one reference code to compare against.
  const scoreFamily = (reference: string[], predicted: string[]): MedicalReportMetrics | undefined =>
    reference.length ? calculateMetrics(reference, predicted) : undefined;

  const icd = scoreFamily(icdReference, icdPredicted);
  const ops = scoreFamily(opsReference, opsPredicted);

  return { icd, ops, average: averageMetrics(ops, icd) };
}
