#include "metalShaderTypes.h"
#include "alpha.h"

// Per-effect parameters for the blur, bound as a constant buffer to both
// shader stages in this file.
// NOTE(review): the host-side layout presumably has to match these two
// consecutive floats - confirm against the code that fills the buffer.
struct fxVars
{
    // Blur strength; the vertex stage multiplies it by 3 to derive the
    // spacing between neighbouring taps.
    float blurAmount;

    // Multiplier on the distance from the texture centre (0.5, 0.5) used by
    // the fragment stage when blending between the sharp and blurred colour.
    float blurRadius;
};

// Values produced by vertexFunc and consumed by fragmentFunc as [[stage_in]].
// The sample coordinates are precomputed per vertex so the fragment stage
// only has to perform the texture reads.
struct fsVertexOut
{
    float4 position [[position]];   // vertex position, passed through unchanged

    float2 tex1;        // source texture coordinate of the vertex
    float2 texMask;     // mask texture coordinate of the vertex
    float2 normPos;     // position.xy * 0.5 + 0.5
    float2 texSize;     // width / height of the texture bound at texture(0)

    // Blur taps: tex1 shifted by a whole number of steps along the axis of
    // the current pass. Indices 0-6 are the -7..-1 steps, indices 7-13 are
    // the +1..+7 steps.
    float2 texCoords0;
    float2 texCoords1;
    float2 texCoords2;
    float2 texCoords3;
    float2 texCoords4;
    float2 texCoords5;
    float2 texCoords6;
    float2 texCoords7;
    float2 texCoords8;
    float2 texCoords9;
    float2 texCoords10;
    float2 texCoords11;
    float2 texCoords12;
    float2 texCoords13;

    // Centre tap: identical to tex1.
    float2 texCoords14;
};

// Vertex stage of the blur.
//
// Passes the vertex position and texture coordinates through and precomputes
// the 15 sample coordinates (7 on each side of the vertex UV plus the UV
// itself) that the fragment stage reads.
//
//   vertexID     - index of the vertex being processed
//   in           - vertex array; pos, tex1 and texMask are read
//   inputTex0    - source texture; only its dimensions are queried here
//   inputTexMask - bound but not referenced by this stage
//   u            - general uniforms; u.pass selects the blur axis
//   vars         - blur parameters; blurAmount controls the tap spacing
vertex fsVertexOut vertexFunc(uint vertexID [[ vertex_id ]]
                              ,const device fxShaderVerts* in[[ buffer(0) ]]
                              ,texture2d<half> inputTex0 [[ texture(0) ]]
                              ,texture2d<half> inputTexMask [[ texture(1) ]]
                              ,constant fxGeneralUniforms& u[[ buffer(1) ]]
                              ,constant fxVars& vars[[ buffer(2)]])
{
    const device fxShaderVerts& vert = in[vertexID];

    fsVertexOut out;
    out.position = vert.pos;
    out.normPos  = vert.pos.xy*vec2(.5,.5f)+vec2(.5,.5);
    out.tex1     = vert.tex1;
    out.texMask  = vert.texMask;
    out.texSize  = float2(inputTex0.get_width(), inputTex0.get_height());

    // A blur amount of exactly zero falls back to a small non-zero scale.
    const float tripled = vars.blurAmount*3.0;
    const float fsScale = (tripled == 0.0) ? 0.1f : tripled;

    // Pass numbers above 1 are shifted down by 2 before the axis is chosen;
    // a resulting 0 blurs along x, anything else along y.
    float passIndex = u.pass;
    if (passIndex > 1)
        passIndex -= 2;

    // Distance between neighbouring taps, converted from texels to UV units.
    const float reach = fsScale + fsScale*.5;
    float2 direction;
    if (passIndex != 0)
        direction = vec2(0.0, reach/out.texSize.y);
    else
        direction = vec2(reach/out.texSize.x, 0.0);

    const float2 uv = vert.tex1;

    // Taps on the negative side, farthest first.
    out.texCoords0  = uv - direction*7.0f;
    out.texCoords1  = uv - direction*6.0f;
    out.texCoords2  = uv - direction*5.0f;
    out.texCoords3  = uv - direction*4.0f;
    out.texCoords4  = uv - direction*3.0f;
    out.texCoords5  = uv - direction*2.0f;
    out.texCoords6  = uv - direction;

    // Taps on the positive side, nearest first.
    out.texCoords7  = uv + direction;
    out.texCoords8  = uv + direction*2.0f;
    out.texCoords9  = uv + direction*3.0f;
    out.texCoords10 = uv + direction*4.0f;
    out.texCoords11 = uv + direction*5.0f;
    out.texCoords12 = uv + direction*6.0f;
    out.texCoords13 = uv + direction*7.0f;

    // Centre tap.
    out.texCoords14 = uv;

    return out;
}

// Fragment stage of the blur.
//
// Averages the 15 precomputed taps with equal weight, then passes the result
// through two mixColor() steps: first against the centre sample using a
// factor that grows with the distance from the texture centre (scaled by
// vars.blurRadius and capped at 1), then against the sample taken at
// fsTexture using u.blendValue (multiplied by the mask alpha when USE_MASK
// is defined).
//
// NOTE(review): getColor(), mixColor(), linearSampler and fsTexture are not
// declared in this file; they are presumably provided by the included
// headers. mixColor() looks like mix(arg2, arg3, arg4) written to arg1 -
// confirm against its definition.
fragment float4 fragmentFunc(fsVertexOut input [[stage_in]]
                             ,texture2d<half> inputTex [[ texture(0) ]]
                             ,texture2d<half> inputTexMask [[ texture(1) ]]
                             ,constant fxGeneralUniforms& u[[ buffer(0) ]]
                             ,constant fxVars& vars[[ buffer(1)]])
{
    // Distance of this fragment's UV from the middle of the texture.
    vec2 center = vec2(.5,.5);
    float dist = length(input.tex1 - center);

    vec4 texel = getColor(inputTex, fsTexture);

    // Centre tap; also the reference colour for the radial blend below.
    vec4 srcColor = getColor(inputTex, input.texCoords14);

    // Sum the taps: negative side, centre, positive side.
    float4 tapSum = vec4(0.0);
    tapSum += getColor(inputTex, input.texCoords0);
    tapSum += getColor(inputTex, input.texCoords1);
    tapSum += getColor(inputTex, input.texCoords2);
    tapSum += getColor(inputTex, input.texCoords3);
    tapSum += getColor(inputTex, input.texCoords4);
    tapSum += getColor(inputTex, input.texCoords5);
    tapSum += getColor(inputTex, input.texCoords6);
    tapSum += srcColor;
    tapSum += getColor(inputTex, input.texCoords7);
    tapSum += getColor(inputTex, input.texCoords8);
    tapSum += getColor(inputTex, input.texCoords9);
    tapSum += getColor(inputTex, input.texCoords10);
    tapSum += getColor(inputTex, input.texCoords11);
    tapSum += getColor(inputTex, input.texCoords12);
    tapSum += getColor(inputTex, input.texCoords13);

    // Equal-weight average over the 15 taps.
    float4 outColor = tapSum * vec4(1.0/15.0);

    // Radial factor, capped at 1 (no lower bound is applied).
    float blend = min(1.0, (dist*vars.blurRadius)*4.);
    mixColor(outColor, srcColor, outColor, blend);

    // Overall effect strength, optionally attenuated by the mask's alpha.
    float blendValue = u.blendValue;
#ifdef USE_MASK
    blendValue *= float(inputTexMask.sample(linearSampler, input.texMask).a);
#endif
    mixColor(outColor, texel, outColor, blendValue);

    return outColor;
}
