gs2_transforms.fpp Source File


Contents

Source Code


Source Code

! Modifications for using FFTW version 3:
! (c) The Numerical Algorithms Group (NAG) Ltd, 2009 
!                                 on behalf of the HECToR project
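!> Fourier transforms between GS2's g_lo (spectral) layout and real (x, y)
!> space, together with the data redistributions they require: either
!> g_lo -> xxf_lo -> yxf_lo through the g2x and x2y redistributes, or, when the
!> accelerated "v" layout applies, 2D transforms on the accel_lo layout.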
module gs2_transforms
  use redistribute, only: redist_type
  use fft_work, only: fft_type

  implicit none

  private

  public :: init_transforms, finish_transforms, init_zf
  public :: transform2, inverse2, kz_spectrum
  public :: transform_x, transform_y, inverse_x, inverse_y ! Only for testing
  logical :: initialized=.false., initialized_y_fft=.false.
  logical :: initialized_x_redist=.false., initialized_y_redist=.false.
  logical :: initialized_zf = .false., initialized_3d=.false.

  interface transform_x
     module procedure transform_x5d
  end interface transform_x

  interface transform_y
     module procedure transform_y5d
  end interface transform_y

  interface transform2
     module procedure transform2_5d_accel
     module procedure transform2_5d
     module procedure transform2_3d
     module procedure transform2_2d !Not actually used, but gives symmetry to inverse2_2d
  end interface transform2

  interface inverse_x
     module procedure inverse_x5d
  end interface inverse_x

  interface inverse_y
     module procedure inverse_y5d
  end interface inverse_y

  interface inverse2
     module procedure inverse2_5d_accel
     module procedure inverse2_5d
     module procedure inverse2_3d
     module procedure inverse2_2d
  end interface inverse2

  ! redistribution
  type (redist_type), save :: g2x, x2y

  ! fft
  type (fft_type), save :: xf_fft, xb_fft, yf_fft, yb_fft, zf_fft
  type (fft_type), save :: xf3d_cr, xf3d_rc
#ifdef SHMEM
  type (fft_type), save :: faccel_shmx, faccel_shmy, baccel_shmx, baccel_shmy
  type (fft_type), save :: yf_fft_shm, yb_fft_shm
  logical, save :: use_shm_2d_plan=.false., use_shm_2d_plan_buff=.false.
  integer, dimension(:), allocatable :: gidx
#endif

  ! accel will be set to true if the v layout is used AND the number of
  ! PEs is such that each PE has a complete copy of the x,y space --
  ! in that case, no communication is needed to evaluate the nonlinear
  ! terms
  logical :: accel = .false.

  logical, dimension(:), allocatable :: aidx  ! aidx == aliased index
  integer, dimension(:), allocatable :: ia, iak
  complex, dimension(:, :), allocatable :: fft, xxf
#ifndef SHMEM
  complex, dimension(:, :, :), allocatable :: ag
#else
  complex, save, dimension(:, :, :), pointer, contiguous :: ag => null()
#endif

contains

  !> Initialise the GS2 Fourier transforms.
  !>
  !> On the first call this initialises the GS2 layouts, optionally loads
  !> FFTW wisdom, decides whether the accelerated transforms can be used and
  !> then sets up the required layouts, redistributions and FFT plans.
  !> Later calls return straight away.
  !>
  !> The accelerated path is selected when the processor layout is 'v' and
  !> (in non-SHMEM builds) negrid*nlambda*nspec is divisible by nproc;
  !> otherwise the g_lo -> xxf_lo -> yxf_lo redistribute path is set up.
  !>
  !> On return `accelerated` reports whether the accelerated path is in use.
  !> Aborts via mp_abort if either nx or ny is zero.
  subroutine init_transforms &
       (ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx, ny, accelerated)
    use mp, only: nproc, iproc, proc0, mp_abort
    use gs2_layouts, only: init_gs2_layouts
    use gs2_layouts, only: pe_layout, init_accel_transform_layouts
    use gs2_layouts, only: init_y_transform_layouts
    use gs2_layouts, only: init_x_transform_layouts
    use gs2_layouts, only: fft_wisdom_file, fft_use_wisdom, fft_measure_plan
    use fft_work, only: load_wisdom, save_wisdom, measure_plan
    implicit none
    integer, intent (in) :: ntgrid, naky, ntheta0, nlambda, negrid, nspec
    integer, intent (in) :: nx, ny
    logical, intent (out) :: accelerated
    logical, parameter :: debug = .false.
    ! Processor layout character returned by pe_layout. Named so that it
    ! does not shadow the CHAR intrinsic.
    character(1) :: layout_char

    ! Report the current state even when we return early below.
    accelerated = accel

    !Early exit if possible
    if (initialized) return
    initialized = .true.

    ! If either nx or ny are zero then we cannot setup the transforms so
    ! detect this and abort with a helpful message. The message is split so
    ! the source line stays within the 132 character free-form limit; the
    ! resulting string is unchanged.
    if (nx == 0 .or. ny == 0) then
       call mp_abort("Trying to initialise gs2 transforms but nx and / or ny are zero. " // &
            "Did you remember to set `grid_option = 'box'` in `kt_grids_knobs`?", .true.)
    end if

    if (debug) write (*,*) 'init_transforms: init_gs2_layouts'
    call init_gs2_layouts

    measure_plan = fft_measure_plan
    if (fft_use_wisdom) call load_wisdom(trim(fft_wisdom_file))

    call pe_layout (layout_char)

#ifndef SHMEM
    accel = (layout_char == 'v') .and. (mod(negrid * nlambda * nspec, nproc) == 0)
#else
    accel = (layout_char == 'v')
#endif

    if (accel) then
       if (debug) write (*,*) 'init_transforms: init_gs2_layouts (accel)'
       call init_accel_transform_layouts (ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx, ny, nproc, iproc)

       ! need these for movies - also called via init_y_redist_local in non-accel branch
       if (debug) write (*,*) 'init_transforms: init_y_transform_layouts'
       call init_y_transform_layouts (ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx, ny, nproc, iproc)
       if (debug) write (*,*) 'init_transforms: init_x_transform_layouts'
       call init_x_transform_layouts (ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx, nproc, iproc)

       call init_accel_fft(ntgrid)
    else
       !Recommended for p+log(p)>log(N) where p is number of
       !processors and N is total number of mesh points. Could automate
       !selection, though above condition is only fairly rough
       if (debug) write (*,*) 'init_transforms: init_y_redist_local'
       ! Note also calls init_x_redist_local
       call init_y_redist_local (ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx, ny)

       if (debug) write (*,*) 'init_transforms: init_xy_fft'
       call init_xy_fft
    end if

    if (debug) write (*,*) 'init_transforms: done'
    accelerated = accel

    if (proc0 .and. fft_use_wisdom) call save_wisdom(trim(fft_wisdom_file))
  end subroutine init_transforms

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

  !> Setup accelerated 2D or SHMEM FFT(s)
  !>
  !> Always creates the 2D complex-to-real (yf_fft) and real-to-complex
  !> (yb_fft) plans on the accel_lo grid and allocates the dealiasing index
  !> arrays ia/iak, the aliased-index mask aidx and the `ag` work array.
  !> In SHMEM builds `ag` is obtained from shm_alloc. Depending on the
  !> GS2_SHM_2D_PLAN / GS2_SHM_2D_PLAN_BUFF environment variables, either
  !> 2D plans over a share of the theta grid or pairs of 1D plans are created.
  !> Does nothing after the first call.
  subroutine init_accel_fft(ntgrid)
    use gs2_layouts, only: accel_lo, accelx_lo, dealiasing
    use fft_work, only: init_crfftw, init_rcfftw, init_ccfftw, FFT_TO_SPECTRAL_SPACE, FFT_TO_REAL_SPACE
#ifdef SHMEM
    use shm_mpi3, only: shm_info, shm_alloc
    use mp, only: proc0
    use gs2_layouts, only: g_lo, accel_lo
    integer :: howmany, stride, n2, n3, idw, nwk, envlen, ts1, ts2
    ! Buffer for environment variable values. It must hold a complete Fortran
    ! logical literal such as ".true.": a 1-character buffer truncated that to
    ! "." which the list-directed logical read below cannot parse.
    character(len=32) :: envval
    complex, dimension(:, :), allocatable :: dummy
#endif
    integer, intent(in) :: ntgrid
    integer :: i, idx

    if (initialized_y_fft) return
    initialized_y_fft = .true.

    !JH FFTW plan creation for the accelerated 2d transforms -- do we need this
    !when we have SHMEM?
    call init_crfftw (yf_fft, FFT_TO_REAL_SPACE, accel_lo%ny, accel_lo%nx, &
         (2*accel_lo%ntgrid+1) * 2)
    call init_rcfftw (yb_fft, FFT_TO_SPECTRAL_SPACE, accel_lo%ny, accel_lo%nx, &
         (2*accel_lo%ntgrid+1) * 2)

    ! prepare for dealiasing
    allocate (ia (accel_lo%nia), iak(accel_lo%nia))

#ifndef SHMEM
    allocate (ag(-ntgrid:ntgrid, 2, accel_lo%llim_proc:accel_lo%ulim_alloc))
    allocate (aidx(accel_lo%llim_proc:accel_lo%ulim_proc))

    do i = accel_lo%llim_proc, accel_lo%ulim_proc
       aidx(i) = dealiasing(accel_lo, i)
    end do

    do idx = 1, accel_lo%nia
       ia (idx) = accelx_lo%llim_proc + (idx-1)*accelx_lo%nxny
       iak(idx) = accel_lo%llim_proc  + (idx-1)*accel_lo%nxnky
    end do
#else
    call shm_alloc(ag, [-ntgrid, ntgrid, 1, 2, accel_lo%llim_proc, accel_lo%ulim_alloc])
    allocate (gidx(accel_lo%llim_node:accel_lo%ulim_node))
    gidx = -1
    idx = g_lo%llim_node
    allocate (aidx(accel_lo%llim_node:accel_lo%ulim_node))
    ! Non-aliased accel_lo indices are numbered consecutively from g_lo%llim_node
    do i = accel_lo%llim_node, accel_lo%ulim_node
       aidx(i) = dealiasing(accel_lo, i)
       if (aidx(i)) cycle
       gidx(i) = idx; idx = idx+1
    end do

    do idx = 1, accel_lo%nia
       ia (idx) = accelx_lo%llim_node + (idx-1)*accelx_lo%nxny
       iak(idx) = accel_lo%llim_node  + (idx-1)*accel_lo%nxnky
    end do

    ! LENGTH is zero when the variable is unset (or has no value), so only
    ! read when something was actually supplied.
    call get_environment_variable("GS2_SHM_2D_PLAN",VALUE=envval,LENGTH=envlen)
    if (envlen >0) read(envval,*) use_shm_2d_plan
    call get_environment_variable("GS2_SHM_2D_PLAN_BUFF",VALUE=envval, LENGTH=envlen)
    if (envlen >0) read(envval,*) use_shm_2d_plan_buff
    if (proc0) then
       if (use_shm_2d_plan) then
          write(*,*)' SHM with 2D plans'
          if (use_shm_2d_plan_buff) then
             write(*,*)' SHM 2D plans with buffering'
          endif
       else
          write(*,*)' SHM with 1D plans'
       endif
    endif

    if (use_shm_2d_plan) then
       n2 = accel_lo%ny
       n3 = accel_lo%nx

       ! Split the shared-memory ranks into two halves (ts2 = 1 or 2) and
       ! share the 2*ntgrid+1 theta points out between the nwk ranks of each half.
       ! attention, trouble if ppn =1
       if ( shm_info%size > 1 ) then
          if (shm_info%id < shm_info%size / 2) then
             idw = shm_info%id
             ts2=1
             nwk = shm_info%size / 2
          else
             idw = shm_info%id - shm_info%size / 2
             ts2=2
             nwk = shm_info%size - shm_info%size / 2! for odd ppn
          endif
       else
          idw = 0
          ts2=1
          nwk=1
       endif
       ! First theta index and number of transforms handled by this rank
       ts1 = -ntgrid + (2*ntgrid+1)/nwk * idw &
            + min(mod(2*ntgrid+1,nwk), idw)
       howmany =  (2*ntgrid+1)/nwk
       if ( idw < mod(2*ntgrid+1,nwk)) howmany = howmany+1
       if(shm_info%size == 1) howmany = 2*(2*ntgrid+1)

       if ( use_shm_2d_plan_buff) then
          stride=howmany
       else
          stride=2*(2*accel_lo%ntgrid+1)
       endif

       call init_crfftw (yf_fft_shm, FFT_TO_REAL_SPACE, accel_lo%ny, accel_lo%nx, &
            howmany, stride) !(2*accel_lo%ntgrid+1) * 2)

       ! now the inverse
       call init_rcfftw (yb_fft_shm, FFT_TO_SPECTRAL_SPACE, accel_lo%ny, accel_lo%nx, &
            howmany, stride) !(2*accel_lo%ntgrid+1) * 2)

       ! init_... has intent out for plan var
       yf_fft_shm%ts1=ts1
       yf_fft_shm%ts2=ts2
       yb_fft_shm%ts1=ts1
       yb_fft_shm%ts2=ts2

    else
       call init_crfftw(faccel_shmx, FFT_TO_REAL_SPACE, accel_lo%ny, 2*(2*ntgrid+1), .true.)
       allocate(dummy( 2*(2*ntgrid+1), accel_lo%ndky * accel_lo%nx))
       call init_ccfftw(faccel_shmy, FFT_TO_REAL_SPACE, accel_lo%nx, 2*(2*ntgrid+1), dummy, .true., accel_lo%ndky)
       deallocate(dummy)
       call init_rcfftw(baccel_shmx, FFT_TO_SPECTRAL_SPACE, accelx_lo%ny, 2*(2*ntgrid+1), .true.)
       allocate(dummy( 2*(2*ntgrid+1), accelx_lo%ny * accelx_lo%nx))
       call init_ccfftw(baccel_shmy, FFT_TO_SPECTRAL_SPACE, accel_lo%nx, 2*(2*ntgrid+1), dummy, .true.,accel_lo%ndky)
       deallocate(dummy)
    end if
#endif
  end subroutine init_accel_fft

  !> Setup non-accelerated separate x and y FFTs.
  !>
  !> Creates the forward/backward complex-to-complex plans along x on the
  !> xxf_lo layout (xf_fft, xb_fft) and the complex-to-real / real-to-complex
  !> plans along y on the yxf_lo layout (yf_fft, yb_fft). It also allocates the
  !> module work arrays `fft` and `xxf` if needed. Does nothing after the
  !> first call.
  subroutine init_xy_fft
    use gs2_layouts, only: xxf_lo, yxf_lo
    use fft_work, only: init_crfftw, init_rcfftw, init_ccfftw, FFT_TO_SPECTRAL_SPACE, FFT_TO_REAL_SPACE
    implicit none
    integer :: n_x_transforms, n_y_transforms

    if (initialized_y_fft) return
    initialized_y_fft = .true.

    ! Module-level work arrays: y-spectral buffer and x-transform buffer.
    if (.not. allocated(fft)) then
       allocate(fft(yxf_lo%ny/2+1, yxf_lo%llim_proc:yxf_lo%ulim_alloc))
    end if
    if (.not. allocated(xxf)) then
       allocate(xxf(xxf_lo%nx, xxf_lo%llim_proc:xxf_lo%ulim_alloc))
    end if

    ! How many 1D transforms each plan performs. ulim_alloc is used rather
    ! than ulim_proc: a processor may have llim_proc > ulim_proc, which would
    ! give a negative count, whereas ulim_alloc is set to llim_proc in that
    ! case and so yields a single (harmless) transform. (JH, Dec 2011)
    n_x_transforms = xxf_lo%ulim_alloc - xxf_lo%llim_proc + 1
    n_y_transforms = yxf_lo%ulim_alloc - yxf_lo%llim_proc + 1

    ! x direction: complex <-> complex, used by transform_x5d / inverse_x5d
    call init_ccfftw (xf_fft, FFT_TO_REAL_SPACE, xxf_lo%nx, n_x_transforms, xxf)
    call init_ccfftw (xb_fft, FFT_TO_SPECTRAL_SPACE, xxf_lo%nx, n_x_transforms, xxf)

    ! y direction: complex -> real and real -> complex
    call init_crfftw (yf_fft, FFT_TO_REAL_SPACE, yxf_lo%ny, n_y_transforms)
    call init_rcfftw (yb_fft, FFT_TO_SPECTRAL_SPACE, yxf_lo%ny, n_y_transforms)
  end subroutine init_xy_fft

  !> Constructs the redistribute mapping from the g_lo data
  !> decomposition to the xxf_lo decomposition.
  !>
  !> Element g(ig, isign, iglo) (ig = -ntgrid..ntgrid, isign = 1..2) is mapped
  !> to element (it0, ixxf) of the xxf_lo-shaped array, where
  !> ixxf = idx(xxf_lo, ig, isign, ik, il, ie, is). it0 is the x position of
  !> the g_lo "it" index in an array of length xxf_lo%nx: indices in the
  !> upper half, it > (ntheta0+1)/2, are shifted by nx - ntheta0.
  !> Points for which point_is_forbidden_and_not_redistributed is true are
  !> not communicated.
  !>
  !> Four passes are made:
  !>  1. count the sends;
  !>  2. count the receives;
  !>  3. fill the send lists, whose loop order (iglo, then isign, then ig
  !>     fastest) defines each message's element order;
  !>  4. fill the receive lists, then sort them by a compound key so that they
  !>     match that send order.
  subroutine setup_x_redist_local(g_lo, xxf_lo, g2x)
    use layouts_type, only: g_layout_type, xxf_layout_type
    use gs2_layouts, only: proc_id, idx_local, opt_local_copy, layout
    use gs2_layouts, only: ik_idx,il_idx,ie_idx,is_idx,idx, ig_idx, isign_idx
    use mp, only: nproc
    use redistribute, only: index_list_type, init_redist, delete_list
    use redistribute, only: set_redist_character_type, set_xxf_optimised_variables
    use sorting, only: quicksort
    implicit none
    type(g_layout_type), intent(in) :: g_lo
    type(xxf_layout_type), intent(in) :: xxf_lo
    type(redist_type), intent(in out) :: g2x
    ! Per-processor index lists: to_list = where received data goes,
    ! from_list = where sent data comes from, sort_list = receive sort keys
    type (index_list_type), dimension(0:nproc-1) :: to_list, from_list, sort_list
    ! Number of elements to receive from (nn_to) / send to (nn_from) each proc
    integer, dimension(0:nproc-1) :: nn_to, nn_from
    integer, dimension (3) :: from_low, from_high
    integer, dimension (2) :: to_high
    integer :: to_low
    integer :: iglo, isign, ig, it, ixxf, it0
    integer :: n, ip, il, ik, ie, is

    ! count number of elements to be redistributed to/from each processor
    nn_to = 0
    nn_from = 0

    !Here we loop over the whole domain (so this doesn't get cheaper with more cores)
    !This is required to ensure that we can associate send and receive message elements
    !i.e. so we know that we put the received data in the correct place.
    !However, as we know the order the iglo,isign,ig indices should increase we could
    !loop over our ixxf local range, calculate the corresponding iglo,isign and ig indices
    !and sort the ixxf messages on this to ensure they're in the correct order.
    !Either way when just counting how much data we are going to send and receive we don't
    !care about order so we can just loop over our local range!
    !First count the sends | g_lo-->xxf_lo
    !Protect against procs with no data
    if(g_lo%ulim_proc>=g_lo%llim_proc)then
       do iglo = g_lo%llim_proc, g_lo%ulim_alloc
          il = il_idx(g_lo, iglo)

          !Convert iglo,isign=1,ig=-ntgrid into ixxf
          ixxf=idx(xxf_lo,-g_lo%ntgrid,1,ik_idx(g_lo,iglo),&
               il, ie_idx(g_lo,iglo),is_idx(g_lo,iglo))

          !Now loop over other local dimensions
          do isign = 1, 2
             do ig = -g_lo%ntgrid, g_lo%ntgrid
                ip = proc_id(xxf_lo, ixxf)

                ! Don't send data from forbidden region, unless it is going to
                ! this processor (i.e. local copy) due to assumptions in opt_local_copy
                ! based methods.
                if (.not. point_is_forbidden_and_not_redistributed(ig, il, ip)) then
                   !Increase the data send count for the proc which has the ixxf
                   nn_from(ip)=nn_from(ip)+1
                end if

                !Increase ixxf using knowledge of the xxf_lo layout
                !(consecutive ig for fixed ik are xxf_lo%naky apart)
                ixxf=ixxf+xxf_lo%naky
             enddo
          enddo
       enddo
    endif

    !Now count the receives | xxf_lo<--g_lo
    !Protect against procs with no data
    if(xxf_lo%ulim_proc>=xxf_lo%llim_proc)then
       do ixxf = xxf_lo%llim_proc, xxf_lo%ulim_alloc
          !Could split it (or x) domain into two parts to account for
          !difference in it (or x) meaning and order in g_lo and xxf_lo
          !but only interested in how much data we receive and not the
          !exact indices (yet) so just do 1-->ntheta0 (this is how many
          !non-zero x's?)
          ik = ik_idx(xxf_lo, ixxf)
          il = il_idx(xxf_lo, ixxf)
          ie = ie_idx(xxf_lo, ixxf)
          is = is_idx(xxf_lo, ixxf)
          ig = ig_idx(xxf_lo, ixxf)
          ! NOTE(review): this counting loop uses xxf_lo%ntheta0 whereas the
          ! matching fill loop below uses g_lo%ntheta0; presumably these are
          ! always equal -- confirm, as the counts must agree with the fill.
          do it=1,xxf_lo%ntheta0
             !Convert ixxf,it indices into iglo, ig and isign indices
             iglo=idx(g_lo, ik, it, il, ie, is)

             ip = proc_id(g_lo, iglo)

             ! Don't send data from forbidden region, unless it is going to
             ! this processor (i.e. local copy) due to assumptions in opt_local_copy
             ! based methods.
             if (point_is_forbidden_and_not_redistributed(ig, il, ip)) cycle

             !Increase the data to receive count for proc with this data
             !Note, we only worry about iglo and not isign/ig because we know that
             !in g_lo each proc has all ig and isign domain.
             !The xxf_lo domain contains ig and isign so we only need to add one
             !to the count for each ixxf.
             nn_to(proc_id(g_lo,iglo))=nn_to(proc_id(g_lo,iglo))+1
          enddo
       enddo
    endif

    !Now allocate storage for data mapping indices
    do ip = 0, nproc-1
       if (nn_from(ip) > 0) then
          allocate (from_list(ip)%first(nn_from(ip)))
          allocate (from_list(ip)%second(nn_from(ip)))
          allocate (from_list(ip)%third(nn_from(ip)))
       end if
       if (nn_to(ip) > 0) then
          allocate (to_list(ip)%first(nn_to(ip)))
          allocate (to_list(ip)%second(nn_to(ip)))
          !For sorting to_list later
          allocate (sort_list(ip)%first(nn_to(ip)))
       end if
    end do

    !Reinitialise count arrays to zero
    nn_to = 0
    nn_from = 0

    !First fill in the sending indices, these define the messages data order
    !Protect against procs with no data
    if(g_lo%ulim_proc>=g_lo%llim_proc)then
       do iglo=g_lo%llim_proc, g_lo%ulim_alloc
          il = il_idx(g_lo, iglo)

          !Convert iglo,isign=1,ig=-ntgrid into ixxf
          ixxf=idx(xxf_lo,-g_lo%ntgrid,1,ik_idx(g_lo,iglo),&
               il, ie_idx(g_lo,iglo),is_idx(g_lo,iglo))

          !Now loop over other local dimensions
          do isign = 1, 2
             do ig = -g_lo%ntgrid, g_lo%ntgrid
                !Get proc id
                ip=proc_id(xxf_lo,ixxf)
                if (.not. point_is_forbidden_and_not_redistributed(ig, il, ip)) then

                   !Increment procs message counter
                   n=nn_from(ip)+1
                   nn_from(ip)=n

                   !Store indices
                   from_list(ip)%first(n) = ig
                   from_list(ip)%second(n) = isign
                   from_list(ip)%third(n) = iglo
                end if

                !We could send this information, transformed to the xxf layout to the proc.

                !Increment counter
                ixxf=ixxf+xxf_lo%naky
             enddo
          enddo
       enddo
    endif

    !Now lets fill in the receiving indices, these must match the messages data order
    !Protect against procs with no data
    if(xxf_lo%ulim_proc>=xxf_lo%llim_proc)then
       do ixxf = xxf_lo%llim_proc, xxf_lo%ulim_alloc
          !Get indices
          isign=isign_idx(xxf_lo,ixxf)
          ik = ik_idx(xxf_lo, ixxf)
          il = il_idx(xxf_lo, ixxf)
          ie = ie_idx(xxf_lo, ixxf)
          is = is_idx(xxf_lo, ixxf)
          ig = ig_idx(xxf_lo, ixxf)

          !Loop over receiving "it" indices
          do it=1,g_lo%ntheta0
             !Convert from g_lo%it to xxf_lo%it
             !(upper half of the it range moves to the end of the length nx array)
             if(it>(xxf_lo%ntheta0+1)/2) then
                it0=it+xxf_lo%nx-xxf_lo%ntheta0
             else
                it0=it
             endif

             !Convert ixxf,it indices into iglo indices
             iglo=idx(g_lo, ik, it, il, ie, is)

             !Get proc id which has this data
             ip=proc_id(g_lo,iglo)

             if (point_is_forbidden_and_not_redistributed(ig, il, ip)) cycle

             !Determine message position index
             n=nn_to(ip)+1
             nn_to(ip)=n

             !Store receive indices
             to_list(ip)%first(n)=it0
             to_list(ip)%second(n)=ixxf

             !Store index for sorting. The key orders by iglo, then isign, then
             !ig (fastest), matching the order of the sending loops above.
             ! NOTE(review): the key scales like 2*ntgridtotal*(number of g_lo
             ! points); if sort_list stores default integers it may overflow
             ! for very large problems -- confirm.
             sort_list(ip)%first(n) = ig+g_lo%ntgrid-1+g_lo%ntgridtotal*(isign-1+2*(iglo-g_lo%llim_world))
          enddo
       enddo
    endif

    !Now we need to sort the to_list message indices based on sort_list | This seems potentially slow + inefficient
    do ip=0,nproc-1
       !Only need to worry about procs which we are receiving from
       if(nn_to(ip)>0) then
          !Sort using quicksort
          call quicksort(nn_to(ip),sort_list(ip)%first,to_list(ip)%first,to_list(ip)%second)
       endif
    enddo

    !Setup array range values
    !Source: g(-ntgrid:ntgrid, 1:2, llim_proc:ulim_alloc)
    !Destination: xxf(1:nx, llim_proc:ulim_alloc)
    from_low (1) = -g_lo%ntgrid
    from_low (2) = 1
    from_low (3) = g_lo%llim_proc

    to_low = xxf_lo%llim_proc

    to_high(1) = xxf_lo%nx
    to_high(2) = xxf_lo%ulim_alloc

    from_high(1) = g_lo%ntgrid
    from_high(2) = 2
    from_high(3) = g_lo%ulim_alloc

    call set_redist_character_type(g2x, 'g2x')
    call set_xxf_optimised_variables(opt_local_copy, xxf_lo%naky,  &
         xxf_lo%ntgrid,  xxf_lo%ntheta0, xxf_lo%nlambda,  xxf_lo%negrid, &
         xxf_lo%nx, xxf_lo%ulim_proc, g_lo%ulim_proc, layout)

    !Create g2x redistribute object
    call init_redist (g2x, 'c', to_low, to_high, to_list, &
         from_low, from_high, from_list)

    !Deallocate list objects
    call delete_list (to_list)
    call delete_list (from_list)
    call delete_list(sort_list)

  end subroutine setup_x_redist_local

  !> Setup global xxf_lo and redistribute between global g_lo -> xxf_lo.
  !>
  !> Builds the xxf_lo layout and the module-level g2x redistribute.
  !> Only the first call does any work; later calls are no-ops.
  subroutine init_x_redist_local(ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx)
    use gs2_layouts, only: init_x_transform_layouts
    use gs2_layouts, only: g_lo, xxf_lo
    use mp, only: nproc, iproc
    implicit none
    integer, intent (in) :: ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx

    if (.not. initialized_x_redist) then
       initialized_x_redist = .true.

       ! The xxf_lo layout has to exist before the g_lo -> xxf_lo mapping
       ! can be constructed.
       call init_x_transform_layouts (ntgrid, naky, ntheta0, nlambda, &
            negrid, nspec, nx, nproc, iproc)
       call setup_x_redist_local(g_lo, xxf_lo, g2x)
    end if
  end subroutine init_x_redist_local

  !> Constructs the redistribute mapping from the xxf_lo data
  !> decomposition to the yxf_lo decomposition.
  !>
  !> The source array is dimensioned (xxf_lo%nx, xxf_lo%llim_proc:ulim_alloc).
  !> The destination is dimensioned (yxf_lo%ny/2+1, yxf_lo%llim_proc:ulim_alloc).
  !> Source element (it, ixxf) goes to destination element (ik, iyxf). Here
  !> ixxf and iyxf share the same (ig, isign, il, ie, is); ik comes from ixxf
  !> and it from iyxf. Points for which
  !> point_is_forbidden_and_not_redistributed is true are skipped.
  !> The send loops (ixxf, then it fastest) define each message's element
  !> order. The receive lists are afterwards sorted by a compound key so
  !> that they match it.
  subroutine setup_y_redist_local(xxf_lo, yxf_lo, x2y)
    use layouts_type, only: xxf_layout_type, yxf_layout_type
    use gs2_layouts, only: proc_id, idx_local
    use gs2_layouts, only: ik_idx,it_idx,il_idx,ie_idx,is_idx,idx,ig_idx,isign_idx
    use mp, only: nproc
    use redistribute, only: index_list_type, init_redist, delete_list
    use redistribute, only: set_yxf_optimised_variables, set_redist_character_type
    use sorting, only: quicksort
    implicit none
    type(xxf_layout_type), intent(in) :: xxf_lo
    type(yxf_layout_type), intent(in) :: yxf_lo
    type(redist_type), intent(in out) :: x2y
    ! Per-processor index lists: to_list = where received data goes,
    ! from_list = where sent data comes from, sort_list = receive sort keys
    type (index_list_type), dimension(0:nproc-1) :: to_list, from_list,sort_list
    ! Number of elements to receive from (nn_to) / send to (nn_from) each proc
    integer, dimension(0:nproc-1) :: nn_to, nn_from
    integer, dimension (2) :: from_low, from_high, to_high
    integer :: to_low
    integer :: it, ixxf, ik, iyxf
    integer :: ixxf_start, iyxf_start
    integer :: n, ip, il, ig

    !Initialise counts to zero
    nn_to = 0
    nn_from = 0

    !First count data to send | xxf_lo-->yxf_lo
    !Protect against procs with no data
    if(xxf_lo%ulim_proc>=xxf_lo%llim_proc)then
       do ixxf=xxf_lo%llim_proc,xxf_lo%ulim_alloc
          il = il_idx(xxf_lo, ixxf)
          ig = ig_idx(xxf_lo, ixxf)

          !Get iyxf index for "it"=1
          iyxf_start=idx(yxf_lo, ig,&
               isign_idx(xxf_lo,ixxf), 1, il,&
               ie_idx(xxf_lo,ixxf),is_idx(xxf_lo,ixxf))

          !Loop over "it" range, note that we actually only want to know
          !iyxf in this range and we know that it-->it+1 => iyxf-->iyxf+1
          !so replace loop with one over iyxf
          do iyxf=iyxf_start,iyxf_start+yxf_lo%nx-1
             ip = proc_id(yxf_lo, iyxf)

             ! Don't send data from forbidden region, unless it is going to
             ! this processor (i.e. local copy) due to assumptions in opt_local_copy
             ! based methods.
             if (point_is_forbidden_and_not_redistributed(ig, il, ip)) cycle
             !Increase the appropriate procs send count
             nn_from(ip)=nn_from(ip)+1
          enddo
       enddo
    endif

    !Now count data to receive | yxf_lo<--xxf_lo
    !Protect against procs with no data
    if(yxf_lo%ulim_proc>=yxf_lo%llim_proc)then
       do iyxf=yxf_lo%llim_proc,yxf_lo%ulim_alloc
          ig = ig_idx(yxf_lo, iyxf)
          il = il_idx(yxf_lo, iyxf)
          !Get ixxf index for "ik"=1
          ixxf_start=idx(xxf_lo, ig,&
               isign_idx(yxf_lo,iyxf), 1, il,&
               ie_idx(yxf_lo,iyxf),is_idx(yxf_lo,iyxf))

          !Loop over "ik" range, note that we actually only want to know
          !ixxf in this range and we know that ik-->ik+1 => ixxf-->ixxf+1
          !so replace loop with one over ixxf
          do ixxf=ixxf_start,ixxf_start+xxf_lo%naky-1
             ip = proc_id(xxf_lo, ixxf)

             ! Don't send data from forbidden region, unless it is going to
             ! this processor (i.e. local copy) due to assumptions in opt_local_copy
             ! based methods.
             if (point_is_forbidden_and_not_redistributed(ig, il, ip)) cycle

             !Increase the appropriate procs recv count
             nn_to(ip) = nn_to(ip)+1
          enddo
       enddo
    endif

    !Now allocate storage for data mapping structures
    do ip = 0, nproc-1
       if (nn_from(ip) > 0) then
          allocate (from_list(ip)%first(nn_from(ip)))
          allocate (from_list(ip)%second(nn_from(ip)))
       end if
       if (nn_to(ip) > 0) then
          allocate (to_list(ip)%first(nn_to(ip)))
          allocate (to_list(ip)%second(nn_to(ip)))
          !For sorting to_list later
          allocate (sort_list(ip)%first(nn_to(ip)))
       end if
    end do

    !Reinitialise count arrays to zero
    nn_to = 0
    nn_from = 0

    !First fill in the sending indices, these define the messages data order
    !Protect against procs with no data
    if(xxf_lo%ulim_proc>=xxf_lo%llim_proc)then
       do ixxf=xxf_lo%llim_proc,xxf_lo%ulim_alloc
          ig = ig_idx(xxf_lo, ixxf)
          il = il_idx(xxf_lo, ixxf)

          !Get iyxf for "it"=1
          iyxf=idx(yxf_lo, ig,&
               isign_idx(xxf_lo,ixxf), 1, il,&
               ie_idx(xxf_lo,ixxf),is_idx(xxf_lo,ixxf))

          !Now loop over other local dimension. Note we need "it" here
          !so don't replace this with loop over iyxf
          do it=1,yxf_lo%nx
             !Get the processor id
             ip=proc_id(yxf_lo,iyxf)

             ! Don't send data from forbidden region, unless it is going to
             ! this processor (i.e. local copy) due to assumptions in opt_local_copy
             ! based methods.
             if (.not. point_is_forbidden_and_not_redistributed(ig, il, ip)) then
                !Increment the procs message counter
                n=nn_from(ip)+1
                nn_from(ip)=n

                !Store indices
                from_list(ip)%first(n)=it
                from_list(ip)%second(n)=ixxf
             end if

             !Increment iyxf
             iyxf=iyxf+1
          enddo
       enddo
    endif

    !Now fill in the receiving indices, these must match the message data order, achieved by later sorting
    !Protect against procs with no data
    if(yxf_lo%ulim_proc>=yxf_lo%llim_proc)then
       do iyxf=yxf_lo%llim_proc,yxf_lo%ulim_alloc
          ig = ig_idx(yxf_lo, iyxf)
          il = il_idx(yxf_lo, iyxf)

          !Get ixxf for "ik"=1
          ixxf=idx(xxf_lo, ig,&
               isign_idx(yxf_lo,iyxf), 1, il,&
               ie_idx(yxf_lo,iyxf),is_idx(yxf_lo,iyxf))

          !Now loop over other local dimension. Note we need "ik" here
          !so don't replace this with loop over ixxf
          do ik=1,xxf_lo%naky
             !Get the processor id
             ip=proc_id(xxf_lo,ixxf)

             ! Don't send data from forbidden region, unless it is going to
             ! this processor (i.e. local copy) due to assumptions in opt_local_copy
             ! based methods.
             if (.not. point_is_forbidden_and_not_redistributed(ig, il, ip)) then

                !Increment the procs message counter
                n=nn_to(ip)+1
                nn_to(ip)=n

                !Store indices
                to_list(ip)%first(n)=ik
                to_list(ip)%second(n)=iyxf

                !Store index for sorting. The key orders by ixxf, then it
                !(it = 1..nx varies fastest), matching the sending loops above.
                ! NOTE(review): ixxf*nx may overflow a default integer for
                ! very large problems if sort_list stores default integers -- confirm.
                sort_list(ip)%first(n)=it_idx(yxf_lo,iyxf)+ixxf*yxf_lo%nx
             end if

             !Increment ixxf
             ixxf=ixxf+1
          enddo
       enddo
    endif

    !Now we need to sort the to_list message indices based on sort_list
    !This could be slow and inefficient
    do ip=0,nproc-1
       !Only need to worry about procs which we are receiving from
       if(nn_to(ip)>0) then
          !Use quicksort based on compound index
          call quicksort(nn_to(ip),sort_list(ip)%first,to_list(ip)%first,to_list(ip)%second)
       endif
    enddo

    !Setup array bound values
    !Source: xxf(1:nx, llim_proc:ulim_alloc)
    !Destination: (1:ny/2+1, llim_proc:ulim_alloc) on yxf_lo
    from_low(1) = 1
    from_low(2) = xxf_lo%llim_proc

    to_low = yxf_lo%llim_proc

    to_high(1) = yxf_lo%ny/2+1
    to_high(2) = yxf_lo%ulim_alloc

    from_high(1) = xxf_lo%nx
    from_high(2) = xxf_lo%ulim_alloc

    call set_redist_character_type(x2y, 'x2y')
    call set_yxf_optimised_variables(yxf_lo%ulim_proc)

    !Create x2y redist object
    call init_redist (x2y, 'c', to_low, to_high, to_list, &
         from_low, from_high, from_list)

    !Deallocate list objects
    call delete_list (to_list)
    call delete_list (from_list)
    call delete_list (sort_list)
  end subroutine setup_y_redist_local

  !> Decide whether the grid point (ig, il) should be left out of a redistribute.
  !>
  !> Returns .true. only when the point lies in the forbidden region
  !> (forbid(ig, il) from le_grids) AND either the partner processor ip is
  !> not this processor or opt_local_copy is off. Forbidden points that stay
  !> on this processor are still transferred when opt_local_copy is on,
  !> because the opt_local_copy based methods rely on them being present.
  pure logical function point_is_forbidden_and_not_redistributed(ig, il, ip) result(flag)
    use mp, only: iproc
    use le_grids, only: forbid
    use gs2_layouts, only: opt_local_copy
    implicit none
    integer, intent(in) :: ig, il, ip
    !> The following is for debug/testing to force all grid
    !> points to be communicated if we wish.
    logical, parameter :: skip_forbidden_region = .true.
    logical :: kept_for_local_copy

    flag = .false.

    ! Nothing is skipped when forbidden-region skipping is disabled,
    ! and allowed points are always communicated.
    if (.not. skip_forbidden_region) return
    if (.not. forbid(ig, il)) return

    ! A forbidden point is still moved if it stays on this processor and
    ! the opt_local_copy path is in use.
    kept_for_local_copy = (ip == iproc) .and. opt_local_copy
    flag = .not. kept_for_local_copy
  end function point_is_forbidden_and_not_redistributed

  !> Setup the module level xxf -> yxf redistribute (x2y) and the yxf_lo
  !> layout.
  !>
  !> [[init_x_redist_local]] is always called first, so the g_lo -> xxf_lo
  !> mapping and the xxf_lo layout exist before the x2y mapping is built.
  !> That call is itself a no-op after its first call. The yxf part is
  !> only set up on the first call of this routine.
  subroutine init_y_redist_local (ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx, ny)
    use gs2_layouts, only: init_y_transform_layouts
    use gs2_layouts, only: xxf_lo, yxf_lo
    use mp, only: nproc, iproc
    implicit none
    integer, intent (in) :: ntgrid, naky, ntheta0, nlambda, negrid, nspec
    integer, intent (in) :: nx, ny

    ! Deliberately done before checking initialized_y_redist.
    call init_x_redist_local (ntgrid, naky, ntheta0, nlambda, negrid, nspec, nx)

    if (.not. initialized_y_redist) then
       initialized_y_redist = .true.

       call init_y_transform_layouts (ntgrid, naky, ntheta0, nlambda, &
            negrid, nspec, nx, ny, nproc, iproc)
       call setup_y_redist_local(xxf_lo, yxf_lo, x2y)
    end if
  end subroutine init_y_redist_local

  !> Transform g from the g_lo layout along x towards real space.
  !>
  !> g is first redistributed into xxf (xxf_lo layout) by the g2x gather.
  !> A complex-to-complex FFT along x (the xf_fft plan) is then applied in
  !> place. On return xxf holds the x-transformed data.
  subroutine transform_x5d (g, xxf)
    use gs2_layouts, only: xxf_lo, g_lo
    use redistribute, only: gather
    use job_manage, only: time_message
    use fft_work, only: time_fft
    use array_utils, only: zero_array
    implicit none
    complex, dimension (-xxf_lo%ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    complex, dimension (:,xxf_lo%llim_proc:), intent (out) :: xxf
    character(len=*), parameter :: fft_timer_label = ' FFT'

    ! The gather below does not write every element of xxf: forbidden
    ! points are not communicated and the x direction may be padded with
    ! zeros for dealiasing. The array therefore has to be cleared on each
    ! call. A one-off zeroing is not enough, because the in-place FFT may
    ! leave non-zero values in those entries and xxf is also the output.
    call zero_array(xxf)

    ! Bring the pieces of g needed for the local x transforms onto this
    ! processor (CMR, 7/3/2011).
    call gather (g2x, g, xxf)

    ! Paired time_message calls bracket the FFT in the time_fft timer.
    call time_message(.false., time_fft, fft_timer_label)
    call xf_fft%execute_c2c(xxf, xxf)
    call time_message(.false., time_fft, fft_timer_label)
  end subroutine transform_x5d

  !> Inverse transform in x.
  !>
  !> xxf is transformed in place with the xb_fft plan, so its contents are
  !> overwritten. The result is then scattered back into g (g_lo layout).
  subroutine inverse_x5d (xxf, g)
    use gs2_layouts, only: xxf_lo, g_lo
    use redistribute, only: scatter
    use job_manage, only: time_message
    use fft_work, only: time_fft
    implicit none
    complex, dimension (:,xxf_lo%llim_proc:), intent (in out) :: xxf
    complex, dimension (-xxf_lo%ntgrid:,:,g_lo%llim_proc:), intent (out) :: g

    ! Timed backward FFT along x.
    call time_message(.false., time_fft, ' FFT')
    call xb_fft%execute_c2c(xxf, xxf)
    call time_message(.false., time_fft, ' FFT')

    ! Return each xxf_lo element to its g_lo owner.
    call scatter (g2x, xxf, g)
  end subroutine inverse_x5d

  !> Forward transform in y.
  !>
  !> xxf (xxf_lo layout) is redistributed into the module work array fft.
  !> A complex-to-real FFT with the yf_fft plan then produces yxf
  !> (yxf_lo layout).
  subroutine transform_y5d (xxf, yxf)
    use gs2_layouts, only: xxf_lo, yxf_lo
    use redistribute, only: gather
    use job_manage, only: time_message
    use fft_work, only: time_fft
    use array_utils, only: zero_array
    implicit none
    complex, dimension (:,xxf_lo%llim_proc:), intent (in) :: xxf
    real, dimension (:,yxf_lo%llim_proc:), intent(in out) :: yxf

    ! fft has to be cleared on every call: the gather leaves forbidden and
    ! padding points untouched, and FFTW may have overwritten this input
    ! buffer during the previous transform.
    call zero_array(fft)

    ! The redistribution is performed even when no FFT routine is in use.
    call gather (x2y, xxf, fft)

    ! Timed c2r FFT along y.
    call time_message(.false., time_fft, ' FFT')
    call yf_fft%execute_c2r(fft, yxf)
    call time_message(.false., time_fft, ' FFT')
  end subroutine transform_y5d

  !> Inverse transform in y.
  !>
  !> yxf (yxf_lo layout) goes through a real-to-complex FFT with the yb_fft
  !> plan into the module work array fft. The result is then scattered back
  !> into xxf (xxf_lo layout).
  subroutine inverse_y5d (yxf, xxf)
    use gs2_layouts, only: xxf_lo, yxf_lo
    use redistribute, only: scatter
    use job_manage, only: time_message
    use fft_work, only: time_fft
    implicit none
    real, dimension (:,yxf_lo%llim_proc:), intent (in out) :: yxf
    complex, dimension (:,xxf_lo%llim_proc:), intent (out) :: xxf

    ! Timed r2c FFT along y.
    call time_message(.false., time_fft, ' FFT')
    call yb_fft%execute_r2c(yxf, fft)
    call time_message(.false., time_fft, ' FFT')

    ! Move the spectral data back to the xxf_lo distribution.
    call scatter (x2y, fft, xxf)
  end subroutine inverse_y5d

  !> Full forward transform of g (g_lo) to real space yxf (yxf_lo).
  !>
  !> g is modified in place (hence intent(in out)): every mode with
  !> ik /= 1 (ky > 0) is halved before the x and then y transforms. The x
  !> transform output goes through the module work array xxf.
  !>
  !> CMR+GC: 2/9/2013
  !>   GS2 Fourier coefficients F_k^gs2 are not in standard form
  !>   (f(x) = f_k e^(i k.x)); they are 2x larger for ky > 0:
  !>                      F_k^gs2 = |    f_k   for ky = 0
  !>                                |  2 f_k   for ky > 0
  !>   The halving loop could be eliminated with standard coefficients, as
  !>   could similar loops in inverse2_5d, transform2_5d_accel and
  !>   inverse2_5d_accel. NB: moving to standard coefficients would impact
  !>   diagnostics considerably (e.g. fac in get_volume_average).
  subroutine transform2_5d (g, yxf)
    use gs2_layouts, only: g_lo, yxf_lo, ik_idx
    use unit_tests,only: debug_message
    implicit none
    complex, dimension (:,:,g_lo%llim_proc:), intent (in out) :: g
    real, dimension (:,yxf_lo%llim_proc:), intent (out) :: yxf
    integer :: iglo

    call debug_message(4, 'gs2_transforms::transform2_5d starting')

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP SHARED(g_lo, g) &
    !$OMP SCHEDULE(dynamic)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ! Only the ky > 0 modes (ik /= 1) carry the extra factor of 2.
       if (ik_idx(g_lo, iglo) /= 1) g(:,:,iglo) = 0.5 * g(:,:,iglo)
    end do
    !$OMP END PARALLEL DO

    call transform_x (g, xxf)
    call transform_y (xxf, yxf)
  end subroutine transform2_5d

  !> Full inverse transform of real-space yxf (yxf_lo) back to g (g_lo).
  !>
  !> Runs the y and then x inverse transforms (via the module work array
  !> xxf), then applies the product of the two backward plan scale factors.
  !> Modes with ik /= 1 (ky > 0) get an extra factor of 2 to restore the
  !> GS2 coefficient convention described in transform2_5d.
  !> CMR+GC: 2/9/2013 - this loop could go if GS2 used standard Fourier
  !> coefficients.
  subroutine inverse2_5d (yxf, g)
    use gs2_layouts, only: g_lo, yxf_lo, ik_idx
    implicit none
    real, dimension (:,yxf_lo%llim_proc:), intent (in out) :: yxf
    complex, dimension (:,:,g_lo%llim_proc:), intent (out) :: g
    integer :: iglo
    real :: scale

    call inverse_y (yxf, xxf)
    call inverse_x (xxf, g)

    ! Combined normalisation of the x and y backward plans.
    scale = xb_fft%scale * yb_fft%scale

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP SHARED(g_lo, g, scale) &
    !$OMP SCHEDULE(dynamic)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       g(:,:,iglo) = g(:,:,iglo) * merge(scale * 2.0, scale, ik_idx(g_lo, iglo) /= 1)
    end do
    !$OMP END PARALLEL DO
  end subroutine inverse2_5d

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

  !> Forward transform of g (g_lo) to real space axf (accelx_lo) using the
  !> accel layouts.
  !>
  !> 1. Halve g for every mode with ik /= 1 (ky /= 0), following the GS2
  !>    coefficient convention described in transform2_5d. g is modified
  !>    in place, hence intent(in out).
  !> 2. Copy g into the padded module array ag (accel_lo). Every k for
  !>    which aidx(k) is true (modes removed by the dealiasing mask) is set
  !>    to zero.
  !> 3. Execute the c2r transform for each block of accel_lo%nxnky ag
  !>    entries, writing into axf from the index given by ia.
  !>
  !> In SHMEM builds, ag, g and axf are accessed through node-shared
  !> pointers. The work is divided between the ranks of the node
  !> (compute_block_partition, shm_info%id), with node barriers and fences
  !> separating the phases.
  subroutine transform2_5d_accel (g, axf)
    use gs2_layouts, only: g_lo, accel_lo, accelx_lo, ik_idx
    use unit_tests, only: debug_message
    use job_manage, only: time_message
    use fft_work, only: time_fft
#ifdef SHMEM
    use shm_mpi3, only : shm_info, shm_get_node_pointer, shm_node_barrier, shm_fence
#endif
    implicit none
    complex, dimension (:,:,g_lo%llim_proc:), target, intent (in out) :: g
    real, dimension (:,:,accelx_lo%llim_proc:), target, intent (in out) :: axf
    integer :: iglo, k, idx
#ifdef SHMEM
    complex, pointer, contiguous ::ag_ptr(:,:,:) => null(), g_ptr(:,:,:) => null()
    real, pointer, contiguous :: axf_ptr(:,:,:) => null() 
    integer  :: kk, nk, kb, bidx
    ! This rank's share (start offset, extent-1) of each node block of g
    ! indices (iglo_s/iglo_e) and ag indices (kbs/kbe), computed once on the
    ! first call.
    integer, save :: iglo_s, iglo_e, kbs, kbe
    logical, save :: firsttime =.true.
#endif 

    ! itgrid runs over the 2*ntgrid+1 theta points; iduo over the size-2
    ! second dimension of g.
    integer :: itgrid, iduo
    integer :: ntgrid

    call debug_message(4, 'gs2_transforms::transform2_5d_accel starting')

    ntgrid = accel_lo%ntgrid
#ifdef SHMEM
    g_ptr (1:,1:,g_lo%llim_node:) => shm_get_node_pointer(g, -1)
    ag_ptr (-ntgrid:,1:,accel_lo%llim_node:) => shm_get_node_pointer(ag, -1)
    if(firsttime)then
       firsttime=.false.
       call compute_block_partition(g_lo%ntheta0 * g_lo%naky, iglo_s, iglo_e)
       call compute_block_partition(accel_lo%nxnky, kbs, kbe)
    end if

#endif
    !
    !CMR+GC, 2/9/2013:
    !  Scaling g would not be necessary if gs2 used standard Fourier coefficients.
    !
    ! scale the g and copy into the anti-aliased array ag
    ! zero out empty ag
    ! touch each g and ag only once
#ifndef SHMEM
    ! idx is the g_lo index paired with the current k. It only advances on
    ! k values that survive the dealiasing mask.
    idx = g_lo%llim_proc
    do k = accel_lo%llim_proc, accel_lo%ulim_proc
       ! CMR: aidx is true for modes killed by the dealiasing mask
       ! so following line removes ks in dealiasing region
       if (aidx(k)) then
          ag(:,:,k) = 0.0
       else
          ! scaling only for k_y not the zero mode
          if (ik_idx(g_lo, idx) /= 1) then
             do iduo = 1, 2
                do itgrid = 1, 2*ntgrid +1
                   ! It seems strange that we scale the input and then
                   ! copy it to ag -- why not just scaling ag, which is
                   ! what we are going to fft and return?
                   g(itgrid, iduo, idx) &
                        = 0.5 * g(itgrid, iduo, idx)
                   ag(itgrid - (ntgrid+1), iduo, k) &
                        = g(itgrid, iduo, idx)
                enddo
             enddo
             ! in case of k_y being zero: just copy
          else
             do iduo = 1, 2
                do itgrid = 1, 2*ntgrid+1
                   ag(itgrid-(ntgrid+1), iduo, k) &
                        = g(itgrid, iduo, idx)
                enddo
             enddo
          endif
          idx = idx + 1
       endif
    enddo
#else
    call shm_node_barrier
    ! Walk the node-wide ag range one block of accel_lo%nxnky at a time.
    ! bidx counts blocks and locates the matching block of naky*ntheta0 g
    ! entries.
    bidx=0
    do kb = accel_lo%llim_node, accel_lo%ulim_node, accel_lo%nxnky
       idx = g_lo%llim_node +bidx * g_lo%naky*g_lo%ntheta0

       ! we might not have scaled all g
       Do iglo = idx+iglo_s, idx+iglo_s+iglo_e
          if (ik_idx(g_lo, iglo) /= 1) then
             g_ptr(:,:, iglo) = 0.5 * g_ptr(:,:,iglo)
          endif
       enddo

       ! Every rank's halving of g must be complete and visible before any
       ! rank copies from g below.
       call shm_node_barrier
       call shm_fence(g_ptr(1,1,g_lo%llim_node))

       do k = kb+kbs, kb+kbs+kbe !accel_lo%llim_node, accel_lo%ulim_node       
          ! CMR: aidx is true for modes killed by the dealiasing mask
          ! so following line removes ks in dealiasing region
          if (aidx(k)) then
             ag_ptr(:,:,k) = 0.0
          else
             ! gidx(k) is the g index that feeds ag entry k.
             do iduo = 1, 2
                do itgrid =1, 2*ntgrid +1
                   ag_ptr(itgrid - (ntgrid+1), iduo, k) &
                        = g_ptr(itgrid, iduo, gidx(k))
                enddo
             enddo
             !         idx=idx+1
          endif
       enddo
       bidx=bidx+1
    enddo
#endif

    !CMR+GC: 2/9/2013
    ! Following large loop can be eliminated if gs2 used standard Fourier coefficients.
    ! (See above comment in transform2_5d.)

#ifndef SHMEM
    ! we might not have scaled all g
    ! (g entries past the last one copied into ag are halved here too, so
    ! every ky /= 0 mode of g ends up scaled.)
    Do iglo = idx, g_lo%ulim_proc
       if (ik_idx(g_lo, iglo) /= 1) then
          g(:,:, iglo) = 0.5 * g(:,:,iglo)
       endif
    enddo
#else

    axf_ptr(-ntgrid:,1:,accelx_lo%llim_node:) => & 
         shm_get_node_pointer(axf, -1)
#endif


    ! transform
#ifndef SHMEM
    ! ia(idx) is the first accelx_lo index of the idx-th output block.
    idx = 1
    call time_message(.false., time_fft, ' FFT')
    do k = accel_lo%llim_proc, accel_lo%ulim_proc, accel_lo%nxnky
       ! remember FFTW3 for c2r destroys the contents of ag
       call yf_fft%execute_c2r(ag(:, :, k:), axf(:, :, ia(idx):))
       idx = idx + 1
    end do
    call time_message(.false., time_fft, ' FFT')
#else
    ! in the following the ranks operate on other data ranges
    call shm_node_barrier

    ! With the 2D plan each rank starts from a block offset by its shm id
    ! (wrapping modulo the node range), so different ranks transform
    ! different blocks at the same time.
    if (use_shm_2d_plan) then
       idx = 1 + shm_info%id
       idx = mod(idx -1, accel_lo%nia) +1
       nk = accel_lo%ulim_node - accel_lo%llim_node +1
    else
       idx = 1
    endif
    call time_message(.false., time_fft, ' FFT')
    do k = accel_lo%llim_node, accel_lo%ulim_node, accel_lo%nxnky
       if (use_shm_2d_plan) then
          kk = k + shm_info%id * accel_lo%nxnky
          kk = mod(kk - accel_lo%llim_node, nk) + accel_lo%llim_node
       else
          kk = k
       endif

       call fft_shm(kk,ia(idx))

       idx = idx + 1
       if ( use_shm_2d_plan) then
          idx = mod(idx -1, accel_lo%nia) +1
       endif
    end do
    call shm_node_barrier
    call time_message(.false., time_fft, ' FFT')
  contains

    !> Transform the ag block starting at k1 into axf starting at k2,
    !> using the 2D shared-memory plan when use_shm_2d_plan is set and the
    !> two-stage 1D version otherwise.
    subroutine fft_shm(k1, k2)
      use mp, only : proc0
      implicit none
      integer, intent(in) :: k1, k2

      if (use_shm_2d_plan) then 
         call fft2d_shm(k1, k2)
      else
         call fft1d_shm(k1,k2)
      endif

    end subroutine fft_shm

    !> Cooperative two-stage c2r transform of one block.
    !>
    !> First, this rank applies the c2c plan faccel_shmy in place to its
    !> share [is, ie] of the n2 = accel_lo%ndky offsets from k1. After a
    !> node barrier and fence, it applies the c2r plan faccel_shmx to its
    !> share [js, je] of the p3 = accelx_lo%nx slices, writing into axf
    !> from k2. Shares are an even split over nwk = g_lo%ppn ranks.
    subroutine fft1d_shm(k1,k2)
      use shm_mpi3, only : shm_info, shm_fence
      implicit none
      integer, intent(in) :: k1, k2

      integer :: i, j, is, ie, js, je, idw, nwk
      integer :: n1,n2,n3,p1,p2,p3

      n1 = 2*(2*ntgrid+1)
      n2 = accel_lo%ndky
      n3 = accel_lo%nx
      p1 = 2*(2*ntgrid+1)
      p2 = accelx_lo%ny
      p3 =  accelx_lo%nx

      if ( p3 /= n3 ) then 
         write(0,*) "fft_shm: WARNING third dimesions not equal"
      endif

      nwk = g_lo%ppn
      idw = shm_info%id
      is = idw*(n2/nwk) + min(idw,mod(n2, nwk))
      ie = is + (n2/nwk) -1 
      if ( idw < mod(n2, nwk)) ie = ie +1
      js = idw*(p3/nwk) + min(idw,mod(p3, nwk))
      je = js + (p3/nwk) -1 
      if ( idw < mod(p3, nwk)) je = je +1

      do i = is, ie
         ! For now these SHMEM calls don't match the modern fortran FFTW interface
         ! so we have to fall back to an implied interface approach
#ifdef SINGLE_PRECISION
         call sfftw_execute_dft( &
#else
         call dfftw_execute_dft( &
#endif
         faccel_shmy%plan, ag_ptr(-ntgrid,1,k1+i), ag_ptr(-ntgrid,1,k1+i))
      end do

      ! The c2r stage reads ag data written by other ranks in stage one.
      call shm_node_barrier
      call shm_fence(ag_ptr(-ntgrid,1,accel_lo%llim_node))

      do j = js, je
         ! For now these SHMEM calls don't match the modern fortran FFTW interface
         ! so we have to fall back to an implied interface approach
#ifdef SINGLE_PRECISION
         call sfftw_execute_dft_c2r( &
#else
         call dfftw_execute_dft_c2r( &
#endif
         faccel_shmx%plan, ag_ptr(-ntgrid,1,k1+j*n2), axf_ptr(-ntgrid,1,k2+j*p2))
      end do
    end subroutine fft1d_shm

    !> Apply the 2D c2r plan yf_fft_shm to this rank's slice of the block
    !> (starting at ts1, ts2, howmany transforms wide).
    !>
    !> With use_shm_2d_plan_buff, the slice is first copied into the local
    !> buffer cin. The transform writes into rout, which is copied out to
    !> axf. Otherwise, the plan is executed directly on the shared arrays.
    subroutine fft2d_shm(k1,k2)
      use shm_mpi3, only : shm_info, shm_fence
      implicit none
      integer, intent(in) :: k1, k2

      integer :: i, j, n2,n3, howmany,ts1,ts2

      complex,allocatable :: cin(:,:)
      real, allocatable ::  rout(:,:) 

      n2 = accel_lo%ny
      n3 = accel_lo%nx
      howmany = yf_fft_shm%howmany
      ts1=yf_fft_shm%ts1
      ts2=yf_fft_shm%ts2

      if (use_shm_2d_plan_buff) then
         allocate(rout(howmany, n2*n3), &
              cin(howmany, (n2/2+1)*n3))

         do j=1,(n2/2+1)*n3
            do i=1,howmany
               cin(i,j) =  ag_ptr(ts1+i-1, ts2,k1+j-1)
            enddo
         enddo

         call yf_fft_shm%execute_c2r(cin, rout)

         do j=1,n2*n3
            do i=1,howmany
               axf_ptr(ts1+i-1,ts2,k2+j-1)= rout(i,j)
            end do
         end do
      else
         ! For now these SHMEM calls don't match the modern fortran FFTW interface
         ! so we have to fall back to an implied interface approach
#ifdef SINGLE_PRECISION
         call sfftw_execute_dft_c2r( &
#else
         call dfftw_execute_dft_c2r( &
#endif
         yf_fft_shm%plan, ag_ptr(ts1, ts2,k1), axf_ptr(ts1,ts2,k2))
      end if

    end subroutine fft2d_shm
# endif

  end subroutine transform2_5d_accel

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

  !> Inverse of transform2_5d_accel: real space axf (accelx_lo) back to g
  !> (g_lo).
  !>
  !> 1. r2c transform of each block of accelx_lo%nxny axf entries into the
  !>    module array ag, starting at the index given by iak.
  !> 2. Copy every ag entry not flagged by aidx (i.e. not removed by the
  !>    dealiasing mask) into g, multiplied by yb_fft%scale. Modes with
  !>    ik /= 1 (ky /= 0) get a further factor of 2 to restore the GS2
  !>    coefficient convention (see transform2_5d).
  !>
  !> In SHMEM builds the work is split across node ranks through
  !> node-shared pointers and compute_block_partition. The ky /= 0 doubling
  !> is applied as a separate pass over this rank's share of g.
  subroutine inverse2_5d_accel (axf, g)
    use gs2_layouts, only: g_lo, accel_lo, accelx_lo, ik_idx
    use job_manage, only: time_message
    use fft_work, only: time_fft
#ifdef SHMEM
    use shm_mpi3, only : shm_info, shm_get_node_pointer, shm_node_barrier, shm_fence
#endif
    implicit none
    real, dimension (:,:,accelx_lo%llim_proc:), target, intent (in out) :: axf
    complex, dimension (:,:,g_lo%llim_proc:), target, intent (out) :: g
    integer :: iglo, idx, k
    integer :: itgrid, iduo
    integer :: ntgrid
#ifdef SHMEM
    complex, pointer, contiguous ::ag_ptr(:,:,:) => null(), g_ptr(:,:,:) => null()
    real, pointer, contiguous :: axf_ptr(:,:,:) => null() 
    integer kk, nk, kb, bidx
    ! This rank's share (start offset, extent-1) of each node block of g
    ! indices (iglo_s/iglo_e) and ag indices (kbs/kbe), computed once on the
    ! first call.
    integer, save :: iglo_s, iglo_e, kbs, kbe
    logical, save :: firsttime =.true.
#endif

    ntgrid = accel_lo%ntgrid
#ifndef SHMEM
    ! transform
    ! iak(idx) is the first accel_lo index of the idx-th output block of ag.
    idx = 1
    call time_message(.false., time_fft, ' FFT')
    do k = accelx_lo%llim_proc, accelx_lo%ulim_proc, accelx_lo%nxny
       call yb_fft%execute_r2c(axf(:, :, k:), ag(:, :, iak(idx):))
       idx = idx + 1
    end do
    call time_message(.false., time_fft, ' FFT')

    ! idx is the g_lo index paired with the current k. It only advances on
    ! k values that survive the dealiasing mask.
    idx = g_lo%llim_proc
    do k = accel_lo%llim_proc, accel_lo%ulim_proc
       ! ignore the large k (anti-alias)
       if ( .not.aidx(k)) then
          !
          !CMR+GC, 2/9/2013:
          !  Scaling g here would be unnecessary if gs2 used standard Fourier coefficients.
          ! different scale factors depending on ky == 0
          if (ik_idx(g_lo, idx) /= 1) then
             do iduo = 1, 2 
                do itgrid = 1, 2*ntgrid+1
                   g(itgrid, iduo, idx) &
                        = 2.0 * yb_fft%scale * ag(itgrid-(ntgrid+1), iduo, k)
                enddo
             enddo
          else
             do iduo = 1, 2 
                do itgrid = 1, 2*ntgrid+1
                   g(itgrid, iduo, idx) &
                        = yb_fft%scale * ag(itgrid-(ntgrid+1), iduo, k)
                enddo
             enddo
          endif
          idx = idx + 1
       endif
    enddo

    !CMR+GC: 2/9/2013
    ! Following large loop can be eliminated if gs2 used standard Fourier coefficients.
    ! (See above comment in transform2_5d.)

    ! we might not have scaled all g
    ! NOTE(review): g is intent(out), and any iglo reached here was not
    ! assigned by the copy loop above, so this doubles undefined values.
    ! Presumably the layouts guarantee idx > g_lo%ulim_proc here (making
    ! the loop empty) -- TODO confirm.
    Do iglo = idx, g_lo%ulim_proc
       if (ik_idx(g_lo, iglo) /= 1) then
          g(:,:, iglo) = 2.0 * g(:,:,iglo)
       endif
    enddo

#else
    ! shm
    ag_ptr (-ntgrid:,1:,accel_lo%llim_node:) => shm_get_node_pointer(ag, -1)
    axf_ptr(1:,1:,accelx_lo%llim_node:) => shm_get_node_pointer(axf, -1)

    call shm_node_barrier
    ! transform
    ! With the 2D plan each rank starts from a block offset by its shm id
    ! (wrapping modulo the node range), so ranks transform different blocks.
    if ( use_shm_2d_plan) then
       idx = 1 + shm_info%id
       idx = mod(idx-1,accel_lo%nia) +1
       nk = accelx_lo%ulim_node - accelx_lo%llim_node +1
    else
       idx = 1
    endif
    call time_message(.false., time_fft, ' FFT')
    do k = accelx_lo%llim_node, accelx_lo%ulim_node, accelx_lo%nxny
       if (use_shm_2d_plan) then
          kk = k + shm_info%id * accelx_lo%nxny
          kk = mod(kk - accelx_lo%llim_node, nk) + accelx_lo%llim_node
       else
          kk = k
       endif

       call fft_shm_inv(kk, iak(idx))

       idx = idx + 1
       if ( use_shm_2d_plan) then
          idx = mod(idx-1, accel_lo%nia) +1
       endif
    end do

    ! see comment from _accel
    call shm_node_barrier
    call time_message(.false., time_fft, ' FFT')

    if(firsttime)then
       firsttime=.false.
       call compute_block_partition(g_lo%ntheta0 * g_lo%naky, iglo_s, iglo_e)
       call compute_block_partition(accel_lo%nxnky, kbs, kbe)
    end if

    g_ptr (1:,1:,g_lo%llim_node:) => shm_get_node_pointer(g, -1)
    ! Walk the node-wide ag range one block of accel_lo%nxnky at a time.
    ! bidx counts blocks and locates the matching block of naky*ntheta0 g
    ! entries.
    bidx=0
    do kb = accel_lo%llim_node, accel_lo%ulim_node,accel_lo%nxnky
       !      idx = g_lo%llim_node +bidx * g_lo%naky*g_lo%ntheta0 + iglo_s !g_lo%llim_node !gidx_s !g_lo%llim_proc
       do k = kb+kbs, kb+kbs+kbe 
          ! ignore the large k (anti-alias)
          if ( .not.aidx(k)) then
             !
             !CMR+GC, 2/9/2013:
             !  Scaling g here would be unnecessary if gs2 used standard Fourier coefficients.
             ! different scale factors depending on ky == 0
             ! (Only the common yb_fft%scale factor is applied here; the
             ! ky /= 0 doubling is done in the pass below.)
             do iduo = 1, 2 
                do itgrid = 1, 2*ntgrid+1
                   g_ptr(itgrid, iduo, gidx(k)) &
                        = yb_fft%scale * ag_ptr(itgrid-(ntgrid+1), iduo, k)
                enddo
             enddo
             !          idx = idx + 1
          endif
       enddo
       ! All ranks' copies into g must be visible before the doubling pass,
       ! which works on a different partition of g.
       call shm_node_barrier
       call shm_fence(g_ptr(1,1,g_lo%llim_node))
       ! we might not have scaled all g
       idx = g_lo%llim_node +bidx * g_lo%naky*g_lo%ntheta0 
       Do iglo = idx+iglo_s, idx+iglo_s+iglo_e
          if (ik_idx(g_lo, iglo) /= 1) then
             g_ptr(:,:, iglo) = 2.0 * g_ptr(:,:,iglo)
          endif
       enddo
       bidx=bidx+1
    enddo

    !CMR+GC: 2/9/2013
    ! Following large loop can be eliminated if gs2 used standard Fourier coefficients.
    ! (See above comment in transform2_5d.)

    call shm_node_barrier

  contains

    !> Transform the axf block starting at k1 into ag starting at k2,
    !> using the 2D shared-memory plan when use_shm_2d_plan is set and the
    !> two-stage 1D version otherwise.
    subroutine fft_shm_inv(k1, k2)
      implicit none
      integer, intent(in) :: k1, k2
      if (use_shm_2d_plan) then
         call fft2d_shm_inv(k1, k2)
      else
         call fft1d_shm_inv(k1,k2)
      endif
    end subroutine fft_shm_inv

    !> Cooperative two-stage r2c transform of one block (reverse order of
    !> fft1d_shm).
    !>
    !> First, this rank applies the r2c plan baccel_shmx to its share
    !> [js, je] of the ny = accelx_lo%nx slices. After a node barrier and
    !> fence, it applies the c2c plan baccel_shmy in place to its share
    !> [is, ie] of the px = accel_lo%ndky offsets from k2. Shares are an
    !> even split over nwk = g_lo%ppn ranks.
    subroutine fft1d_shm_inv(k1,k2)
      use shm_mpi3, only : shm_info, shm_fence
      implicit none
      integer, intent(in) :: k1, k2

      integer i, j, is, ie, js, je, idw, nwk
      integer nx, ny, px, py

      ! test ny == py !!!
      nx = accelx_lo%ny
      ny = accelx_lo%nx
      px = accel_lo%ndky 
      py = accel_lo%nx

      nwk = g_lo%ppn
      idw = shm_info%id
      is = idw*(px/nwk) + min(idw,mod(px, nwk))
      ie = is + (px/nwk) -1 
      if ( idw < mod(px, nwk)) ie = ie +1
      js = idw*(ny/nwk) + min(idw,mod(ny, nwk))
      je = js + (ny/nwk) -1 
      if ( idw < mod(ny, nwk)) je = je +1

      do j = js, je
         ! For now these SHMEM calls don't match the modern fortran FFTW interface
         ! so we have to fall back to an implied interface approach
#ifdef SINGLE_PRECISION
         call sfftw_execute_dft_r2c( &
#else
         call dfftw_execute_dft_r2c( &
#endif
         baccel_shmx%plan, axf_ptr(1,1,k1+j*nx), ag_ptr(-ntgrid,1,k2+j*px))
      end do

      ! The c2c stage reads ag data written by other ranks in stage one.
      call shm_node_barrier
      call shm_fence(ag_ptr(-ntgrid,1,accel_lo%llim_node))

      do i = is, ie
         ! For now these SHMEM calls don't match the modern fortran FFTW interface
         ! so we have to fall back to an implied interface approach
#ifdef SINGLE_PRECISION
         call sfftw_execute_dft( &
#else
         call dfftw_execute_dft( &
#endif
         baccel_shmy%plan, ag_ptr(-ntgrid,1,k2+i), ag_ptr(-ntgrid,1,k2+i))
      end do
    end subroutine fft1d_shm_inv

    !> Apply the 2D r2c plan yb_fft_shm to this rank's slice of the block
    !> (ts1, ts2, howmany transforms wide).
    !>
    !> axf_ptr has lower bound 1 in its first dimension while ag_ptr starts
    !> at -ntgrid, so the axf offset is shifted to ts1o = ntgrid+1+ts1. With
    !> use_shm_2d_plan_buff, the data go through the local buffers rin/cout.
    subroutine fft2d_shm_inv(k1,k2)
      use, intrinsic :: iso_c_binding 
      use shm_mpi3, only : shm_info, shm_fence
      implicit none
      integer, intent(in) :: k1, k2

      integer i,j, ts1, ts1o, ts2, n2,n3, howmany
      real, allocatable :: rin(:,:)
      complex,allocatable::cout(:,:)

      n2 = accel_lo%ny
      n3 = accel_lo%nx
      ts1 = yb_fft_shm%ts1
      ts2=yb_fft_shm%ts2
      howmany=yb_fft_shm%howmany
      ts1o=ntgrid +1+ts1

      if ( use_shm_2d_plan_buff) then 
         allocate( rin(howmany, n2*n3), &
              cout(howmany,(n2/2+1)*n3))

         do j=1,n2*n3
            do i=1,howmany
               rin(i,j) =  axf_ptr(ts1o+i-1,ts2,k1+j-1)
            enddo
         enddo
         call yb_fft_shm%execute_r2c(rin, cout)
         do j=1,(n2/2+1)*n3
            do i=1,howmany
               ag_ptr(ts1+i-1,ts2,k2+j-1) = cout(i,j)
            enddo
         enddo
      else
         ! For now these SHMEM calls don't match the modern fortran FFTW interface
         ! so we have to fall back to an implied interface approach
#ifdef SINGLE_PRECISION
         call sfftw_execute_dft_r2c( &
#else
         call dfftw_execute_dft_r2c( &
#endif
         yb_fft_shm%plan, axf_ptr(ts1o,ts2,k1), ag_ptr(ts1,ts2,k2))
      end if
    end subroutine fft2d_shm_inv
# endif
  end subroutine inverse2_5d_accel

  !> Create the 2D real<->complex plans xf3d_cr (to real space) and xf3d_rc
  !> (to spectral space) used by transform2_3d and inverse2_3d.
  !>
  !> If plans already exist for the same nny, nnx and how_many, this
  !> returns immediately. Otherwise any existing plans are deleted and
  !> rebuilt for the new sizes.
  subroutine init_3d (nny_in, nnx_in, how_many_in)
    use fft_work, only: init_crfftw, init_rcfftw, FFT_TO_SPECTRAL_SPACE, FFT_TO_REAL_SPACE
    implicit none
    integer, intent(in) :: nny_in, nnx_in, how_many_in
    ! Sizes the current plans were built with (kept between calls).
    integer, save :: nnx = 0, nny = 0, how_many = 0
    logical :: same_sizes

    if (initialized_3d) then
       same_sizes = (nnx == nnx_in) .and. (nny == nny_in) .and. (how_many == how_many_in)
       if (same_sizes) return
       ! Sizes changed: discard the old plans before building new ones.
       call xf3d_cr%delete
       call xf3d_rc%delete
    end if

    nnx = nnx_in
    nny = nny_in
    how_many = how_many_in
    initialized_3d = .true.

    call init_crfftw (xf3d_cr, FFT_TO_REAL_SPACE, nny, nnx, how_many)
    call init_rcfftw (xf3d_rc, FFT_TO_SPECTRAL_SPACE, nny, nnx, how_many)
  end subroutine init_3d

  !> Transform phi(theta, kx, ky) to real space phixf(x, y, theta).
  !>
  !> Steps:
  !> - Halve the ky > 0 modes (GS2 coefficient convention).
  !> - Place kx >= 0 at the start and kx < 0 at the end of a zero-padded
  !>   nnx-wide grid (dealiasing), transposing kx/ky.
  !> - Do a c2r FFT over each theta point with xf3d_cr.
  !> - Reorder the result to (x, y, theta).
  !> The plans are (re)built by init_3d when needed.
  subroutine transform2_3d (phi, phixf, nny, nnx)
    use theta_grid, only: ntgrid
    use kt_grids, only: naky, ntheta0, aky
    use job_manage, only: time_message
    use fft_work, only: time_fft
    implicit none
    integer, intent(in) :: nnx, nny
    complex, dimension (-ntgrid:,:,:), intent (in) :: phi
    real, dimension (:,:,-ntgrid:), intent (out) :: phixf  
    real, dimension (:,:,:), allocatable :: phix
    complex, dimension (:,:,:), allocatable :: aphi
    real :: fac
    integer :: ig, ik, it, nkx_pos

    ! Ensure plans exist for these sizes (no-op if already built).
    call init_3d (nny, nnx, 2*ntgrid+1)

    allocate (aphi (-ntgrid:ntgrid, nny/2+1, nnx))
    allocate (phix (-ntgrid:ntgrid, nny, nnx))
    aphi = 0.

    ! Number of kx >= 0 entries at the start of phi's kx dimension.
    nkx_pos = (ntheta0 + 1) / 2

    do ik = 1, naky
       if (aky(ik) < epsilon(0.)) then
          fac = 1.0
       else
          fac = 0.5
       end if
       ! Note the transpose of it and ik.
       do it = 1, nkx_pos
          aphi(-ntgrid:ntgrid, ik, it) = phi(-ntgrid:ntgrid, it, ik) * fac
       end do
       ! Negative kx go at the top of the padded nnx grid.
       ! CMR, 30/3/2010: bug fix to replace nx by nnx here.
       do it = nkx_pos + 1, ntheta0
          aphi(-ntgrid:ntgrid, ik, it - ntheta0 + nnx) = phi(-ntgrid:ntgrid, it, ik) * fac
       end do
    end do

    ! transform
    call time_message(.false., time_fft, ' FFT')
    call xf3d_cr%execute_c2r(aphi, phix)
    call time_message(.false., time_fft, ' FFT')

    ! Reorder (theta, y, x) -> (x, y, theta); innermost loop runs over the
    ! first index of phixf.
    do ig = -ntgrid, ntgrid
       do ik = 1, nny
          do it = 1, nnx
             phixf(it, ik, ig) = phix(ig, ik, it)
          end do
       end do
    end do

    deallocate (aphi, phix)

  end subroutine transform2_3d

  !> Inverse of transform2_3d: real-space phixf(x, y, theta) back to
  !> phi(theta, kx, ky).
  !>
  !> The data are reordered to (theta, y, x) and transformed with the r2c
  !> plan xf3d_rc. Only the kx modes kept by dealiasing are copied back:
  !> the first (ntheta0+1)/2 columns, and the last ntheta0/2 columns of the
  !> padded nnx grid. Each mode is multiplied by the plan scale, and ky > 0
  !> modes by a further factor of 2 (GS2 coefficient convention).
  !>
  !> CMR, 30/4/2010: fixed up previously buggy handling of dealiasing.
  subroutine inverse2_3d (phixf, phi, nny, nnx)
    use theta_grid, only: ntgrid
    use kt_grids, only: naky, ntheta0, aky
    use job_manage, only: time_message
    use fft_work, only: time_fft
    implicit none
    real, dimension (:,:,-ntgrid:), intent(in) :: phixf
    complex, dimension (-ntgrid:,:,:), intent(out) :: phi
    integer, intent(in) :: nnx, nny
    complex, dimension (:,:,:), allocatable :: aphi
    real, dimension (:,:,:), allocatable :: phix
    real :: fac
    integer :: ik, it, ig

    ! Previously this routine silently assumed init_3d had already been
    ! called with matching sizes. init_3d returns immediately when the plans
    ! already match, so calling it here costs nothing in the usual case. It
    ! also prevents executing an uninitialised or mismatched xf3d_rc plan.
    call init_3d (nny, nnx, 2*ntgrid+1)

    allocate (aphi (-ntgrid:ntgrid, nny/2+1, nnx))
    allocate (phix (-ntgrid:ntgrid, nny, nnx))

    ! Reorder (x, y, theta) -> (theta, y, x) for the FFT.
    do it = 1, nnx
       do ik = 1, nny
          do ig = -ntgrid, ntgrid
             phix(ig, ik, it) = phixf(it, ik, ig)
          end do
       end do
    end do

    ! transform
    call time_message(.false., time_fft, ' FFT')
    call xf3d_rc%execute_r2c(phix, aphi)
    call time_message(.false., time_fft, ' FFT')

    ! dealias and scale: kx >= 0 modes sit at the start of the padded grid
    do it = 1, (ntheta0+1) / 2
       do ik = 1, naky
          fac = 2.0
          if (aky(ik) < epsilon(0.0)) fac = 1.0
          do ig = -ntgrid, ntgrid
             phi(ig, it, ik) = aphi(ig, ik, it) * fac * xf3d_rc%scale
          end do
       end do
    end do

    ! kx < 0 modes sit at the end of the padded nnx grid
    do it = (ntheta0+1) / 2 + 1, ntheta0
       do ik = 1, naky
          fac = 2.0
          if (aky(ik) < epsilon(0.0)) fac = 1.0
          do ig = -ntgrid, ntgrid
             phi(ig,it,ik) = aphi(ig, ik, it-ntheta0+nnx) * fac * xf3d_rc%scale
          end do
       end do
    end do

    deallocate (aphi, phix)

  end subroutine inverse2_3d

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

  !> Transform a single 2D field phi(kx, ky) to real space phixf(x, y).
  !>
  !> Steps:
  !> - Halve the ky > 0 modes (GS2 coefficient convention).
  !> - Place kx >= 0 at the start and kx < 0 at the end of a zero-padded
  !>   nnx-wide grid, transposing kx/ky.
  !> - Do a c2r FFT with a temporary plan, then transpose the (y, x) result
  !>   into phixf.
  !>
  !> Not used by GS2 itself; it is the counterpart of inverse2_2d.
  subroutine transform2_2d (phi, phixf, nny, nnx)
    use fft_work, only: FFT_TO_REAL_SPACE, init_crfftw, time_fft
    use kt_grids, only: naky, nakx => ntheta0, aky
    use job_manage, only: time_message
    implicit none
    integer, intent(in) :: nnx, nny
    complex, dimension(:, :), intent (in) :: phi
    real, dimension(:, :), intent (out) :: phixf
    real, dimension(:, :), allocatable :: phix
    complex, dimension(:, :), allocatable :: aphi
    real :: fac
    integer :: ik, it
    type (fft_type) :: xf2d

    !May be inefficient to create and destroy this fft plan
    !on every call to the routine. We may want to move this
    !variable to module level and check its created flag.
    call init_crfftw (xf2d, FFT_TO_REAL_SPACE, nny, nnx, 1)

    allocate (phix (nny, nnx))
    allocate (aphi (nny/2+1, nnx))
    phix(:,:)=0.; aphi(:,:)=cmplx(0.,0.)

    ! scale, dealias and transpose
    do ik = 1, naky
       fac = 0.5
       if (aky(ik) < epsilon(0.)) fac = 1.0
       do it = 1, (nakx+1) / 2
          aphi(ik, it) = phi(it, ik) * fac
       end do
       ! Negative kx go at the top of the padded grid, which has nnx
       ! columns. This previously used kt_grids' nx, which is only correct
       ! when the caller passes nnx == nx and otherwise indexes outside
       ! aphi (cf. CMR's identical nx -> nnx fix in transform2_3d).
       do it = (nakx+1) / 2 + 1, nakx
          aphi(ik, it-nakx+nnx) = phi(it, ik) * fac
       end do
    end do

    ! transform
    call time_message(.false., time_fft, ' FFT')
    call xf2d%execute_c2r(aphi, phix)
    call time_message(.false., time_fft, ' FFT')

    phixf = transpose(phix)

    deallocate (aphi, phix)
    !RN> this statement causes error for lahey with DEBUG. I don't know why
    !<DD>Reinstating after discussion with RN, if this causes anyone an issue
    !    then we can guard this line with some directives.   
    call xf2d%delete
  end subroutine transform2_2d

  !> Inverse of transform2_2d: real-space phixf(x, y) back to phi(kx, ky).
  !>
  !> phixf is transposed to (y, x) and transformed with a temporary r2c
  !> plan. Only the kx modes kept by dealiasing are copied back: the first
  !> (nakx+1)/2 columns, and the last nakx/2 columns of the padded nnx grid
  !> (matching the placement in transform2_2d and inverse2_3d). Each mode is
  !> multiplied by the plan scale, and ky > 0 modes by a further factor of 2.
  subroutine inverse2_2d (phixf, phi, nny, nnx)
    use fft_work, only: FFT_TO_SPECTRAL_SPACE, init_rcfftw, time_fft
    use kt_grids, only: naky, nakx => ntheta0, aky
    use job_manage, only: time_message
    implicit none
    real, intent(in) :: phixf(:,:)
    complex, intent(out) :: phi(:,:)
    integer, intent(in) :: nnx, nny
    complex, allocatable :: aphi(:,:)
    real, allocatable :: phix(:,:)
    real :: fac
    integer :: ik, it
    type (fft_type) :: xf2d

    !May be inefficient to create and destroy this fft plan
    !on every call to the routine. We may want to move this
    !variable to module level and check its created flag.
    call init_rcfftw (xf2d, FFT_TO_SPECTRAL_SPACE, nny, nnx, 1)

    allocate (aphi (nny/2+1, nnx))
    allocate (phix (nny, nnx))
    aphi = cmplx(0.,0.)
    phix = transpose(phixf)

    ! transform
    call time_message(.false., time_fft, ' FFT')
    call xf2d%execute_r2c(phix, aphi)
    call time_message(.false., time_fft, ' FFT')

    ! scale, dealias and transpose.
    ! Previously every kx was read from aphi(ik, it). That is only right for
    ! kx >= 0: the kx < 0 modes live at the end of the padded nnx grid, so
    ! for nnx > nakx the old code returned high positive kx instead.
    do ik = 1, naky
       fac = 2.0
       if (aky(ik) < epsilon(0.0)) fac = 1.0
       do it = 1, (nakx+1) / 2
          phi(it, ik) = aphi(ik, it) * fac * xf2d%scale
       end do
       do it = (nakx+1) / 2 + 1, nakx
          phi(it, ik) = aphi(ik, it-nakx+nnx) * fac * xf2d%scale
       end do
    end do

    deallocate (aphi, phix)
    !RN> this statement causes error for lahey with DEBUG. I don't know why
    !<DD>Reinstating after discussion with RN, if this causes anyone an issue
    !    then we can guard this line with some directives.   
    call xf2d%delete
  end subroutine inverse2_2d

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

  !> Set up the zf_fft plan used by kz_spectrum. Only the first call does
  !> any work; later calls return without changes.
  subroutine init_zf (ntgrid, howmany)
    use fft_work, only: init_z, FFT_TO_REAL_SPACE
    implicit none
    integer, intent (in) :: ntgrid, howmany

    if (.not. initialized_zf) then
       initialized_zf = .true.
       ! Use 2*ntgrid points rather than 2*ntgrid+1 so that the duplicate
       ! theta point at {-pi,pi} is not included twice.
       call init_z(zf_fft, FFT_TO_REAL_SPACE, 2*ntgrid, howmany)
    end if
  end subroutine init_zf

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

  !> Power spectrum along z: FFT an with the zf_fft plan (see init_zf) and
  !> return conjg(F)*F element-wise in an2.
  subroutine kz_spectrum (an, an2)
    use job_manage, only: time_message
    use fft_work, only: time_fft
    implicit none
    complex, dimension (:,:,:), intent(in out)  :: an
    complex, dimension (:,:,:), intent(out) :: an2

    ! Clear the output before handing it to the transform.
    an2 = 0.0

    call time_message(.false., time_fft, ' FFT')
    call zf_fft%execute_c2c(an, an2)
    call time_message(.false., time_fft, ' FFT')

    ! Squared magnitude of each Fourier coefficient.
    an2 = conjg(an2)*an2
  end subroutine kz_spectrum

  !> Tear down all module-level transform state so the transforms can be
  !> initialised again.
  !>
  !> Steps:
  !> - Free the work arrays and index tables (ag via shm_free in SHMEM
  !>   builds).
  !> - Delete the g2x/x2y redistributions and the FFT plans.
  !> - Reset every initialisation flag.
  !> - Call finish_fft_work.
  subroutine finish_transforms
    use redistribute, only : delete_redist
    use fft_work, only: finish_fft_work
#ifdef SHMEM
    use shm_mpi3, only : shm_free
#endif
    implicit none

    ! Work arrays and index lookup tables.
    if (allocated(xxf)) deallocate(xxf)
    if (allocated(ia)) deallocate(ia)
    if (allocated(iak)) deallocate(iak)
    if (allocated(aidx)) deallocate(aidx)
#ifdef SHMEM
    ! In SHMEM builds ag is a pointer released through shm_free.
    if (associated(ag)) call shm_free(ag)
#else
    if (allocated(ag)) deallocate(ag)
#endif
    if (allocated(fft)) deallocate(fft)

    ! Redistribution objects.
    call delete_redist(g2x)
    call delete_redist(x2y)

    ! FFTW plans.
    call yf_fft%delete
    call yb_fft%delete
    call xf_fft%delete
    call xb_fft%delete
    call zf_fft%delete
    call xf3d_cr%delete
    call xf3d_rc%delete

    ! Initialisation flags.
    initialized = .false.
    initialized_y_fft = .false.
    initialized_x_redist = .false.
    initialized_y_redist = .false.
    initialized_3d = .false.
    initialized_zf = .false.

    ! Tidy up fft internals (FFTW3 only).
    call finish_fft_work
  end subroutine finish_transforms

#ifdef SHMEM
  !> Split whole_range items as evenly as possible over the shm_info%size
  !> ranks of the node.
  !>
  !> Returns this rank's zero-based start offset (shift) and its item count
  !> minus one (extend), so the rank owns offsets shift .. shift+extend. The
  !> first mod(whole_range, shm_info%size) ranks receive one extra item.
  subroutine compute_block_partition(whole_range, shift, extend)
    use shm_mpi3, only : shm_info
    implicit none

    integer, intent(in) :: whole_range
    integer, intent(out) :: shift, extend

    integer :: rank, base_chunk, leftover

    rank = shm_info%id
    base_chunk = whole_range / shm_info%size
    leftover = mod(whole_range, shm_info%size)

    shift = rank * base_chunk + min(rank, leftover)
    if (rank < leftover) then
       extend = base_chunk
    else
       extend = base_chunk - 1
    end if

  end subroutine compute_block_partition
#endif       
end module gs2_transforms