#include "metalShaderTypes.h"
#include "alpha.h"

// Per-draw parameters for this effect's vertex stage (bound at buffer(2)
// in vertexFunc).
struct fxVars
{
    // Index into the 9-entry dirMult table; expected to be in 0..8.
    int direction;
};


// Per-direction component multipliers, indexed by fxVars::direction.
// The 9 entries form a 3x3 grid, index = row*3 + col, where
//   x = 1 - col   (1, 0, -1 across a row)
//   y = row - 1   (-1, 0, 1 down the rows)
// z is always 1 (so z offsets pass through unscaled) and w is always 0.
constant vec4 dirMult[9] = {
    // row 0: y = -1
    vec4( 1.0, -1.0, 1.0, 0.0), vec4( 0.0, -1.0, 1.0, 0.0), vec4(-1.0, -1.0, 1.0, 0.0),
    // row 1: y = 0 (index 4 has both x and y multipliers equal to 0)
    vec4( 1.0,  0.0, 1.0, 0.0), vec4( 0.0,  0.0, 1.0, 0.0), vec4(-1.0,  0.0, 1.0, 0.0),
    // row 2: y = +1
    vec4( 1.0,  1.0, 1.0, 0.0), vec4( 0.0,  1.0, 1.0, 0.0), vec4(-1.0,  1.0, 1.0, 0.0),
};

// Vertex stage for a directional slide/zoom transition.
//
// Each vertex is transformed as row-vector products:
//   position * uniforms.worldViewProj * translation * projection
// The translation matrix starts with a fixed z term of -2.414 (assuming mat4
// is Metal's column-major float4x4). Depending on uniforms.transitionState,
// an extra offset is added to it. That offset is animated by
// uniforms.transitionProgress along the direction that vars.direction selects
// from dirMult:
//   transitionState == 0 -> "In" offset (presumably the incoming transition;
//                           TODO confirm against the host code)
//   transitionState == 2 -> "Out" offset
//   any other state      -> no extra offset
// normPos receives the untransformed xy position. tex1, tex2 and texMask are
// passed through unchanged.
vertex vertexOut vertexFunc(uint vertexID [[ vertex_id ]],const device txShaderVers* in[[ buffer(0) ]],constant txShaderUniforms& uniforms[[ buffer(3) ]],constant fxVars& vars[[ buffer(2)]])
{
    vertexOut out;

    out.normPos=in[vertexID].pos.xy;

    // vars.direction is supplied by the host. Clamp it so an out-of-range
    // value cannot read past the end of the 9-entry dirMult table.
    const vec4 dir=dirMult[clamp(vars.direction,0,8)];

    vec4 deltaTransIn=vec4(-7.,-5.,-16.,.0)*dir;
    vec4 transOffsetIn=vec4(-3.5,-2.5,-8.0,.0)*dir;

    vec4 deltaTransOut=vec4(5.5,4.,10.0,0.0)*dir;
    vec4 transOffsetOut=vec4(0.0,0.0,0.0,0.0);

    mat4 projection = mat4(
                           2.4140,0.0000,0.0000,0.0000,
                           0.0000,2.4140,0.0000,0.0000,
                           0.0000,0.0000,-1.020,-.4870,
                           0.0000,0.0000,-1.000,0.0000);
    mat4 translation=mat4(
                          1.0, 0.0, 0.0, 0.00,
                          0.0, 1.0, 0.0, 0.00,
                          0.0, 0.0, 1.0,-2.414,
                          0.0, 0.0, 0.0, 1.00);

    vec4 worldPos=vec4(in[vertexID].pos.xyz,1.0)*uniforms.worldViewProj;

    // Note on translation[c][3]: this is component 3 of column c (assuming a
    // column-major float4x4). Combined with the row-vector product
    // worldPos*translation below, it adds offset * w to output component c.
    if(uniforms.transitionState==0)
    {
        // Base offset runs from transOffsetIn at progress 0 to zero at
        // progress 1 (the .5 scale matches deltaTransIn == 2*transOffsetIn).
        float progress=uniforms.transitionProgress*.5;
        transOffsetIn-=deltaTransIn*progress;
        // Extra lateral push along the direction. It is zero at both
        // progress 0 and progress 1.
        float offset=(.5-progress)*6.5*uniforms.transitionProgress;
        transOffsetIn.xy+=dir.xy*offset;

        translation[0][3]+=transOffsetIn.x;
        translation[1][3]+=transOffsetIn.y;
        translation[2][3]+=transOffsetIn.z;
    }
    else if(uniforms.transitionState==2)
    {
        // progress is negated, so transOffsetOut grows from zero to
        // deltaTransOut/2 as transitionProgress goes from 0 to 1.
        float progress=-uniforms.transitionProgress*.5;
        transOffsetOut-=deltaTransOut*progress;
        float offset=(.5-progress)*3.*uniforms.transitionProgress;
        // Fix: this lateral push used to be added to transOffsetIn. That value
        // is never read in this branch, so the push was silently dropped.
        // Apply it to the Out offset that is actually used.
        transOffsetOut.xy+=dir.xy*offset;

        translation[0][3]+=transOffsetOut.x;
        translation[1][3]+=transOffsetOut.y;
        translation[2][3]+=transOffsetOut.z;
    }

    out.position=(worldPos*translation)*projection;

    out.tex1=in[vertexID].tex1;
    out.tex2=in[vertexID].tex2;
    out.texMask=in[vertexID].texMask;
    return out;
}

// Fragment stage: samples the primary texture via getColor and scales its
// opacity by transitionAlpha * baseAlpha. When USE_MASK is defined, the alpha
// channel of the mask texture is also multiplied in.
// inputTex1 and subVars are bound but not read by this function.
fragment float4 fragmentFunc(vertexOut input [[stage_in]]
                             ,texture2d<half> inputTex0 [[ texture(0) ]]
                             ,texture2d<half> inputTex1 [[ texture(1) ]]
                             ,texture2d<half> inputTexMask [[ texture(2) ]]
                             ,constant txShaderUniforms& u[[ buffer(2) ]]
                             ,constant subFXVars& subVars[[buffer(3)]])
{
    float4 color = getColor(inputTex0, input.tex1);

    // Combined opacity from the transition and the layer's base alpha.
    float opacity = u.transitionAlpha * u.baseAlpha;

#ifdef USE_MASK
    // The mask contributes only through its alpha channel.
    const half maskAlpha = inputTexMask.sample(linearSampler, input.texMask).a;
    opacity *= float(maskAlpha);
#endif

    multiplyOpacity(color, opacity);
    return color;
}
