// Generic vertex function and first part of fragment function
// Included by ssInputShadersML::LoadShader()

#if __METAL__ || __METAL_MACOS__ || __METAL_IOS__
#include <metal_stdlib>
using namespace metal;
#endif

#include <simd/simd.h>

// Uniforms bound at [[buffer(1)]] of vertexFunc (both variants).
// Declared with <simd/simd.h> types, presumably so host code can fill the same
// layout — NOTE(review): keep field order/types in sync with the host writer.
// vector_float4 is 16-byte aligned, so endColor starts at offset 16 (there is
// padding after vScale).
typedef struct
{
    float hScale;             // horizontal multiplier applied to texCoordUV in vertexFunc
    float vScale;             // vertical multiplier applied to texCoordUV in vertexFunc
    vector_float4 endColor;   // not read in this section; presumably used by fragment code — confirm
    float endOpacity;         // not read in this section; presumably used by fragment code — confirm
} planarUniforms;

// Per-vertex input read from [[buffer(0)]] by vertexFunc when USE_MASK is not
// defined. NOTE(review): layout presumably mirrors the host-side vertex writer.
typedef struct
{
    vector_float4 pos;   // copied unchanged to the [[position]] output
    packed_float2 tc;    // pixel-space texture coordinate; divided by texture sizes in vertexFunc
} planarVerts;

// Per-vertex input read from [[buffer(0)]] by vertexFunc when USE_MASK is
// defined: planarVerts plus a mask coordinate.
// NOTE(review): layout presumably mirrors the host-side vertex writer.
typedef struct
{
    vector_float4 pos;      // copied unchanged to the [[position]] output
    packed_float2 tc;       // pixel-space texture coordinate; divided by texture sizes in vertexFunc
    packed_float2 maskTC;   // pixel-space mask coordinate; divided by the maskTex size in vertexFunc
} planarVertsMask;

#if __METAL__ || __METAL_MACOS__ || __METAL_IOS__

// Output of vertexFunc, interpolated and handed to the fragment stage.
typedef struct
{
    float4 position [[position]];
    float2 texCoordY;   // normalized: tc / size of inputTex0
    float2 texCoordUV;  // normalized: tc / size of inputTex1, then scaled by (hScale, vScale)
    float2 texCoordPS;  // pixel space (tc unchanged)
    float2 maskCoord;   // normalized: maskTC / size of maskTex; only written when USE_MASK is defined
    float2 texCoordV;   // normalized; holds V if different from UV
                        // NOTE(review): not written by vertexFunc in this file — confirm it is set elsewhere or unused
} inputVertexType;

// Point sampler used by textureFetchNearest: nearest filtering for mip, mag and
// min, and reads outside the texture return the border colour (transparent black).
constexpr sampler nearestSampler(
    mip_filter::nearest,
    mag_filter::nearest,
    min_filter::nearest,
    address::clamp_to_border,
    border_color::transparent_black);

#endif

// GLSL-style vector aliases, presumably so fragment code shared with a GLSL
// path compiles unchanged — confirm.
#define vec2 float2
#define vec3 float3
#define vec4 float4
// Point-samples texture `a` at coordinate `b` with nearestSampler (Metal's
// default normalized coordinates) and returns the sample as float4.
// Arguments are parenthesized so expression arguments (e.g. a ternary choosing
// between two textures) bind correctly to `.sample`.
#define textureFetchNearest(a, b) float4((a).sample(nearestSampler, (b)))

// Source-profile -> sRGB gamut conversion, used by processColorProfile when
// USE_COLOR_PROFILE is defined and BT709 is not. Each profileToRGB_* vector is one
// row of a 3x3 matrix that maps linear RGB in the source primaries to linear
// sRGB/BT.709 primaries. The rows are applied with dot(); their w component is 0,
// so alpha does not contribute. Define only one profile macro, otherwise these
// constants are redefined.
#ifdef AdobeRGB1998
constant float4 profileToRGB_R = float4(1.398283, -0.398283064, 0.0, 0.0);
constant float4 profileToRGB_G = float4(0.0, 1.0, 0.0, 0.0);
constant float4 profileToRGB_B = float4(0.0, -0.0429382771, 1.04293835, 0.0);
#endif

#ifdef DISPLAY_P3
constant float4 profileToRGB_R = float4(1.2249, -0.2247, 0.0, 0.0);
constant float4 profileToRGB_G = float4(-0.0420, 1.0419, 0.0, 0.0);
constant float4 profileToRGB_B = float4(-0.0197, -0.0786, 1.0979, 0.0);
#endif

#ifdef PRO_PHOTO_RGB
constant float4 profileToRGB_R = float4(2.03407574, -0.72733432, -0.306741565, 0.0);
constant float4 profileToRGB_G = float4(-0.228813201, 1.23173022, -0.00291692792, 0.0);
constant float4 profileToRGB_B = float4(-0.00856976956, -0.153286651, 1.1618564, 0.0);
#endif

#ifdef BT2020
// values obtained from:
// Report ITU-R BT.2407-0 (10/2017) - Colour gamut conversion from Recommendation ITU-R BT.2020 to Recommendation ITU-R
// BT.709
constant float4 profileToRGB_R = float4(1.6605, -0.5876, -0.0728, 0.0);
constant float4 profileToRGB_G = float4(-0.1246, 1.1329, -0.0083, 0.0);
constant float4 profileToRGB_B = float4(-0.0182, -0.1006, 1.1187, 0.0);
#endif

// YCbCr -> RGB conversion rows, selected by FULL_RANGE.
// The coefficients imply a vector layout of (Cr, Y, Cb, unused): yuv2R uses only
// x (Cr) and y (Y), and yuv2B uses only y (Y) and z (Cb). stdbias re-centres both
// chroma channels by -0.5 and, for limited range, removes the luma offset
// (-0.0625 = -16/256).
// NOTE(review): how stdbias and these rows are applied is not shown in this
// section; presumably the bias is added to the sample before the dot products.
#ifdef FULL_RANGE
// full range values from wikipedia https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion
// NOTE(review): the coefficients below are the BT.601 / JPEG full-range values
// (R = Y + 1.402 Cr, G = Y - 0.344136 Cb - 0.714136 Cr, B = Y + 1.772 Cb), not the
// BT.709 ones that the link's anchor names. BT.709 full range would be 1.5748,
// -0.1873, -0.4681 and 1.8556. Confirm which matrix full-range sources require.
constant float4 yuv2R = float4(1.402, 1., 0.0, 0.0);
constant float4 yuv2G = float4(-0.714136, 1., -0.344136, 0.0);
constant float4 yuv2B = float4(0.000000, 1., 1.772, 0.0);
constant float4 stdbias = float4(-0.5, 0.0, -0.5, 0.0);
#else
// Limited (video) range BT.709 coefficients. Luma is scaled by 1.164383 (= 255/219);
// the chroma terms are the BT.709 coefficients scaled by roughly 255/224.
constant float4 yuv2R = float4(1.793, 1.164383, 0.000000, 0.0);
constant float4 yuv2G = float4(-0.534, 1.164383, -0.213, 0.0);
constant float4 yuv2B = float4(0.000000, 1.164383, 2.115, 0.0);
constant float4 stdbias = float4(-0.5, -0.0625, -0.5, 0.0);
#endif

// Component-wise "a < b" test, encoded as floats: each component is 1.0 where
// a[i] < b[i] and 0.0 otherwise (including when either side is NaN). Intended as
// the blend factor for mix() when choosing between two curve segments.
float3 lessThan(float3 a, float3 b)
{
    return float3(a.x < b.x ? 1.0f : 0.0f,
                  a.y < b.y ? 1.0f : 0.0f,
                  a.z < b.z ? 1.0f : 0.0f);
}

// Encodes linear-light RGB with the sRGB transfer function; alpha is unchanged.
//   c <  0.0031308 : 12.92 * c
//   c >= 0.0031308 : 1.055 * c^(1/2.4) - 0.055
float4 toGamma(float4 linearRGB)
{
    float3 cutoff = lessThan(linearRGB.rgb, float3(0.0031308));
    // Clamp the pow() base at 0. Negative components (for example out-of-gamut
    // results of the profile matrices, which have negative coefficients) make
    // pow() undefined/NaN. mix() computes x + (y - x) * a, so that NaN would leak
    // into the result even where 'lower' is the selected branch.
    float3 higher = float3(1.055) * pow(max(linearRGB.rgb, float3(0.0)), float3(1.0 / 2.4)) - float3(0.055);
    float3 lower = linearRGB.rgb * float3(12.92);

    return float4(mix(higher, lower, cutoff), linearRGB.a);
}

// Decodes sRGB-encoded RGB to linear light (the inverse of the sRGB curve);
// alpha is unchanged.
//   c <  0.04045 : c / 12.92
//   c >= 0.04045 : ((c + 0.055) / 1.055)^2.4
float4 toLinear(float4 sRGB)
{
    float3 cutoff = lessThan(sRGB.rgb, float3(0.04045));
    // Clamp the pow() base at 0 so inputs below -0.055 cannot produce a NaN that
    // mix() (x + (y - x) * a) would propagate even where 'lower' is selected.
    float3 higher = pow(max((sRGB.rgb + float3(0.055)) / float3(1.055), float3(0.0)), float3(2.4));
    float3 lower = sRGB.rgb / float3(12.92);

    return float4(mix(higher, lower, cutoff), sRGB.a);
}

// Inverts the ITU-R BT.709 OETF, returning linear light; alpha is unchanged.
//   V <  0.081 : V / 4.5
//   V >= 0.081 : ((V + 0.099) / 1.099)^(1 / 0.45)
float4 linearizeBT709(float4 bt709)
{
    float3 cutoff = lessThan(bt709.rgb, float3(0.081));
    // Normalise by 1.099 and use the exponent 1/0.45 so that V = 1 maps to 1 and
    // the curve meets V / 4.5 at the 0.081 knee. Without the normalisation, and
    // with an exponent of 2.2, white decodes to about 1.23. The base is clamped
    // at 0 to keep pow() defined (and NaN out of mix()) for strongly negative
    // inputs.
    float3 higher = pow(max((bt709.rgb + float3(0.099)) / float3(1.099), float3(0.0)), float3(1.0 / 0.45));
    float3 lower = bt709.rgb / float3(4.5);

    return float4(mix(higher, lower, cutoff), bt709.a);
}

// Converts a decoded colour to gamma-encoded sRGB according to the source colour
// profile chosen at compile time. Alpha is carried through unchanged.
//   - USE_COLOR_PROFILE not defined: identity, the colour is returned as-is.
//   - BT709: BT.709 shares its primaries with sRGB, so no gamut conversion is
//     needed; only the transfer curve is swapped (BT.709 -> linear -> sRGB).
//   - any other profile: decode with the sRGB curve, convert primaries with the
//     profileToRGB_* rows, then re-encode with the sRGB curve.
//     NOTE(review): the sRGB curve is used for every non-BT709 profile, although
//     AdobeRGB/ProPhoto/BT.2020 define their own transfer functions — confirm
//     this approximation is intended.
float4 processColorProfile(float4 colorIn)
{
#if !defined(USE_COLOR_PROFILE)
    return colorIn;
#elif defined(BT709)
    return toGamma(linearizeBT709(colorIn));
#else
    const float4 profileLinear = toLinear(colorIn);
    const float3 srgbLinearRGB = float3(dot(profileToRGB_R, profileLinear),
                                        dot(profileToRGB_G, profileLinear),
                                        dot(profileToRGB_B, profileLinear));
    return toGamma(float4(srgbLinearRGB, colorIn.a));
#endif
}

#ifdef USE_MASK
// Vertex stage, mask variant. Passes the position through and derives every
// texture coordinate from the pixel-space coordinates in the vertex buffer:
//   texCoordY  = tc / size(inputTex0)
//   texCoordUV = tc / size(inputTex1), scaled by (hScale, vScale)
//   texCoordPS = tc, unnormalized
//   maskCoord  = maskTC / size(maskTex)
// inputTex2 and inputTex3 are bound but not read here.
// NOTE(review): texCoordV is left unwritten — confirm the fragment stage does not
// read it in this configuration.
vertex inputVertexType vertexFunc(
    uint vertexID [[vertex_id]],
    const device planarVertsMask *in [[buffer(0)]],
    constant planarUniforms &uniforms [[buffer(1)]],
    texture2d<half> inputTex0 [[texture(0)]],
    texture2d<half> inputTex1 [[texture(1)]],
    texture2d<half> inputTex2 [[texture(2)]],
    texture2d<half> inputTex3 [[texture(3)]],
    texture2d<half> maskTex [[texture(4)]])
{
    const device planarVertsMask &vtx = in[vertexID];

    const float2 pixelTC = float2(vtx.tc);
    const float2 lumaSize = float2(inputTex0.get_width(), inputTex0.get_height());
    const float2 chromaSize = float2(inputTex1.get_width(), inputTex1.get_height());
    const float2 maskSize = float2(maskTex.get_width(), maskTex.get_height());
    const float2 chromaScale = float2(uniforms.hScale, uniforms.vScale);

    inputVertexType result;
    result.position = vtx.pos;
    result.texCoordY = pixelTC / lumaSize;
    result.texCoordUV = (pixelTC / chromaSize) * chromaScale;
    result.texCoordPS = pixelTC;
    result.maskCoord = float2(vtx.maskTC) / maskSize;
    return result;
}
#else

// Vertex stage, no-mask variant. Same derivation as the mask variant, without
// the mask coordinate:
//   texCoordY  = tc / size(inputTex0)
//   texCoordUV = tc / size(inputTex1), scaled by (hScale, vScale)
//   texCoordPS = tc, unnormalized
// NOTE(review): maskCoord and texCoordV are left unwritten — confirm the fragment
// stage does not read them in this configuration.
vertex inputVertexType vertexFunc(
    uint vertexID [[vertex_id]],
    const device planarVerts *in [[buffer(0)]],
    constant planarUniforms &uniforms [[buffer(1)]],
    texture2d<half> inputTex0 [[texture(0)]],
    texture2d<half> inputTex1 [[texture(1)]])
{
    const device planarVerts &vtx = in[vertexID];

    const float2 pixelTC = float2(vtx.tc);
    const float2 lumaSize = float2(inputTex0.get_width(), inputTex0.get_height());
    const float2 chromaSize = float2(inputTex1.get_width(), inputTex1.get_height());
    const float2 chromaScale = float2(uniforms.hScale, uniforms.vScale);

    inputVertexType result;
    result.position = vtx.pos;
    result.texCoordY = pixelTC / lumaSize;
    result.texCoordUV = (pixelTC / chromaSize) * chromaScale;
    result.texCoordPS = pixelTC;
    return result;
}
#endif
