!===============================================================================
! Copyright (C) 2022 Intel Corporation
!
! This software and the related documents are Intel copyrighted  materials,  and
! your use of  them is  governed by the  express license  under which  they were
! provided to you (License).  Unless the License provides otherwise, you may not
! use, modify, copy, publish, distribute,  disclose or transmit this software or
! the related documents without Intel's prior written permission.
!
! This software and the related documents  are provided as  is,  with no express
! or implied  warranties,  other  than those  that are  expressly stated  in the
! License.
!===============================================================================

! Content:
!       Example of using the dfftw_plan_dft_r2c_2d function to compute a
!       forward 2D real-to-complex FFT on a (GPU) device using the OpenMP
!       target (offload) interface and out-of-place data I/O.
!
!*****************************************************************************

include "fftw/offload/fftw3_omp_offload.f90"

program dp_plan_dft_2d_outofplace

  use FFTW3_OMP_OFFLOAD
  use omp_lib, ONLY : omp_get_num_devices
  use, intrinsic :: ISO_C_BINDING

  implicit none

  include 'fftw/fftw3.f'

  ! Sizes of 2D transform
  integer, parameter :: N1 = 128
  integer, parameter :: N2 = 64

  ! Arbitrary harmonic used to verify FFT
  integer, parameter :: H1 = 1
  integer, parameter :: H2 = -N2/2

  ! Need double precision
  integer, parameter :: WP = selected_real_kind(15,307)

  ! Execution status:
  !   status  - allocation / overall status (0 = success)
  !   statusf - result of verifying the forward transform (0 = success)
  integer :: statusf = 0, status = 0

  ! Data arrays: real input x and complex r2c output y
  real(WP), allocatable :: x(:, :)
  complex(WP), allocatable :: y(:,:)

  ! FFTW plan handle; 0 means no plan has been created
  integer*8 :: fwd_offload = 0

  print *,"Example dp_plan_dft_r2c_2d_outofplace"
  print *,"Forward 2D real out-of-place FFT"
  print *,"Configuration parameters:"
  print '("  N = [",I0,",",I0,"]")', N1, N2
  print '("  H = [",I0,",",I0,"]")', H1, H2

  print *,"Allocate array for input data"
  allocate ( x(N1,N2), STAT = status)
  if (0 /= status) goto 999

  print *,"Allocate array for output data"
  ! The r2c output stores only N1/2+1 entries along the first dimension
  allocate ( y(N1/2 + 1, N2), STAT = status)
  if (0 /= status) goto 999

  print *,"Initialize data for forward transform"
  call init(x, N1, N2, H1, H2)

  print *,"Create FFTW forward transform plan"
  ! Branching out of the OpenMP target data region is not allowed, so a
  ! failed plan is handled by skipping the execute call inside the region
  ! and jumping to the error path only after the region has ended.
  !$omp target data map(to:x) map(from:y) device(0)
  !$omp dispatch
  call dfftw_plan_dft_r2c_2d(fwd_offload, N1, N2, x, y, FFTW_ESTIMATE)
  if (0 == fwd_offload) then
    print *, "plan dft r2c failed"
  else
    print *,"Compute forward transform"
    !$omp dispatch
    call dfftw_execute_dft_r2c(fwd_offload, x, y)
  end if
  !$omp end target data
  if (0 == fwd_offload) then
    status = 1
    goto 999
  end if

  print *,"Verify the result of the forward transform"
  statusf = verificate(y, N1, N2, H1, H2)
  if (0 /= statusf) goto 999

100 continue

  print *,"Destroy FFTW plans"
  ! Only destroy a plan that was actually created
  if (0 /= fwd_offload) call dfftw_destroy_plan(fwd_offload)

  print *,"Deallocate arrays"
  ! The error path can be reached before either allocation succeeded
  if (allocated(x)) deallocate(x)
  if (allocated(y)) deallocate(y)

  if (status == 0) then
    print *, "TEST PASSED"
    call exit(0)
  else
    print *, "TEST FAILED"
    call exit(1)
  end if

  ! Error path: report both statuses, mark the run as failed, clean up
999 print '("  Error, status = ",I0,", status forward = ",I0)', status, statusf
  status = 1
  goto 100

contains

  ! Compute mod(K*L,M) without overflowing default integers: the product
  ! K*L is formed in integer*8 arithmetic before the reduction.
  pure integer*8 function moda(k,l,m)
    integer, intent(in) :: k,l,m
    integer*8 :: k8
    k8 = k
    ! MOD requires both arguments to have the same kind, so promote M too
    moda = mod(k8*l, int(m, kind(k8)))
  end function moda

  ! Initialize x(:,:) to a real cosine with harmonic (H1,H2), scaled by
  ! 1/(N1*N2) so that each nonzero bin of the unnormalized forward DFT
  ! equals 1 (the value verificate expects).
  !   x      - output array, N1 x N2
  !   N1, N2 - transform sizes
  !   H1, H2 - harmonic indices
  subroutine init(x, N1, N2, H1, H2)
    integer, intent(in) :: N1, N2, H1, H2
    real(WP), intent(out) :: x(:,:)
    real(WP) :: factor

    integer :: k1, k2
    real(WP), parameter :: TWOPI = 6.2831853071795864769_WP

    ! cos splits into the +H and -H harmonics with weight 1/2 each. When
    ! both alias to the same bin (2*H == 0 mod N in every dimension) the
    ! halves add up to 1; otherwise doubling gives a unit peak at each.
    if (mod(2*(N1-H1),N1)==0 .and. mod(2*(N2-H2),N2)==0) then
      factor = 1.0_WP
    else
      factor = 2.0_WP
    end if

    ! Column-major order: first index varies fastest
    do k2 = 1, N2
      do k1 = 1, N1
        x(k1,k2) = factor * cos( TWOPI * ( &
          real  (moda(H1,k1-1,N1),WP) / N1 &
          + real(moda(H2,k2-1,N2),WP) / N2)) / (N1*N2)
      end do
    end do
  end subroutine init

  ! Verify that x(:,:) holds the r2c spectrum of harmonic (H1,H2): 1 at bin
  ! (H1,H2) or its mirror (-H1,-H2) modulo (N1,N2), 0 elsewhere. Only the
  ! stored half k1 = 1..N1/2+1 is checked.
  ! Prints diagnostics and returns 1 at the first element whose error is
  ! not below the threshold; returns 0 if every element passes.
  integer function verificate(x, N1, N2, H1, H2)
    integer, intent(in) :: N1, N2, H1, H2
    complex(WP), intent(in) :: x(:,:)

    integer :: k1, k2
    real(WP) :: err, errthr, maxerr
    complex(WP) :: res_exp, res_got

    ! Note, this simple error bound doesn't take into account error of
    ! input data
    errthr = 2.5_WP * log(real(N1*N2,WP)) / log(2.0_WP) * EPSILON(1.0_WP)
    print '("  Check if err is below errthr ",G10.3)', errthr

    maxerr = 0.0_WP
    do k2 = 1, N2
      do k1 = 1, N1/2+1
        if (mod(k1-1-H1,N1)==0 .and. mod(k2-1-H2,N2)==0) then
          res_exp = 1.0_WP
        else if (mod(-k1+1-H1,N1)==0 .and. mod(-k2+1-H2,N2)==0) then
          res_exp = 1.0_WP
        else
          res_exp = 0.0_WP
        end if
        res_got = x(k1, k2)
        err = abs(res_got - res_exp)
        maxerr = max(err,maxerr)
        ! Written as .not.(err < errthr) so that a NaN error also fails
        if (.not.(err < errthr)) then
          write(*, '("  x(",I0,",",I0,"):")', advance='no') k1, k2
          write(*, '(" expected (",G24.17,",",G24.17,"),")', advance='no') res_exp
          write(*, '(" got (",G24.17,",",G24.17,"),")', advance='no') x(k1,k2)
          print '(" err ",G10.3)', err
          print *,"  Verification FAILED"
          verificate = 1
          return
        end if
      end do
    end do
    print '("  Verified,  maximum error was ",G10.3)', maxerr
    verificate = 0
  end function verificate
end program dp_plan_dft_2d_outofplace
