#version 310 es
// Compute shader performing one radix-2 butterfly pass along the x axis of
// two complex-valued images (see main() below).
precision highp float;

// Each work group contains a single invocation.
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

// Used by main() as the mask (cP - 1) that selects the twiddle index and as
// the x distance between the two outputs of a butterfly.
// NOTE(review): the mask only acts as a modulo when cP is a power of two;
// presumably cP is the sub-transform size of the current pass - confirm
// against the host code.
uniform uint cP;

// Inputs. main() treats .xy and .zw of iFieldTimeXY as two independent complex
// numbers (real, imaginary) and reads only .xy of iFieldTimeZ.
layout(binding=0, rgba32f) uniform readonly highp image2D iFieldTimeXY;
layout(binding=1, rgba32f) uniform readonly highp image2D iFieldTimeZ;//rg32f would suffice (only .xy is read) - NOTE(review): presumably rgba32f is used because rg32f is not an available image format here; confirm

// Outputs, same channel layout as the inputs. main() writes 0.0 to .zw of
// oDisplacementZ.
layout(binding=2, rgba32f) uniform writeonly highp image2D oDisplacementXY;
layout(binding=3, rgba32f) uniform writeonly highp image2D oDisplacementZ;//rg32f would suffice (.zw is always written as 0.0)

const float M_PI = 3.14159265358979323846264338327950288419716939937510582097;          // pi

// Complex product a * b, with each operand stored as (real, imaginary) in (x, y).
vec2 complexMul(vec2 a, vec2 b)
{
    float realPart = a.x * b.x - a.y * b.y;
    float imagPart = a.x * b.y + a.y * b.x;
    return vec2(realPart, imagPart);
}

// source: http://www.bealto.com/gpu-fft_opencl-1.html

// In-place 2-point DFT (butterfly): on return a holds the original a + b and
// b holds the original a - b.
void DFT2(inout vec2 a, inout vec2 b)
{
    vec2 sum = a + b;
    b = a - b;
    a = sum;
}

// One radix-2 butterfly per invocation, applied along x within row
// gl_GlobalInvocationID.y. The invocation reads the texels at x = i and
// x = i + t (t = total invocation count in x), multiplies the second one by
// the twiddle factor exp(-i*pi*k/cP) with k = i & (cP - 1), and writes the sum
// and difference to x = j and x = j + cP, where j = 2*i - k.
void main()
{
    uint row     = gl_GlobalInvocationID.y;
    uint column  = gl_GlobalInvocationID.x;
    uint threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x; // invocations along x
    uint k       = column & (cP - 1u);

    ivec2 srcLo = ivec2(column, row);
    ivec2 srcHi = ivec2(column + threads, row);

    // Twiddle factor (cos, sin) for this butterfly.
    float angle = -M_PI*float(k)/float(cP);
    vec2 w = vec2(cos(angle), sin(angle));

    // Fetch both inputs; the second of each pair is twiddled right away.
    vec4 loXY = imageLoad(iFieldTimeXY, srcLo);
    vec4 hiXY = imageLoad(iFieldTimeXY, srcHi);

    vec2 evenA = loXY.xy;
    vec2 oddA  = complexMul(hiXY.xy, w);
    vec2 evenB = loXY.zw;
    vec2 oddB  = complexMul(hiXY.zw, w);
    vec2 evenZ = imageLoad(iFieldTimeZ, srcLo).xy;
    vec2 oddZ  = complexMul(imageLoad(iFieldTimeZ, srcHi).xy, w);

    // Butterflies: even <- even + odd, odd <- even - odd.
    DFT2(evenA, oddA);
    DFT2(evenB, oddB);
    DFT2(evenZ, oddZ);

    // Output column: 2*column - k, identical to ((column - k) << 1) + k.
    uint j = (column << 1u) - k;
    ivec2 dstLo = ivec2(j, row);
    ivec2 dstHi = ivec2(j + cP, row);

    imageStore(oDisplacementXY, dstLo, vec4(evenA, evenB));
    imageStore(oDisplacementXY, dstHi, vec4(oddA, oddB));

    imageStore(oDisplacementZ, dstLo, vec4(evenZ, 0.0, 0.0));
    imageStore(oDisplacementZ, dstHi, vec4(oddZ, 0.0, 0.0));
}