#include "metalShaderTypes.h"
#include "alpha.h"

// Colour-adjustment parameters consumed by fragmentFunc through buffer(1).
// Field order and types define the buffer layout and must not change.
// NOTE(review): presumably mirrored by a host-side struct — confirm before editing.
struct colorAdjVars
{
    float redLevel;    // offset added to the red channel
    float greenLevel;  // offset added to the green channel
    float blueLevel;   // offset added to the blue channel
    float alphaLevel;  // fraction of the remaining (1 - alpha) added to alpha
    float blackLevel;  // rgb is scaled by (1 - blackLevel), then raised by blackLevel
    float gamma;       // rgb is raised to the power (1 + gamma)
};

// One input vertex as read by vertexFunc from buffer(0).
// Member order and types define the buffer layout and must not change.
struct basicPT3
{
    vector_float4 pos;  // vertex position; vertexFunc negates its y component
    packed_float4 tc;   // texture coordinate; only .xyz is read by vertexFunc
};

// Vertex-stage output that fragmentFunc receives as its [[stage_in]] input.
struct fsPT3
{
    float4 position [[position]];  // output position written by vertexFunc
    float3 texCoord;               // fragmentFunc divides .xy by .z before sampling
};

// Vertex stage: forwards one vertex from the input buffer to the rasterizer.
//
// vertexID - index of the vertex to fetch from `in`.
// in       - vertex buffer bound at buffer(0).
//
// Returns the vertex position with its y component negated, together with
// the first three components of the stored texture coordinate.
vertex fsPT3 vertexFunc(uint vertexID [[ vertex_id ]],
                        const device basicPT3* in [[ buffer(0) ]])
{
    const device basicPT3& src = in[vertexID];

    fsPT3 result;
    // Multiplying by (1, -1, 1, 1) negates y and leaves x, z, w untouched.
    result.position = src.pos * vec4(1, -1, 1, 1);
    result.texCoord = src.tc.xyz;
    return result;
}

// Fragment stage: samples the input texture and applies the colour
// adjustments described by `vars`.
//
// input    - interpolated vertex output; texCoord.xy is divided by texCoord.z
//            and then by the texture size before sampling.
//            NOTE(review): this suggests texCoord is expressed in texels — confirm.
// inputTex - source texture bound at texture(0).
// vars     - adjustment parameters bound at buffer(1).
//
// Adjustment order: power curve (1 + gamma) on rgb, per-channel level offsets,
// black-level rescale, then alpha lift. When PRE_MULT is defined the colour is
// passed through divideAlpha() before and multiplyAlpha() after the adjustments.
fragment float4 fragmentFunc(fsPT3 input [[stage_in]],
                             texture2d<half> inputTex [[ texture(0) ]],
                             constant colorAdjVars& vars [[ buffer(1) ]])
{
    const vec2 dims = vec2(inputTex.get_width(), inputTex.get_height());
    const vec2 projected = input.texCoord.xy / input.texCoord.z;
    vec4 color = vec4(inputTex.sample(linearSampler, projected / dims));

#ifdef PRE_MULT
    divideAlpha(color);
#endif

    // Power curve on the colour channels only.
    const float exponent = 1.0 + vars.gamma;
    color.rgb = pow(color.rgb, vec3(exponent));

    // Per-channel level offsets; the alpha component receives 0.
    const vec4 levelOffset = vec4(vars.redLevel, vars.greenLevel, vars.blueLevel, 0.0);
    color += levelOffset;

    // Scale rgb by (1 - blackLevel), then raise it by blackLevel.
    const float span = 1.0 - vars.blackLevel;
    color.rgb *= span;
    color.rgb += vars.blackLevel;

    // Add alphaLevel times the remaining (1 - alpha) to alpha.
    const float alphaHeadroom = 1.0 - color.a;
    color.a += vars.alphaLevel * alphaHeadroom;

#ifdef PRE_MULT
    multiplyAlpha(color);
#endif

    return color;
}
