/* -LICENSE-START-
 ** Copyright (c) 2023 Blackmagic Design
 **
 ** Permission is hereby granted, free of charge, to any person or organization
 ** obtaining a copy of the software and accompanying documentation (the
 ** "Software") to use, reproduce, display, distribute, sub-license, execute,
 ** and transmit the Software, and to prepare derivative works of the Software,
 ** and to permit third-parties to whom the Software is furnished to do so, in
 ** accordance with:
 **
 ** (1) if the Software is obtained from Blackmagic Design, the End User License
 ** Agreement for the Software Development Kit ("EULA") available at
 ** https://www.blackmagicdesign.com/EULA/DeckLinkSDK; or
 **
 ** (2) if the Software is obtained from any third party, such licensing terms
 ** as notified by that third party,
 **
 ** and all subject to the following:
 **
 ** (3) the copyright notices in the Software and this entire statement,
 ** including the above license grant, this restriction and the following
 ** disclaimer, must be included in all copies of the Software, in whole or in
 ** part, and all derivative works of the Software, unless such copies or
 ** derivative works are solely in the form of machine-executable object code
 ** generated by a source language processor.
 **
 ** (4) THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 ** OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 ** FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
 ** SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
 ** FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
 ** ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 ** DEALINGS IN THE SOFTWARE.
 **
 ** A copy of the Software is available free of charge at
 ** https://www.blackmagicdesign.com/desktopvideo_sdk under the EULA.
 **
 ** -LICENSE-END-
 */

#include "metalShaderTypes.h"
#include "commonMLDX.h"
#include "alpha.h"
#include "inputColorHelper.h"

// Rows of the BT.709 Y'CbCr -> R'G'B' matrix, applied as dot(row, yuvIn).
// Lane usage (read off the coefficients):
//   .x = Cr  (1.793 into R, -0.534 into G)
//   .y = Y   (1.164383 = 255/219, the limited-range luma scale)
//   .z = Cb  (-0.213 into G, 2.115 into B)
//   .w = 0 in every row, so whatever sits in yuvIn.a never affects the colour.
constant float4 std709R = float4( 1.793, 1.164383,  0.0,   0.0);
constant float4 std709G = float4(-0.534, 1.164383, -0.213, 0.0);
constant float4 std709B = float4( 0.0,   1.164383,  2.115, 0.0);


// Plain aggregate holding one luma sample and its two chroma samples.
// Members are declared (and aggregate-initialised) in Y, Cb, Cr order.
// NOTE(review): not referenced by any kernel in this file.
template <class T>
struct YCbCr
{
	T Y, Cb, Cr;
};

// Normalises the 10-bit field starting at bit `shift` of a packed 32-bit word
// to [0, 1] (code value / 1023).
static inline float yuv10Field(uint word, uint shift)
{
    return ((word >> shift) & 0x3ff) / 1023.0;
}

// Converts one (already biased) Y'CbCr vector with the yuv2R/yuv2G/yuv2B rows,
// runs it through processColorProfile and stores it at (x, y). Pixels past the
// right edge of outTexture are skipped: each thread emits 6 pixels, so the last
// group of a row can overhang a width that is not a multiple of 6.
static inline void yuv10StorePixel(texture2d<half, access::write> outTexture,
                                   float4 yuvIn,
                                   uint x,
                                   uint y)
{
    if (x >= outTexture.get_width())
        return;

    float4 outColor = float4(dot(yuv2R, yuvIn),
                             dot(yuv2G, yuvIn),
                             dot(yuv2B, yuvIn),
                             1.0);
    outColor = processColorProfile(outColor);
    outTexture.write(half4(outColor), uint2(x, y));
}

// Unpacks one group of four 32-bit words of 10-bit 4:2:2 Y'CbCr (three 10-bit
// fields per word, presumably v210 packing) into six RGB pixels.
//
// inTexture1 - packed source, one uint per word; rows are read bottom-up, so
//              the output is vertically flipped.
// inTexture2 - unused by this kernel.
// outTexture - RGB destination; thread gid writes pixels 6*gid.x .. 6*gid.x+5
//              of row gid.y.
//
// Even pixels load luma + both chroma samples and add stdbias; the following
// odd pixel reuses that biased chroma and only replaces luma, applying the
// 0.0625 luma offset by hand (presumably the luma lane of stdbias).
kernel void YUV10In(
    texture2d<uint, access::read>  inTexture1           [[texture(0)]],
    texture2d<half, access::read>  inTexture2           [[texture(1)]],
    texture2d<half, access::write> outTexture          [[texture(2)]],
    uint2                          gid                 [[thread_position_in_grid]])
{
    texture2d<uint, access::read> inTexture = inTexture1;

    // The grid may be rounded up to whole threadgroups. Threads whose word group
    // starts past the source must not read, and gid.y past the height would make
    // the unsigned flippedY below wrap around.
    if (gid.x * 4 >= inTexture.get_width() || gid.y >= inTexture.get_height())
        return;

    uint flippedY = inTexture.get_height() - 1 - gid.y;
    uint2 flippedGid = uint2(gid.x * 4, flippedY);

    uint word0 = inTexture.read(flippedGid + uint2(0, 0)).r;
    uint word1 = inTexture.read(flippedGid + uint2(1, 0)).r;
    uint word2 = inTexture.read(flippedGid + uint2(2, 0)).r;
    uint word3 = inTexture.read(flippedGid + uint2(3, 0)).r;

    float4 yuvIn;
    uint baseX = gid.x * 6;
    uint y = gid.y;

    // Pixels 0/1: chroma word0[0:9] (.b) and word0[20:29] (.r); luma word0[10:19], word1[0:9].
    yuvIn.g = yuv10Field(word0, 10);
    yuvIn.b = yuv10Field(word0, 0);
    yuvIn.r = yuv10Field(word0, 20);
    yuvIn.a = 1.;
    yuvIn += stdbias;
    yuv10StorePixel(outTexture, yuvIn, baseX + 0, y);

    yuvIn.g = yuv10Field(word1, 0) - 0.0625;
    yuv10StorePixel(outTexture, yuvIn, baseX + 1, y);

    // Pixels 2/3: chroma word1[10:19] (.b) and word2[0:9] (.r); luma word1[20:29], word2[10:19].
    yuvIn.g = yuv10Field(word1, 20);
    yuvIn.b = yuv10Field(word1, 10);
    yuvIn.r = yuv10Field(word2, 0);
    yuvIn += stdbias;
    yuv10StorePixel(outTexture, yuvIn, baseX + 2, y);

    yuvIn.g = yuv10Field(word2, 10) - 0.0625;
    yuv10StorePixel(outTexture, yuvIn, baseX + 3, y);

    // Pixels 4/5: chroma word2[20:29] (.b) and word3[10:19] (.r); luma word3[0:9], word3[20:29].
    yuvIn.g = yuv10Field(word3, 0);
    yuvIn.b = yuv10Field(word2, 20);
    yuvIn.r = yuv10Field(word3, 10);
    yuvIn += stdbias;
    yuv10StorePixel(outTexture, yuvIn, baseX + 4, y);

    yuvIn.g = yuv10Field(word3, 20) - 0.0625;
    yuv10StorePixel(outTexture, yuvIn, baseX + 5, y);
}

// Unpacks one 32-bit word of 8-bit 4:2:2 Y'CbCr into two BT.709 RGB pixels.
//
// Byte layout of each word, low to high: Cb, Y0, Cr, Y1. The chroma pair is
// shared by both output pixels.
// inTexture1 - packed source, one uint per word; rows are read bottom-up, so
//              the output is vertically flipped.
// inTexture2 - unused by this kernel.
// outTexture - RGB destination; thread gid writes pixels 2*gid.x and
//              2*gid.x+1 of row gid.y.
kernel void YUV8In(
    texture2d<uint, access::read>  inTexture1           [[texture(0)]],
    texture2d<half, access::read>  inTexture2           [[texture(1)]],
    texture2d<half, access::write> outTexture          [[texture(2)]],
    uint2                          gid                 [[thread_position_in_grid]])
{
    texture2d<uint, access::read> inTexture = inTexture1;

    // The grid may be rounded up to whole threadgroups. Threads outside the
    // source must not read, and gid.y past the height would make the unsigned
    // flippedY below wrap around.
    if (gid.x >= inTexture.get_width() || gid.y >= inTexture.get_height())
        return;

    uint flippedY = inTexture.get_height() - 1 - gid.y;
    uint word = inTexture.read(uint2(gid.x, flippedY)).r;
    uint baseX = gid.x * 2;

    // Lanes follow std709R/G/B: .r = Cr, .g = Y, .b = Cb.
    float4 yuvIn;
    yuvIn.g = float((word >> 8) & 0xff) / 255.0;
    yuvIn.b = float(word & 0xff) / 255.0;
    yuvIn.r = float((word >> 16) & 0xff) / 255.0;
    yuvIn.a = 1.;
    yuvIn += stdbias;

    float4 outColor = float4(dot(std709R, yuvIn),
                             dot(std709G, yuvIn),
                             dot(std709B, yuvIn),
                             1.0);
    outColor = processColorProfile(outColor);
    outTexture.write(half4(outColor), uint2(baseX + 0, gid.y));

    // Second pixel: reuse the already-biased chroma and apply the 0.0625 luma
    // offset (presumably the luma lane of stdbias) to the new luma by hand.
    yuvIn.g = float((word >> 24) & 0xff) / 255.0 - 0.0625;

    outColor = float4(dot(std709R, yuvIn),
                      dot(std709G, yuvIn),
                      dot(std709B, yuvIn),
                      1.0);
    outColor = processColorProfile(outColor);
    outTexture.write(half4(outColor), uint2(baseX + 1, gid.y));
}

// Converts one pixel of 16-bit 4:2:2 semi-planar Y'CbCr to BT.709 RGB.
//
// inTexture1 holds two planes stacked vertically in one texture of height H:
// rows [0, H/2) are luma, and rows [H/2, H) are chroma, stored as one
// interleaved sample pair per two pixels. Rows are read bottom-up, so the
// output is vertically flipped.
// inTexture2 - unused by this kernel.
// outTexture - RGB destination; thread gid writes pixel gid.
kernel void P216In(
    texture2d<half, access::read>  inTexture1           [[texture(0)]],
    texture2d<half, access::read>  inTexture2           [[texture(1)]],
    texture2d<half, access::write> outTexture          [[texture(2)]],
    uint2                          gid                 [[thread_position_in_grid]])
{
    texture2d<half, access::read> inTexture = inTexture1;
    uint height = inTexture.get_height();
    uint chromaOffset = height / 2;

    // Skip threads outside the source. For gid.y >= height - chromaOffset the
    // unsigned luma row (flippedY - chromaOffset) would wrap around.
    if (gid.x >= inTexture.get_width() || gid.y >= height - chromaOffset)
        return;

    uint flippedY = height - 1 - gid.y;
    uint2 lumaGid = uint2(gid.x, flippedY - chromaOffset);
    // Both pixels of a horizontal pair share the chroma pair at the even column.
    uint2 uvGid = uint2(gid.x / 2 * 2, flippedY);

    half luma = inTexture.read(lumaGid).r;
    // The first sample of the pair feeds the Cb lane (.b, weighted 2.115 into B)
    // and the second feeds the Cr lane (.r, weighted 1.793 into R).
    half cb = inTexture.read(uvGid).r;
    half cr = inTexture.read(uvGid + uint2(1, 0)).r;

    float4 yuvIn;
    yuvIn.g = luma;
    yuvIn.b = cb;
    yuvIn.r = cr;
    yuvIn.a = 1.;
    yuvIn += stdbias;

    float4 outColor = float4(dot(std709R, yuvIn),
                             dot(std709G, yuvIn),
                             dot(std709B, yuvIn),
                             1.0);
    outColor = processColorProfile(outColor);
    outTexture.write(half4(outColor), gid);
}

// Converts one pixel of 16-bit 4:2:2 semi-planar Y'CbCr with an alpha plane
// to BT.709 RGB.
//
// inTexture1 holds three equal sections stacked vertically (sectionHeight =
// H/3): luma rows [0, s), interleaved chroma pairs rows [s, 2s), and alpha
// rows [2s, 3s). Rows are read bottom-up, so the output is vertically flipped.
// inTexture2 - unused by this kernel.
// outTexture - RGB destination; thread gid writes pixel gid.
kernel void PA16In(
    texture2d<half, access::read>  inTexture1           [[texture(0)]],
    texture2d<half, access::read>  inTexture2           [[texture(1)]],
    texture2d<half, access::write> outTexture          [[texture(2)]],
    uint2                          gid                 [[thread_position_in_grid]])
{
    texture2d<half, access::read> inTexture = inTexture1;
    uint height = inTexture.get_height();
    uint sectionHeight = height / 3;

    // Skip threads outside the source. For gid.y >= height - 2*sectionHeight
    // the unsigned luma row below would wrap around.
    if (gid.x >= inTexture.get_width() || gid.y >= height - sectionHeight * 2)
        return;

    uint flippedY = height - 1 - gid.y;
    uint2 lumaGid = uint2(gid.x, flippedY - sectionHeight * 2);
    // Both pixels of a horizontal pair share the chroma pair at the even column.
    uint2 uvGid = uint2(gid.x / 2 * 2, flippedY - sectionHeight);
    // The alpha plane has the same width as the luma plane, so it is addressed
    // per pixel with gid.x (not the even-aligned chroma column).
    uint2 alphaGid = uint2(gid.x, flippedY);

    half luma = inTexture.read(lumaGid).r;
    // First sample of the pair -> Cb lane (.b), second -> Cr lane (.r).
    half cb = inTexture.read(uvGid).r;
    half cr = inTexture.read(uvGid + uint2(1, 0)).r;

    float4 yuvIn;
    yuvIn.g = luma;
    yuvIn.b = cb;
    yuvIn.r = cr;
    yuvIn.a = inTexture.read(alphaGid).r;
    yuvIn += stdbias;

    // NOTE(review): std709R/G/B all weight .a by 0 and the output alpha is fixed
    // at 1.0, so the alpha sample currently has no effect on the result. Confirm
    // whether it is meant to be written to outColor.a.
    float4 outColor = float4(dot(std709R, yuvIn),
                             dot(std709G, yuvIn),
                             dot(std709B, yuvIn),
                             1.0);
    outColor = processColorProfile(outColor);
    outTexture.write(half4(outColor), gid);
}
