import { ShaderVersions } from "./constants";

/**
 * Minimal vertex shader source.
 *
 * Forwards the `uv` attribute to the fragment stage as `vUv` and transforms
 * `position` by `projectionMatrix * modelViewMatrix`.
 *
 * Written with GLSL ES 3.00 `in`/`out` qualifiers but contains no `#version`
 * directive, so the consumer must prepend one.
 * NOTE(review): presumably done by the material (e.g. a three.js
 * RawShaderMaterial configured for GLSL3); confirm.
 * Both matrix uniforms are declared here and must be supplied by the caller.
 */
export const render_vert = `in vec3 position;
in vec2 uv;

out vec2 vUv;

uniform mat4 modelViewMatrix;
uniform mat4 projectionMatrix;

void main() {
    vUv = uv;
    gl_Position = projectionMatrix * modelViewMatrix * vec4( position, 1.0 );
}`;

/**
 * Vertex shader source for the G-buffer pass.
 *
 * Outputs to the fragment stage:
 *  - `vUv`: the `uv` attribute, unchanged.
 *  - `vPosition`: the `position` attribute as given (before `modelMatrix`).
 *  - `rayDirection`: `(modelMatrix * position).xyz - cameraPosition`.
 *    Presumably this is world space, given that `cameraPosition` is subtracted.
 *    It is NOT normalized here; `gbuffer_frag` normalizes it after
 *    interpolation.
 *
 * GLSL ES 3.00 syntax without a `#version` directive. All four uniforms are
 * declared here and must be supplied by the caller.
 */
export const gbuffer_vert = `in vec3 position;
in vec2 uv;

out vec2 vUv;
out vec3 vPosition;
out vec3 rayDirection;

uniform mat4 modelViewMatrix;
uniform mat4 projectionMatrix;
uniform mat4 modelMatrix;
uniform vec3 cameraPosition;

void main() {
    vUv = uv;
    vPosition = position;
    gl_Position = projectionMatrix * modelViewMatrix * vec4( position, 1.0 );
    rayDirection = (modelMatrix * vec4( position, 1.0 )).rgb - cameraPosition;
}`;

/**
 * Fragment shader source for the G-buffer pass, writing to four render
 * targets.
 *
 * Per fragment it writes:
 *  - location 0 `gColor0`: `normalize(rayDirection)` with alpha 1.0
 *  - location 1 `gColor1`: sample of `tDiffuse0` at `vUv`
 *  - location 2 `gColor2`: sample of `tDiffuse1` at `vUv`
 *  - location 3 `gNormal`: sample of `tNormals` at `vUv`
 *
 * A fragment whose `tDiffuse0` red channel is exactly 0.0 is discarded, so
 * none of the targets are written for it.
 *
 * `vPosition` is declared but unused. The source is GLSL ES 3.00 without a
 * `#version` directive.
 * NOTE(review): `gbufferFragment()` returns the same statements with
 * different indentation; keep the two in sync or consolidate them.
 */
export const gbuffer_frag = `precision highp float;

layout(location = 0) out vec4 gColor0;
layout(location = 1) out vec4 gColor1;
layout(location = 2) out vec4 gColor2;
layout(location = 3) out vec4 gNormal;

uniform mediump sampler2D tDiffuse0;
uniform mediump sampler2D tDiffuse1;
uniform mediump sampler2D tNormals;

in vec2 vUv;
in vec3 vPosition;
in vec3 rayDirection;

void main() {

    // write color to G-Buffer
    gColor1 = texture( tDiffuse0, vUv );
    if (gColor1.r == 0.0) discard;
    gColor0 = vec4( normalize(rayDirection), 1.0 );
    gColor2 = texture( tDiffuse1, vUv );
    gNormal = texture( tNormals, vUv);

}`;

/**
 * Produces the G-buffer fragment shader source (four render targets).
 *
 * Outputs, in order of `layout(location = …)`:
 *  0. normalized `rayDirection` with alpha 1.0
 *  1. `tDiffuse0` sampled at `vUv`
 *  2. `tDiffuse1` sampled at `vUv`
 *  3. `tNormals` sampled at `vUv`
 * Fragments whose `tDiffuse0` red channel equals 0.0 are discarded.
 *
 * The returned text starts with a newline and has no `#version` directive.
 * NOTE(review): this repeats the statements of `gbuffer_frag` with different
 * indentation; keep them in sync.
 *
 * @returns A fresh copy of the same shader source on every call.
 */
export const gbufferFragment = (): string =>
  `
    precision highp float;

    layout(location = 0) out vec4 gColor0;
    layout(location = 1) out vec4 gColor1;
    layout(location = 2) out vec4 gColor2;
    layout(location = 3) out vec4 gNormal;
  
    uniform mediump sampler2D tDiffuse0;
    uniform mediump sampler2D tDiffuse1;
    uniform mediump sampler2D tNormals;

  
    in vec2 vUv;
    in vec3 vPosition;
    in vec3 rayDirection;
  
    void main() {
  
        // write color to G-Buffer
        gColor1 = texture( tDiffuse0, vUv );
        if (gColor1.r == 0.0) discard;
        gColor0 = vec4( normalize(rayDirection), 1.0 );
        gColor2 = texture( tDiffuse1, vUv );
        gNormal = texture( tNormals, vUv);
  
    }`;

/**
 * GLSL snippet (not a standalone shader) that evaluates the single dense
 * "diffuse" layer.
 *
 * It must be spliced into a scope that provides:
 *  - `f0` and `f1` (vec4)
 *  - the `weightDiffuse` sampler
 *  - the preprocessor names `NUM_CHANNELS_DIFFUSE` and `BIAS_DIFFUSE_DENSE`
 * `BIAS_DIFFUSE_DENSE` is expanded inside `float[](…)`, so it is presumably
 * a comma-separated list of NUM_CHANNELS_DIFFUSE floats; confirm.
 *
 * For each output i it computes
 *   result_diffuse[i] = bias[i] + Σ_{j<8} in[j] * texelFetch(weightDiffuse, (j, i)).x
 * where `in = (f0.rgba, f1.rgba)`.
 *
 * No activation is applied here; `viewDependenceNetworkShaderFunctionsF`
 * later applies `1 / (1 + exp(-(x - log(3))))` per channel.
 */
export const diffuse_network_shader = `
        /// pbr_diffuse
        mediump float result_diffuse[NUM_CHANNELS_DIFFUSE] = float[](
          BIAS_DIFFUSE_DENSE
        );
        for (int j = 0; j < 8; ++j) {
            mediump float input_value = 0.0;
            if (j < 4) {
            input_value =
                (j == 0) ? f0.r : (
                (j == 1) ? f0.g : (
                (j == 2) ? f0.b : f0.a));
            } else if (j < 8) {
            input_value =
                (j == 4) ? f1.r : (
                (j == 5) ? f1.g : (
                (j == 6) ? f1.b : f1.a));
            }
            for (int i = 0; i < NUM_CHANNELS_DIFFUSE; ++i) {
              result_diffuse[i] += input_value *
                texelFetch(weightDiffuse, ivec2(j, i), 0).x;
            }
        }`;

/**
 * GLSL snippet (not a standalone shader) that evaluates the single dense
 * "tint" layer followed by a sigmoid.
 *
 * It must be spliced into a scope that provides:
 *  - `f0` and `f1` (vec4)
 *  - the `weightsTint` sampler
 *  - the preprocessor names `NUM_CHANNELS_TINT` and `BIAS_TINT_DENSE`
 * `BIAS_TINT_DENSE` is presumably a comma-separated float list; confirm.
 *
 * For each output i it computes
 *   result_tint[i] = sigmoid(bias[i] + Σ_{j<8} in[j] * texelFetch(weightsTint, (j, i)).x)
 * where `in = (f0.rgba, f1.rgba)`.
 */
export const tint_network_shader = `
        // pbr_tint
        mediump float result_tint[NUM_CHANNELS_TINT] = float[](
          BIAS_TINT_DENSE
        );
        for (int j = 0; j < 8; ++j) {
            mediump float input_value = 0.0;
            if (j < 4) {
            input_value =
                (j == 0) ? f0.r : (
                (j == 1) ? f0.g : (
                (j == 2) ? f0.b : f0.a));
            } else if (j < 8) {
            input_value =
                (j == 4) ? f1.r : (
                (j == 5) ? f1.g : (
                (j == 6) ? f1.b : f1.a));
            }
            for (int i = 0; i < NUM_CHANNELS_TINT; ++i) {
              result_tint[i] += input_value *
                texelFetch(weightsTint, ivec2(j, i), 0).x;
            }
        }
        for (int i = 0; i < NUM_CHANNELS_TINT; ++i) {
            result_tint[i] = 1.0 / (1.0 + exp(-result_tint[i]));
        }`;

/**
 * GLSL snippet: the three-layer view-dependent color MLP.
 *
 * Expected in scope:
 *  - `refdirs` (vec3), `dotprod` (float), `f0` and `f1` (vec4)
 *  - samplers `weightsZero`, `weightsOne`, `weightsTwo`
 *  - preprocessor names NUM_CHANNELS_{ZERO,ONE,TWO,THREE} and
 *    BIAS_LIST_{ZERO,ONE,TWO}
 *
 * Input vector, by index j:
 *  - 0-2: refdirs.rgb
 *  - 3: dotprod
 *  - 4-7: f0.rgba
 *  - 8-11: f1.rgba
 * NOTE(review): indices above 11 would also read f1.a, so this assumes
 * NUM_CHANNELS_ZERO <= 12; confirm.
 *
 * Layers:
 *  - input → NUM_CHANNELS_ONE (weightsZero, BIAS_LIST_ZERO)
 *  - → NUM_CHANNELS_TWO (weightsOne, BIAS_LIST_ONE)
 *  - → NUM_CHANNELS_THREE (weightsTwo, BIAS_LIST_TWO)
 * Weights are read with texelFetch at (input j, output i). Hidden units with
 * value <= 0 are skipped when feeding the next layer, which acts as a ReLU.
 * The final `result` goes through a sigmoid.
 *
 * NOTE(review): the inline GLSL comment `//switch y-z axes` does not match
 * the code, which reads refdirs.r/.g/.b in order; it may be stale.
 */
export const version_color_3layer_MLP = `
        mediump float intermediate_one[NUM_CHANNELS_ONE] = float[](BIAS_LIST_ZERO);

        for (int j = 0; j < NUM_CHANNELS_ZERO; ++j) {
            mediump float input_value = 0.0;
            if (j < 4) {
              input_value =
                (j == 0) ? refdirs.r : (
                (j == 1) ? refdirs.g :(
                (j == 2) ? refdirs.b: dotprod)); //switch y-z axes
            } else if (j < 8) {
            input_value =
                (j == 4) ? f0.r : (
                (j == 5) ? f0.g : (
                (j == 6) ? f0.b : f0.a));
            } else {
              input_value =
                (j == 8) ? f1.r : (
                (j == 9) ? f1.g : (
                (j == 10) ? f1.b : f1.a));
            }
            for (int i = 0; i < NUM_CHANNELS_ONE; ++i) {
            intermediate_one[i] += input_value *
                texelFetch(weightsZero, ivec2(j, i), 0).x;
            }
        }

        mediump float intermediate_two[NUM_CHANNELS_TWO] = float[](BIAS_LIST_ONE);

        for (int j = 0; j < NUM_CHANNELS_ONE; ++j) {
            if (intermediate_one[j] <= 0.0) {
                continue;
            }
            for (int i = 0; i < NUM_CHANNELS_TWO; ++i) {
                intermediate_two[i] += intermediate_one[j] *
                    texelFetch(weightsOne, ivec2(j, i), 0).x;
            }
        }
        mediump float result[NUM_CHANNELS_THREE] = float[](BIAS_LIST_TWO);

        for (int j = 0; j < NUM_CHANNELS_TWO; ++j) {
            if (intermediate_two[j] <= 0.0) {
                continue;
            }
            for (int i = 0; i < NUM_CHANNELS_THREE; ++i) {
                result[i] += intermediate_two[j] *
                    texelFetch(weightsTwo, ivec2(j, i), 0).x;
            }
        }

        for (int i = 0; i < NUM_CHANNELS_THREE; ++i) {
            result[i] = 1.0 / (1.0 + exp(-result[i]));
        }
        `;

/**
 * GLSL snippet: the two-layer view-dependent color MLP.
 *
 * The scope requirements and input layout are the same as
 * `version_color_3layer_MLP`:
 *  - 0-2: refdirs.rgb
 *  - 3: dotprod
 *  - 4-7: f0.rgba
 *  - 8-11: f1.rgba
 *
 * Layers:
 *  - input → NUM_CHANNELS_ONE (weightsZero, BIAS_LIST_ZERO)
 *  - → NUM_CHANNELS_THREE (weightsTwo, BIAS_LIST_TWO)
 * `weightsOne`, `BIAS_LIST_ONE` and `NUM_CHANNELS_TWO` are not referenced.
 * Hidden units with value <= 0 are skipped, which acts as a ReLU. The output
 * `result` goes through a sigmoid.
 *
 * NOTE(review): the inline GLSL comment `//switch y-z axes` does not match
 * the code, which reads refdirs.r/.g/.b in order; it may be stale.
 */
export const version_color_2layer_MLP = `
        mediump float intermediate_one[NUM_CHANNELS_ONE] = float[](BIAS_LIST_ZERO);
        for (int j = 0; j < NUM_CHANNELS_ZERO; ++j) {
            mediump float input_value = 0.0;
            if (j < 4) {
              input_value =
                (j == 0) ? refdirs.r : (
                (j == 1) ? refdirs.g :(
                (j == 2) ? refdirs.b: dotprod)); //switch y-z axes
            } else if (j < 8) {
            input_value =
                (j == 4) ? f0.r : (
                (j == 5) ? f0.g : (
                (j == 6) ? f0.b : f0.a));
            } else {
              input_value =
                (j == 8) ? f1.r : (
                (j == 9) ? f1.g : (
                (j == 10) ? f1.b : f1.a));
            }
            for (int i = 0; i < NUM_CHANNELS_ONE; ++i) {
            intermediate_one[i] += input_value *
                texelFetch(weightsZero, ivec2(j, i), 0).x;
            }
        }

        mediump float result[NUM_CHANNELS_THREE] = float[](BIAS_LIST_TWO);

        for (int j = 0; j < NUM_CHANNELS_ONE; ++j) {
            if (intermediate_one[j] <= 0.0) {
                continue;
            }
            for (int i = 0; i < NUM_CHANNELS_THREE; ++i) {
                result[i] += intermediate_one[j] *
                    texelFetch(weightsTwo, ivec2(j, i), 0).x;
            }
        }
        for (int i = 0; i < NUM_CHANNELS_THREE; ++i) {
            result[i] = 1.0 / (1.0 + exp(-result[i]));
        }`;

/**
 * Assembles the full fragment shader that turns the rasterized feature
 * textures into view-dependent color.
 *
 * The returned text uses GLSL ES 3.00 constructs (`layout(location = …)`,
 * `texelFetch`, `in`) but has no `#version` directive. It also relies on
 * preprocessor names not defined here (NUM_CHANNELS_*, BIAS_*). The consumer
 * must supply both; presumably via material defines, confirm.
 *
 * `main` samples the textures and calls `evaluateNetwork` with
 * f0 = tDiffuse1x, f1 = tDiffuse2x, viewdir = tDiffuse0x, and the tNormals
 * sample swizzled as `.bgr`. Fragments with tDiffuse0x alpha below 0.6 are
 * discarded; output alpha is always 1.0.
 *
 * Inside `evaluateNetwork`:
 *  1. normal = normal_tex * 2 - 1. refdirs = -reflect(-viewdir.rgb, normal),
 *     then mapped through (x + 1) / 2.
 *  2. The diffuse and tint snippets run on (f0, f1).
 *  3. dotprod = normal · viewdir.rgb feeds the color network as input 3.
 *  4. The selected color MLP produces a sigmoid-activated `result`.
 *  5. Per channel: clamp(tint * result + sigmoid(diffuse - log 3), 0, 1).
 *     The result is then blended toward 1.0 by (1 - viewdir.a).
 * The result_diffuse and result_tint arrays are indexed up to
 * NUM_CHANNELS_THREE, so their channel counts must be at least that large.
 *
 * @param versionShaders - Selects the color MLP. `ShaderVersions.second`
 *   picks the 2-layer network; `ShaderVersions.first` and any other value
 *   pick the 3-layer network.
 * @returns The complete fragment shader source.
 */
export const viewDependenceNetworkShaderFunctionsF = (
  versionShaders: number
): string => {
  // `first` takes precedence over `second` (a switch's case order), and every
  // unrecognized value keeps the 3-layer default.
  const usesTwoLayerColor =
    versionShaders !== ShaderVersions.first &&
    versionShaders === ShaderVersions.second;
  const colorNetwork = usesTwoLayerColor
    ? version_color_2layer_MLP
    : version_color_3layer_MLP;

  return `
    precision mediump float;

    layout(location = 0) out vec4 pc_FragColor;

    in vec2 vUv;

    uniform mediump sampler2D tDiffuse0x;
    uniform mediump sampler2D tDiffuse1x;
    uniform mediump sampler2D tDiffuse2x;
    uniform mediump sampler2D tNormals;

    uniform mediump sampler2D weightsZero;
    uniform mediump sampler2D weightsOne;
    uniform mediump sampler2D weightsTwo;

    uniform mediump sampler2D weightsTint;
    uniform mediump sampler2D weightDiffuse;

    mediump vec3 evaluateNetwork( mediump vec4 f0, mediump vec4 f1, mediump vec4 viewdir, mediump vec3 normal_tex) {
        vec3 normal = normal_tex*2.0 - 1.0;

        vec3 refdirs = -reflect(-vec3(viewdir.rgb), normal);
        refdirs = (refdirs + 1.0)/2.0;

        
        /// pbr_diffuse
        ${diffuse_network_shader}

        // pbr_tint
        ${tint_network_shader}

        // pbr_use_n_dot_v
        mediump float dotprod = 0.0;
        dotprod += normal[0]*viewdir.r;
        dotprod += normal[1]*viewdir.g;
        dotprod += normal[2]*viewdir.b;
        
        ${colorNetwork}
        
        // pbr_use_diffuse_color 
        mediump float diffuse_linear[NUM_CHANNELS_THREE];

        for (int i = 0; i < NUM_CHANNELS_THREE; ++i) {
            diffuse_linear[i] = 1.0 / (1.0 + exp(-(result_diffuse[i] - log(3.0))));
            result[i] = clamp((result_tint[i]*result[i] + diffuse_linear[i]), 0.0, 1.0);
        }  

        return vec3(result[0]*viewdir.a+(1.0-viewdir.a),
                    result[1]*viewdir.a+(1.0-viewdir.a),
                    result[2]*viewdir.a+(1.0-viewdir.a));
      }


    void main() {

        vec4 diffuse0 = texture( tDiffuse0x, vUv );
        if (diffuse0.a < 0.6) discard;
        vec4 diffuse1 = texture( tDiffuse1x, vUv );
        vec4 diffuse2 = texture( tDiffuse2x, vUv );
        vec4 normal_tex = texture( tNormals, vUv);

        //deal with iphone
        //diffuse0.a = diffuse0.a*2.0-1.0;
        //diffuse1.a = diffuse1.a*2.0-1.0;
        //diffuse2.a = diffuse2.a*2.0-1.0;


        // pc_FragColor.rgb  = diffuse1.rgb;
        pc_FragColor.rgb = evaluateNetwork(diffuse1,diffuse2,diffuse0, normal_tex.bgr);
        pc_FragColor.a = 1.0;
    }
`;
};
