import { NormalizedLandmark } from "@mediapipe/tasks-vision";
import { ControlOptions } from "../control-options";
import {
  POSE_LANDMARKS,
  POSE_LANDMARKS_LEFT,
  POSE_LANDMARKS_RIGHT,
} from "./landmarks";
import { JointAngle, getAnglesForJointsPair } from "./pose-utils";

/**
 * Names of the joint pairs that PoseScoring scores. Each member keys an entry of
 * `scorableJointSets` and the optional per-pair maps of PoseFrame (`pairs`) and
 * PoseEvaluation (`jointPairs`).
 */
export enum ScoreableJointPairNames {
  elbows = "elbows",
  shoulders = "shoulders",
  hips = "hips",
  knees = "knees",
}

/**
 * Landmark-index triples measured for each scoreable joint pair: the first
 * triple is the right side, the second the left side. The middle landmark of
 * each triple is the joint the pair is named after (e.g. the knee for `knees`);
 * presumably `getAnglesForJointsPair` measures the angle at that middle
 * landmark — confirm in pose-utils.
 *
 * Map insertion order (knees, hips, shoulders, elbows) is the order in which
 * PoseScoring iterates the pairs.
 *
 * NOTE(review): knee/ankle indices come from `POSE_LANDMARKS_RIGHT` /
 * `POSE_LANDMARKS_LEFT` while the others come from `POSE_LANDMARKS`;
 * presumably they resolve to the same index space — verify against ./landmarks.
 */
export const scorableJointSets = new Map<ScoreableJointPairNames, number[][]>([
  [
    ScoreableJointPairNames.knees,
    [
      [
        POSE_LANDMARKS.RIGHT_HIP,
        POSE_LANDMARKS_RIGHT.RIGHT_KNEE,
        POSE_LANDMARKS_RIGHT.RIGHT_ANKLE,
      ],
      [
        POSE_LANDMARKS.LEFT_HIP,
        POSE_LANDMARKS_LEFT.LEFT_KNEE,
        POSE_LANDMARKS_LEFT.LEFT_ANKLE,
      ],
    ],
  ],
  [
    ScoreableJointPairNames.hips,
    [
      [
        POSE_LANDMARKS.RIGHT_SHOULDER,
        POSE_LANDMARKS.RIGHT_HIP,
        POSE_LANDMARKS_RIGHT.RIGHT_KNEE,
      ],
      [
        POSE_LANDMARKS.LEFT_SHOULDER,
        POSE_LANDMARKS.LEFT_HIP,
        POSE_LANDMARKS_LEFT.LEFT_KNEE,
      ],
    ],
  ],
  [
    ScoreableJointPairNames.shoulders,
    [
      [
        POSE_LANDMARKS.RIGHT_HIP,
        POSE_LANDMARKS.RIGHT_SHOULDER,
        POSE_LANDMARKS.RIGHT_ELBOW,
      ],
      [
        POSE_LANDMARKS.LEFT_HIP,
        POSE_LANDMARKS.LEFT_SHOULDER,
        POSE_LANDMARKS.LEFT_ELBOW,
      ],
    ],
  ],
  [
    ScoreableJointPairNames.elbows,
    [
      [
        POSE_LANDMARKS.RIGHT_SHOULDER,
        POSE_LANDMARKS.RIGHT_ELBOW,
        POSE_LANDMARKS.RIGHT_WRIST,
      ],
      [
        POSE_LANDMARKS.LEFT_SHOULDER,
        POSE_LANDMARKS.LEFT_ELBOW,
        POSE_LANDMARKS.LEFT_WRIST,
      ],
    ],
  ],
]);

/**
 * Joint pair names, including `wrists`, which is not a member of
 * ScoreableJointPairNames (no angle triple is defined for it in
 * `scorableJointSets`). Not referenced by the scoring code in this file;
 * presumably consumed elsewhere — verify before changing.
 */
export enum JointPairNames {
  wrists = "wrists",
  elbows = "elbows",
  shoulders = "shoulders",
  hips = "hips",
  knees = "knees",
}

/**
 * Score and confidence of one joint pair in a single frame, as produced by
 * PoseScoring.getFrameScoreForPoseLandmarks.
 */
export type JointPairFrameEvaluation = {
  /** Mean of the pair's measured angles divided by Math.PI (0 if that is NaN). */
  score: number;
  /** Mean confidence of the pair's measured angles (0 if that is NaN). */
  confidence: number;
};

/**
 * Evaluation of a single frame for a pose, as produced by
 * PoseScoring.getFrameScoreForPoseLandmarks. The frame score is the average of
 * the per-joint-pair scores (each being the pair's mean angle divided by
 * Math.PI), not a direct average of raw angles.
 */
export type PoseFrame = {
  /** Frame time as supplied by the caller; units are not determined here. */
  timestamp?: number;
  /** Mean `score` of the joint pairs present in `pairs`. */
  score: number;
  /** Mean `confidence` of the joint pairs present in `pairs`. */
  confidence: number;
  /** Per-pair evaluations; a pair is absent when no angle could be measured for it. */
  pairs: { [jointPair in ScoreableJointPairNames]?: JointPairFrameEvaluation };
};

/**
 * Aggregated score and confidence per scoreable joint pair across all frames of
 * a pose. Same shape as a per-frame evaluation; a pair is omitted when none of
 * its per-frame evaluations met the confidence threshold.
 */
type PoseEvaluationJointPairs = Partial<
  Record<ScoreableJointPairNames, JointPairFrameEvaluation>
>;

/**
 * Evaluation of a pose based on an aggregation of all frames for which the pose was observed.
 */
export type PoseEvaluation = {
  /**
   * The average score of the frames for which the pose was observed whose
   * confidence met the visibility threshold (see `qualifyingFramesCount`).
   */
  score: number;
  /** The average confidence of those same qualifying frames. */
  confidence: number;
  /** Total number of frames evaluated, qualifying or not. */
  framesCount: number;
  /** Number of frames whose confidence met the visibility threshold. */
  qualifyingFramesCount: number;
  /**
   * Value from 0 to 1 representing the normalized angle between the direction the
   * user is facing and the direction the camera is facing. 0 means the user is facing
   * away from the camera, 1 means the user is facing the camera, and 0.5 means the user
   * is facing perpendicular to the camera.
   *
   * NOTE: not set by PoseScoring.evaluateFrames; presumably filled in by another
   * component — verify.
   */
  facingCamera?: number;
  /**
   * Per-joint-pair averages over the per-frame pair evaluations that met the
   * visibility threshold; pairs with no qualifying evaluation are omitted.
   */
  jointPairs: PoseEvaluationJointPairs;
};

/**
 * Minimal shape accepted by PoseScoring.qualifyConfidence: anything carrying a
 * numeric confidence that is compared against the visibility threshold.
 */
export interface IWithConfidence {
  confidence: number;
}

/**
 * Scores poses frame by frame and aggregates frame scores over time.
 * Which measurements count as trustworthy is decided by
 * `options.landmarkPoseVisibilityThreshold`.
 */
export class PoseScoring {
  options: ControlOptions;

  constructor(options: ControlOptions) {
    this.options = options;
  }

  /**
   * Filter `items` down to those confident enough to be considered for pose
   * evaluation, i.e. whose `confidence` is at least
   * `options.landmarkPoseVisibilityThreshold`.
   *
   * Falsy entries are dropped, and so are entries whose confidence is NaN
   * (every comparison with NaN is false).
   *
   * @param items - Candidates to filter; the array is not mutated.
   * @returns A new array with the qualifying items in their original order.
   */
  qualifyConfidence<Type extends IWithConfidence>(items: Type[]): Type[] {
    const threshold = this.options.landmarkPoseVisibilityThreshold;
    return items.filter((item: Type) => item && item.confidence >= threshold);
  }

  /**
   * Evaluate a single frame for a pose.
   *
   * For every joint pair in `scorableJointSets`, the angles returned by
   * `getAnglesForJointsPair` are averaged into a JointPairFrameEvaluation:
   * `score` is the summed angle divided by Math.PI and by the number of angles,
   * and `confidence` is the mean angle confidence. Either value falls back to 0
   * if it comes out NaN. Presumably angles are radians, so scores lie in [0, 1]
   * for angles in [0, π] — confirm in pose-utils. Joint pairs for which no angle
   * was returned are left out of `pairs` and of the frame averages.
   *
   * @param poseLandmarks - Landmarks of the detected pose in this frame.
   * @param timestamp - Copied verbatim onto the returned frame.
   * @param jointAnglesMap - Passed through unchanged to `getAnglesForJointsPair`.
   * @returns The frame evaluation. Its `score`/`confidence` are the means over
   *   the joint pairs present in `pairs`, or 0 when no joint pair could be
   *   measured at all (rather than NaN from dividing by zero).
   */
  getFrameScoreForPoseLandmarks(
    poseLandmarks: NormalizedLandmark[],
    timestamp: number,
    jointAnglesMap: { [key: number]: JointAngle }
  ): PoseFrame {
    const pairs: PoseFrame["pairs"] = {};
    const qualifiedScores: JointPairFrameEvaluation[] = [];
    for (const [jointsPairName, jointsPair] of scorableJointSets) {
      const angles = getAnglesForJointsPair(
        jointsPair,
        poseLandmarks,
        this.options.landmarkPoseVisibilityThreshold,
        jointAnglesMap
      );
      // Nothing measurable for this pair: exclude it instead of scoring it 0.
      if (angles.length === 0) {
        continue;
      }
      const jointScore: JointPairFrameEvaluation = {
        // `|| 0` maps a NaN result (e.g. from a NaN angle) to 0.
        score:
          PoseScoring.sumOf(angles, (a) => a.angle) / Math.PI / angles.length ||
          0,
        confidence:
          PoseScoring.sumOf(angles, (a) => a.confidence) / angles.length || 0,
      };
      qualifiedScores.push(jointScore);
      pairs[jointsPairName] = jointScore;
    }
    return {
      timestamp,
      pairs,
      score: PoseScoring.meanOf(qualifiedScores, (s) => s.score),
      confidence: PoseScoring.meanOf(qualifiedScores, (s) => s.confidence),
    };
  }

  /**
   * Evaluate a set of frames to aggregate scores over time for a pose event.
   *
   * Per joint pair, the per-frame pair evaluations that pass
   * `qualifyConfidence` are averaged; pairs with no qualifying evaluation are
   * omitted from `jointPairs`. The overall `score`/`confidence` average the
   * frames that pass `qualifyConfidence`, and are 0 (not NaN) when no frame
   * qualifies, including when `sourceFrames` is empty.
   *
   * @param sourceFrames - Frames in which the pose was observed; not mutated.
   * @returns The aggregated evaluation. `facingCamera` is not set here.
   */
  evaluateFrames(sourceFrames: PoseFrame[]): PoseEvaluation {
    const pairs: PoseEvaluationJointPairs = {};
    for (const jointsPairName of scorableJointSets.keys()) {
      const jointEvaluations = sourceFrames
        .map((frame: PoseFrame) => frame.pairs?.[jointsPairName])
        .filter((e): e is JointPairFrameEvaluation => !!e);
      const qualifyingJointScores = this.qualifyConfidence(jointEvaluations);
      if (qualifyingJointScores.length > 0) {
        pairs[jointsPairName] = {
          score: PoseScoring.meanOf(qualifyingJointScores, (e) => e.score),
          confidence: PoseScoring.meanOf(
            qualifyingJointScores,
            (e) => e.confidence
          ),
        };
      }
    }

    const qualifyingFrames = this.qualifyConfidence(sourceFrames);
    return {
      score: PoseScoring.meanOf(qualifyingFrames, (f) => f.score),
      confidence: PoseScoring.meanOf(qualifyingFrames, (f) => f.confidence),
      framesCount: sourceFrames.length,
      qualifyingFramesCount: qualifyingFrames.length,
      jointPairs: pairs,
    };
  }

  /** Left-to-right sum of `pick(item)` over `items`; 0 for an empty array. */
  private static sumOf<T>(
    items: readonly T[],
    pick: (item: T) => number
  ): number {
    return items.reduce((total, item) => total + pick(item), 0);
  }

  /** Arithmetic mean of `pick(item)` over `items`; 0 (not NaN) when empty. */
  private static meanOf<T>(
    items: readonly T[],
    pick: (item: T) => number
  ): number {
    return items.length === 0
      ? 0
      : PoseScoring.sumOf(items, pick) / items.length;
  }
}
