/*******************************************************************************
* Copyright (C) 2025 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       API oneapi::mkl::experimental::dft::distributed_descriptor to perform
*       2-D Single Precision Complex to Complex Fast Fourier Transform
*       distributed across SYCL GPU devices.
*
*       The supported floating point data types for data are:
*           float
*           std::complex<float>
*
*******************************************************************************/

#include <mpi.h>
#include <sycl/sycl.hpp>
#include <algorithm>
#include <cfloat>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <list>
#include <memory>
#include <stdexcept>
#include <vector>

#include "oneapi/mkl/experimental/distributed_dft.hpp"
#include "oneapi/mkl/exceptions.hpp"
#include "common_for_examples.hpp"
#include "mkl.h" // mkl_malloc

// Distributed descriptor type used by this example: single-precision,
// complex-to-complex transforms.
using distributed_desc_t = oneapi::mkl::experimental::dft::distributed_descriptor<
    oneapi::mkl::dft::precision::SINGLE,
    oneapi::mkl::dft::domain::COMPLEX>;

// Verdicts returned by the verification helpers and the example driver.
constexpr bool SUCCESS{true};
constexpr bool FAILURE{!SUCCESS};

// 2*pi in single precision, used to build the test signal's phase.
constexpr float TWOPI{6.2831853071795864769f};

// Fills this rank's forward-domain slab with the test signal
//     exp(i * 2*pi * (moda(n0, H0, N0)/N0 + moda(g1, H1, N1)/N1)) / (N1*N0),
// where g1 is the global row index of local row n1.
//
// Layout: distribute(N1, ...) rows of N0 interleaved (re, im) float pairs,
// i.e. element (n1, n0) lives at data[2*(n1*N0 + n0)].
// NOTE(review): `distribute`, `global_index` and `moda` come from
// common_for_examples.hpp. moda(a, b, n) is presumably (a*b) mod n as a
// float; confirm against that header.
static void init(float *data, int N0, int N1, int H0, int H1,
                 int mpi_rank, int mpi_nproc)
{
    const int local_rows = distribute(N1, mpi_rank, mpi_nproc);
    const int total = N1 * N0;

    for (int row = 0; row < local_rows; ++row) {
        // The dimension-1 phase term is constant along a row.
        const float row_term =
            moda<float>(global_index(row, N1, mpi_rank, mpi_nproc), H1, N1) / N1;
        float *row_ptr = data + 2 * (row * N0);

        for (int col = 0; col < N0; ++col) {
            const float phase = TWOPI * (moda<float>(col, H0, N0) / N0 + row_term);
            row_ptr[2 * col + 0] = cosf(phase) / total;
            row_ptr[2 * col + 1] = sinf(phase) / total;
        }
    }
}

// Verifies this rank's slab of the forward-transform result.
//
// `data` holds all N1 rows but only this rank's N0_local columns
// (N0_local = distribute(N0, ...)), as interleaved (re, im) pairs with row
// stride N0_local. The expected result is a unit spike at global index
// (H1 mod N1, H0 mod N0) and zero everywhere else.
//
// Collective over MPI_COMM_WORLD: every rank must call it. Returns SUCCESS
// on every rank only if all ranks passed. Rank 0 prints the global maximum
// error on success.
static int verify_fwd(const float *data, int N0, int N1, int H0, int H1,
                      int mpi_rank, int mpi_nproc)
{
    // Note: this simple error bound doesn't take into account error of
    //       input data
    float errthr = 5.0f * logf((float) N1*N0) / logf(2.0f) * FLT_EPSILON;
    if(mpi_rank == 0) std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    // Strides for row-major addressing of data
    // within the local slab
    int N0_local = distribute(N0, mpi_rank, mpi_nproc);
    int S0 = 1, S1 = N0_local;

    bool status = SUCCESS;
    float maxerr = 0.0f;
    // Stop scanning at the first mismatch. Every rank must still reach the
    // collective status exchange below.
    for (int n1 = 0; n1 < N1 && status == SUCCESS; n1++) {
        for (int n0 = 0; n0 < N0_local && status == SUCCESS; n0++) {
            int g0 = global_index(n0, N0, mpi_rank, mpi_nproc);
            float re_exp = (
                    ((g0 - H0) % N0 == 0) &&
                    ((n1 - H1) % N1 == 0)
                ) ? 1.0f : 0.0f;
            float im_exp = 0.0f;

            int index = 2*(n1*S1 + n0*S0);
            float re_got = data[index+0];  // real component
            float im_got = data[index+1];  // imaginary component
            float err  = fabsf(re_got - re_exp) + fabsf(im_got - im_exp);
            if (err > maxerr) maxerr = err;
            if (!(err < errthr)) {
                // Report the global column so messages from different ranks
                // identify distinct elements.
                std::cout << "\t\tOn process:" << mpi_rank << ", data["
                          <<  n1 << ", " << g0 << "]: "
                          << "expected (" << re_exp << "," << im_exp << "), "
                          << "got (" << re_got << "," << im_got << "), "
                          << "err " << err << std::endl;
                std::cout << "\t\tVerification FAILED" << std::endl;
                status = FAILURE;
            }
        }
    }

    // All ranks agree on the verdict so they return the same value.
    int mpi_err = MPI_Allreduce(MPI_IN_PLACE, &status, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
    if(mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_AllReduce error" << std::endl;
        return FAILURE;
    }
    if(status == FAILURE) return FAILURE;

    // MPI_IN_PLACE is a valid send buffer for MPI_Reduce only at the root.
    // Non-root ranks must send their own value (their recvbuf is ignored).
    if (mpi_rank == 0) {
        mpi_err = MPI_Reduce(MPI_IN_PLACE, &maxerr, 1, MPI_FLOAT,
                             MPI_MAX, 0, MPI_COMM_WORLD);
    } else {
        mpi_err = MPI_Reduce(&maxerr, nullptr, 1, MPI_FLOAT,
                             MPI_MAX, 0, MPI_COMM_WORLD);
    }
    if(mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_Reduce error" << std::endl;
        return FAILURE;
    }

    if(mpi_rank == 0)
        std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;

    return SUCCESS;
}

static int verify_bwd(const float* data, int N0, int N1, int H0, int H1,
                      int mpi_rank, int mpi_nproc) {
    // Note: this simple error bound doesn't take into account error of
    //       input data
    float errthr = 5.0f * logf((float) N1*N0) / logf(2.0f) * FLT_EPSILON;
    if(mpi_rank == 0) std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    // Strides for row-major addressing of data
    // within the local slab
    int S0 = 1, S1 = N0;
    int N1_local = distribute(N1, mpi_rank, mpi_nproc);

    int mpi_err;
    bool status = SUCCESS;
    float maxerr = 0.0f;
    for (int n1 = 0; n1 < N1_local; n1++) {
        for (int n0 = 0; n0 < N0; n0++) {
            float phase = TWOPI * (moda<float>(n0, H0, N0) / N0
                                    + moda<float>(global_index(n1, N1, mpi_rank, mpi_nproc), H1, N1) / N1);
            float re_exp = cosf(phase) / (N1*N0);
            float im_exp = sinf(phase) / (N1*N0);

            int index = 2*(n1*S1 + n0*S0);
            float re_got = data[index+0];  // real component
            float im_got = data[index+1];  // imaginary component
            float err  = fabsf(re_got - re_exp) + fabsf(im_got - im_exp);
            if (err > maxerr) maxerr = err;
            if (!(err < errthr)) {
                std::cout << "\t\tOn process:" << mpi_rank << ", data["
                          << n1 << ", " << n0 << "]: "
                          << "expected (" << re_exp << "," << im_exp << "), "
                          << "got (" << re_got << "," << im_got << "), "
                          << "err " << err << std::endl;
                std::cout << "\t\tVerification FAILED" << std::endl;
                status = FAILURE;
                goto done;
            }
        }
    }

    done:
        mpi_err = MPI_Allreduce(MPI_IN_PLACE, &status, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
        if(mpi_err != MPI_SUCCESS) {
            std::cout << "MPI_AllReduce error" << std::endl;
            return FAILURE;
        }
        if(status == FAILURE) return FAILURE;

    mpi_err =  MPI_Reduce(MPI_IN_PLACE, &maxerr, 1, MPI_FLOAT,
                          MPI_MAX, 0, MPI_COMM_WORLD);
    if(mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_Reduce error" << std::endl;
        return FAILURE;
    }

    if(mpi_rank == 0)
        std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;

    return SUCCESS;
}

// Runs a forward + backward distributed 2-D complex-to-complex FFT round
// trip on `dev` and verifies both results.
//
// Collective over MPI_COMM_WORLD: every rank must call it.
//
// Returns SUCCESS when verification passed and FAILURE otherwise, including
// on an MPI error. Throws std::runtime_error when host or device allocation
// failed on any rank. SYCL and oneMKL exceptions propagate to the caller.
// Host and device buffers are owned by RAII wrappers and are released on
// every exit path.
int run_dft_example(sycl::device &dev, int mpi_rank, int mpi_nproc) {
    //
    // Initialize data for DFT
    //
    int N0 = 60, N1 = 120;
    int H0 = -1, H1 = -2;
    int result = FAILURE;
    // Must be `bool`: it is reduced with MPI_CXX_BOOL, which reads and
    // writes exactly sizeof(bool) bytes.
    bool alloc_result = SUCCESS;

    // Allocate local memory and initialize
    // based on default slab decomposition
    auto N1_local = distribute(N1, mpi_rank, mpi_nproc);
    auto N0_local = distribute(N0, mpi_rank, mpi_nproc);
    auto fwd_size = 2 * N0 * N1_local;
    auto bwd_size = 2 * N0_local * N1;
    auto alloc_size = std::max(fwd_size, bwd_size);

    auto host_deleter = [](float *p) { mkl_free(p); };
    std::unique_ptr<float, decltype(host_deleter)> in_owner(
        static_cast<float *>(mkl_malloc(alloc_size*sizeof(float), 64)),
        host_deleter);
    float* in = in_owner.get();
    if(!in) alloc_result = FAILURE;
    int mpi_err = MPI_Allreduce(MPI_IN_PLACE, &alloc_result, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
    if (mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_Allreduce error" << std::endl;
        return FAILURE;
    }

    if(alloc_result == FAILURE) {
        throw std::runtime_error("Failed to allocate memory using mkl_malloc");
    }

    init(in, N0, N1, H0, H1, mpi_rank, mpi_nproc);

    //
    // Execute DFT
    //
    // Catch asynchronous exceptions
    auto exception_handler = [] (sycl::exception_list exceptions) {
        for (std::exception_ptr const& e : exceptions) {
            try {
                std::rethrow_exception(e);
            } catch(sycl::exception const& e) {
                std::cout << "Caught asynchronous SYCL exception:\n"
                            << e.what() << std::endl;
            }
        }
    };

    // create execution queue with asynchronous error handling
    sycl::queue queue(dev, exception_handler);

    distributed_desc_t desc(MPI_COMM_WORLD, {N1, N0});
    desc.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0/(N0*N1)));
    // Distributed DFT assumes the same default forward and backward strides
    // as the DFT SYCL APIs corresponding to packed data layouts.
    // The default behavior is as if the following strides are set,
    // desc.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES,
    //                std::vector<std::int64_t>{0, N0, 1});
    // desc.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES,
    //                std::vector<std::int64_t>{0, N0, 1});

    // Default slab decomposition behavior is as if the following were set,
    // desc.set_value(oneapi::mkl::experimental::dft::distributed_config_param::fwd_divided_dimension,
    //                0);
    // desc.set_value(oneapi::mkl::experimental::dft::distributed_config_param::bwd_divided_dimension,
    //                1);
    desc.commit(queue);

    // Get the size of local USM memory to be allocated after commit
    std::int64_t fwd_usm_bytes = 0, bwd_usm_bytes = 0;
    desc.get_value(oneapi::mkl::experimental::dft::distributed_config_param::fwd_local_data_size_bytes, &fwd_usm_bytes);
    desc.get_value(oneapi::mkl::experimental::dft::distributed_config_param::bwd_local_data_size_bytes, &bwd_usm_bytes);
    auto usm_alloc_size = std::max(fwd_usm_bytes, bwd_usm_bytes);

    // Drain the queue before freeing. Otherwise an exception thrown between a
    // submission and its wait() could free memory that a memcpy or the DFT
    // is still using. This owner is declared after `queue` and `in_owner`,
    // so it is destroyed before both. Errors from wait() are ignored on
    // purpose because a deleter must not throw.
    auto usm_deleter = [&queue](float *p) {
        try {
            queue.wait();
        } catch (...) {
            // Best-effort cleanup.
        }
        sycl::free(p, queue);
    };
    std::unique_ptr<float, decltype(usm_deleter)> in_usm_owner(
        static_cast<float *>(sycl::malloc_device(static_cast<std::size_t>(usm_alloc_size), queue)),
        usm_deleter);
    float *in_usm = in_usm_owner.get();
    if(!in_usm) alloc_result = FAILURE;
    mpi_err = MPI_Allreduce(MPI_IN_PLACE, &alloc_result, 1,
                            MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
    if (mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_Allreduce error" << std::endl;
        return FAILURE;
    }

    if(alloc_result == FAILURE) {
        throw std::runtime_error("Failed to allocate USM memory");
    }

    sycl::event copy_ev = queue.memcpy(in_usm, in, fwd_size*sizeof(float));

    oneapi::mkl::experimental::dft::compute_forward(desc, in_usm, {copy_ev}).wait();
    queue.memcpy(in, in_usm, bwd_size*sizeof(float)).wait();
    result = verify_fwd(in, N0, N1, H0, H1, mpi_rank, mpi_nproc);

    if (result == SUCCESS) {
        oneapi::mkl::experimental::dft::compute_backward(desc, in_usm).wait();
        queue.memcpy(in, in_usm, fwd_size*sizeof(float)).wait();
        result = verify_bwd(in, N0, N1, H0, H1, mpi_rank, mpi_nproc);
    }

    // in_usm_owner and in_owner release both buffers here.
    return result;
}

// Writes the example's banner to std::cout, with an empty line before and
// after it. Every line ends with std::endl, so output is flushed per line.
void print_example_banner() {
    static const char *const banner_lines[] = {
        "",
        "########################################################################",
        "# Distributed 2D GPU FFT Complex-Complex Single-Precision Example: ",
        "# ",
        "# Using apis:",
        "#   experimental::dft",
        "# ",
        "# Supported floating point type precisions:",
        "#   float",
        "#   std::complex<float>",
        "# ",
        "########################################################################",
    };

    for (const char *line : banner_lines) {
        std::cout << line << std::endl;
    }
    std::cout << std::endl;
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_gpu -- only supports SYCL GPU implementation
//
// Initializes MPI and, for each SYCL device type configured at build time,
// runs the distributed DFT example collectively. A device type is used only
// if every rank found a GPU of that type. Rank 0 prints the combined
// pass/fail verdict.
//
// Returns 0 on success and 1 if a test failed or an exception was caught.
// Returns the MPI error code if an MPI call fails.
int main(int argc, char **argv) {
    int mpi_err = MPI_Init(&argc, &argv);
    if (mpi_err != MPI_SUCCESS) {
        std::cout << "MPI initialization error" << std::endl;
        std::cout << "Test Failed" << std::endl;
        return mpi_err;
    }

    int mpi_rank, mpi_nproc;
    MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &mpi_nproc);

    if(mpi_rank == 0) print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int returnCode = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);
        bool dev_found_and_is_gpu = my_dev_is_found && my_dev.is_gpu();
        // Check if all processes found the GPU
        mpi_err = MPI_Allreduce(MPI_IN_PLACE, &dev_found_and_is_gpu, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
        if (mpi_err != MPI_SUCCESS) {
            std::cout << "MPI_Allreduce error" << std::endl;
            std::cout << "Test Failed" << std::endl;
            return mpi_err;
        }

        if (dev_found_and_is_gpu) {
            if(mpi_rank == 0) {
                std::cout << "Running tests on " << sycl_device_names[*it] << " with " << mpi_nproc << " processes" <<".\n";
                std::cout << "\tRunning with float precision complex-to-complex distributed 2-D FFT:" << std::endl;
            }

            try {
                bool status = run_dft_example(my_dev, mpi_rank, mpi_nproc);
                // Combine the per-rank verdicts on rank 0. MPI_IN_PLACE is a
                // valid send buffer for MPI_Reduce only at the root. The
                // other ranks must send their own value; their recvbuf is
                // ignored.
                if (mpi_rank == 0) {
                    mpi_err = MPI_Reduce(MPI_IN_PLACE, &status, 1, MPI_CXX_BOOL,
                                         MPI_LAND, 0, MPI_COMM_WORLD);
                } else {
                    mpi_err = MPI_Reduce(&status, nullptr, 1, MPI_CXX_BOOL,
                                         MPI_LAND, 0, MPI_COMM_WORLD);
                }
                if (mpi_err != MPI_SUCCESS) {
                    std::cout << "MPI_Reduce error" << std::endl;
                    std::cout << "Test Failed" << std::endl;
                    return mpi_err;
                }

                if(mpi_rank == 0) {
                    if (status != SUCCESS) {
                        std::cout << "\tTest Failed" << std::endl << std::endl;
                        returnCode = 1;
                    } else {
                        std::cout << "\tTest Passed" << std::endl << std::endl;
                    }
                }
            }
            catch(sycl::exception const& e) {
                std::cout << "\t\tSYCL exception during FFT" << std::endl;
                std::cout << "\t\t" << e.what() << std::endl;
                std::cout << "\t\tError code: " << e.code().value() << std::endl;
                returnCode = 1;
            }
            catch(oneapi::mkl::exception const& e) {
                std::cout << "\t\toneMKL exception during FFT" << std::endl;
                std::cout << "\t\t" << e.what() << std::endl;
                returnCode = 1;
            }
            catch(std::runtime_error const& e) {
                std::cout << "\t\tRuntime exception during FFT" << std::endl;
                std::cout << "\t\t" << e.what() << std::endl;
                returnCode = 1;
            }
        } else if(my_dev_is_found) {
            std::cout << "Distributed DFT does not support " << sycl_device_names[*it] << " device for now; skipping tests" << std::endl;
        } else {
#ifdef FAIL_ON_MISSING_DEVICES
            std::cout << "No " << sycl_device_names[*it] << " devices found; Fail on missing devices is enabled." << std::endl;
            return 1;
#else
            std::cout << "No " << sycl_device_names[*it] << " devices found; skipping " << sycl_device_names[*it] << " tests." << std::endl;
#endif
        }
    }

    mkl_free_buffers();
    MPI_Finalize();
    return returnCode;
}
