#include "commonMLDX.h"
#include "alpha.h"
#include "outputColorHelper.h"

// Source image. In INTERLACED builds getInputPixel reads even rows from here.
Texture2D<float4> inTexture1 : register(t0);
// Second source image; only read by getInputPixel in INTERLACED builds (odd rows).
Texture2D<float4> inTexture2 : register(t1);
#ifdef TO_YUV
// Packed YUV output: each texel is one 32-bit word assembled by the TO_YUV kernels.
RWTexture2D<uint> outTexture : register(u0);
#else
// Unpacked float4 output used by the non-YUV copy kernel.
RWTexture2D<float4> outTexture : register(u0);
#endif

// Fetches the unfiltered source texel at integer position `coord` (mip level 0).
// With INTERLACED defined, rows alternate between the two inputs: even rows come
// from inTexture1 and odd rows from inTexture2, both addressed with the same coord.
// NOTE(review): this assumes both inputs are full-height images; if they are
// half-height fields the row index would need halving — confirm with the host.
inline float4 getInputPixel(const uint2 coord)
{
    const int3 texel = int3(coord, 0);  // z component selects mip level 0
#ifdef INTERLACED
    if ((coord.y & 1u) != 0u)
    {
        return inTexture2.Load(texel);
    }
#endif
    return inTexture1.Load(texel);
}

// Quantizes a normalized value to an 8-bit code in [0, 255].
// The input is clamped to [0, 1] first. Without the clamp, a slight overshoot
// (e.g. 1.004 -> 256) would wrap to 0 when masked, and a negative value would
// hit an undefined float->uint conversion. With the clamp the result always
// fits in 8 bits, so callers can shift it into a packed word without masking.
inline uint to8Bit(const float val)
{
    return uint(round(saturate(val) * 255.0));
}

// Quantizes a normalized value to a 10-bit code in [0, 1023].
// The input is clamped to [0, 1] first. Without the clamp, an overshoot
// (e.g. 1.01 -> 1033) would wrap to 9 when masked, and a negative value would
// hit an undefined float->uint conversion. With the clamp the result always
// fits in 10 bits, so the three fields of a packed word cannot collide.
inline uint to10Bit(const float val)
{
    return uint(round(saturate(val) * 1023.0));
}

#ifdef TO_YUV

[numthreads(16, 16, 1)]
void BGRATo2vuyKernel8(const uint3 DTid : SV_DispatchThreadID)
{
    // One thread per output word. Each word covers a horizontal pair of input
    // pixels, so the input column is twice the output column.
    const uint2 outCoord = DTid.xy;
    const uint2 leftCoord = uint2(outCoord.x * 2, outCoord.y);
    const uint2 rightCoord = leftCoord + uint2(1, 0);

    // toOutput yields (Y, Cb, Cr) in the (r, g, b) channels.
    const float3 left = toOutput(getInputPixel(leftCoord)).rgb;
    const float3 right = toOutput(getInputPixel(rightCoord)).rgb;

    // The pair shares one chroma sample: the mean of both pixels.
    const float3 pairMean = 0.5f * (left + right);

    // Byte order from least significant: Cb, Y0, Cr, Y1.
    uint packed = to8Bit(pairMean.g);
    packed |= to8Bit(left.r) << 8;
    packed |= to8Bit(pairMean.b) << 16;
    packed |= to8Bit(right.r) << 24;
    outTexture[outCoord] = packed;
}

[numthreads(16, 16, 1)]
void BGRATo2vuyKernel10(const uint3 DTid : SV_DispatchThreadID)
{
    // One thread per group of 6 input pixels, stored as four 32-bit words with
    // three 10-bit components each (v210 word layout).
    const uint row = DTid.y;
    const uint firstPixelX = DTid.x * 6;
    const uint firstWordX = DTid.x * 4;

    uint luma[6];
    uint cb[3];
    uint cr[3];

    // Walk the three horizontal pixel pairs. toOutput yields (Y, Cb, Cr) in
    // (r, g, b). Each pair contributes two luma samples and one averaged chroma pair.
    [unroll]
    for (uint pair = 0; pair < 3; ++pair)
    {
        const uint x = firstPixelX + pair * 2;
        const float3 left = toOutput(getInputPixel(uint2(x, row))).rgb;
        const float3 right = toOutput(getInputPixel(uint2(x + 1, row))).rgb;

        luma[pair * 2] = to10Bit(left.r);
        luma[pair * 2 + 1] = to10Bit(right.r);
        cb[pair] = to10Bit(0.5 * (left.g + right.g));
        cr[pair] = to10Bit(0.5 * (left.b + right.b));
    }

    // Fields per word (bits 0-9, 10-19, 20-29):
    //   Cb0 Y0 Cr0 | Y1 Cb1 Y2 | Cr1 Y3 Cb2 | Y4 Cr2 Y5
    outTexture[uint2(firstWordX, row)] = cb[0] | (luma[0] << 10) | (cr[0] << 20);
    outTexture[uint2(firstWordX + 1, row)] = luma[1] | (cb[1] << 10) | (luma[2] << 20);
    outTexture[uint2(firstWordX + 2, row)] = cr[1] | (luma[3] << 10) | (cb[2] << 20);
    outTexture[uint2(firstWordX + 3, row)] = luma[4] | (cr[2] << 10) | (luma[5] << 20);
}

#else  // !TO_YUV

[numthreads(16, 16, 1)]
void BRGACopy(const uint3 DTid : SV_DispatchThreadID)
{
    // One thread per pixel: convert the source texel, then rotate the channels
    // so the output's (x, y, z, w) components hold (a, r, g, b).
    const float4 converted = toOutput(getInputPixel(DTid.xy));
    outTexture[DTid.xy] = float4(converted.a, converted.rgb);
}

#endif  // !TO_YUV
