import { NormalizedLandmark } from "@mediapipe/tasks-vision";
import { JointAngle, calculateDistance } from "./pose-utils";
import { CriterionForEvaluator } from "./PoseDefinitionInterpreter";
import {
  PoseDetectionFeedback,
  ThresholdOperator,
} from "./pose-detection-feedback";
import {
  calculateCenterOfMass,
  checkJointsBent,
  checkJointsStraight,
  checkLandmarksAbove,
  checkLandmarksApart,
  checkLandmarksSameHeight,
  checkLandmarksStacked,
  combineFeedbackSides,
  combineFeedbackSidesOr,
  formatAngleValue,
  getSideFeedback,
  validateThresholds,
} from "./pose-detection-feedback-utils";
import { POSE_LANDMARKS } from "./landmarks";

/**
 * Evaluates criteria for a pose, such as whether joints are straight or bent, or whether
 * landmarks are stacked or apart. Each check method returns true if its criterion is met, and
 * writes an entry describing the evaluation (value, thresholds, visibility, ...) into the shared
 * `poseDetectionFeedback` object under the given feedback key.
 *
 * Every check selects its threshold from `isInPose`: `startThreshold` is used while the pose has
 * not been entered yet and `endThreshold` once it has, so entering and staying in a pose can use
 * different limits.
 */
export class PoseCriteriaEvaluator {
  /** Cache for {@link getFacingDirection}; filled on first use. */
  private facingDirection: { x: number; z: number } | null = null;

  /**
   * @param isInPose whether the person is already in the pose; selects `endThreshold` (true) or
   *   `startThreshold` (false) in every check.
   * @param poseLandmarks landmarks of the current frame, indexed by landmark id
   *   (e.g. `POSE_LANDMARKS.LEFT_SHOULDER`).
   * @param jointAnglesMap joint angles keyed by the landmark index of the joint's vertex (the
   *   middle index of a `[a, vertex, b]` triple).
   * @param poseDetectionFeedback feedback object every check writes its result into.
   * @param visibilityThreshold a landmark's visibility must exceed this to count as visible.
   */
  constructor(
    private readonly isInPose: boolean,
    private readonly poseLandmarks: NormalizedLandmark[],
    private readonly jointAnglesMap: { [key: number]: JointAngle },
    private readonly poseDetectionFeedback: PoseDetectionFeedback,
    private readonly visibilityThreshold: number
  ) {}

  /**
   * Computes the horizontal direction (x/z plane) the person is facing, based on the position of
   * their shoulders: the left-to-right shoulder vector is rotated by 90 degrees and normalized to
   * a unit vector. The result is cached after the first call.
   *
   * If the shoulders coincide in the x/z plane the direction is undefined. In that case a zero
   * vector is returned, so projections onto it are 0 instead of NaN.
   */
  private getFacingDirection(): { x: number; z: number } {
    if (this.facingDirection) return this.facingDirection;

    const leftShoulder = this.poseLandmarks[POSE_LANDMARKS.LEFT_SHOULDER];
    const rightShoulder = this.poseLandmarks[POSE_LANDMARKS.RIGHT_SHOULDER];
    const shoulderVector = {
      x: rightShoulder.x - leftShoulder.x,
      z: rightShoulder.z - leftShoulder.z,
    };
    // Rotate the shoulder vector by 90° within the x/z plane.
    const rawFacingDirection = { x: -shoulderVector.z, z: shoulderVector.x };
    const magnitude = Math.sqrt(
      rawFacingDirection.x ** 2 + rawFacingDirection.z ** 2
    );
    this.facingDirection =
      magnitude === 0
        ? { x: 0, z: 0 }
        : {
            x: rawFacingDirection.x / magnitude,
            z: rawFacingDirection.z / magnitude,
          };
    return this.facingDirection;
  }

  /**
   * Checks that both the left and the right joint are straight. Delegates to `checkJointsStraight`
   * from pose-detection-feedback-utils.
   *
   * @param landmarksLeft landmark index triple describing the left joint.
   * @param landmarksRight landmark index triple describing the right joint.
   * @param startThreshold angle in radians used while not in the pose.
   * @param endThreshold angle in radians used once in the pose.
   * @returns whether the criterion is met.
   */
  checkJointsStraight(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarksLeft: [number, number, number],
    landmarksRight: [number, number, number],
    startThreshold = Math.PI * 0.85,
    endThreshold = Math.PI * 0.75
  ): boolean {
    return checkJointsStraight(
      criterion,
      this.isInPose,
      this.poseLandmarks,
      this.jointAnglesMap,
      this.poseDetectionFeedback,
      this.visibilityThreshold,
      feedbackKey,
      landmarksLeft,
      landmarksRight,
      startThreshold,
      endThreshold
    );
  }

  /**
   * Checks that both the left and the right joint are bent. Delegates to `checkJointsBent` from
   * pose-detection-feedback-utils.
   *
   * @param landmarksLeft landmark index triple describing the left joint.
   * @param landmarksRight landmark index triple describing the right joint.
   * @param startThreshold angle in radians used while not in the pose.
   * @param endThreshold angle in radians used once in the pose.
   * @returns whether the criterion is met.
   */
  checkJointsBent(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarksLeft: [number, number, number],
    landmarksRight: [number, number, number],
    startThreshold = Math.PI * 0.6,
    endThreshold = Math.PI * 0.5
  ): boolean {
    return checkJointsBent(
      criterion,
      this.isInPose,
      this.poseLandmarks,
      this.jointAnglesMap,
      this.poseDetectionFeedback,
      this.visibilityThreshold,
      feedbackKey,
      landmarksLeft,
      landmarksRight,
      startThreshold,
      endThreshold
    );
  }

  /**
   * Checks that at least one of the two joints (or the joint on `side`, if given) is straight.
   *
   * The per-side value is the joint angle in radians, looked up in `jointAnglesMap` by the middle
   * landmark index of each triple. The sides are combined with `combineFeedbackSidesOr`.
   *
   * @param side restricts the check to one side; `undefined` lets `combineFeedbackSidesOr` decide
   *   across both sides.
   * @param landmarksLeft `[a, vertex, b]` landmark indexes of the left joint.
   * @param landmarksRight `[a, vertex, b]` landmark indexes of the right joint.
   * @param startThreshold angle in radians used while not in the pose.
   * @param endThreshold angle in radians used once in the pose.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkOneJointStraight(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    side: "left" | "right" | undefined,
    landmarksLeft: [number, number, number],
    landmarksRight: [number, number, number],
    startThreshold = Math.PI * 0.7,
    endThreshold = Math.PI * 0.6
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const base = {
      criterion,
      min: 0,
      max: Math.PI,
      startThreshold,
      endThreshold,
      thresholdOperator: ThresholdOperator.GreaterThan,
      visibilityThreshold: this.visibilityThreshold,
      format: formatAngleValue,
    };
    const threshold = this.isInPose ? base.endThreshold : base.startThreshold;
    const leftJointAngle = this.jointAnglesMap[landmarksLeft[1]];
    const rightJointAngle = this.jointAnglesMap[landmarksRight[1]];
    this.poseDetectionFeedback[feedbackKey] = combineFeedbackSidesOr(
      base,
      getSideFeedback(leftJointAngle?.angle, threshold, [
        this.poseLandmarks[landmarksLeft[0]],
        this.poseLandmarks[landmarksLeft[1]],
        this.poseLandmarks[landmarksLeft[2]],
      ]),
      getSideFeedback(rightJointAngle?.angle, threshold, [
        this.poseLandmarks[landmarksRight[0]],
        this.poseLandmarks[landmarksRight[1]],
        this.poseLandmarks[landmarksRight[2]],
      ]),
      side
    );
    return this.poseDetectionFeedback[feedbackKey].isMet;
  }

  /**
   * Checks that at least one of the two joints (or the joint on `side`, if given) is bent.
   *
   * The per-side value is the bend amount `Math.PI - angle` in radians (0 for a fully straight
   * joint). If a joint angle is missing from `jointAnglesMap`, that side's value is NaN.
   *
   * NOTE(review): the value grows as the joint bends, yet the feedback is labelled
   * `ThresholdOperator.LessThan`. Confirm this matches how getSideFeedback and feedback consumers
   * compare it.
   *
   * @param side restricts the check to one side; `undefined` lets `combineFeedbackSidesOr` decide
   *   across both sides.
   * @param landmarksLeft `[a, vertex, b]` landmark indexes of the left joint.
   * @param landmarksRight `[a, vertex, b]` landmark indexes of the right joint.
   * @param startThreshold bend amount in radians used while not in the pose.
   * @param endThreshold bend amount in radians used once in the pose.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkOneJointBent(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    side: "left" | "right" | undefined,
    landmarksLeft: [number, number, number],
    landmarksRight: [number, number, number],
    startThreshold = Math.PI * 0.25,
    endThreshold = Math.PI * 0.2
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const base = {
      criterion,
      min: 0,
      max: Math.PI,
      startThreshold,
      endThreshold,
      thresholdOperator: ThresholdOperator.LessThan,
      visibilityThreshold: this.visibilityThreshold,
      format: formatAngleValue,
    };
    const threshold = this.isInPose ? base.endThreshold : base.startThreshold;
    const leftJointAngle = this.jointAnglesMap[landmarksLeft[1]];
    const rightJointAngle = this.jointAnglesMap[landmarksRight[1]];
    this.poseDetectionFeedback[feedbackKey] = combineFeedbackSidesOr(
      base,
      getSideFeedback(Math.PI - leftJointAngle?.angle, threshold, [
        this.poseLandmarks[landmarksLeft[0]],
        this.poseLandmarks[landmarksLeft[1]],
        this.poseLandmarks[landmarksLeft[2]],
      ]),
      getSideFeedback(Math.PI - rightJointAngle?.angle, threshold, [
        this.poseLandmarks[landmarksRight[0]],
        this.poseLandmarks[landmarksRight[1]],
        this.poseLandmarks[landmarksRight[2]],
      ]),
      side
    );
    return this.poseDetectionFeedback[feedbackKey].isMet;
  }

  /**
   * Checks that landmarks are stacked on each side. Delegates to `checkLandmarksStacked` from
   * pose-detection-feedback-utils.
   *
   * @param landmarksLeft pair of landmark indexes on the left side.
   * @param landmarksRight pair of landmark indexes on the right side.
   * @returns whether the criterion is met.
   */
  checkLandmarksStacked(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarksLeft: [number, number],
    landmarksRight: [number, number],
    startThreshold = -0.25,
    endThreshold = -0.3
  ): boolean {
    return checkLandmarksStacked(
      criterion,
      this.isInPose,
      this.poseLandmarks,
      this.poseDetectionFeedback,
      this.visibilityThreshold,
      feedbackKey,
      landmarksLeft,
      landmarksRight,
      startThreshold,
      endThreshold
    );
  }

  /**
   * Checks that the upper landmarks are above the lower ones. Delegates to `checkLandmarksAbove`
   * from pose-detection-feedback-utils.
   *
   * @param landmarksLower `[left, right]` indexes of the landmarks expected to be lower.
   * @param landmarksUpper `[left, right]` indexes of the landmarks expected to be higher.
   * @returns whether the criterion is met.
   */
  checkLandmarksAbove(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarksLower: [number, number],
    landmarksUpper: [number, number],
    startThreshold = 0,
    endThreshold = -0.1
  ): boolean {
    return checkLandmarksAbove(
      criterion,
      this.isInPose,
      this.poseLandmarks,
      this.poseDetectionFeedback,
      this.visibilityThreshold,
      feedbackKey,
      landmarksLower,
      landmarksUpper,
      startThreshold,
      endThreshold
    );
  }

  /**
   * Checks that each upper landmark is above the lowest of the two lower landmarks.
   *
   * The per-side value is `lowestLowerY - upper.y` in normalized image units (positive y is down,
   * so a larger value means the upper landmark is higher). Both sides must pass
   * (`combineFeedbackSides`).
   *
   * @param landmarksLower `[left, right]` indexes of the reference (lower) landmarks.
   * @param landmarksUpper `[left, right]` indexes of the landmarks expected to be higher.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkLandmarksAboveLowest(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarksLower: [number, number],
    landmarksUpper: [number, number],
    startThreshold = 0.3,
    endThreshold = 0.2
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const leftLower = this.poseLandmarks[landmarksLower[0]];
    const rightLower = this.poseLandmarks[landmarksLower[1]];
    const leftUpper = this.poseLandmarks[landmarksUpper[0]];
    const rightUpper = this.poseLandmarks[landmarksUpper[1]];

    // positive y is down, so the lowest landmark has the largest y
    const lowestLowerY = Math.max(leftLower.y, rightLower.y);

    const base = {
      criterion,
      min: -0.5,
      max: 0.5,
      startThreshold: startThreshold,
      endThreshold: endThreshold,
      thresholdOperator: ThresholdOperator.GreaterThan,
      visibilityThreshold: this.visibilityThreshold,
      format: (value: number) => value.toFixed(2),
      activeLandmarkIndexes: [
        landmarksLower[0],
        landmarksLower[1],
        landmarksUpper[0],
        landmarksUpper[1],
      ],
    };
    const threshold = this.isInPose ? base.endThreshold : base.startThreshold;
    this.poseDetectionFeedback[feedbackKey] = combineFeedbackSides(
      base,
      getSideFeedback(lowestLowerY - leftUpper.y, threshold, [
        leftUpper,
        leftLower,
        rightLower,
      ]),
      getSideFeedback(lowestLowerY - rightUpper.y, threshold, [
        rightUpper,
        leftLower,
        rightLower,
      ])
    );

    return this.poseDetectionFeedback[feedbackKey].isMet;
  }

  /**
   * Checks that each lower landmark is below the highest *visible* upper landmark.
   *
   * Only upper landmarks whose visibility exceeds `visibilityThreshold` are used to find the
   * highest point. The per-side value is `lower.y - highestUpperY` (positive y is down). If no
   * upper landmark is visible, `Math.min()` yields Infinity, both values become -Infinity, and
   * visibility is evaluated against both upper landmarks.
   *
   * @param landmarksLower `[left, right]` indexes of the landmarks expected to be lower.
   * @param landmarksUpper `[left, right]` indexes of the reference (upper) landmarks.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkLandmarksBelowHighest(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarksLower: [number, number],
    landmarksUpper: [number, number],
    startThreshold = 0.3,
    endThreshold = 0.2
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const leftLower = this.poseLandmarks[landmarksLower[0]];
    const rightLower = this.poseLandmarks[landmarksLower[1]];
    const leftUpper = this.poseLandmarks[landmarksUpper[0]];
    const rightUpper = this.poseLandmarks[landmarksUpper[1]];

    // Keep the indexes of the visible upper landmarks directly instead of recovering them later
    // with indexOf (which is a linear search and breaks if a landmark object is aliased).
    const visibleUpperIndexes = landmarksUpper.filter((index) => {
      const landmark = this.poseLandmarks[index];
      return (
        landmark.visibility && landmark.visibility > this.visibilityThreshold
      );
    });
    const visibleUppers = visibleUpperIndexes.map(
      (index) => this.poseLandmarks[index]
    );

    // positive y is down, so the highest landmark has the smallest y
    const highestUpperY = Math.min(
      ...visibleUppers.map((landmark) => landmark.y)
    );

    const base = {
      criterion,
      min: -0.5,
      max: 0.5,
      startThreshold: startThreshold,
      endThreshold: endThreshold,
      thresholdOperator: ThresholdOperator.GreaterThan,
      visibilityThreshold: this.visibilityThreshold,
      format: (value: number) => value.toFixed(2),
      activeLandmarkIndexes: [
        landmarksLower[0],
        landmarksLower[1],
        ...visibleUpperIndexes,
      ],
    };
    const threshold = this.isInPose ? base.endThreshold : base.startThreshold;
    // With no visible upper landmark, fall back to both so the visibility check fails visibly.
    const uppersForVisibility = visibleUppers.length
      ? visibleUppers
      : [leftUpper, rightUpper];
    const leftSideFeedback = getSideFeedback(
      leftLower.y - highestUpperY,
      threshold,
      [leftLower, ...uppersForVisibility]
    );
    const rightSideFeedback = getSideFeedback(
      rightLower.y - highestUpperY,
      threshold,
      [rightLower, ...uppersForVisibility]
    );
    this.poseDetectionFeedback[feedbackKey] = combineFeedbackSides(
      base,
      leftSideFeedback,
      rightSideFeedback
    );

    return this.poseDetectionFeedback[feedbackKey].isMet;
  }

  /**
   * Checks that two landmarks are apart. Delegates to `checkLandmarksApart` from
   * pose-detection-feedback-utils.
   *
   * @param landmarkIndexes the two landmark indexes to compare.
   * @returns whether the criterion is met.
   */
  checkLandmarksApart(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarkIndexes: [number, number],
    startThreshold = 0.25,
    endThreshold = 0.15
  ): boolean {
    return checkLandmarksApart(
      criterion,
      this.isInPose,
      this.poseLandmarks,
      this.poseDetectionFeedback,
      this.visibilityThreshold,
      feedbackKey,
      landmarkIndexes,
      startThreshold,
      endThreshold
    );
  }

  /**
   * Checks that two landmarks are close together.
   *
   * The value is the negated `calculateDistance` between the landmarks. The criterion is met when
   * both landmarks are visible (minimum visibility above `visibilityThreshold`) and that value is
   * greater than the threshold.
   *
   * @param landmarkIndexes the two landmark indexes to compare.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkLandmarksTogether(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarkIndexes: [number, number],
    startThreshold = -0.4,
    endThreshold = -0.45
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const landmark0 = this.poseLandmarks[landmarkIndexes[0]];
    const landmark1 = this.poseLandmarks[landmarkIndexes[1]];
    const base = {
      criterion,
      min: -1,
      max: 0,
      startThreshold,
      endThreshold,
      thresholdOperator: ThresholdOperator.GreaterThan,
      visibilityThreshold: this.visibilityThreshold,
      format: (value: number) => value.toFixed(2),
    };
    const threshold = this.isInPose ? base.endThreshold : base.startThreshold;
    // Negated so that "closer" means "larger", matching the GreaterThan operator.
    const apartDistance = -calculateDistance(landmark0, landmark1);
    const visibility = Math.min(
      landmark0.visibility || 0,
      landmark1.visibility || 0
    );
    this.poseDetectionFeedback[feedbackKey] = {
      ...base,
      value: apartDistance,
      isMet: visibility > this.visibilityThreshold && apartDistance > threshold,
      visibility,
    };
    return this.poseDetectionFeedback[feedbackKey].isMet;
  }

  /**
   * Checks that pairs of landmarks are at the same height. Delegates to
   * `checkLandmarksSameHeight` from pose-detection-feedback-utils.
   *
   * @param landmarks0 first pair of landmark indexes.
   * @param landmarks1 second pair of landmark indexes.
   * @returns whether the criterion is met.
   */
  checkLandmarksSameHeight(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarks0: [number, number],
    landmarks1: [number, number],
    startThreshold = -0.2,
    endThreshold = -0.25
  ): boolean {
    return checkLandmarksSameHeight(
      criterion,
      this.isInPose,
      this.poseLandmarks,
      this.poseDetectionFeedback,
      this.visibilityThreshold,
      feedbackKey,
      landmarks0,
      landmarks1,
      startThreshold,
      endThreshold
    );
  }

  /**
   * Checks that two landmarks are at approximately the same height.
   *
   * The reported value is the negated vertical distance `-|y0 - y1|`, so it is always <= 0. The
   * criterion is met when that value is greater than the threshold, i.e. the landmarks are
   * vertically closer than `-threshold`. Visibility is reported but not part of `isMet`.
   *
   * @param landmarkIndexes the two landmark indexes to compare.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkLandmarkPairSameHeight(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarkIndexes: [number, number],
    startThreshold = -0.2,
    endThreshold = -0.25
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const leftLandmark = this.poseLandmarks[landmarkIndexes[0]];
    const rightLandmark = this.poseLandmarks[landmarkIndexes[1]];
    const distanceThreshold = this.isInPose ? endThreshold : startThreshold;
    // Negate so the reported value is on the same scale as the (negative) thresholds and the
    // GreaterThan comparison used for isMet.
    const negatedHeightDifference = -Math.abs(leftLandmark.y - rightLandmark.y);

    const base = {
      criterion,
      min: -0.5,
      max: 0.5,
      startThreshold,
      endThreshold,
      thresholdOperator: ThresholdOperator.GreaterThan,
      visibilityThreshold: this.visibilityThreshold,
      format: (value: number) => value.toFixed(2),
    };
    this.poseDetectionFeedback[feedbackKey] = {
      ...base,
      value: negatedHeightDifference,
      visibility: Math.min(
        leftLandmark.visibility || 0,
        rightLandmark.visibility || 0
      ),
      isMet: negatedHeightDifference > distanceThreshold,
    };
    return this.poseDetectionFeedback[feedbackKey].isMet;
  }

  /**
   * Checks that the center of mass (from `calculateCenterOfMass`) lies horizontally (x/z plane)
   * close to the midpoint of two landmarks.
   *
   * The reported value is the negated horizontal distance. The criterion is met when the distance
   * is smaller than `-threshold`. Visibility is reported but not part of `isMet`.
   *
   * @param landmarkIndexes the two landmark indexes whose midpoint the mass should be over.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkMassCenteredOn(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarkIndexes: [number, number],
    startThreshold = -0.2,
    endThreshold = -0.3
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const center = calculateCenterOfMass(this.poseLandmarks);
    const leftLandmark = this.poseLandmarks[landmarkIndexes[0]];
    const rightLandmark = this.poseLandmarks[landmarkIndexes[1]];

    const landmarksCenter = {
      x: (leftLandmark.x + rightLandmark.x) / 2,
      y: (leftLandmark.y + rightLandmark.y) / 2,
      z: (leftLandmark.z + rightLandmark.z) / 2,
    };
    const centerThreshold = this.isInPose ? endThreshold : startThreshold;
    const landmarksFromCenterHorizontally = Math.hypot(
      landmarksCenter.x - center.x,
      landmarksCenter.z - center.z
    );
    const massCenteredOnLandmarks =
      landmarksFromCenterHorizontally < -centerThreshold;

    this.poseDetectionFeedback[feedbackKey] = {
      criterion,
      min: -0.5,
      max: 0,
      startThreshold: startThreshold,
      endThreshold: endThreshold,
      thresholdOperator: ThresholdOperator.GreaterThan,
      value: -landmarksFromCenterHorizontally,
      isMet: massCenteredOnLandmarks,
      visibilityThreshold: this.visibilityThreshold,
      visibility: Math.min(
        leftLandmark.visibility || 0,
        rightLandmark.visibility || 0
      ),
      format: (value: number) => value.toFixed(2),
    };

    return massCenteredOnLandmarks;
  }

  /**
   * Checks that one landmark is higher than the other.
   *
   * Left side value: `rightLandmark.y - leftLandmark.y` (positive when the left landmark is higher,
   * since positive y is down). The right side uses the opposite sign. The sides are combined with
   * `combineFeedbackSidesOr`.
   *
   * @param side restricts which landmark must be the higher one; `undefined` lets
   *   `combineFeedbackSidesOr` decide across both sides.
   * @param landmarkIndexes `[left, right]` landmark indexes.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkOneLandmarkHigher(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    side: "left" | "right" | undefined,
    landmarkIndexes: [number, number],
    startThreshold = 0.2,
    endThreshold = 0.1
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const leftLandmark = this.poseLandmarks[landmarkIndexes[0]];
    const rightLandmark = this.poseLandmarks[landmarkIndexes[1]];
    const higherLandmarkThreshold = this.isInPose
      ? endThreshold
      : startThreshold;

    const oneLandmarkHigherBase = {
      criterion,
      min: -0.5,
      max: 0.5,
      startThreshold,
      endThreshold,
      thresholdOperator: ThresholdOperator.GreaterThan,
      visibilityThreshold: this.visibilityThreshold,
      format: (value: number) => value.toFixed(2),
    };
    this.poseDetectionFeedback[feedbackKey] = combineFeedbackSidesOr(
      oneLandmarkHigherBase,
      getSideFeedback(
        // positive y is down
        -(leftLandmark.y - rightLandmark.y),
        higherLandmarkThreshold,
        [leftLandmark, rightLandmark]
      ),
      getSideFeedback(
        -(rightLandmark.y - leftLandmark.y),
        higherLandmarkThreshold,
        [leftLandmark, rightLandmark]
      ),
      side
    );
    return this.poseDetectionFeedback[feedbackKey].isMet;
  }

  /**
   * Checks that one landmark is in front of the other along the person's facing direction.
   *
   * Left side value: the projection of `(left - right)` in the x/z plane onto the facing direction
   * from {@link getFacingDirection}. The right side uses the opposite sign. The sides are combined
   * with `combineFeedbackSidesOr`.
   *
   * @param side restricts which landmark must be in front; `undefined` lets
   *   `combineFeedbackSidesOr` decide across both sides.
   * @param landmarkIndexes `[left, right]` landmark indexes.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkOneLandmarkInFront(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    side: "left" | "right" | undefined,
    landmarkIndexes: [number, number],
    startThreshold = 0.2,
    endThreshold = 0.1
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const leftLandmark = this.poseLandmarks[landmarkIndexes[0]];
    const rightLandmark = this.poseLandmarks[landmarkIndexes[1]];
    const distanceThreshold = this.isInPose ? endThreshold : startThreshold;

    const normalizedFacingDirection = this.getFacingDirection();
    const differenceVector = {
      x: leftLandmark.x - rightLandmark.x,
      z: leftLandmark.z - rightLandmark.z,
    };
    const inFrontDistance =
      differenceVector.x * normalizedFacingDirection.x +
      differenceVector.z * normalizedFacingDirection.z;

    const base = {
      criterion,
      min: -0.5,
      max: 0.5,
      startThreshold,
      endThreshold,
      thresholdOperator: ThresholdOperator.GreaterThan,
      visibilityThreshold: this.visibilityThreshold,
      format: (value: number) => value.toFixed(2),
    };
    this.poseDetectionFeedback[feedbackKey] = combineFeedbackSidesOr(
      base,
      getSideFeedback(inFrontDistance, distanceThreshold, [
        leftLandmark,
        rightLandmark,
      ]),
      getSideFeedback(-inFrontDistance, distanceThreshold, [
        leftLandmark,
        rightLandmark,
      ]),
      side
    );
    return this.poseDetectionFeedback[feedbackKey].isMet;
  }

  /**
   * Checks that corresponding landmarks of two pairs are touching on both sides: `landmarks0[0]`
   * with `landmarks1[0]` (left) and `landmarks0[1]` with `landmarks1[1]` (right).
   *
   * The per-side value is the negated `calculateDistance`. Both sides must pass
   * (`combineFeedbackSides`).
   *
   * @param landmarks0 `[left, right]` indexes of the first landmark pair.
   * @param landmarks1 `[left, right]` indexes of the second landmark pair.
   * @returns whether the criterion is met; false if there are no landmarks.
   */
  checkLandmarksTouching(
    criterion: CriterionForEvaluator,
    feedbackKey: string,
    landmarks0: [number, number],
    landmarks1: [number, number],
    startThreshold = -0.2,
    endThreshold = -0.25
  ): boolean {
    validateThresholds(feedbackKey, startThreshold, endThreshold);
    if (!this.poseLandmarks) return false;

    const left0 = this.poseLandmarks[landmarks0[0]];
    const right0 = this.poseLandmarks[landmarks0[1]];
    const left1 = this.poseLandmarks[landmarks1[0]];
    const right1 = this.poseLandmarks[landmarks1[1]];

    const base = {
      criterion,
      min: -0.5,
      max: 0,
      startThreshold,
      endThreshold,
      thresholdOperator: ThresholdOperator.GreaterThan,
      visibilityThreshold: this.visibilityThreshold,
      activeLandmarkIndexes: [
        landmarks0[0],
        landmarks0[1],
        landmarks1[0],
        landmarks1[1],
      ],
      format: (value: number) => value.toFixed(2),
    };
    const threshold = this.isInPose ? base.endThreshold : base.startThreshold;
    this.poseDetectionFeedback[feedbackKey] = combineFeedbackSides(
      base,
      getSideFeedback(-calculateDistance(left0, left1), threshold, [
        left0,
        left1,
      ]),
      getSideFeedback(-calculateDistance(right0, right1), threshold, [
        right0,
        right1,
      ])
    );

    return this.poseDetectionFeedback[feedbackKey].isMet;
  }
}
