#include "metalShaderTypes.h"
#include "alpha.h"

// Effect-specific uniform block. It is bound at buffer(2) in vertexFunc and at
// buffer(1) in fragmentFunc. This effect has no parameters of its own. The
// single field is presumably a placeholder so the struct is non-empty; it is
// never read in this file.
struct fxVars
{
    float none; // placeholder, unused
};


/// Vertex stage: forwards the vertex position and texture coordinates, and
/// derives the extra per-vertex values consumed by the fragment stage.
///
/// - position : copied unchanged from the vertex buffer.
/// - normPos  : pos.xy * 0.5 + 0.5, i.e. maps [-1, 1] onto [0, 1]
///              (assumes pos.xy is in normalised device coordinates — TODO confirm).
/// - texSize  : pixel width/height of inputTex0.
///
/// inputTexMask, u and vars are bound to keep the argument table consistent
/// but are not read here.
vertex fxVertexOut vertexFunc(uint vertexID [[ vertex_id ]]
                              ,const device fxShaderVerts* in[[ buffer(0) ]]
                              ,texture2d<half> inputTex0 [[ texture(0) ]]
                              ,texture2d<half> inputTexMask [[ texture(1) ]]
                              ,constant fxGeneralUniforms& u[[ buffer(1) ]]
                              ,constant fxVars& vars[[ buffer(2)]])
{
    // Bind the current vertex once instead of re-indexing the buffer per field.
    const device fxShaderVerts& vert = in[vertexID];

    fxVertexOut result;
    result.position = vert.pos;
    result.normPos  = vert.pos.xy * vec2(.5, .5f) + vec2(.5, .5);
    result.tex1     = vert.tex1;
    result.texMask  = vert.texMask;
    result.texSize  = float2(inputTex0.get_width(), inputTex0.get_height());
    return result;
}

/// Fragment stage of the "heat signature" effect.
///
/// The shader computes a scalar luminance for the source pixel and maps it
/// onto a false-colour ramp: black -> blue -> green -> yellow -> red.
/// Each of the four segments covers a quarter of the [0, 1] range.
/// The ramp colour is scaled by u.nitsScale and then blended with the
/// source colour by u.blendValue. When USE_MASK is defined, the blend is
/// further multiplied by the mask's alpha.
///
/// @param input         interpolated vertex output (see vertexFunc)
/// @param inputTex      source image (texture 0)
/// @param inputTexMask  mask image; only sampled when USE_MASK is defined
/// @param u             shared uniforms; nitsScale and blendValue are read
/// @param vars          effect-specific uniforms (unused)
/// @return the blended colour; alpha is the source pixel's alpha
fragment float4 fragmentFunc(fxVertexOut input [[stage_in]]
                             ,texture2d<half> inputTex [[ texture(0) ]]
                             ,texture2d<half> inputTexMask [[ texture(1) ]]
                             ,constant fxGeneralUniforms& u[[ buffer(0) ]]
                             ,constant fxVars& vars[[ buffer(1)]])
{

    vec4 color = getColor(inputTex, fsTexture);
#ifdef PRE_MULT
    // Presumably converts premultiplied colour to straight alpha so the
    // luminance is not darkened by alpha; undone by multiplyAlpha at the end.
    divideAlpha(color);
#endif


#ifdef LINEAR_16
    // NOTE(review): these weights sum to 1.017, not 1.0, and are not the
    // standard BT.2020 luma weights — confirm they are intended.
    float gray = dot(color.rgb  / u.nitsScale, vec3(0.355, 0.658, 0.004));
#else
    // BT.601 luma weights.
    float gray = dot(color.rgb, vec3(0.299, 0.587, 0.114));
#endif

    // The ramp below is only defined on [0, 1]. Without clamping, gray > 1
    // (reachable with HDR / float input) makes the last segment emit a
    // negative green channel, and gray < 0 makes the first emit negative blue.
    gray = clamp(gray, 0.0f, 1.0f);

    // compute the heat signature value

    vec4 heatSig = vec4( 0.0, 0.0, 0.0, 0.0 );

    if ( gray < 0.25 )
    {
        // black to blue
        gray *= 4.0;

        heatSig.r = 0.0;
        heatSig.g = 0.0;
        heatSig.b = gray;
    }
    else if ( gray < 0.5 )
    {
        // blue to green
        gray -= 0.25;
        gray *= 4.0;

        heatSig.r = 0.0;
        heatSig.g = gray;
        heatSig.b = 1.0 - gray;
    }
    else if ( gray < 0.75 )
    {
        // green to yellow
        gray -= 0.5;
        gray *= 4.0;

        heatSig.r = gray;
        heatSig.g = 1.0;
        heatSig.b = 0.0;
    }
    else
    {
        // yellow to red
        gray -= 0.75;
        gray *= 4.0;

        heatSig.r = 1.0;
        heatSig.g = 1.0 - gray;
        heatSig.b = 0.0;
    }
    // NOTE(review): applied in both paths, although only the LINEAR_16 path
    // divides the input by nitsScale — confirm nitsScale is 1 otherwise.
    heatSig.rgb *= u.nitsScale;

    heatSig.a = color.a;

    float4 outColor = heatSig;


    float blendValue=u.blendValue;
#ifdef USE_MASK
    blendValue*=getColor(inputTexMask,input.texMask).a;
#endif
    // blendValue 0 -> original colour, 1 -> full heat-signature colour.
    outColor.rgb=mix(color.rgb,outColor.rgb,blendValue);
#ifdef PRE_MULT
    multiplyAlpha(outColor);
#endif
    return outColor;
}
