#include "metalShaderTypes.h"
#include "alpha.h"

// Per-effect parameters bound by the host as a constant buffer.
struct fxVars
{
    // Index into directionTable; the table has 9 entries, so valid values are 0..8.
    int direction;
};

// Wipe directions laid out as a 3x3 grid, row-major:
//   entry i == float2(1 - i % 3, 1 - i / 3)
// so x runs 1, 0, -1 across each row and y runs 1, 0, -1 down the rows.
// Entry 4 is the zero vector (no spatial direction).
constant float2 directionTable[9] = {
    float2( 1.0,  1.0), float2( 0.0,  1.0), float2(-1.0,  1.0),
    float2( 1.0,  0.0), float2( 0.0,  0.0), float2(-1.0,  0.0),
    float2( 1.0, -1.0), float2( 0.0, -1.0), float2(-1.0, -1.0)
};

// Midpoint (0.5, 0.5): offsets normPos in the vertex stage and is the pivot of the wipe projection.
constant float2 center = float2(0.5);

vertex vertexOut vertexFunc(uint vertexID [[vertex_id]],
                            const device txShaderVers *in [[buffer(0)]],
                            constant txShaderUniforms &uniforms [[buffer(3)]],
                            constant fxVars &vars [[buffer(2)]])
{
    // Transforms the vertex position, forwards all texture coordinates, and
    // derives normPos by scaling pos.xy by 0.5 and shifting by center.
    // NOTE(review): `vars` is not read in this stage; it is kept only so the
    // buffer binding layout seen by the host stays unchanged.
    const device txShaderVers &vtx = in[vertexID];

    vertexOut result;
    result.position = vtx.pos * uniforms.worldViewProj;
    result.tex1     = vtx.tex1;
    result.tex2     = vtx.tex2;
    result.texMask  = vtx.texMask;
    // presumably maps pos.xy from [-1, 1] into [0, 1] — confirm vertex data range
    result.normPos  = center + vtx.pos.xy * 0.5;
    return result;
}

fragment float4 fragmentFunc(vertexOut input [[stage_in]],
                             texture2d<half> inputTex0 [[texture(0)]],
                             texture2d<half> inputTex1 [[texture(1)]],
                             texture2d<half> inputTexMask [[texture(2)]],
                             constant txShaderUniforms &u [[buffer(2)]],
                             constant subFXVars &subVars [[buffer(3)]],
                             constant fxVars &vars [[buffer(1)]])
{
    // Directional wipe between two textures: each fragment is projected onto
    // the selected direction, and that projection (offset by transition
    // progress) picks colorA vs colorB through a soft smoothstep edge.
    float4 colorA = getColor(inputTex0, input.tex1);
    float4 colorB = getColor(inputTex1, input.tex2);

    // Guard against a host-supplied index outside the 9-entry table.
    int dirIndex = clamp(vars.direction, 0, 8);

    // Scale the direction so |x| + |y| == 1. Axis-aligned and diagonal wipes
    // then cover the same projected range, [-0.5, 0.5] assuming normPos spans
    // [0, 1]. Do NOT normalize() first: normalizing the zero entry (index 4)
    // yields NaN, which fails the `> 0` test and poisons the result. Left as
    // (0, 0), the projection is 0 everywhere and the effect becomes a uniform
    // fade driven only by progress.
    float2 dir = directionTable[dirIndex];
    float l1 = abs(dir.x) + abs(dir.y);
    if (l1 > 0)
        dir /= l1;

    // Progress 0 shifts prog up by 1.1 and progress 1 shifts it down by 0.1,
    // so the smoothstep edge sweeps fully across the projected range.
    float prog = dot(dir, input.normPos - center)
                 + (1 - u.transitionProgress) * 1.2 - 0.1;
    float blend = smoothstep(0.4, 0.6, prog);
    // u.mix interpolates between the inverted and non-inverted blend
    // (presumably 0 or 1 to choose which texture leads — confirm with host).
    blend = mix(1 - blend, blend, u.mix);

    float4 outColor;
    mixColor(outColor, colorA, colorB, blend);

    float alpha = u.baseAlpha;
#ifdef USE_MASK
    alpha *= float(inputTexMask.sample(linearSampler, input.texMask).a);
#endif
    multiplyOpacity(outColor, alpha);

    return outColor;
}
