
#include "kernel_noisecancel.cuh"

// Five-decimal approximation of pi. Not referenced anywhere in the visible part of this file.
#define PI 3.14159

// Flag tested by the host-side Device_* wrappers below to decide whether to print kernel timings.
// NOTE(review): the variable is declared __device__ but is read directly from host code in those
// wrappers; a host read does not observe the device-side value - confirm whether a plain host
// variable (or cudaMemcpyFromSymbol) was intended.
__device__ bool messageprint = false;

/*
Builds, for every (plane, frequency) pair, a legendresamples x legendresamples matrix holding the
real part of the outer product SRF * SRF^T (plain transpose, no conjugation):

    c[plane][freq][r][col] = a[r] * a[col] - a_im[r] * a_im[col]

where a / a_im are the real / imaginary SRF vectors of that (plane, freq) pair.

Buffer layout as indexed by this kernel:
    a, a_im : planes x noffrequency x legendresamples
    c       : planes x noffrequency x legendresamples x legendresamples

Thread mapping:
    x -> frequency index (guarded against noffrequency)
    y -> matrix column   (guarded against legendresamples)
    z -> plane index     (NOT guarded - the kernel receives no plane count, so the launch must
                          cover exactly the number of planes held by the buffers)
Each thread writes one full column of one output matrix.
*/
__global__ void CovarianceMatricesKernel(float* c, float* a, float* a_im, int noffrequency, int legendresamples)
{
    // For legendre samples
    const int COL = blockIdx.y * blockDim.y + threadIdx.y;
    // For Frequency indices
    const int ROW = blockIdx.x * blockDim.x + threadIdx.x;
    // For Block Size
    const int PLANE_OFFSET = blockIdx.z * blockDim.z + threadIdx.z;

    if (ROW >= noffrequency || COL >= legendresamples) {
        return;
    }

    // Flattened offsets are computed in size_t: planes * frequencies * samples^2 can exceed the
    // range of a 32-bit int for large inputs.
    const size_t samples = static_cast<size_t>(legendresamples);
    const size_t bin = static_cast<size_t>(PLANE_OFFSET) * static_cast<size_t>(noffrequency)
                     + static_cast<size_t>(ROW);

    const float* srf_re = a + bin * samples;
    const float* srf_im = a_im + bin * samples;
    float* cov = c + bin * samples * samples;

    // The column operand is the same for every row this thread produces.
    const float col_re = srf_re[COL];
    const float col_im = srf_im[COL];

    for (int row = 0; row < legendresamples; row++)
    {
        cov[static_cast<size_t>(row) * samples + static_cast<size_t>(COL)] =
            srf_re[row] * col_re - srf_im[row] * col_im;
    }
}
/*
Left multiplication A * COV: for every (plane, frequency) pair the coefficient matrix
dev_wpfcoeff (nofsources x legendresamples) is multiplied with that pair's
legendresamples x legendresamples matrix from dev_a, and the product is ADDED to dev_c.

Buffer layout as indexed by this kernel:
    dev_a        : planes x tfheight x legendresamples x legendresamples   (input)
    dev_wpfcoeff : nofsources x legendresamples                            (input)
    dev_c        : planes x tfheight x (nofsources * legendresamples)      (accumulated output)

dev_c is accumulated into, not overwritten, so its incoming contents contribute to the result.

Thread mapping:
    x -> frequency index (guarded against tfheight)
    y -> flattened (source, column) index (guarded against legendresamples * nofsources)
    z -> plane index (NOT guarded - the launch must cover exactly the planes held by the buffers)

Only this single real-valued product is computed here; there is no imaginary counterpart in this kernel.
*/
__global__ void LeftMultiplierMatricesKernel(float* dev_c, float* dev_a, float* dev_wpfcoeff, int tfheight, int legendresamples, int nofsources)
{
    // For legendre samples * Nofsources
    const int COL = blockIdx.y * blockDim.y + threadIdx.y;
    // For Frequency indices
    const int ROW = blockIdx.x * blockDim.x + threadIdx.x;
    // For Block Size
    const int PLANE_OFFSET = blockIdx.z * blockDim.z + threadIdx.z;

    if (ROW >= tfheight || COL >= legendresamples * nofsources) {
        return;
    }

    const int col = COL % legendresamples;   // column inside the covariance matrix
    const int row = COL / legendresamples;   // row of dev_wpfcoeff (source index)

    // Flattened offsets are computed in size_t so that large tensors do not overflow 32-bit int.
    const size_t samples = static_cast<size_t>(legendresamples);
    const size_t sources = static_cast<size_t>(nofsources);
    const size_t bin = static_cast<size_t>(PLANE_OFFSET) * static_cast<size_t>(tfheight)
                     + static_cast<size_t>(ROW);

    const float* cov = dev_a + bin * samples * samples + static_cast<size_t>(col);
    const float* coeff = dev_wpfcoeff + static_cast<size_t>(row) * samples;
    float* out = dev_c + bin * samples * sources + static_cast<size_t>(COL);

    // Accumulate in a register, starting from the value already stored, and write back once.
    // The summation order is the same as updating global memory on every iteration.
    float acc = *out;
    for (int ik = 0; ik < legendresamples; ik++) {
        acc += coeff[ik] * cov[static_cast<size_t>(ik) * samples];
    }
    *out = acc;
}

/*
Right multiplication COV * A: for every (plane, frequency) pair, each of the
nofsources * nofsources output cells receives the dot product of one coefficient row of
dev_wpfcoeff with one legendresamples-long segment of dev_a. The dot product is added to the
value already in dev_c and the absolute value of the total is stored.

Buffer layout as indexed by this kernel:
    dev_a        : planes x tfheight x (nofsources * legendresamples)   (input)
    dev_wpfcoeff : nofsources x legendresamples                         (input)
    dev_c        : planes x tfheight x (nofsources * nofsources)        (output, accumulated then abs)

For output cell COL the dev_a segment is COL / nofsources and the coefficient row is
(nofsources - 1) - (COL % nofsources), i.e. the coefficient rows are visited in reverse order.

Thread mapping:
    x -> frequency index (guarded against tfheight)
    y -> output cell     (guarded against nofsources * nofsources)
    z -> plane index     (NOT guarded - the launch must cover exactly the planes held by the buffers)
*/

__global__ void RightMultiplierMatricesKernel(float* dev_c, float* dev_a, float* dev_wpfcoeff, int tfheight, int legendresamples, int nofsources)
{
    // For Nofsources * Nofsources
    const int COL = blockIdx.y * blockDim.y + threadIdx.y;
    // For Frequency indices
    const int ROW = blockIdx.x * blockDim.x + threadIdx.x;
    // For Block Size
    const int PLANE_OFFSET = blockIdx.z * blockDim.z + threadIdx.z;

    if (ROW >= tfheight || COL >= nofsources * nofsources) {
        return;
    }

    const int col_m = COL / nofsources;
    const int row_m = (nofsources - 1) - (COL % nofsources);

    // Flattened offsets are computed in size_t so that large tensors do not overflow 32-bit int.
    const size_t samples = static_cast<size_t>(legendresamples);
    const size_t sources = static_cast<size_t>(nofsources);
    const size_t bin = static_cast<size_t>(PLANE_OFFSET) * static_cast<size_t>(tfheight)
                     + static_cast<size_t>(ROW);

    const float* left = dev_a + bin * samples * sources + static_cast<size_t>(col_m) * samples;
    const float* coeff = dev_wpfcoeff + static_cast<size_t>(row_m) * samples;
    float* out = dev_c + bin * sources * sources + static_cast<size_t>(COL);

    // Accumulate in a register, starting from the value already stored, and write back once.
    float acc = *out;
    for (int ik = 0; ik < legendresamples; ik++) {
        acc += coeff[ik] * left[ik];
    }

    *out = fabsf(acc);
}


/*
Applies a per-source gain to the complex values in dev_a / dev_a_im, in place.

For each (plane, frequency, source) element, both the real and the imaginary value are multiplied
by the diagonal entry [source][source] of that bin's nofsources x nofsources block in
dev_tfmultiplier, then divided by that bin's entry in dev_sumtfmultiplier.

Buffer layout as indexed by this kernel:
    dev_a, dev_a_im     : planes x tfheight x nofsources
    dev_tfmultiplier    : planes x tfheight x nofsources x nofsources
    dev_sumtfmultiplier : planes x tfheight

Thread mapping: x -> frequency (guarded), y -> source (guarded), z -> plane (not guarded).
The divisor is not checked for zero. legendresamples is accepted but not used.
*/
__global__ void FinalMultiplierMatricesKernel(float* dev_a, float* dev_a_im, float* dev_tfmultiplier, float* dev_sumtfmultiplier, int tfheight, int legendresamples, int nofsources)
{
    const int source = blockIdx.y * blockDim.y + threadIdx.y;
    const int freq = blockIdx.x * blockDim.x + threadIdx.x;
    const int plane = blockIdx.z * blockDim.z + threadIdx.z;

    // Surplus threads of a rounded-up grid have nothing to do.
    if (freq >= tfheight || source >= nofsources) {
        return;
    }

    const int binIndex = plane * tfheight + freq;
    const int element = plane * tfheight * nofsources + freq * nofsources + source;
    const int diagonal = plane * tfheight * nofsources * nofsources + freq * nofsources * nofsources + source * nofsources + source;

    // Both components share the same gain; compute both products before storing anything.
    const float gain = dev_tfmultiplier[diagonal];
    const float scaled_re = dev_a[element] * gain;
    const float scaled_im = dev_a_im[element] * gain;

    const float normaliser = dev_sumtfmultiplier[binIndex];
    dev_a[element] = scaled_re / normaliser;
    dev_a_im[element] = scaled_im / normaliser;
}


/*
Adds, for every (plane, frequency) pair, the sum of all nofsources * nofsources entries of that
bin's block in dev_tfmultiplier to the bin's entry in dev_a.

Buffer layout as indexed by this kernel:
    dev_tfmultiplier : planes x tfheight x (nofsources * nofsources)   (input)
    dev_a            : planes x tfheight                               (accumulated output)

dev_a is accumulated into, not overwritten, so its incoming contents contribute to the result.

Thread mapping:
    x -> frequency index (guarded against tfheight)
    y -> unused; only threads with y index 0 do work so that a launch with more than one
         thread along y cannot update the same output cell concurrently
    z -> plane index (NOT guarded - the launch must cover exactly the planes held by the buffers)
*/
__global__ void SumTFMatricesKernel(float* dev_a, float* dev_tfmultiplier, int tfheight, int nofsources)
{
    // Not used for addressing; see the thread mapping above
    const int COL = blockIdx.y * blockDim.y + threadIdx.y;
    // For Frequency indices
    const int ROW = blockIdx.x * blockDim.x + threadIdx.x;
    // For Block Size
    const int PLANE_OFFSET = blockIdx.z * blockDim.z + threadIdx.z;

    // Without the ROW guard, the surplus threads of a rounded-up grid alias the first rows of
    // the next plane (double counting them) or run past the end of both buffers.
    if (ROW >= tfheight || COL != 0) {
        return;
    }

    // Flattened offsets are computed in size_t so that large tensors do not overflow 32-bit int.
    const size_t bin = static_cast<size_t>(PLANE_OFFSET) * static_cast<size_t>(tfheight)
                     + static_cast<size_t>(ROW);
    const int cells = nofsources * nofsources;
    const float* block = dev_tfmultiplier + bin * static_cast<size_t>(cells);

    // Accumulate in a register, starting from the value already stored, and write back once.
    float acc = dev_a[bin];
    for (int ijk = 0; ijk < cells; ijk++)
    {
        acc += block[ijk];
    }
    dev_a[bin] = acc;
}



// Host launcher for CovarianceMatricesKernel; blocks until the kernel has finished.
//
// IN  : dev_a, dev_a_im : real / imaginary SRF values, read by the kernel as
//                         BlockSize x tfheight x legendresamples
// OUT : dev_c           : BlockSize x tfheight x legendresamples x legendresamples
// freqanalysiswindow and freqstart are accepted but not used here.
//
// Returns the launch error if the launch itself was rejected, otherwise the result of
// cudaDeviceSynchronize().
cudaError_t Device_CovarianceSRF(float* dev_c, float* dev_a, float* dev_a_im, unsigned int freqanalysiswindow, unsigned int BlockSize, unsigned int tfheight, unsigned int freqstart, unsigned int legendresamples)
{
    cudaError_t cudaStatus;

    // Small extents fit into a single block; larger ones are tiled below.
    dim3 threadsPerBlock(tfheight, legendresamples, BlockSize);
    dim3 blocksPerGrid(1, 1, 1);

    if (tfheight > 16)
    {
        threadsPerBlock.x = 16;
        blocksPerGrid.x = (tfheight + threadsPerBlock.x - 1) / threadsPerBlock.x;
    }

    if (legendresamples > 16)
    {
        threadsPerBlock.y = 16;
        blocksPerGrid.y = (legendresamples + threadsPerBlock.y - 1) / threadsPerBlock.y;
    }

    if (BlockSize > 1)
    {
        // The kernel does not bounds-check the plane (z) index, so the z extent of the launch
        // must equal BlockSize exactly. Pair planes only when BlockSize is even; rounding an
        // odd BlockSize up would touch one plane past the end of the buffers.
        threadsPerBlock.z = (BlockSize % 2 == 0) ? 2 : 1;
        blocksPerGrid.z = BlockSize / threadsPerBlock.z;
    }

    std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
    CovarianceMatricesKernel << <blocksPerGrid, threadsPerBlock >> > (dev_c, dev_a, dev_a_im, tfheight, legendresamples);

    // A rejected launch is reported only here; cudaDeviceSynchronize() would return success
    // afterwards, so bail out now instead of letting the error be overwritten.
    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "CovarianceMatricesKernel launch failed: %s\n", cudaGetErrorString(cudaStatus));
        return cudaStatus;
    }

    // Wait for the kernel to finish; this reports errors raised while it was executing.
    cudaStatus = cudaDeviceSynchronize();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaDeviceSynchronize returned error code %d after launching CovarianceMatricesKernel!\n", static_cast<int>(cudaStatus));
    }

    std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
    // NOTE(review): messageprint is a __device__ variable read from host code - confirm intent.
    if (messageprint)
        printf("MatrixSwap(us) = %lld \n", static_cast<long long>(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()));

    return cudaStatus;
}
// Host launcher for LeftMultiplierMatricesKernel; blocks until the kernel has finished.
//
// OUT : dev_c        : BlockSize x tfheight x (nofsources * legendresamples); the kernel ADDS
//                      its result to the existing contents
// IN  : dev_a        : BlockSize x tfheight x legendresamples x legendresamples
// IN  : dev_wpfcoeff : nofsources x legendresamples
// freqanalysiswindow and freqstart are accepted but not used here.
//
// Returns the launch error if the launch itself was rejected, otherwise the result of
// cudaDeviceSynchronize().
cudaError_t Device_LeftMultiplier(float* dev_c, float* dev_a, float* dev_wpfcoeff, unsigned int freqanalysiswindow, unsigned int BlockSize, unsigned int tfheight, unsigned int freqstart, unsigned int legendresamples, unsigned int nofsources)
{
    cudaError_t cudaStatus;

    const unsigned int columns = legendresamples * nofsources;

    // Small extents fit into a single block; larger ones are tiled below.
    dim3 threadsPerBlock(tfheight, columns, BlockSize);
    dim3 blocksPerGrid(1, 1, 1);

    if (tfheight > 16)
    {
        threadsPerBlock.x = 16;
        blocksPerGrid.x = (tfheight + threadsPerBlock.x - 1) / threadsPerBlock.x;
    }

    if (columns > 16)
    {
        threadsPerBlock.y = 16;
        blocksPerGrid.y = (columns + threadsPerBlock.y - 1) / threadsPerBlock.y;
    }

    if (BlockSize > 1)
    {
        // The kernel does not bounds-check the plane (z) index, so the z extent of the launch
        // must equal BlockSize exactly. Pair planes only when BlockSize is even; rounding an
        // odd BlockSize up would touch one plane past the end of the buffers.
        threadsPerBlock.z = (BlockSize % 2 == 0) ? 2 : 1;
        blocksPerGrid.z = BlockSize / threadsPerBlock.z;
    }

    std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
    LeftMultiplierMatricesKernel << <blocksPerGrid, threadsPerBlock >> > (dev_c, dev_a, dev_wpfcoeff, tfheight, legendresamples, nofsources);

    // A rejected launch is reported only here; cudaDeviceSynchronize() would return success
    // afterwards, so bail out now instead of letting the error be overwritten.
    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "LeftMultiplierMatricesKernel launch failed: %s\n", cudaGetErrorString(cudaStatus));
        return cudaStatus;
    }

    // Wait for the kernel to finish; this reports errors raised while it was executing.
    cudaStatus = cudaDeviceSynchronize();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaDeviceSynchronize returned error code %d after launching LeftMultiplierMatricesKernel!\n", static_cast<int>(cudaStatus));
    }

    std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
    // NOTE(review): messageprint is a __device__ variable read from host code - confirm intent.
    if (messageprint)
        printf("MatrixSwap(us) = %lld \n", static_cast<long long>(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()));

    return cudaStatus;
}

// Host launcher for RightMultiplierMatricesKernel; blocks until the kernel has finished.
//
// OUT : dev_c        : BlockSize x tfheight x (nofsources * nofsources); the kernel adds its
//                      result to the existing contents and stores the absolute value
// IN  : dev_a        : BlockSize x tfheight x (nofsources * legendresamples)
// IN  : dev_wpfcoeff : nofsources x legendresamples
// freqanalysiswindow and freqstart are accepted but not used here.
//
// Returns the launch error if the launch itself was rejected, otherwise the result of
// cudaDeviceSynchronize().
cudaError_t Device_RightMultiplier(float* dev_c, float* dev_a, float* dev_wpfcoeff, unsigned int freqanalysiswindow, unsigned int BlockSize, unsigned int tfheight, unsigned int freqstart, unsigned int legendresamples, unsigned int nofsources)
{
    cudaError_t cudaStatus;

    const unsigned int cells = nofsources * nofsources;

    // Small extents fit into a single block; larger ones are tiled below.
    dim3 threadsPerBlock(tfheight, cells, BlockSize);
    dim3 blocksPerGrid(1, 1, 1);

    if (tfheight > 16)
    {
        threadsPerBlock.x = 16;
        blocksPerGrid.x = (tfheight + threadsPerBlock.x - 1) / threadsPerBlock.x;
    }

    if (cells > 2)
    {
        threadsPerBlock.y = 2;
        blocksPerGrid.y = (cells + threadsPerBlock.y - 1) / threadsPerBlock.y;
    }

    if (BlockSize > 1)
    {
        // The kernel does not bounds-check the plane (z) index, so the z extent of the launch
        // must equal BlockSize exactly. Pair planes only when BlockSize is even; rounding an
        // odd BlockSize up would touch one plane past the end of the buffers.
        threadsPerBlock.z = (BlockSize % 2 == 0) ? 2 : 1;
        blocksPerGrid.z = BlockSize / threadsPerBlock.z;
    }

    std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
    RightMultiplierMatricesKernel << <blocksPerGrid, threadsPerBlock >> > (dev_c, dev_a, dev_wpfcoeff, tfheight, legendresamples, nofsources);

    // A rejected launch is reported only here; cudaDeviceSynchronize() would return success
    // afterwards, so bail out now instead of letting the error be overwritten.
    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "RightMultiplierMatricesKernel launch failed: %s\n", cudaGetErrorString(cudaStatus));
        return cudaStatus;
    }

    // Wait for the kernel to finish; this reports errors raised while it was executing.
    cudaStatus = cudaDeviceSynchronize();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaDeviceSynchronize returned error code %d after launching RightMultiplierMatricesKernel!\n", static_cast<int>(cudaStatus));
    }

    std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
    // NOTE(review): messageprint is a __device__ variable read from host code - confirm intent.
    if (messageprint)
        printf("MatrixSwap(us) = %lld \n", static_cast<long long>(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()));

    return cudaStatus;
}


// Host launcher for FinalMultiplierMatricesKernel; blocks until the kernel has finished.
//
// IN/OUT : dev_a, dev_a_im     : BlockSize x tfheight x nofsources, scaled in place
// IN     : dev_tfmultiplier    : BlockSize x tfheight x nofsources x nofsources
// IN     : dev_sumtfmultiplier : BlockSize x tfheight
// freqanalysiswindow and freqstart are accepted but not used here; legendresamples is only
// forwarded to the kernel.
//
// Returns the launch error if the launch itself was rejected, otherwise the result of
// cudaDeviceSynchronize().
cudaError_t Device_TFBinMasking(float* dev_a, float* dev_a_im, float* dev_tfmultiplier, float* dev_sumtfmultiplier, unsigned int freqanalysiswindow, unsigned int BlockSize, unsigned int tfheight, unsigned int freqstart, unsigned int legendresamples, unsigned int nofsources)
{
    cudaError_t cudaStatus;

    // Small extents fit into a single block; larger ones are tiled below. The x extent must
    // start at tfheight: starting at 1 would process only frequency row 0 when tfheight <= 16.
    dim3 threadsPerBlock(tfheight, nofsources, BlockSize);
    dim3 blocksPerGrid(1, 1, 1);

    if (tfheight > 16)
    {
        threadsPerBlock.x = 16;
        blocksPerGrid.x = (tfheight + threadsPerBlock.x - 1) / threadsPerBlock.x;
    }

    if (nofsources > 1)
    {
        threadsPerBlock.y = 2;
        blocksPerGrid.y = (nofsources + threadsPerBlock.y - 1) / threadsPerBlock.y;
    }

    if (BlockSize > 1)
    {
        // The kernel does not bounds-check the plane (z) index, so the z extent of the launch
        // must equal BlockSize exactly. Pair planes only when BlockSize is even; rounding an
        // odd BlockSize up would touch one plane past the end of the buffers.
        threadsPerBlock.z = (BlockSize % 2 == 0) ? 2 : 1;
        blocksPerGrid.z = BlockSize / threadsPerBlock.z;
    }

    std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
    FinalMultiplierMatricesKernel << <blocksPerGrid, threadsPerBlock >> > (dev_a, dev_a_im, dev_tfmultiplier, dev_sumtfmultiplier, tfheight, legendresamples, nofsources);

    // A rejected launch is reported only here; cudaDeviceSynchronize() would return success
    // afterwards, so bail out now instead of letting the error be overwritten.
    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "FinalMultiplierMatricesKernel launch failed: %s\n", cudaGetErrorString(cudaStatus));
        return cudaStatus;
    }

    // Wait for the kernel to finish; this reports errors raised while it was executing.
    cudaStatus = cudaDeviceSynchronize();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaDeviceSynchronize returned error code %d after launching FinalMultiplierMatricesKernel!\n", static_cast<int>(cudaStatus));
    }

    std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
    // NOTE(review): messageprint is a __device__ variable read from host code - confirm intent.
    if (messageprint)
        printf("MatrixSwap(us) = %lld \n", static_cast<long long>(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()));

    return cudaStatus;
}

// Host launcher for SumTFMatricesKernel; blocks until the kernel has finished.
//
// OUT : dev_sum          : BlockSize x tfheight; the kernel ADDS each bin's total to the
//                          existing contents
// IN  : dev_tfmultiplier : BlockSize x tfheight x (nofsources * nofsources)
// freqanalysiswindow, freqstart and legendresamples are accepted but not used here.
//
// Returns the launch error if the launch itself was rejected, otherwise the result of
// cudaDeviceSynchronize().
cudaError_t Device_TFBinSummation(float* dev_sum, float* dev_tfmultiplier, unsigned int freqanalysiswindow, unsigned int BlockSize, unsigned int tfheight, unsigned int freqstart, unsigned int legendresamples, unsigned int nofsources)
{

    cudaError_t cudaStatus;

    // One thread sums a whole nofsources x nofsources block, so the y extent stays at 1.
    dim3 threadsPerBlock(tfheight, 1, BlockSize);
    dim3 blocksPerGrid(1, 1, 1);

    if (tfheight > 16)
    {
        threadsPerBlock.x = 16;
        blocksPerGrid.x = (tfheight + threadsPerBlock.x - 1) / threadsPerBlock.x;
    }

    if (BlockSize > 1)
    {
        // The kernel does not bounds-check the plane (z) index, so the z extent of the launch
        // must equal BlockSize exactly. Pair planes only when BlockSize is even; rounding an
        // odd BlockSize up would touch one plane past the end of the buffers.
        threadsPerBlock.z = (BlockSize % 2 == 0) ? 2 : 1;
        blocksPerGrid.z = BlockSize / threadsPerBlock.z;
    }

    std::chrono::steady_clock::time_point begin = std::chrono::steady_clock::now();
    SumTFMatricesKernel << <blocksPerGrid, threadsPerBlock >> > (dev_sum, dev_tfmultiplier, tfheight, nofsources);

    // A rejected launch is reported only here; cudaDeviceSynchronize() would return success
    // afterwards, so bail out now instead of letting the error be overwritten.
    cudaStatus = cudaGetLastError();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "SumTFMatricesKernel launch failed: %s\n", cudaGetErrorString(cudaStatus));
        return cudaStatus;
    }

    // Wait for the kernel to finish; this reports errors raised while it was executing.
    cudaStatus = cudaDeviceSynchronize();
    if (cudaStatus != cudaSuccess) {
        fprintf(stderr, "cudaDeviceSynchronize returned error code %d after launching SumTFMatricesKernel!\n", static_cast<int>(cudaStatus));
    }

    std::chrono::steady_clock::time_point end = std::chrono::steady_clock::now();
    // NOTE(review): messageprint is a __device__ variable read from host code - confirm intent.
    if (messageprint)
        printf("MatrixSwap(us) = %lld \n", static_cast<long long>(std::chrono::duration_cast<std::chrono::microseconds>(end - begin).count()));

    return cudaStatus;
}