
#include "metalShaderTypes.h"
#include "alpha.h"

// Sampler used by getBackgroundColor for the background texture:
// bilinear minification/magnification, nearest mip level selection,
// and s/t coordinates clamped to the edge texels.
constexpr sampler backgroundSampler(min_filter::linear,
                                    mag_filter::linear,
                                    mip_filter::nearest,
                                    s_address::clamp_to_edge,
                                    t_address::clamp_to_edge);

// Samples texture `tex` at coordinate `coord` with backgroundSampler and
// converts the result to float4. Both arguments are parenthesized so that
// expression arguments (e.g. `*texPtr`) expand correctly.
#define getBackgroundColor(tex, coord) float4((tex).sample(backgroundSampler, (coord)))

// Fragment-stage parameters, bound at buffer(1) of fragmentFunc.
struct fxVars
{
    // Weight used as the mix() factor for the background darkening:
    // 0 leaves the background untouched, 1 applies the full darkening.
    float shadowEnable;
    // Padding; presumably keeps the struct at 16 bytes to match the
    // CPU-side buffer layout — confirm before changing.
    float pad[3];
};

// Pass-through vertex stage. It reads vertex `vertexID` from the buffer bound
// at buffer(0) and forwards its position and the tex1/tex2/texMask coordinate
// sets unchanged. No transform is applied here, so pos is presumably already
// in clip space (confirm on the CPU side).
vertex vertexOut vertexFunc(uint vertexID [[ vertex_id ]],const device txShaderVers* in[[ buffer(0) ]])
{
    const device txShaderVers& src = in[vertexID];

    vertexOut result;
    result.position = src.pos;
    result.tex1     = src.tex1;
    result.tex2     = src.tex2;
    result.texMask  = src.texMask;
    return result;
}

// Composites the image from texture 1 (sampled at tex2 via getColor) over the
// background from texture 0 (sampled at tex1 via getBackgroundColor).
//
// The background is darkened by a factor that ramps smoothly:
// - 0.55 where tex2 is inside [0,1] on both axes.
// - 0.85 once tex2 lies 0.05 or more outside that range.
//
// vars.shadowEnable is the mix() weight between the undarkened (0) and darkened
// (1) background. When USE_MASK is defined, the result's opacity is scaled by
// the alpha of texture 2 sampled at texMask.
fragment float4 fragmentFunc(vertexOut input [[stage_in]]
                             ,texture2d<half> inputTex0 [[ texture(0) ]]
                             ,texture2d<half> inputTex1 [[ texture(1) ]]
                             ,texture2d<half> inputTexMask [[ texture(2) ]]
                             ,constant fxVars& vars[[ buffer(1)]])
{
    // Brightness factor right at the border of the [0,1] tex2 region; keeps the
    // edge from going completely dark.
    const float kEdgeShade = 0.55;
    // Brightness factor far from the border, so the whole outer area is dimmed a bit.
    const float kFarShade = 0.85;
    // Overshoot distance (in tex2 units) over which kEdgeShade ramps to kFarShade.
    const float kShadowRamp = 0.05;

    float4 background = getBackgroundColor(inputTex0, input.tex1);
    float4 source = getColor(inputTex1, input.tex2);

    // Per-axis distance by which tex2 lies outside [0,1]; zero inside.
    // Note: (edgeDist > 0.0) is exactly "tex2 is outside [0,1] on some axis",
    // i.e. the pixel is in the area surrounding the source image.
    float2 uv = float2(input.tex2.xy);
    float2 overshoot = abs(uv - clamp(uv, float2(0.0), float2(1.0)));
    float edgeDist = max(overshoot.x, overshoot.y);

    float shade = mix(kEdgeShade, kFarShade, smoothstep(0.0, kShadowRamp, edgeDist));
    background.rgb = mix(background.rgb, background.rgb * shade, vars.shadowEnable);

    float4 outColor;
    blendColorSrcOverDst(outColor, source, background);

#ifdef USE_MASK
    float maskAlpha = float(inputTexMask.sample(linearSampler, input.texMask).a);
    multiplyOpacity(outColor, maskAlpha);
#endif

    return outColor;
}
