import * as THREE from "three";
import { OrbitControls } from "three/examples/jsm/controls/OrbitControls";
import { DefaultCameraSpecs } from "../../utils/types";
import { LoadingSpinner } from "./LoadingSpinner";
import { PlyLoader } from "./PlyLoader";
import { SplatBuffer } from "./SplatBuffer";
import { SplatLoader } from "./SplatLoader";
import { createSortWorker } from "./utils";

/**
 * Camera intrinsics used by `Viewer` when no `cameraSpecs` option is given.
 *
 * `fx` / `fy` are focal lengths expressed in the same units as the viewer's
 * render dimensions (the element's offsetWidth/offsetHeight): the viewer's
 * projection matrix divides `2 * fx` by the render width and `2 * fy` by the
 * render height. `near` / `far` are the clip-plane distances of that
 * projection.
 * NOTE(review): the focal values look like they come from a specific camera
 * calibration — confirm their origin before changing them.
 */
const DEFAULT_CAMERA_SPECS: DefaultCameraSpecs = {
  fx: 1159.5880733038064,
  fy: 1164.6601287484507,
  near: 0.1,
  far: 500,
};

/** Options accepted by the `Viewer` constructor; every field is optional. */
interface PropsViewer {
  /**
   * Element the renderer canvas is appended to. When absent, `init()` creates
   * a 100% x 100% `<div>` and appends it to `document.body`.
   */
  rootElement?: HTMLElement | null;
  /** Camera up vector as `[x, y, z]` (defaults to `[0, 1, 0]`). */
  cameraUp?: number[];
  /** Initial camera position as `[x, y, z]` (defaults to `[0, 10, 15]`). */
  initialCameraPos?: number[];
  /** Point the camera initially faces, as `[x, y, z]` (defaults to the origin). */
  initialCameraLookAt?: number[];
  /** Focal lengths and clip planes used to build the splat projection matrix. */
  cameraSpecs?: DefaultCameraSpecs;
  /**
   * Pre-built controls to drive. When null/absent, `init()` creates
   * OrbitControls targeting `initialCameraLookAt`.
   */
  controls?: OrbitControls | null;
  /**
   * When true (the default), `start()`/`update()` schedule themselves with
   * `requestAnimationFrame`; when false, `start()` throws.
   */
  selfDrivenMode?: boolean;
  /** Receives the current splat mesh after each sort-worker update. */
  callbackMesh?: any;
}

/**
 * three.js viewer for Gaussian-splat scenes loaded from `.splat` or `.ply`
 * files.
 *
 * Splats are drawn as a single instanced quad mesh. The loaded buffer is not
 * uploaded to the mesh directly: it is sent to a Web Worker (built from
 * `createSortWorker`), which is given the current view-projection matrix every
 * frame and posts back `color` / `centerCov` arrays. Those arrays are copied
 * into the mesh's instanced attributes.
 *
 * Typical use: `new Viewer(opts)`, `init()`, `await loadFile(url)`, `start()`.
 */
export class Viewer {
  cameraUp: THREE.Vector3;
  rootElement: HTMLElement | null;
  initialCameraPos: THREE.Vector3;
  initialCameraLookAt: THREE.Vector3;
  cameraSpecs: DefaultCameraSpecs;
  controls: OrbitControls | null;
  selfDrivenMode: boolean;
  scene: THREE.Scene | null;
  camera: THREE.PerspectiveCamera | null;
  /**
   * Projection built from `cameraSpecs` (focal lengths + clip planes). It is
   * used for the splats instead of `camera.projectionMatrix`.
   */
  realProjectionMatrix: THREE.Matrix4;
  renderer: THREE.WebGLRenderer | null;
  splatBuffer: SplatBuffer | null;
  splatMesh: THREE.Mesh<
    THREE.InstancedBufferGeometry,
    THREE.ShaderMaterial
  > | null;
  sortWorker: Worker | null;
  /** Optional; called with `splatMesh` after each worker result is applied. */
  callbackMesh: any;

  resizeFunc: () => void;
  selfDrivenUpdateFunc: () => void;

  /**
   * Stores configuration only. No DOM elements, GPU resources or workers are
   * created until `init()` is called.
   *
   * Array options (`cameraUp`, `initialCameraPos`, `initialCameraLookAt`) are
   * read as `[x, y, z]`. The argument may be omitted entirely.
   */
  constructor({
    rootElement = null,
    cameraUp = [0, 1, 0],
    initialCameraPos = [0, 10, 15],
    initialCameraLookAt = [0, 0, 0],
    // Copy the shared default so mutating one viewer's specs cannot leak into
    // viewers constructed later.
    cameraSpecs = { ...DEFAULT_CAMERA_SPECS },
    controls = null,
    selfDrivenMode = true,
    callbackMesh,
  }: PropsViewer = {}) {
    this.rootElement = rootElement;
    this.cameraUp = new THREE.Vector3().fromArray(cameraUp);
    this.initialCameraPos = new THREE.Vector3().fromArray(initialCameraPos);
    this.initialCameraLookAt = new THREE.Vector3().fromArray(
      initialCameraLookAt
    );
    this.cameraSpecs = cameraSpecs;
    this.controls = controls;
    this.selfDrivenMode = selfDrivenMode;
    this.scene = null;
    this.camera = null;
    this.realProjectionMatrix = new THREE.Matrix4();
    this.renderer = null;
    this.splatBuffer = null;
    this.splatMesh = null;
    this.selfDrivenUpdateFunc = this.update.bind(this);
    this.resizeFunc = this.onResize.bind(this);
    this.sortWorker = null;
    this.callbackMesh = callbackMesh;
  }

  /**
   * Writes the root element's current size (offsetWidth/offsetHeight) into
   * `outDimensions`, or 0 for each axis when there is no root element yet.
   */
  getRenderDimensions(outDimensions: THREE.Vector2) {
    outDimensions.x = this.rootElement?.offsetWidth ?? 0;
    outDimensions.y = this.rootElement?.offsetHeight ?? 0;
  }

  /**
   * Rebuilds `realProjectionMatrix` from `cameraSpecs` for the given render
   * size. X and Y scale are `2 * focal / size`; Z and W follow the standard
   * OpenGL perspective form for the `near` / `far` planes.
   */
  updateRealProjectionMatrix(renderDimensions: THREE.Vector2) {
    const { fx, fy, near, far } = this.cameraSpecs;
    const depthRange = far - near;
    // Matrix4.set takes its arguments in row-major order.
    this.realProjectionMatrix.set(
      (2 * fx) / renderDimensions.x, 0, 0, 0,
      0, (2 * fy) / renderDimensions.y, 0, 0,
      0, 0, -(far + near) / depthRange, -(2.0 * far * near) / depthRange,
      0, 0, -1, 0
    );
  }

  /**
   * Resizes the renderer and both projections to the root element's current
   * size, then pushes the new viewport to the splat material.
   */
  onResize = (function () {
    const renderDimensions = new THREE.Vector2();

    return function (this: Viewer) {
      if (!this.renderer || !this.camera) {
        return;
      }

      // Presumably shrinks the canvas first so it does not keep the container
      // at its old size while the container is being measured.
      this.renderer.setSize(1, 1);
      this.getRenderDimensions(renderDimensions);
      this.camera.aspect = renderDimensions.x / renderDimensions.y;
      this.camera.updateProjectionMatrix();
      this.renderer.setSize(renderDimensions.x, renderDimensions.y);
      this.updateRealProjectionMatrix(renderDimensions);
      this.updateSplatMeshUniforms();
    };
  })();

  /**
   * Creates the root element (if none was given), camera, scene, renderer,
   * controls (if none were given) and sort worker. Also registers a window
   * resize listener.
   */
  init() {
    if (!this.rootElement) {
      this.rootElement = document.createElement("div");
      this.rootElement.style.width = "100%";
      this.rootElement.style.height = "100%";
      document.body.appendChild(this.rootElement);
    }

    const renderDimensions = new THREE.Vector2();
    this.getRenderDimensions(renderDimensions);

    this.camera = new THREE.PerspectiveCamera(
      70,
      renderDimensions.x / renderDimensions.y,
      // Keep the three.js camera's clip planes consistent with the splat
      // projection built from the same specs.
      this.cameraSpecs.near,
      this.cameraSpecs.far
    );
    // `up` must be set before lookAt(): Object3D.lookAt orients using `up`.
    this.camera.up.copy(this.cameraUp).normalize();
    this.camera.position.copy(this.initialCameraPos);
    this.camera.lookAt(this.initialCameraLookAt);
    this.updateRealProjectionMatrix(renderDimensions);

    this.scene = new THREE.Scene();

    this.renderer = new THREE.WebGLRenderer({
      antialias: false,
    });
    this.renderer.shadowMap.enabled = true;
    this.renderer.shadowMap.type = THREE.PCFSoftShadowMap;
    this.renderer.setSize(renderDimensions.x, renderDimensions.y);

    if (!this.controls) {
      this.controls = new OrbitControls(this.camera, this.renderer.domElement);
      this.controls.maxPolarAngle = (0.9 * Math.PI) / 2;
      this.controls.enableDamping = true;
      this.controls.dampingFactor = 0.15;
      this.controls.target.copy(this.initialCameraLookAt);
    }

    window.addEventListener("resize", this.resizeFunc, false);

    this.rootElement.appendChild(this.renderer.domElement);

    // The worker body is the stringified createSortWorker, invoked with the
    // worker's global scope.
    const workerUrl = URL.createObjectURL(
      new Blob(["(", createSortWorker.toString(), ")(self)"], {
        type: "application/javascript",
      })
    );
    this.sortWorker = new Worker(workerUrl);

    this.sortWorker.onmessage = (e) => {
      const { color, centerCov } = e.data;
      this.updateSplatMeshAttributes(color, centerCov);
      this.updateSplatMeshUniforms();
      // callbackMesh is optional; only invoke it when one was supplied.
      if (typeof this.callbackMesh === "function") {
        this.callbackMesh(this.splatMesh);
      }
    };
  }

  /**
   * Copies one worker result into the splat mesh's instanced attributes and
   * sets the instance count to the number of splats received.
   *
   * @param colors - 4 floats per splat (the `splatColor` attribute).
   * @param centerCovariances - 9 floats per splat: center plus the six unique
   *   3D covariance terms (the `splatCenterCovariance` mat3 attribute).
   */
  updateSplatMeshAttributes(
    colors: Float32Array,
    centerCovariances: Float32Array
  ) {
    if (!this.splatMesh) {
      return;
    }

    const geometry = this.splatMesh.geometry;
    const attributes = geometry.attributes as {
      [name: string]: THREE.BufferAttribute;
    };
    const centerCovarianceAttr = attributes.splatCenterCovariance;
    const colorAttr = attributes.splatColor;

    // A result larger than the current attributes would make
    // BufferAttribute.set throw a RangeError. Presumably such a result was
    // computed for a previously loaded, larger buffer, so drop it and wait
    // for the next sort.
    if (
      centerCovariances.length > centerCovarianceAttr.array.length ||
      colors.length > colorAttr.array.length
    ) {
      return;
    }

    centerCovarianceAttr.set(centerCovariances);
    centerCovarianceAttr.needsUpdate = true;

    colorAttr.set(colors);
    colorAttr.needsUpdate = true;

    geometry.instanceCount = centerCovariances.length / 9;
  }

  /**
   * Pushes the projection matrix, focal lengths and current viewport size
   * into the splat material's uniforms. Does nothing before a mesh exists.
   */
  updateSplatMeshUniforms = (function () {
    const renderDimensions = new THREE.Vector2();

    return function (this: Viewer) {
      this.getRenderDimensions(renderDimensions);
      if (this.splatMesh) {
        const uniforms = this.splatMesh.material.uniforms;
        uniforms.realProjectionMatrix.value.copy(this.realProjectionMatrix);
        uniforms.focal.value.set(this.cameraSpecs.fx, this.cameraSpecs.fy);
        uniforms.viewport.value.set(renderDimensions.x, renderDimensions.y);
        this.splatMesh.material.uniformsNeedUpdate = true;
      }
    };
  })();

  /**
   * Loads a `.splat` or `.ply` file, replaces any previously loaded splat
   * mesh, adds the new mesh to the scene, and hands the data to the sort
   * worker. A loading spinner is shown for the duration of the load.
   *
   * @returns The new splat mesh.
   * @throws Error "Viewer::loadFile -> File format not supported: …" for
   *   other extensions, or "Viewer::loadFile -> Could not load file …" when
   *   the loader fails. The loader's error is attached as `cause`.
   */
  async loadFile(fileName: string) {
    const loadingSpinner = new LoadingSpinner("");
    loadingSpinner.show();

    let splatBuffer: SplatBuffer;
    try {
      splatBuffer = await this.loadSplatBuffer(fileName);
    } finally {
      // Hide the spinner on failure too, not only on success.
      loadingSpinner.hide();
    }

    // Only one mesh receives worker updates, so an older one would be left
    // frozen in the scene; remove and free it.
    this.disposeSplatMesh();

    const splatMesh = this.buildMesh(splatBuffer);
    // Splats are positioned in the vertex shader, so the geometry's bounds
    // are not meaningful for culling.
    splatMesh.frustumCulled = false;

    this.splatBuffer = splatBuffer;
    this.splatMesh = splatMesh;
    this.scene?.add(splatMesh);
    this.updateWorkerBuffer();

    return splatMesh;
  }

  /** Chooses a loader by file extension and wraps load failures. */
  private async loadSplatBuffer(fileName: string): Promise<SplatBuffer> {
    let fileLoadPromise: Promise<SplatBuffer>;
    if (fileName.endsWith(".splat")) {
      fileLoadPromise = new SplatLoader().loadFromFile(
        fileName
      ) as Promise<SplatBuffer>;
    } else if (fileName.endsWith(".ply")) {
      fileLoadPromise = new PlyLoader().loadFromFile(
        fileName
      ) as Promise<SplatBuffer>;
    } else {
      throw new Error(
        `Viewer::loadFile -> File format not supported: ${fileName}`
      );
    }

    try {
      return await fileLoadPromise;
    } catch (err: unknown) {
      const wrapped = new Error(
        `Viewer::loadFile -> Could not load file ${fileName}`
      );
      // Keep the underlying failure reachable for debugging.
      (wrapped as Error & { cause?: unknown }).cause = err;
      throw wrapped;
    }
  }

  /** Removes the current splat mesh (if any) from the scene and disposes it. */
  private disposeSplatMesh() {
    if (!this.splatMesh) {
      return;
    }
    this.scene?.remove(this.splatMesh);
    this.splatMesh.geometry.dispose();
    this.splatMesh.material.dispose();
    this.splatMesh = null;
  }

  /**
   * Adds four unit-radius marker spheres 50 units from the origin: red on
   * the ±X axis and green on the ±Z axis.
   */
  addDebugMeshesToScene() {
    const sphereGeometry = new THREE.SphereGeometry(1, 32, 32);
    const markers: Array<[color: number, x: number, y: number, z: number]> = [
      [0xff0000, -50, 0, 0],
      [0xff0000, 50, 0, 0],
      [0x00ff00, 0, 0, -50],
      [0x00ff00, 0, 0, 50],
    ];

    for (const [color, x, y, z] of markers) {
      const sphereMesh = new THREE.Mesh(
        sphereGeometry,
        new THREE.MeshBasicMaterial({ color })
      );
      this.scene?.add(sphereMesh);
      sphereMesh.position.set(x, y, z);
    }
  }

  /**
   * Starts the requestAnimationFrame loop.
   *
   * @throws Error when the viewer is not in self-driven mode.
   */
  start() {
    if (this.selfDrivenMode) {
      requestAnimationFrame(this.selfDrivenUpdateFunc);
    } else {
      throw new Error("Cannot start viewer unless it is in self driven mode.");
    }
  }

  /**
   * Per-frame tick. Reschedules itself when self-driven, advances the
   * controls, requests a splat sort for the current view, then renders.
   * Call this manually each frame when not in self-driven mode.
   */
  update() {
    if (this.selfDrivenMode) {
      requestAnimationFrame(this.selfDrivenUpdateFunc);
    }
    this.controls?.update();
    this.updateView();

    if (!this.renderer || !this.scene || !this.camera) {
      return;
    }

    // No explicit clear. The splat blend mode weights new fragments by
    // (1 - destination alpha). NOTE(review): this appears to rely on WebGL
    // discarding the drawing buffer between frames (the renderer is created
    // without preserveDrawingBuffer); confirm before changing either.
    this.renderer.autoClear = false;
    this.renderer.render(this.scene, this.camera);
  }

  /**
   * Sends the current view-projection matrix (realProjectionMatrix *
   * inverse camera world matrix) to the sort worker.
   */
  updateView = (function () {
    const viewProjection = new THREE.Matrix4();

    return function (this: Viewer) {
      if (!this.camera || !this.sortWorker) {
        return;
      }

      // controls.update() may have moved the camera this frame. Refresh its
      // world matrix so the sort uses the view about to be rendered, not the
      // previous frame's.
      this.camera.updateMatrixWorld();
      viewProjection.copy(this.camera.matrixWorld).invert();
      viewProjection.premultiply(this.realProjectionMatrix);
      this.sortWorker.postMessage({
        sort: {
          view: viewProjection.elements,
        },
      });
    };
  })();

  /**
   * Sends the loaded splat buffer (raw rows plus precomputed covariance and
   * color data) to the sort worker. Does nothing before init()/loadFile().
   */
  updateWorkerBuffer() {
    if (!this.sortWorker || !this.splatBuffer) {
      return;
    }

    this.sortWorker.postMessage({
      bufferUpdate: {
        rowSizeFloats: SplatBuffer.RowSizeFloats,
        rowSizeBytes: SplatBuffer.RowSizeBytes,
        splatBuffer: this.splatBuffer.getBufferData(),
        precomputedCovariance: this.splatBuffer.getCovarianceBufferData(),
        precomputedColor: this.splatBuffer.getColorBufferData(),
        vertexCount: this.splatBuffer.getVertexCount(),
      },
    });
  }

  /**
   * Builds the splat shader material.
   *
   * The vertex shader projects each splat's 3D covariance to a 2D screen
   * ellipse and stretches the quad along its axes. Splats failing the
   * clip/bounds test are moved outside the clip volume (z = 2).
   *
   * The fragment shader applies a Gaussian falloff exp(-|p|^2), discards
   * beyond |p|^2 > 4, and outputs alpha-premultiplied color.
   */
  buildMaterial(): THREE.ShaderMaterial {
    const vertexShaderSource = `
              #include <common>
              precision mediump float;
  
              attribute vec4 splatColor;
              attribute mat3 splatCenterCovariance;
  
              uniform mat4 realProjectionMatrix;
              uniform vec2 focal;
              uniform vec2 viewport;
  
              varying vec4 vColor;
              varying vec2 vPosition;
  
              void main () {
  
              vec3 splatCenter = splatCenterCovariance[0];
              vec3 cov3D_M11_M12_M13 = splatCenterCovariance[1];
              vec3 cov3D_M22_M23_M33 = splatCenterCovariance[2];
  
              vec4 camspace = viewMatrix * vec4(splatCenter, 1);
              vec4 pos2d = realProjectionMatrix * camspace;
  
              float bounds = 1.2 * pos2d.w;
              if (pos2d.z < -pos2d.w || pos2d.x < -bounds || pos2d.x > bounds
                  || pos2d.y < -bounds || pos2d.y > bounds) {
                  gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
                  return;
              }
  
              mat3 Vrk = mat3(
                  cov3D_M11_M12_M13.x, cov3D_M11_M12_M13.y, cov3D_M11_M12_M13.z,
                  cov3D_M11_M12_M13.y, cov3D_M22_M23_M33.x, cov3D_M22_M23_M33.y,
                  cov3D_M11_M12_M13.z, cov3D_M22_M23_M33.y, cov3D_M22_M23_M33.z
              );
  
              mat3 J = mat3(
                  focal.x / camspace.z, 0., -(focal.x * camspace.x) / (camspace.z * camspace.z),
                  0., focal.y / camspace.z, -(focal.y * camspace.y) / (camspace.z * camspace.z),
                  0., 0., 0.
              );
  
              mat3 W = transpose(mat3(viewMatrix));
              mat3 T = W * J;
              mat3 cov2Dm = transpose(T) * Vrk * T;
              cov2Dm[0][0] += 0.3;
              cov2Dm[1][1] += 0.3;
              vec3 cov2Dv = vec3(cov2Dm[0][0], cov2Dm[0][1], cov2Dm[1][1]);
  
              vec2 vCenter = vec2(pos2d) / pos2d.w;
  
              float diagonal1 = cov2Dv.x;
              float offDiagonal = cov2Dv.y;
              float diagonal2 = cov2Dv.z;
  
              float mid = 0.5 * (diagonal1 + diagonal2);
              float radius = length(vec2((diagonal1 - diagonal2) / 2.0, offDiagonal));
              float lambda1 = mid + radius;
              float lambda2 = max(mid - radius, 0.1);
              vec2 diagonalVector = normalize(vec2(offDiagonal, lambda1 - diagonal1));
              vec2 v1 = min(sqrt(2.0 * lambda1), 1024.0) * diagonalVector;
              vec2 v2 = min(sqrt(2.0 * lambda2), 1024.0) * vec2(diagonalVector.y, -diagonalVector.x);
  
              vColor = splatColor;
              vPosition = position.xy;
  
              vec2 projectedCovariance = vCenter +
                                         position.x * v1 / viewport * 2.0 +
                                         position.y * v2 / viewport * 2.0;
  
              gl_Position = vec4(projectedCovariance, 0.0, 1.0);
          }`;

    const fragmentShaderSource = `
              #include <common>
              precision mediump float;
  
              varying vec4 vColor;
              varying vec2 vPosition;
  
              void main () {
                  float A = -dot(vPosition, vPosition);
                  if (A < -4.0) discard;
                  float B = exp(A) * vColor.a;
                  gl_FragColor = vec4(B * vColor.rgb, B);
              }`;

    // three.js infers uniform types from their values; the legacy `type`
    // field is not needed.
    const uniforms = {
      realProjectionMatrix: { value: new THREE.Matrix4() },
      focal: { value: new THREE.Vector2() },
      viewport: { value: new THREE.Vector2() },
    };

    return new THREE.ShaderMaterial({
      uniforms: uniforms,
      vertexShader: vertexShaderSource,
      fragmentShader: fragmentShaderSource,
      transparent: true,
      alphaTest: 1.0,
      // result = src * (1 - dstAlpha) + dst. NOTE(review): this presumably
      // expects the worker to deliver splats sorted front-to-back; confirm
      // against createSortWorker.
      blending: THREE.CustomBlending,
      blendEquation: THREE.AddEquation,
      blendSrc: THREE.OneMinusDstAlphaFactor,
      blendDst: THREE.OneFactor,
      blendSrcAlpha: THREE.OneMinusDstAlphaFactor,
      blendDstAlpha: THREE.OneFactor,
      depthTest: false,
      depthWrite: false,
      side: THREE.DoubleSide,
    });
  }

  /**
   * Builds the instanced geometry: one two-triangle quad spanning [-2, 2] in
   * X and Y, plus empty per-instance `splatColor` (4 floats) and
   * `splatCenterCovariance` (9 floats) attributes sized for every splat in
   * `splatBuffer`. The attributes are filled later by worker results.
   */
  buildGeometry(splatBuffer: SplatBuffer) {
    const vertexCount = splatBuffer.getVertexCount();
    const geometry = new THREE.InstancedBufferGeometry();

    // prettier-ignore
    const quadPositions = new Float32Array([
      2.0, 2.0, 0.0,
      -2.0, -2.0, 0.0,
      -2.0, 2.0, 0.0,
      2.0, 2.0, 0.0,
      2.0, -2.0, 0.0,
      -2.0, -2.0, 0.0,
    ]);
    geometry.setAttribute(
      "position",
      new THREE.BufferAttribute(quadPositions, 3)
    );

    const splatColors = new THREE.InstancedBufferAttribute(
      new Float32Array(vertexCount * 4),
      4,
      false
    );
    splatColors.setUsage(THREE.DynamicDrawUsage);
    geometry.setAttribute("splatColor", splatColors);

    const splatCenters = new THREE.InstancedBufferAttribute(
      new Float32Array(vertexCount * 9),
      9,
      false
    );
    splatCenters.setUsage(THREE.DynamicDrawUsage);
    geometry.setAttribute("splatCenterCovariance", splatCenters);

    // Draw nothing until the sort worker has delivered data;
    // updateSplatMeshAttributes() sets the real count.
    geometry.instanceCount = 0;

    return geometry;
  }

  /** Combines buildGeometry() and buildMaterial() into the splat mesh. */
  buildMesh(splatBuffer: SplatBuffer) {
    const geometry = this.buildGeometry(splatBuffer);
    const material = this.buildMaterial();
    const mesh = new THREE.Mesh(geometry, material);
    return mesh;
  }
}
