
#include "metalShaderTypes.h"
#include "alpha.h"

vertex vertexOut vertexFunc(uint vertexID [[ vertex_id ]],
                            const device txShaderVers* in[[ buffer(0) ]],
                            constant txShaderUniforms& uniforms[[ buffer(3) ]])
{
    // Vertex stage: transforms the vertex position by uniforms.worldViewProj and
    // copies the three texture-coordinate sets (tex1, tex2, texMask) through
    // unchanged for interpolation.
    //
    // in        - vertex array bound at vertex buffer index 0, indexed by vertex_id.
    // uniforms  - bound at vertex buffer index 3 (the fragment stage in this file
    //             reads the same struct type from its own buffer index 2).
    const device txShaderVers& vert = in[vertexID];

    vertexOut out;

    // The position is the left operand (row-vector convention); matrix products
    // are not commutative, so this order must stay as-is.
    // NOTE(review): presumably the host uploads worldViewProj laid out for this
    // convention - confirm against the CPU-side matrix setup.
    out.position = vert.pos * uniforms.worldViewProj;

    out.texMask = vert.texMask;
    out.tex2    = vert.tex2;
    out.tex1    = vert.tex1;

    return out;
}

vec4 grayscale (vec4 color)
{
    // Desaturates a color using the Rec. 709 luma weights on R, G and B.
    // The weighted sum is replicated into all three color channels and the
    // input alpha is passed through untouched.
    float luma = color.r*0.2126 + color.g*0.7152 + color.b*0.0722;

    return vec4(luma,
                luma,
                luma,
                color.a);
}

fragment float4 fragmentFunc(vertexOut input [[stage_in]]
                             ,texture2d<half> inputTex0 [[ texture(0) ]]
                             ,texture2d<half> inputTex1 [[ texture(1) ]]
                             ,texture2d<half> inputTexMask [[ texture(2) ]]
                             ,constant txShaderUniforms& u[[ buffer(2) ]]
                             ,constant subFXVars& subVars[[buffer(3)]])
{
    // Desaturate-and-crossfade transition between two textures.
    //
    // With p = u.transitionProgress (assumed to run 0 -> 1; not clamped here):
    //  - p == 0: output is the inputTex1 sample at full color.
    //  - the inputTex1 sample fades to grayscale as p goes 0 -> 0.4;
    //  - the inputTex0 sample stays grayscale until p = 0.6, then regains
    //    full color by p == 1;
    //  - the two are linearly crossfaded by p (inputTex0 weight = p).
    // Optionally (USE_MASK) the result's opacity is scaled by the mask texture's alpha.
    // subVars is bound but not read by this shader.

    // Keep the uniform's declared type so smoothstep/mix precision is unchanged.
    const auto progress = u.transitionProgress;

    vec4 sample0 = getColor(inputTex0, input.tex1);
    vec4 sample1 = getColor(inputTex1, input.tex2);

#ifndef PRE_MULT
    // Blend in premultiplied space; divideAlpha below reverses this.
    // NOTE(review): presumably non-PRE_MULT builds feed straight-alpha textures.
    multiplyAlpha(sample0);
    multiplyAlpha(sample1);
#endif

    // How much original color each sample keeps (0 = grayscale, 1 = full color).
    const auto saturation0 = smoothstep(.6, 1.0, progress);
    const auto saturation1 = smoothstep(.6, 1.0, 1. - progress);

    vec4 graded0 = mix(grayscale(sample0), sample0, saturation0);
    vec4 graded1 = mix(grayscale(sample1), sample1, saturation1);

    float4 result = mix(graded1, graded0, progress);

#ifdef USE_MASK
    const float maskAlpha = float(inputTexMask.sample(linearSampler, input.texMask).a);
    multiplyOpacityPremultiplied(result, maskAlpha);
#endif

#ifndef PRE_MULT
    divideAlpha(result);
#endif

    return result;
}
