OpenMP on device and cufft interoperability

Hello,

I am trying to implement a code in which the cufft functions are called with device array allocated by openmp (address obtained from use_device_ptr). Something similar was done in the article belwo with openacc:

My problem is solving iteratively an time-dependent partial differential equation in inverse space. At each time step a non-linear term is computed in real space, then there are two cufft calls, followed by an update in k-space and finally a cufft call to get the next step field.
My code works to compute the energy which means no change to the main field, but the update routine does not seem to work. The field appears to be unchanged after the iteration.
The idea is to have many stepswith the fields on the device and very rare transfer to the host. Is there something subtle happening with the device pointers? In the case of openacc+cufft they set the plan to the acc stream. COuld this be the issue? Is there equivalent way to get the stream for the openmp?

Cristian

MODULE F_T_C
    INTERFACE
        subroutine new_dpr2c(plan,idata,odata) bind(C,name="dpr2c_for_interface")
          use iso_c_binding
          implicit none
          type(c_ptr), value :: plan
          type(c_ptr),value :: idata,odata
        end subroutine
        subroutine new_dpc2r(plan,idata,odata) bind(C,name="dpc2r_for_interface")
          use iso_c_binding
          implicit none
          type(c_ptr), value :: plan
          type(c_ptr),value :: idata,odata
        end subroutine

        subroutine new_make_plan_double_r2c(plan,nx,ny) bind(C,name="make_plan_double_r2c_for_interface")
          use iso_c_binding
          implicit none
          type(c_ptr) :: plan
          integer(c_int),value :: nx,ny
        end subroutine

        subroutine new_make_plan_double_c2r(plan,nx,ny) bind(C,name="make_plan_double_c2r_for_interface")
          use iso_c_binding
          implicit none
          type(c_ptr) :: plan
          integer(c_int),value :: nx,ny
        end subroutine
    END INTERFACE
END MODULE F_T_C

PROGRAM test
  use iso_c_binding
  use omp_lib
  use F_T_C

  implicit none

integer*4 :: ii,jj,cmm,it,id
integer(c_int), parameter :: lx=784, ly=512, lyhp=ly/2+1, nend=100
integer(c_int), parameter :: size_r=lx*ly, size_k=lx*(ly/2+1) ! Follow C order
real(c_double), parameter :: invvol=1.d0/(size_r) ! Follow C order
real(c_double), parameter :: pi=3.1415926535897931d0,sq3=1.7320508075688772d0
real(c_double), parameter :: qqt=sq3/2.d0
real(c_double), parameter :: at=2.d0*pi/qqt
real(c_double), parameter :: dt=0.125d0
real(c_double), parameter :: dx=at/8.d0,dy=(at*sq3/2.d0)/8.d0
real(c_double), parameter :: r=-0.25d0,pm=-0.25d0
real(c_double) :: rnd, ssuumm
!real(c_double):: psi_h(size_r), nt_h(size_r)
!complex(c_double_complex):: psi_k_h(size_k),nt_k_h(size_k)
real(c_double), Dimension(:), allocatable:: psi_h, nt_h, new_psi_h
complex(c_double_complex), Dimension(:), allocatable:: psi_k_h,nt_k_h, new_psi_k_h

type(c_ptr) :: planR2C, planC2R
real(c_double) :: dkx,dky

call init()
!$omp target enter data map(alloc:psi_h, new_psi_h, nt_h,psi_k_h,new_psi_k_h,nt_k_h)

!$omp target teams distribute parallel do private(cmm,ii,jj)
do ii=1, lx
  do jj=1,ly
    cmm=(ii-1)*ly+jj
    psi_h(cmm)=-0.13*(dcos(qqt*(ii-1)*dx) *dcos(qqt*(jj-1)*dy/dsqrt(3.d0))+0.5*dcos(2.0*qqt*(jj-1)*dy/dsqrt(3.d0)))+pm
  enddo
enddo
!$omp end target teams distribute parallel do
!$omp target update from (psi_h)
rnd=sum(psi_h)*invvol-pm
!psi_h=psi_h-rnd
write(*,*) rnd, dsqrt(3.d0), dacos(-1.d0), sum(psi_h)/(size_r)

!$omp target update to(psi_h) ! just in case

call energy()
do id=1, 1
  do it=1,10 !nend
    call update()
    ssuumm=0.0
    do ii=1, lx
      do jj=1,ly
        cmm=(ii-1)*ly+jj
        ssuumm=ssuumm+abs(psi_h(cmm)-(-0.13*(dcos(qqt*(ii-1)*dx) *dcos(qqt*(jj-1)*dy/dsqrt(3.d0))+0.5*dcos(2.0*qqt*(jj-1)*dy/dsqrt(3.d0)))+pm))
      enddo
    enddo
    print*, (ssuumm)*invvol
    call energy()
  enddo
  stop
  call energy()
  write(*,*)  id, sum(nt_h(1:size_r))*invvol, 0.5*(r+1.d0)*pm*pm+0.25*pm**4
enddo

!$omp target exit data
contains

  subroutine update()
    integer*4 :: ie, je
    real(c_double) :: llfact, ntfact, fact,mksq,kkx,kky
    integer*4 :: ccc

    !$omp target update to(psi_h)

    !Compute non-linear term
    !$omp target teams distribute parallel do private(ccc,ie,je)
    do ie=1, lx
      do je=1,ly
        ccc=(ie-1)*ly+je
        nt_h(ccc)=psi_h(ccc)**3
      enddo
    enddo
    !$omp end target teams distribute parallel do

    !$omp target data use_device_ptr(psi_h,psi_k_h)
    call new_dpr2c(planR2C, c_loc(psi_h),c_loc(psi_k_h))
    !$omp end target data

    !$omp target data use_device_ptr(nt_h,nt_k_h)
    call new_dpr2c(planR2C, c_loc(nt_h),c_loc(nt_k_h))
    !$omp end target data

    !$omp target teams distribute parallel do private(ccc,ie,je,fact,llfact,ntfact,mksq,kkx,kky)
    do ie=1, lx
      if(ie<=lx/2+1) then
        kkx=dkx*(ie-1.d0)
      endif
      if(ie>lx/2+1) then
        kkx=dkx*(ie-1.d0-lx)
      endif
      do je=1,lyhp
        ccc=(ie-1)*lyhp+je
        kky=dky*(je-1.d0)
        mksq=-(kkx*kkx+kky*kky)
        llfact=(r+1.d0+2.d0*mksq+mksq*mksq)
        fact=1.d0/(1.d0-dt*mksq*llfact)
        ntfact=mksq*dt*1.d0/(1.d0-dt*mksq*llfact)
        new_psi_k_h(ccc)=(psi_k_h(ccc)*fact+nt_k_h(ccc)*ntfact)*invvol
      enddo
    enddo
    !$omp end target teams distribute parallel do

    !$omp target data use_device_ptr(new_psi_h,new_psi_k_h)
    call new_dpr2c(planC2R, c_loc(new_psi_k_h),c_loc(new_psi_h))
    !$omp end target data

    !Compute non-linear term
    !$omp target teams distribute parallel do private(ccc,ie,je)
    do ie=1, lx
      do je=1,ly
        ccc=(ie-1)*ly+je
        nt_h(ccc)=new_psi_h(ccc)**3
      enddo
    enddo
    !$omp end target teams distribute parallel do

    !$omp target data use_device_ptr(new_psi_h,new_psi_k_h)
    call new_dpr2c(planR2C, c_loc(new_psi_h),c_loc(new_psi_k_h))
    !$omp end target data

    !$omp target data use_device_ptr(nt_h,nt_k_h)
    call new_dpr2c(planR2C, c_loc(nt_h),c_loc(nt_k_h))
    !$omp end target data

    !$omp target teams distribute parallel do private(ccc,ie,je,fact,llfact,ntfact,mksq,kkx,kky)
    do ie=1, lx
      if(ie<=lx/2+1) then
        kkx=dkx*(ie-1.d0)
      endif
      if(ie>lx/2+1) then
        kkx=dkx*(ie-1.d0-lx)
      endif
      do je=1,lyhp
        ccc=(ie-1)*lyhp+je
        kky=dky*(je-1.d0)
        mksq=-(kkx*kkx+kky*kky)
        llfact=(r+1.d0+2.d0*mksq+mksq*mksq)
        fact=1.d0/(1.d0-dt*mksq*llfact)
        ntfact=mksq*dt*1.d0/(1.d0-dt*mksq*llfact)
        psi_k_h(ccc)=(new_psi_k_h(ccc)*fact+nt_k_h(ccc)*ntfact)*invvol
      enddo
    enddo
    !$omp end target teams distribute parallel do

    !$omp target data use_device_ptr(psi_h,psi_k_h)
    call new_dpr2c(planC2R, c_loc(psi_k_h),c_loc(psi_h))
    !$omp end target data
  !!$omp target update from(psi_h)
  !print *, sum(psi_h)*invvol
  end subroutine

  subroutine energy()
    integer*4 :: ie, je
      real(c_double) :: fact,mksq,kkx,kky
      integer*4 :: ccc

      !$omp target update to(psi_h)

      !$omp target data use_device_ptr(psi_h,psi_k_h)
      call new_dpr2c(planR2C, c_loc(psi_h),c_loc(psi_k_h))
      !$omp end target data

      !$omp target teams distribute parallel do private(ccc,ie,je,fact,mksq,kkx,kky)
      do ie=1, lx
        if(ie<=lx/2+1) then
          kkx=dkx*(ie-1.d0)
        endif
        if(ie>lx/2+1) then
          kkx=dkx*(ie-1.d0-lx)
        endif
        do je=1,lyhp
          ccc=(ie-1)*lyhp+je

          kky=dky*(je-1.d0)
          mksq=-(kkx*kkx+kky*kky)
          fact=(r+1.d0+2.d0*mksq+mksq*mksq)
          nt_k_h(ccc)=psi_k_h(ccc)*fact*invvol
        enddo
      enddo
      !$omp end target teams distribute parallel do

      !$omp target data use_device_ptr(nt_h,nt_k_h)
      call new_dpc2r(planC2R,c_loc(nt_k_h),c_loc(nt_h))
      !$omp end target data

      ! Compute the local energy in real space
      !$omp target teams distribute parallel do private(ccc,ie,je)
      do ie=1, lx
        do je=1,ly
          ccc=(ie-1)*ly+je
          nt_h(ccc)=0.5d0*psi_h(ccc) *nt_h(ccc)+0.25*(psi_h(ccc)**4)
        enddo
      enddo
      !$omp end target teams distribute parallel do

      !$omp target update from(nt_h)
      write(*,*)  id, sum(nt_h(1:size_r))*invvol, 0.5*(r+1.d0)*pm*pm+0.25*pm**4
      !stop

  end subroutine

  subroutine init()

    dkx=2.d0*pi/(lx*dx)
    dky=2.d0*pi/(ly*dy)
    call new_make_plan_double_r2c(planR2C, lx, ly)
    call new_make_plan_double_c2r(planC2R, lx, ly)
    allocate(psi_h(size_r),nt_h(size_r))
    allocate(psi_k_h(size_k),nt_k_h(size_k))
    allocate(new_psi_k_h(size_k),new_psi_h(size_r))

  end subroutine

END PROGRAM test