#include "metalShaderTypes.h"
#include "alpha.h"

// Per-draw effect parameters. Bound at buffer(2) in vertexFunc and at
// buffer(1) in fragmentFunc; field order defines the buffer layout.
struct fxVars
{
    float sharpnessValue; // grid frequency is kHalftoneWidth * min(2000, 1 / sharpnessValue)
    float gcrValue;       // gray-component replacement amount (multiplied by 2 in the fragment stage)
    float ucrValue;       // under-color removal amount (fraction of K subtracted)
};

// Orientation of the halftone grid, in degrees (positive = counterclockwise).
constant float angle = -45.0f;

// Contrast controls applied to the squared dot radius.
constant float contrastDelta = 0.3f; // higher -> grey tones get darker
constant float brightness    = 0.0f; // higher -> dots shrink, image gets lighter
constant float blackness     = 1.1f; // higher -> larger areas completely covered by dots

// Width of the soft black-to-white dot edge (pseudo anti-aliasing).
constant float smoothValue = 0.2f;

// Ink color of the halftone dots (opaque black).
constant vec4 dotColor = vec4(0.0f, 0.0f, 0.0f, 1.0f);

// Base grid frequency before scaling by the sharpness parameter.
constant float kHalftoneWidth = 128.0f;


// Converts an RGB color to CMYK, returned as (C, M, Y, K) in .xyzw.
// K is the smallest of the inverted channels; the remaining C, M, Y are
// normalised by (1 - K).
// For (near-)pure black, 1 - K is ~0 and the normalisation would be 0/0 (NaN),
// which would poison every later computation on that pixel. In that case
// C, M, Y are returned as 0 (their limiting value); K alone carries the color.
vec4 RGBtoCMYK(vec3 rgb)
{
    vec3 cmycolor = vec3(1.0, 1.0, 1.0) - rgb;
    float k = min(cmycolor.r, min(cmycolor.g, cmycolor.b));
    float denom = 1.0 - k;

    if (abs(denom) < 1e-5)
    {
        return vec4(0.0, 0.0, 0.0, k);
    }

    return vec4((cmycolor - k) / denom, k);
}
// Converts a CMYK color (C, M, Y, K in .xyzw) back to RGB:
// each output channel is (1 - ink) * (1 - K).
vec3 CMYKtoRGB(vec4 cmyk)
{
    const float keyFactor = 1.0 - cmyk.w;
    return (vec3(1.0, 1.0, 1.0) - cmyk.xyz) * keyFactor;
}

// Pass-through vertex stage: forwards the vertex position and texture
// coordinates, and additionally provides the position remapped to [0, 1]
// and the size of the input texture to the fragment stage.
// inputTexMask, u and vars are unused here but keep the binding layout.
vertex fxVertexOut vertexFunc(uint vertexID [[ vertex_id ]]
                              ,const device fxShaderVerts* in[[ buffer(0) ]]
                              ,texture2d<half> inputTex0 [[ texture(0) ]]
                              ,texture2d<half> inputTexMask [[ texture(1) ]]
                              ,constant fxGeneralUniforms& u[[ buffer(1) ]]
                              ,constant fxVars& vars[[ buffer(2)]])
{
    const device fxShaderVerts& vert = in[vertexID];

    fxVertexOut result;
    result.position = vert.pos;
    // Map xy from [-1, 1] to [0, 1].
    result.normPos  = vert.pos.xy * 0.5f + 0.5f;
    result.tex1     = vert.tex1;
    result.texMask  = vert.texMask;
    result.texSize  = float2(inputTex0.get_width(), inputTex0.get_height());
    return result;
}

/*
 * Fragment stage of the halftone effect.
 *
 * For each pixel:
 * - rotate the texture coordinate into the halftone grid;
 * - find the center of the grid cell it falls in;
 * - size a dot from the red channel sampled at that center;
 * - use a smoothstep falloff to mix dotColor over a GCR/UCR-adjusted version
 *   of the source color.
 *
 * The result is blended with the untouched source by u.blendValue, further
 * scaled by the mask alpha when USE_MASK is defined.
 */
fragment float4 fragmentFunc(fxVertexOut input [[stage_in]]
                             ,texture2d<half> inputTex [[ texture(0) ]]
                             ,texture2d<half> inputTexMask [[ texture(1) ]]
                             ,constant fxGeneralUniforms& u[[ buffer(0) ]]
                             ,constant fxVars& vars[[ buffer(1)]])
{
    // Source color, un-premultiplied when PRE_MULT is defined. It is also the
    // base color for GCR/UCR below, so the texture is sampled only once here.
    vec4 texel = getColor(inputTex, fsTexture);
#ifdef PRE_MULT
    divideAlpha(texel);
#endif

    // `angle` is specified in degrees, but cos/sin take radians.
    const float theta = angle * 0.0174532925f; // pi / 180
    const float cosA = cos(theta);
    const float sinA = sin(theta);
    mat2 rotate = mat2(cosA, -sinA,
                       sinA, cosA);
    mat2 inverse_rotate = mat2(cosA, sinA,
                               -sinA, cosA);

    // Grid frequency: halftone cells per unit of (rotated) texture coordinate.
    float halftoneWidth = kHalftoneWidth * min(2000.0, 1.0 / vars.sharpnessValue);

    // Distance to the next dot divided by two (zero when the grid is degenerate).
    vec2 halfLineDist = halftoneWidth > 0.0 ? vec2(0.5 / halftoneWidth) : vec2(0.0);
    float halfLineDist2 = halfLineDist.x * halfLineDist.x;

    // Find the center of the halftone dot: snap the rotated coordinate to its
    // cell corner, step to the cell center, then rotate back to texture space.
    vec2 st = fsTexture;
    vec2 center = rotate * st;
    if (halftoneWidth > 0.0)
    {
        center = floor(center * halftoneWidth) / halftoneWidth;
    }
    center += halfLineDist;
    center = inverse_rotate * center;

    // Only the red channel drives dot size. The original code assumed a
    // grayscale input; NOTE(review): confirm whether a true luminance of the
    // RGB color is wanted for color inputs.
    vec4 centerColor = getColor(inputTex, center);
#ifdef PRE_MULT
    divideAlpha(centerColor);
#endif
    float luminance = centerColor.r;

    // Radius of the halftone dot: a darker cell center gives a larger dot.
    float radius = sqrt(2.0) * halfLineDist.x * (1.0 - luminance) * blackness;

    // Squared radius after the contrast/brightness adjustment.
    // Squares are written as products because pow(x, 2.0) is undefined for
    // x < 0, and radius is negative whenever the red channel exceeds 1.0.
    float contrast = 1.0 + contrastDelta / 2.0;
    float radiusSqrd = contrast * radius * radius
                     - (contrastDelta * halfLineDist2) / 2.0
                     - brightness * halfLineDist2;

    // Squared distance from this pixel to the dot center.
    vec2 offset = center - st;
    float pixelDist2 = dot(offset, offset);

    // Soft edge of half-width `delta` around the dot boundary (pseudo anti-aliasing).
    // smoothstep is undefined when both edges coincide, which happens when the
    // grid is degenerate (halfLineDist == 0). Treat that case as "no dot".
    float delta = smoothValue * halfLineDist2;
    float gradient = delta > 0.0
        ? smoothstep(radiusSqrd - delta, radiusSqrd + delta, pixelDist2)
        : 1.0;

    float4 outColor = texel;
    vec3 irgb = outColor.rgb;

    // Gray-component replacement: move a share of the common CMY part into K.
    vec3 cmycolor = vec3(1, 1, 1) - irgb;
    float k = min(cmycolor.x, min(cmycolor.y, cmycolor.z)) * vars.gcrValue * 2.0;
    vec4 cmyk = vec4(cmycolor - vec3(k, k, k), k);

    // Under-color removal: re-derive CMYK, then subtract a share of K from
    // all four channels.
    irgb = CMYKtoRGB(cmyk);
    cmyk = RGBtoCMYK(irgb);
    k = cmyk.w * vars.ucrValue;
    cmyk = cmyk - vec4(k, k, k, k);
    irgb = CMYKtoRGB(cmyk);

    // Dot color inside the dot, adjusted source color outside, soft edge between.
    outColor.rgb = mix(dotColor.rgb, irgb, gradient);

    float blendValue = u.blendValue;
#ifdef USE_MASK
    blendValue *= getColor(inputTexMask, input.texMask).a;
#endif
    outColor.rgb = mix(texel.rgb, outColor.rgb, blendValue);
#ifdef PRE_MULT
    multiplyAlpha(outColor);
#endif
    return outColor;
}
