
#include "metalShaderTypes.h"
#include "alpha.h"

// Per-draw parameters for the slide effect, bound at vertex buffer(2).
// `direction` selects an entry (0..8) of the direction lookup tables below.
struct fxVars
{
    int direction;
};

// Per-direction amount the slide offset is reduced by over the transition
// (vertexFunc scales it by transitionProgress before subtracting it).
// Entries form a 3x3 grid indexed as row * 3 + column:
//   x is -2, 0, 2 across the columns; y is 2, 0, -2 down the rows,
// so entry 4 (the centre) is the zero vector.
constant vec4 deltaTransTable[9] = {
    vec4(-2.0,  2.0, 0.0, 0.0), vec4(0.0,  2.0, 0.0, 0.0), vec4(2.0,  2.0, 0.0, 0.0),
    vec4(-2.0,  0.0, 0.0, 0.0), vec4(0.0,  0.0, 0.0, 0.0), vec4(2.0,  0.0, 0.0, 0.0),
    vec4(-2.0, -2.0, 0.0, 0.0), vec4(0.0, -2.0, 0.0, 0.0), vec4(2.0, -2.0, 0.0, 0.0)
};

// Per-direction starting offset of the slide (the offset applied when
// transitionProgress is 0). Same 3x3 layout as deltaTransTable:
//   index = row * 3 + column; x is -2, 0, 2 across, y is 2, 0, -2 down;
// entry 4 (the centre) means "no offset".
constant vec4 transOffsetTable[9] = {
    vec4(-2.0,  2.0, 0.0, 0.0), vec4(0.0,  2.0, 0.0, 0.0), vec4(2.0,  2.0, 0.0, 0.0),
    vec4(-2.0,  0.0, 0.0, 0.0), vec4(0.0,  0.0, 0.0, 0.0), vec4(2.0,  0.0, 0.0, 0.0),
    vec4(-2.0, -2.0, 0.0, 0.0), vec4(0.0, -2.0, 0.0, 0.0), vec4(2.0, -2.0, 0.0, 0.0)
};



// Vertex stage of a directional slide transition.
//
// Transforms each vertex by uniforms.worldViewProj (row-vector convention:
// position * matrix). While uniforms.transitionState == 0 it also shifts the
// clip-space position by
//     transOffsetTable[dir] - deltaTransTable[dir] * uniforms.transitionProgress
// which equals the full table offset at progress 0 and zero at progress 1
// (both tables hold the same values).
// Texture coordinates are passed through unchanged.
//
// NOTE(review): uniforms are bound at buffer(3) and fxVars at buffer(2) here,
// while fragmentFunc binds txShaderUniforms at buffer(2). Vertex and fragment
// stages have separate buffer tables, so this may be intentional — confirm
// against the host-side setVertexBuffer/setFragmentBuffer calls.
vertex vertexOut vertexFunc(uint vertexID [[ vertex_id ]],
                            const device txShaderVers* in [[ buffer(0) ]],
                            constant txShaderUniforms& uniforms [[ buffer(3) ]],
                            constant fxVars& vars [[ buffer(2) ]])
{
    vertexOut out;
    const device txShaderVers& vtx = in[vertexID];

    out.normPos = vtx.pos.xy;

    // `direction` comes straight from a host buffer; clamp it to the 9 valid
    // table entries so a bad value cannot read past the constant arrays
    // (an out-of-bounds read is undefined behaviour).
    const int dirIndex = clamp(vars.direction, 0, 8);
    vec4 deltaTrans = deltaTransTable[dirIndex];
    vec4 transOffset = transOffsetTable[dirIndex];

    mat4 translation = mat4(
                            1.0, 0.0, 0.0, 0.00,
                            0.0, 1.0, 0.0, 0.00,
                            0.0, 0.0, 1.0, 0.00,
                            0.0, 0.0, 0.0, 1.00);

    vec4 tmpPos = vec4(vtx.pos.xyz, 1.0) * uniforms.worldViewProj;
    if (uniforms.transitionState == 0)
    {
        transOffset -= deltaTrans * uniforms.transitionProgress;

        // m[c][r] addresses column c, row r. With the row-vector product
        // below, result component c is dot(tmpPos, column c), so these
        // entries add tmpPos.w * transOffset to x/y/z. That is a shift of
        // transOffset in coordinates after the perspective divide.
        translation[0][3] += transOffset.x;
        translation[1][3] += transOffset.y;
        translation[2][3] += transOffset.z;
    }
    out.position = tmpPos * translation;

    out.tex1 = vtx.tex1;
    out.tex2 = vtx.tex2;
    out.texMask = vtx.texMask;
    return out;
}

// Fragment stage of the slide transition.
//
// Samples inputTex0 at input.tex1 via getColor and scales the result's opacity
// by a factor computed as follows:
//   * 1.0 while u.transitionState == 0, otherwise 1.0 - u.transitionProgress;
//   * multiplied by u.baseAlpha;
//   * when USE_MASK is defined, also multiplied by the mask texture's alpha
//     at input.texMask.
// inputTex1 and subVars are bound but not read here; they are kept so the
// argument table stays unchanged.
fragment float4 fragmentFunc(vertexOut input [[stage_in]],
                             texture2d<half> inputTex0 [[ texture(0) ]],
                             texture2d<half> inputTex1 [[ texture(1) ]],
                             texture2d<half> inputTexMask [[ texture(2) ]],
                             constant txShaderUniforms& u [[ buffer(2) ]],
                             constant subFXVars& subVars [[ buffer(3) ]])
{
    float4 color = getColor(inputTex0, input.tex1);

    // Fully opaque in transition state 0; otherwise fade out with progress.
    float opacity = 1.0;
    if (u.transitionState != 0)
    {
        opacity -= u.transitionProgress;
    }
    opacity *= u.baseAlpha;

#ifdef USE_MASK
    opacity *= float(inputTexMask.sample(linearSampler, input.texMask).a);
#endif

    multiplyOpacity(color, opacity);
    return color;
}
