! le_grids.f90 source file

!> Construction of the energy grid and energy integration weights.
!>
!> For each species this module provides the energy grid points and
!> integration weights (via [[setvgrid]] or, for the generalised-quadrature
!> option, [[setvgrid_genquad]]) and records the corresponding speeds:
!> `zeroes` holds sqrt(energy) of the first negrid-1 points and `x0` holds
!> sqrt(energy) of the last point.
module egrid

! By Tomo Tatsuno, Aug 2005
! Improved accuracy and speed and maximum number of energy grid points
!

  implicit none

  private

  public :: setvgrid, setvgrid_genquad, init_egrid
  public :: zeroes, x0, zeroes_maxw, x0_maxw

  !> Counterpart of `x0` for the `_maxw` grid. It is not set in this module.
  !> NOTE(review): presumably the speed of the last point of a Maxwellian
  !> reference grid -- confirm against the code that fills it.
  real :: x0_maxw
  !> Speed, sqrt(energy), of the last energy grid point of each species, (nspec)
  real, dimension(:), allocatable :: x0
  !> Counterpart of `zeroes` for the `_maxw` grid, (negrid-1). It is
  !> allocated and zeroed here but filled elsewhere.
  real, dimension(:), allocatable, save :: zeroes_maxw
  !> Speeds, sqrt(energy), of the first negrid-1 energy grid points of
  !> each species, (negrid-1, nspec)
  real, dimension(:,:), allocatable, save :: zeroes

contains

  !> Allocate the module arrays for an energy grid with `negrid` points.
  !>
  !> This routine is safe to call more than once:
  !>  - An array that is already allocated with the required shape is left
  !>    untouched, so its contents survive the call.
  !>  - An array that is missing, or whose shape no longer matches
  !>    `negrid`/`nspec`, is (re)allocated.
  !>  - Freshly allocated `zeroes` and `zeroes_maxw` are set to zero.
  subroutine init_egrid (negrid)
    use species, only: nspec
    implicit none
    !> Total number of energy grid points
    integer, intent (in) :: negrid

    ! Discard arrays whose shape no longer matches the requested grid, so
    ! that the whole-array assignments in setvgrid* stay conformable.
    if (allocated(zeroes)) then
       if (size(zeroes, 1) /= negrid - 1 .or. size(zeroes, 2) /= nspec) deallocate (zeroes)
    end if
    if (allocated(zeroes_maxw)) then
       if (size(zeroes_maxw) /= negrid - 1) deallocate (zeroes_maxw)
    end if
    if (allocated(x0)) then
       if (size(x0) /= nspec) deallocate (x0)
    end if

    ! Check each array on its own. A partial deallocation done elsewhere
    ! then cannot cause an "already allocated" error here, nor leave x0
    ! unallocated.
    if (.not. allocated(zeroes)) then
       allocate (zeroes(negrid-1, nspec)) ; zeroes = 0.
    end if
    if (.not. allocated(zeroes_maxw)) then
       allocate (zeroes_maxw(negrid-1)) ; zeroes_maxw = 0.
    end if
    if (.not. allocated(x0)) allocate (x0(nspec))
  end subroutine init_egrid

  !> Build the default (Gauss-Legendre + Gauss-Laguerre) energy grid and
  !> weights for every species.
  !>
  !> Grid construction:
  !>  - Speeds in [0, vcut] are covered by `nesub` Gauss-Legendre points in v.
  !>  - The remaining `negrid - nesub` points (if any) are Gauss-Laguerre
  !>    points in y = E - vcut**2, covering E > vcut**2.
  !>  - Species with `f0type == f0_sdanalytic` instead put all `negrid`
  !>    points on a Gauss-Legendre grid in v over [0, 1].
  !>
  !> On return:
  !>  - `epts` holds the energies E = v**2.
  !>  - `wgts` holds weights that include the volume factor
  !>    pi*v**2 dv (= (pi/2)*sqrt(E) dE) and the values `f0_values`
  !>    produced by `calculate_f0_arrays`.
  !>  - The speeds are stored in `zeroes` and `x0`.
  subroutine setvgrid (vcut_in, negrid, epts, wgts, nesub_in)
    use constants, only: pi => dpi
    use gauss_quad, only: get_legendre_grids_from_cheb, get_laguerre_grids
    use species, only: nspec, spec, f0_sdanalytic, f0_maxwellian, calculate_f0_arrays, f0_values

    implicit none
    !> Total number of energy grid points
    integer, intent (in) :: negrid
    !> Speed separating the Legendre and Laguerre regions
    !> (replaced by 1.0 for f0_sdanalytic species)
    real, intent (in) :: vcut_in
    !> Number of Legendre points below vcut (replaced by negrid for
    !> f0_sdanalytic species)
    integer, intent (in) :: nesub_in
    !> Energies and weights, (negrid, nspec) as allocated by le_grids::set_vgrid
    real, dimension(:,:), intent (out) :: epts, wgts
    integer:: nesub, is
    real:: vcut

    call init_egrid (negrid)

    do is = 1, nspec
       ! Choose the Legendre/Laguerre split for this species
       select case (spec(is)%f0type)
       case(f0_maxwellian)
          vcut = vcut_in
          nesub = nesub_in
       case(f0_sdanalytic)
          ! All points on a Legendre grid in v over [0, 1]; no Laguerre part
          vcut = 1.0
          nesub = negrid
       case default
          vcut = vcut_in
          nesub = nesub_in
       end select

       ! get grid points in v up to vcut (epts is not E yet)
       call get_legendre_grids_from_cheb (0., vcut, epts(:nesub,is), wgts(:nesub,is))

       ! change from v to E
       epts(:nesub,is) = epts(:nesub,is)**2

       ! Absorb the volume element pi*v**2 into the weights. F0 is no longer
       ! absorbed here; it is multiplied in after the loop (GW).
       wgts(:nesub, is) = wgts(:nesub, is)*epts(:nesub,is)*pi

       if (negrid > nesub) then

          ! get grid points in y = E - vcut**2 (epts not E yet)
          call get_laguerre_grids (epts(nesub+1:,is), wgts(nesub+1:,is))

          ! change from y to E
          epts(nesub+1:,is) = epts(nesub+1:,is) + vcut**2

          ! Multiply by exp(y) to undo the exponential weighting of the
          ! Laguerre rule (f0 is multiplied in later). Then apply the volume
          ! element pi*v**2 dv = (pi/2)*sqrt(E) dE.
          wgts(nesub+1:, is) = wgts(nesub+1:, is)*exp(epts(nesub+1:,is)-vcut**2)*pi*0.5*sqrt(epts(nesub+1:,is))

       end if
    end do

    ! Include the equilibrium distribution in the weights
    call calculate_f0_arrays(epts)
    wgts = wgts * f0_values

    zeroes(:,:) = sqrt(epts(:negrid-1,:))
    x0(:) = sqrt(epts(negrid,:))

  end subroutine setvgrid

  !> Build the energy grid and weights with the generalised quadrature
  !> scheme, which builds a quadrature rule from each species' `eval_f0`.
  !>
  !> Maxwellian species use a semi-infinite speed domain; all others use a
  !> bounded one.
  !>
  !> On return `epts` holds the energies E = v**2 and `wgts` includes the
  !> pi*v**2 volume factor. Unlike [[setvgrid]], the weights are not
  !> multiplied by `f0_values` here. Presumably this is because `eval_f0` is
  !> already the weight function of the rule.
  !> NOTE(review): confirm in genquad::get_quadrature_rule.
  subroutine setvgrid_genquad (negrid, epts, wgts)
    use genquad, only: get_quadrature_rule
    use constants, only: pi => dpi
    use species, only: nspec, spec, calculate_f0_arrays, eval_f0, set_current_f0_species, f0_maxwellian

    implicit none

    !> Total number of energy grid points
    integer, intent (in) :: negrid
    !> Energies, (negrid, nspec)
    real, dimension(:,:), intent (out) :: epts
    !> Integration weights, (negrid, nspec)
    real, dimension(:,:), intent (out) :: wgts
    integer :: is

    call init_egrid (negrid)

    do is = 1, nspec

       ! Set the species used by eval_f0 so that it takes only one argument
       call set_current_f0_species(is)

       if (spec(is)%f0type == f0_maxwellian) then
          ! If Maxwellian, use a semi-infinite domain
          call get_quadrature_rule(eval_f0,negrid,0.0,1.0,epts(:,is),wgts(:,is),.false.,.true.)
       else
          ! Otherwise, use a bounded domain with vmax = vref_s.
          call get_quadrature_rule(eval_f0,negrid,0.0,1.0,epts(:,is),wgts(:,is),.false.,.false.)
       end if

       ! change from v to E
       epts(:,is) = epts(:,is)**2

       ! Absorb the volume element into the weights. The Maxwellian is no
       ! longer absorbed here, to allow for an arbitrary f0 (EGH/GW).
       ! See eq. 4.12 of M. Barnes's thesis.
       wgts(:, is) = wgts(:, is)*pi*epts(:,is)

    end do

    call calculate_f0_arrays(epts)

    zeroes(:,:) = sqrt(epts(:negrid-1,:))
    x0(:) = sqrt(epts(negrid,:))

  end subroutine setvgrid_genquad

end module egrid

!> FIXME : Add documentation
module le_grids
  use abstract_config, only: abstract_config_type, CONFIG_MAX_NAME_LEN
  use redistribute, only: redist_type
  implicit none

  private

  ! Procedures and data made available to the rest of the code
  public :: init_le_grids, finish_le_grids, wnml_le_grids
  public :: integrate_species, integrate_moment, integrate_kysum, integrate_volume
  public :: energy, energy_maxw, speed, speed_maxw, al, jend, forbid, wl, w, wxi, w_maxw, vcut
  public :: negrid, nlambda, ng2, nxi, lmax, nesub, ixi_to_il, ixi_to_isgn, xi, sgn
  public :: init_weights, finish_weights, legendre_transform, hermite_prob
  public :: eint_error, lint_error, trap_error, wdim, get_flux_vs_theta_vs_vpa
  public :: lambda_map, energy_map, g2le, g2gf, init_map, init_g2gf
  public :: setup_passing_lambda_grids, setup_trapped_lambda_grids_old_finite_difference
  public :: new_trap_int, passing_wfb, trapped_wfb, mixed_wfb
  public :: grid_has_trapped_particles, il_is_trapped, il_is_passing
  public :: il_is_wfb, can_be_ttp, is_ttp, is_passing_hybrid_electron
  public :: is_bounce_point, is_lower_bounce_point, is_upper_bounce_point

  public :: le_grids_config_type, set_le_grids_config, get_le_grids_config

  !> Generic velocity-space moment integral; the specific procedure is
  !> selected by the type/rank of the actual arguments.
  interface integrate_moment
     module procedure integrate_moment_c34
     module procedure integrate_moment_r34
     module procedure integrate_moment_lec
     module procedure integrate_moment_r33
  end interface

  !> Generic velocity-space integral over species; the specific procedure
  !> is selected by the actual arguments.
  interface integrate_species
     module procedure integrate_species_master
     module procedure integrate_species_gf_nogather
  end interface

  !> Generic volume integral for complex and real arguments
  interface integrate_volume
     module procedure integrate_volume_c
     module procedure integrate_volume_r
  end interface

  !> Generic Hermite polynomial evaluation for 1D and 4D arguments
  interface get_hermite_polynomials
     module procedure get_hermite_polynomials_1d
     module procedure get_hermite_polynomials_4d
  end interface

  !> Passing pitch-angle nodes, (ng2); used as Lagrange nodes in init_weights
  real, dimension (:), allocatable :: xx ! (ng2)
  !> Passing pitch-angle weights with one point dropped, (ng2, ng2): column
  !> ipt omits point ipt (see init_weights)
  real, dimension (:,:), allocatable :: wlerr ! mbmark
  !> Energy weights with one point dropped, (negrid-1, nesub, nspec):
  !> column ipt omits point ipt (see init_weights)
  real, dimension (:,:,:), allocatable :: werr
  ! Further error-estimate / transform work arrays.
  ! NOTE(review): these are set and used outside the visible part of this module.
  real, dimension (:,:,:), allocatable :: wlterr
  real, dimension (:,:,:), allocatable :: wlmod
  real, dimension (:,:,:), allocatable :: wtmod
  real, dimension (:,:,:), allocatable :: wmod
  real, dimension (:,:,:), allocatable :: lpe
  real, dimension (:,:), allocatable, save :: lpl
  real, dimension (:,:,:,:), allocatable, save :: lpt

  !> Species-independent energy-grid data, each (negrid)
  real, dimension (:), allocatable :: energy_maxw, w_maxw, speed_maxw 
  !> Energy grid, weights and speeds per species, each (negrid, nspec);
  !> energy and w are allocated in set_vgrid
  real, dimension (:,:), allocatable :: energy, w, speed
  real, dimension (:), allocatable :: al ! (nlambda)
  !> Pitch-angle weights
  real, dimension (:,:), allocatable :: wl, wxi ! wl: (-ntgrid:ntgrid,nlambda), wxi: (2*nlambda,-ntgrid:ntgrid)
  integer, dimension (:), allocatable :: jend ! (-ntgrid:ntgrid)
  logical, dimension (:,:), allocatable :: forbid, is_ttp, is_bounce_point ! (-ntgrid:ntgrid,nlambda)
  logical, dimension (:), allocatable :: can_be_ttp ! (nlambda)

  !> Signed pitch-angle data and the maps from the xi index back to the
  !> lambda index / sign index, each (2*nlambda, -ntgrid:ntgrid)
  real, dimension (:,:), allocatable :: xi
  integer, dimension (:,:), allocatable :: ixi_to_il, ixi_to_isgn
  integer, dimension (2) :: sgn

  !> Set to nesub in init_weights
  integer :: wdim
  complex, dimension (:,:), allocatable :: integration_work
  ! (-ntgrid:ntgrid, -*- processor-dependent -*-)
  !> True if the le_grids_knobs namelist was present in the input
  logical :: exist

 ! knobs: module copies of le_grids_config values, set in read_parameters
  integer :: ngauss, npassing, negrid, nesuper, nesub
  real :: bouncefuzz, vcut
  logical :: genquad

  ! ng2 (number of passing pitch angles) is derived in read_parameters;
  ! nlambda, lmax and nxi are broadcast from proc0 in broadcast_results
  integer :: nlambda, ng2, lmax, nxi
  logical :: test = .false.
  logical :: trapped_particles = .true.
  logical :: new_trap_int = .false.
  logical :: new_trap_int_split = .false.
  logical :: radau_gauss_grid =.true. ! default lambda grid changed
  logical :: split_passing_region = .false.
  ! The default lambda grid is now gauss-radau, for passing particles up to 
  ! and including the passing-trapped boundary (wfb). The grid gives finite weight to
  ! the class of particles which bounce at theta = +/- pi (wfb) in the passing integral
  ! Note that the old gauss-legendre is used in the case that trapped_particles =.false.
  ! Initialisation guards
  logical :: slintinit = .false.
  logical :: lintinit = .false.
  logical :: eintinit = .false.
  logical :: initialized = .false.
  logical :: leinit = .false.
  logical :: gfinit = .false.
  logical :: lzinit = .false.
  logical :: einit = .false.
  logical :: init_weights_init = .false.

  ! MRH trapped_wfb, passing_wfb control the choice of wfb boundary condition;
  ! mixed_wfb is true when neither is set (see read_parameters)
  logical :: trapped_wfb, passing_wfb, mixed_wfb

  ! Selected wfbbc_option, one of the wfbbc_option_* values below
  integer :: wfbbc_option_switch
  integer, parameter :: wfbbc_option_mixed = 1, &
       wfbbc_option_passing = 2, &
       wfbbc_option_trapped = 3

  ! Maximum number of points per subinterval in get_weights
  integer :: nmax = 500

  ! Scales the largest acceptable weight in get_weights
  real :: wgt_fac = 10.0

  ! Redistribution maps between data layouts
  type (redist_type), save :: lambda_map
  type (redist_type), save :: energy_map
  type (redist_type), save :: g2le
  type (redist_type), save :: g2gf

  ! Communication bookkeeping for species integrals (set elsewhere in this module)
  integer, dimension(:),allocatable :: recvcnts_intspec,displs_intspec
  integer :: sz_intspec, local_rank_intspec
  
  !> Used to represent the input configuration of le_grids, read from the
  !> `le_grids_knobs` namelist. read_parameters copies each value into the
  !> module-level variable of the same name. The exception is
  !> `wfbbc_option`, which is translated into the wfb flags instead.
  type, extends(abstract_config_type) :: le_grids_config_type
     ! namelist : le_grids_knobs
     ! indexed : false
     !> Acts as a small tolerance in deciding if a particular pitch
     !> angle is forbidden at a particular theta grid point. Rather
     !> than considering particles to be forbidden if `1.0 - lambda *
     !> B(theta) < 0` we treat them as forbidden if `1.0 - lambda *
     !> B(theta) < -bouncefuzz`. This provides a small "buffer" to
     !> account for floating-point round-off in the calculation of
     !> `lambda` and `B`.
     real :: bouncefuzz = 10*epsilon(0.0)
     !> If true use generalised quadrature scheme for velocity
     !> integrals and energy grid. See [G. Wilkie
     !> thesis](https://drum.lib.umd.edu/handle/1903/17302).
     logical :: genquad = .false.
     !> Sets the total number of energy grid points.
     !>
     !> - If set (any value other than the default -10), then
     !>   `nesuper = min(negrid/10 + 1, 4)` and `nesub = negrid - nesuper`.
     !>   These override any values given for `nesub` and `nesuper`.
     !> - If not set, `negrid = nesub + nesuper`.
     !>
     !> The energy grid is split into two regions at a point controlled
     !> by `vcut`.
     integer :: negrid = -10
     !> Sets the number of energy grid points below the cutoff `vcut`
     !> (Gauss-Legendre points in speed). Ignored if `negrid` is set.
     integer :: nesub = 8
     !> Sets the number of energy grid points above the cutoff `vcut`
     !> (Gauss-Laguerre points in energy). Ignored if `negrid` is set.
     integer :: nesuper = 2
     !> If `true` then use a more accurate integration method for the
     !> trapped pitch angle contribution to velocity integrals. The
     !> default uses a simple, but potentially less accurate, finite
     !> difference method.
     logical :: new_trap_int = .false.
     !> If `true` then split the trapped region into two symmetric
     !> regions when calculating the integration weights associated
     !> with the new_trap_int approach.
     logical :: new_trap_int_split = .false.
     !> The number of untrapped pitch-angles moving in one direction
     !> along field line. If negative (the default), `2*ngauss` is used
     !> instead.
     integer :: npassing = -1
     !> The number of untrapped pitch-angles moving in one direction
     !> along field line is `2*ngauss` if [[le_grids_knobs:npassing]] is
     !> negative (not set). Note that the number of trapped pitch angles is
     !> directly related to the number of theta grid points (per
     !> \(2\pi\)) and is `ntheta/2 + 1`.
     integer :: ngauss = 5
     !> Maximum number of grid points in one integration subinterval
     !> when calculating the velocity-space error-estimate weights
     !> (see `get_weights`).
     integer :: nmax = 500
     !> The default lambda grid is now gauss-radau, for passing
     !> particles up to and including the passing-trapped boundary
     !> (wfb). The grid gives finite weight to the class of particles
     !> which bounce at theta = +/- pi (wfb) in the passing integral
     !> Note that the old gauss-legendre is used in the case that
     !> trapped_particles =.false.
     !>
     !> @note In reference to above comment -
     !> code doesn't seem to change radau_gauss_grid when
     !> trapped_particles is false, should it?
     logical :: radau_gauss_grid = .true.
     !> If true then we split the passing region into two separate
     !> regions for the purpose of choosing the integration weights.
     !> The split point is determined automatically by an attempt
     !> to minimise the passing weights error.
     logical :: split_passing_region = .false.
     !> If `true` then writes the pitch angle and energy grids to screen
     !> and then stops the simulation.
     logical :: test = .false.
     !> If set to `false` then the lambda grid weighting `wl` is set
     !> to zero for trapped pitch angles. This means that integrals
     !> over velocity space do not include a contribution from trapped
     !> particles which is equivalent to the situation where
     !> `eps<=0.0`.  Trapped particle drifts are not set to zero so
     !> "trapped" particles still enter the source term through
     !> `wdfac`. At least for s-alpha the drifts are the main
     !> difference (*please correct if not true*) between the
     !> `eps<=0.0` and the `trapped_particles = .false.` cases as
     !> `Bmag` is not a function of theta in the `eps<=0.0` case
     !> whilst it is in the `trapped_particles = .false.` case.
     logical :: trapped_particles = .true.
     !> Speed at which the default energy grid switches from
     !> Gauss-Legendre points in speed (below `vcut`) to Gauss-Laguerre
     !> points in energy (above `vcut**2`). It is not used for species with
     !> an `f0_sdanalytic` distribution, which use 1.0.
     !> NOTE(review): earlier documentation described this as the number
     !> of standard deviations of the Maxwellian beyond which the
     !> distribution function is set to 0. Confirm before relying on that.
     real :: vcut = 2.5
     !> Set boundary condition for WFB in the linear/parallel solve:
     !>
     !> - "default", "mixed": The previous default boundary condition which
     !>   mixes the passing and trapped boundary conditions
     !> - "passing": Treats the WFB using a passing particle boundary condition
     !> - "trapped": Treats the WFB using a trapped particle boundary condition
     character(len = 20) :: wfbbc_option = 'default'
     !> Controls the largest integration weight accepted by `get_weights`.
     !> The limit is `wgt_fac * (interval length) / (number of points)`.
     real :: wgt_fac = 10.0
   contains
     procedure, public :: read => read_le_grids_config
     procedure, public :: write => write_le_grids_config
     procedure, public :: reset => reset_le_grids_config
     procedure, public :: broadcast => broadcast_le_grids_config
     procedure, public, nopass :: get_default_name => get_default_name_le_grids_config
     procedure, public, nopass :: get_default_requires_index => get_default_requires_index_le_grids_config
  end type le_grids_config_type
  
  !> The module's configuration instance, read and copied out by read_parameters
  type(le_grids_config_type) :: le_grids_config
  
contains

  !> Write a subset of the `le_grids_knobs` namelist (`nesub`, `nesuper`,
  !> `ngauss`, `vcut`) to `unit`. Nothing is written when the namelist
  !> was absent from the input (`exist` is false).
  subroutine wnml_le_grids(unit)
    implicit none
    !> Fortran unit to write the namelist to
    integer, intent(in) :: unit

    if (exist) then
       write (unit, *)
       write (unit, fmt="(' &',a)") "le_grids_knobs"
       write (unit, fmt="(' nesub = ',i4)") nesub
       write (unit, fmt="(' nesuper = ',i4)") nesuper
       write (unit, fmt="(' ngauss = ',i4)") ngauss
       write (unit, fmt="(' vcut = ',e17.10)") vcut
       write (unit, fmt="(' /')")
    end if
  end subroutine wnml_le_grids

  !> Initialise the energy and pitch-angle (lambda) grids.
  !>
  !> Steps:
  !>  - Reads the `le_grids_knobs` configuration.
  !>  - Builds the energy grid on every processor (set_vgrid).
  !>  - Calls set_grids on proc0 only; its results are then shared by
  !>    broadcast_results.
  !>  - Initialises the distribution-function layouts.
  !>
  !> Subsequent calls return immediately. If the `test` knob is set, proc0
  !> prints the lambda grid `al` and the energy grid, then the parallel
  !> environment is finalised and the program stops.
  !>
  !> @param le_grid_config_in Optional configuration that replaces the
  !>        module's `le_grids_config` before it is read.
  subroutine init_le_grids (le_grid_config_in)
    use mp, only: proc0, finish_mp
    use species, only: init_species
    use theta_grid, only: init_theta_grid
    use kt_grids, only: init_kt_grids
    use gs2_layouts, only: init_gs2_layouts
    implicit none
    type(le_grids_config_type), intent(in), optional :: le_grid_config_in
    integer :: il, ie

    if (initialized) return
    initialized = .true.

    call init_gs2_layouts
    call init_species
    call init_theta_grid
    call init_kt_grids

    call read_parameters(le_grid_config_in)
    ! The energy grid is computed independently on every processor ...
    call set_vgrid
    ! ... while set_grids runs on proc0 only and its results are broadcast
    if (proc0) then
       call set_grids
    end if
    call broadcast_results
    call init_integrations

    if (test) then
       if (proc0) then
          do il = 1, nlambda
             write(*,*) al(il)
          end do
          write(*,*) 
          do ie = 1, negrid
             write(*,*) energy(ie,:)
          end do
       end if
       call finish_mp
       stop
    endif

  end subroutine init_le_grids

  !> Allocate the energy grid `energy` and weights `w`, both of shape
  !> (negrid, nspec), and fill them. The generalised quadrature scheme is
  !> used when `genquad` is set; otherwise the default Legendre/Laguerre
  !> scheme with cut-off `vcut` and `nesub` points below it is used.
  subroutine set_vgrid
    use species, only: nspec, init_species
    use egrid, only: setvgrid, setvgrid_genquad
    implicit none

    call init_species

    allocate (energy(negrid, nspec))
    allocate (w(negrid, nspec))

    if (.not. genquad) then
       call setvgrid (vcut, negrid, energy, w, nesub)
    else
       call setvgrid_genquad (negrid, energy, w)
    end if
  end subroutine set_vgrid

  !> Share the grid data held on proc0 (presumably produced by set_grids)
  !> with all other processors.
  !>
  !> The grid sizes are broadcast first so that the other processors can
  !> allocate their copies of the arrays before receiving the contents.
  !> init_egrid ensures `zeroes_maxw` is allocated before it is broadcast.
  subroutine broadcast_results
    use mp, only: proc0, broadcast
    use species, only: nspec
    use egrid, only: init_egrid, zeroes_maxw, x0_maxw
    use theta_grid, only: ntgrid
    implicit none

    ! Sizes needed to allocate the receiving arrays below
    call broadcast (lmax)
    call broadcast (ng2)
    call broadcast (nlambda)
    call broadcast (nxi)

    ! proc0 already holds these arrays; every other processor needs storage
    if (.not. proc0) then
       ! Energy-dependent data
       allocate (speed(negrid,nspec))
       allocate (w_maxw(negrid), energy_maxw(negrid), speed_maxw(negrid))
       ! Pitch-angle data
       allocate (al(nlambda), can_be_ttp(nlambda))
       allocate (wl(-ntgrid:ntgrid,nlambda))
       allocate (forbid(-ntgrid:ntgrid,nlambda), is_ttp(-ntgrid:ntgrid,nlambda), &
            is_bounce_point(-ntgrid:ntgrid,nlambda))
       allocate (jend(-ntgrid:ntgrid))
       allocate (xx(ng2))
       ! Signed pitch-angle (xi) data
       allocate (xi(2*nlambda, -ntgrid:ntgrid), wxi(2*nlambda, -ntgrid:ntgrid))
       allocate (ixi_to_il(2*nlambda, -ntgrid:ntgrid), ixi_to_isgn(2*nlambda, -ntgrid:ntgrid))
    end if

    call init_egrid (negrid)

    ! These collective broadcasts must be issued in the same order on
    ! every processor.
    call broadcast (xx)
    call broadcast (x0_maxw)
    call broadcast (zeroes_maxw)
    call broadcast (al)
    call broadcast (jend)
    call broadcast (energy_maxw)
    call broadcast (speed)
    call broadcast (speed_maxw)
    call broadcast (w_maxw)
    call broadcast (wl)
    call broadcast (forbid)
    call broadcast (is_ttp)
    call broadcast (is_bounce_point)
    call broadcast (can_be_ttp)
    call broadcast (xi)
    call broadcast (ixi_to_il)
    call broadcast (ixi_to_isgn)
    call broadcast (wxi)
    call broadcast (sgn)
  end subroutine broadcast_results

  !> Read the `le_grids_knobs` configuration and derive the module-level
  !> grid parameters from it.
  !>
  !> First copies every knob from `le_grids_config` into the module
  !> variable of the same name. Then:
  !>  - Resolves `negrid`/`nesub`/`nesuper`. If `negrid` was left at its
  !>    default it becomes `nesub + nesuper`. Otherwise `nesuper` and
  !>    `nesub` are recomputed from `negrid`, overriding user values.
  !>  - Sets the number of passing pitch angles `ng2`: `npassing` if it is
  !>    non-negative, otherwise `2*ngauss`.
  !>  - Translates `wfbbc_option` into the `passing_wfb`, `trapped_wfb`
  !>    and `mixed_wfb` flags.
  !>
  !> @param le_grid_config_in Optional configuration which, if present,
  !>        replaces the module's `le_grids_config` before it is initialised.
  subroutine read_parameters(le_grid_config_in)
    use file_utils, only: error_unit
    use text_options, only: text_option, get_option_value
    implicit none
    type(le_grids_config_type), intent(in), optional :: le_grid_config_in
    ! Default value of le_grids_config%negrid, meaning "not set by the user"
    integer, parameter :: negrid_unset = -10
    type (text_option), dimension (4), parameter :: wfbbcopts = &
         (/ text_option('default', wfbbc_option_mixed), &
            text_option('passing', wfbbc_option_passing), &
            text_option('trapped', wfbbc_option_trapped), &
            text_option('mixed', wfbbc_option_mixed)/)
    character(20) :: wfbbc_option
    integer :: ierr

    if (present(le_grid_config_in)) le_grids_config = le_grid_config_in

    call le_grids_config%init(name = 'le_grids_knobs', requires_index = .false.)

    ! Copy out internal values into module level parameters
    bouncefuzz = le_grids_config%bouncefuzz
    genquad = le_grids_config%genquad
    negrid = le_grids_config%negrid
    nesub = le_grids_config%nesub
    nesuper = le_grids_config%nesuper
    new_trap_int = le_grids_config%new_trap_int
    new_trap_int_split = le_grids_config%new_trap_int_split
    npassing = le_grids_config%npassing
    ngauss = le_grids_config%ngauss
    nmax = le_grids_config%nmax
    radau_gauss_grid = le_grids_config%radau_gauss_grid
    split_passing_region = le_grids_config%split_passing_region
    test = le_grids_config%test
    trapped_particles = le_grids_config%trapped_particles
    vcut = le_grids_config%vcut
    wfbbc_option = le_grids_config%wfbbc_option
    wgt_fac = le_grids_config%wgt_fac

    exist = le_grids_config%exist

    if (negrid == negrid_unset) then
       ! User did not set negrid (preferred for the old scheme)
       negrid = nesub + nesuper
    else
       ! User chose negrid, so set nesuper and nesub accordingly
       nesuper = min(negrid/10+1, 4)
       nesub = negrid - nesuper
    endif

    ! Set ng2 from ngauss if not specified in the input file
    if (npassing < 0) then
       ng2 = 2 * ngauss
    else
       ng2 = npassing
    end if

    ierr = error_unit()

    ! wfbbc_option is read from le_grids_knobs, so name that namelist in
    ! any error message
    call get_option_value &
         (wfbbc_option, wfbbcopts, wfbbc_option_switch, &
         ierr, "wfbbc_option in le_grids_knobs",.true.)

    trapped_wfb = .false.! flag for treating wfb as a trapped particle
    passing_wfb = .false.! flag for treating wfb as a passing particle

    select case (wfbbc_option_switch)
    case(wfbbc_option_mixed) ! The previous default boundary condition which mixes the passing and trapped boundary conditions
       passing_wfb =.false.
       trapped_wfb =.false.
    case(wfbbc_option_passing) ! Treats wfb using a passing particle boundary condition
       passing_wfb =.true.
       trapped_wfb =.false.
    case(wfbbc_option_trapped) ! Treats the wfb using a trapped particle boundary condition
       passing_wfb =.false.
       trapped_wfb =.true.
    end select
    mixed_wfb = (.not. passing_wfb) .and. (.not. trapped_wfb)
  end subroutine read_parameters

  !> Set up the distribution-function layouts from the le_grids sizes
  !> (`nlambda`, `negrid`), the theta, ky/theta0 and species dimensions,
  !> and the processor count/rank.
  subroutine init_integrations
    use gs2_layouts, only: init_dist_fn_layouts, pe_layout
    use mp, only: nproc, iproc
    use theta_grid, only: ntgrid
    use kt_grids, only: naky, ntheta0
    use species, only: nspec
    implicit none

    call init_dist_fn_layouts (ntgrid, naky, ntheta0, &
         nlambda, negrid, nspec, &
         nproc, iproc)
  end subroutine init_integrations

  !> Returns true if the current pitch-angle grid contains trapped
  !> particles, i.e. if it has more points (`nlambda`) than there are
  !> passing pitch angles (`ng2`).
  pure logical function grid_has_trapped_particles() result(has_trapped)
    implicit none
    has_trapped = (ng2 < nlambda)
  end function grid_has_trapped_particles

  !> Returns true if pitch-angle index `il` is considered trapped. Indices
  !> up to `ng2` are passing and `ng2 + 1` is the wfb, so trapped indices
  !> start at `ng2 + 2`.
  elemental logical function il_is_trapped(il) result(is_trapped)
    implicit none
    !> Pitch-angle grid index
    integer, intent(in) :: il
    is_trapped = (il >= ng2 + 2)
    ! The wfb (il == ng2 + 1) is currently never counted as trapped, even
    ! when trapped_wfb is set. For a more consistent treatment we could use
    !   is_trapped = il > ng2
    ! or, more explicitly,
    !   if (trapped_wfb) then
    !       is_trapped = il > ng2
    !   else
    !       is_trapped = il > ng2 + 1
    !   end if
  end function il_is_trapped

  !> Returns true if pitch-angle index `il` is considered passing, i.e. it
  !> is one of the first `ng2` points.
  elemental logical function il_is_passing(il) result(is_passing)
    implicit none
    !> Pitch-angle grid index
    integer, intent(in) :: il
    is_passing = .not. (il > ng2)
    ! The wfb (il == ng2 + 1) is currently never counted as passing, even
    ! when passing_wfb is set. For a more consistent treatment we could use
    !   is_passing = il <= ng2 + 1
    ! or, more explicitly,
    !   if (passing_wfb) then
    !       is_passing = il <= ng2 + 1
    !   else
    !       is_passing = il <= ng2
    !   end if
  end function il_is_passing

  !> Returns true if pitch-angle index `il` is the wfb, the point
  !> immediately after the `ng2` passing points.
  elemental logical function il_is_wfb(il) result(is_wfb)
    implicit none
    !> Pitch-angle grid index
    integer, intent(in) :: il
    is_wfb = (il - ng2 == 1)
    ! If we improve the consistency of the wfb treatment, we should
    ! hopefully rarely need to identify the wfb specifically.
  end function il_is_wfb

  !> Calculates energy grid and (untrapped) pitch angle weights for
  !> integrals which drop a point from the energy/lambda
  !> dimensions. These are only used for estimating the velocity space
  !> error in [[dist_fn::get_verr]] as a part of the adaptive
  !> collisionality algorithm.
  !>
  !> Fills:
  !>  - `werr(negrid-1, nesub, nspec)`: column `ipt` holds the energy
  !>    weights obtained with the `ipt`-th of the first `nesub` points
  !>    removed; that entry is left at zero. Entries above `nesub` are
  !>    copied from `w`.
  !>  - `wlerr(ng2, ng2)`: column `ipt` holds the passing pitch-angle
  !>    weights (integration limits llim = 1, ulim = 0) obtained with the
  !>    `ipt`-th point of `xx` removed.
  !>
  !> Also sets `wdim = nesub`. Only the first call does any work.
  subroutine init_weights
    use egrid, only: zeroes
    use constants, only: pi => dpi
    use species, only: nspec, f0_values

    implicit none

    real, dimension (:), allocatable :: modzeroes, werrtmp  ! (nesub-1)
    real, dimension (:), allocatable :: lmodzeroes, wlerrtmp ! (ng2-1)
    integer :: ipt, ndiv, divmax, is
    ! Set by get_weights if the weights could not be bounded; currently
    ! not acted upon. Deliberately not initialised in the declaration, so
    ! it is not implicitly saved.
    logical :: eflag

    if (init_weights_init) return
    init_weights_init = .true.

    allocate(lmodzeroes(ng2-1), wlerrtmp(ng2-1))
    allocate(wlerr(ng2,ng2))

    wlerr = 0.0; lmodzeroes = 0.0; wlerrtmp = 0.0

    wdim = nesub
    allocate(modzeroes(nesub-1), werrtmp(nesub-1))
    allocate(werr(negrid-1,nesub,nspec))

    werr = 0.0 ; modzeroes = 0.0 ; werrtmp = 0.0

    ! Loop to obtain weights for energy grid points. One set of weights is
    ! built for each of the nesub points below the cut that can be dropped
    ! from the gaussian quadrature.
    do ipt=1,nesub
       do is = 1,nspec

          ! drops the point corresponding to ipt from the energy grid
          if (ipt /= 1) modzeroes(:ipt-1) = zeroes(:ipt-1,is)
          if (ipt /= nesub) modzeroes(ipt:nesub-1) = zeroes(ipt+1:nesub,is)

          ! get weights for energy grid points
          call get_weights (nmax,0.0,vcut,modzeroes,werrtmp,ndiv,divmax,eflag)

          ! a zero is left in the position corresponding to the dropped point
          if (ipt /= 1) werr(:ipt-1,ipt,is) = werrtmp(:ipt-1)
          if (ipt /= nesub) werr(ipt+1:nesub,ipt,is) = werrtmp(ipt:nesub-1)
          werr(nesub+1:,ipt,is) = w(nesub+1:negrid-1,is)

          ! absorbing volume element into weights
          ! NOTE(review): egrid::setvgrid scales the Legendre weights (and
          ! hence the w values copied just above) by pi*v**2, whereas here
          ! the factor is 2*pi*v**2. Confirm this factor of 2 is intended.
          werr(:nesub,ipt,is) = werr(:nesub,ipt,is)*energy(:nesub,is)*2.0*pi*f0_values(:nesub,is)
       end do
    end do

    ! same thing done here for lambda as was
    ! done earlier for energy space
    do ipt=1,ng2

       if (ipt /= 1) lmodzeroes(:ipt-1) = xx(:ipt-1)
       if (ipt /= ng2) lmodzeroes(ipt:ng2-1) = xx(ipt+1:)

       call get_weights (nmax,1.0,0.0,lmodzeroes,wlerrtmp,ndiv,divmax,eflag)

       if (ipt /= 1) wlerr(:ipt-1,ipt) = wlerrtmp(:ipt-1)
       if (ipt /= ng2) wlerr(ipt+1:,ipt) = wlerrtmp(ipt:)

    end do

    deallocate(modzeroes,werrtmp,lmodzeroes,wlerrtmp)

  end subroutine init_weights

  !> The get_weights subroutine determines how to divide up the integral
  !> over [llim, ulim] into subintervals, how many grid points each
  !> subinterval should contain, and the resulting Lagrange quadrature
  !> weights for `nodes`.
  !>
  !> How the weights are built:
  !>  - Each subinterval's weights come from [[get_intrvl_weights]].
  !>  - Adjacent subintervals share their boundary node, whose weight
  !>    therefore accumulates contributions from both.
  !>  - The first attempt uses at most `maxpts_in` points per subinterval.
  !>  - While the largest weight exceeds `abs(ulim-llim)*wgt_fac/size(nodes)`,
  !>    the number of points per subinterval is reduced and the weights
  !>    recomputed.
  !>  - If the limit is still exceeded with fewer than three points per
  !>    subinterval, `err_flag` is set and the last weights are returned.
  !>
  !> NOTE(review): the inner subinterval limits are taken from the node
  !> values, so `nodes` is presumably ordered monotonically from `llim`
  !> towards `ulim`.
  !>
  !> @param maxpts_in Maximum number of points per subinterval (values
  !>                  below 2 are treated as 2)
  !> @param llim      Lower limit of integration
  !> @param ulim      Upper limit of integration
  !> @param nodes     Integration nodes
  !> @param wgts      Output weights, one per node
  !> @param ndiv      Number of subintervals used for the returned weights
  !> @param divmax    Largest number of points in any subinterval
  !> @param err_flag  True if the weights could not be kept below the limit
  subroutine get_weights (maxpts_in, llim, ulim, nodes, wgts, ndiv, divmax, err_flag)

    implicit none

    integer, intent (in) :: maxpts_in
    real, intent (in) :: llim, ulim
    real, dimension (:), intent (in) :: nodes
    real, dimension (:), intent (out) :: wgts
    logical, intent (out) :: err_flag
    integer, intent (out) :: ndiv, divmax

    integer :: npts, rmndr, basepts, divrmndr, base_idx, idiv, epts, maxpts
    integer, dimension (:), allocatable :: divpts

    real :: wgt_max

    ! npts is the number of grid points in the integration interval
    npts = size(nodes)

    ! Give every intent(out) argument a defined value. In particular,
    ! err_flag must read .false. when the weights are acceptable.
    wgts = 0.0; ndiv = 1; divmax = npts
    err_flag = .false.

    ! Nothing to integrate (and wgt_max below would divide by zero)
    if (npts == 0) return

    ! maxpts is the max number of pts in an integration subinterval. A
    ! subinterval needs at least two points (its ends). With fewer, the
    ! subdivision arithmetic below would divide by maxpts - 1 = 0.
    maxpts = min(max(maxpts_in, 2), npts)

    ! Largest acceptable weight; it does not depend on the subdivision
    wgt_max = abs(ulim-llim)*wgt_fac/npts

    do

       ! only need to subdivide integration interval if maxpts < npts
       if (maxpts >= npts) then
          call get_intrvl_weights (llim, ulim, nodes, wgts)
       else
          rmndr = mod(npts-maxpts,maxpts-1)

          ! if rmndr is 0, then each subinterval contains maxpts pts
          if (rmndr == 0) then
             ! ndiv is the number of subintervals
             ndiv = (npts-maxpts)/(maxpts-1) + 1
             allocate (divpts(ndiv))
             ! divpts is an array containing the # of pts for each subinterval
             divpts = maxpts
          else
             ndiv = (npts-maxpts)/(maxpts-1) + 2
             allocate (divpts(ndiv))
             ! epts is the effective number of pts after taking into account
             ! double counting of some grid points (those that are boundaries
             ! of subintervals are used twice)
             epts = npts + ndiv - 1
             basepts = epts/ndiv
             divrmndr = mod(epts,ndiv)

             ! determines if all intervals have same # of pts
             if (divrmndr == 0) then
                divpts = basepts
             else
                divpts(:divrmndr) = basepts + 1
                divpts(divrmndr+1:) = basepts
             end if
          end if

          base_idx = 0

          ! Get the weights for each subinterval. The first starts at llim,
          ! the last ends at ulim, and the others span their first to last
          ! node.
          do idiv=1,ndiv
             if (idiv == 1) then
                call get_intrvl_weights (llim, nodes(base_idx+divpts(idiv)), &
                     nodes(base_idx+1:base_idx+divpts(idiv)),wgts(base_idx+1:base_idx+divpts(idiv)))
             else if (idiv == ndiv) then
                call get_intrvl_weights (nodes(base_idx+1), ulim, &
                     nodes(base_idx+1:base_idx+divpts(idiv)),wgts(base_idx+1:base_idx+divpts(idiv)))
             else
                call get_intrvl_weights (nodes(base_idx+1), nodes(base_idx+divpts(idiv)), &
                     nodes(base_idx+1:base_idx+divpts(idiv)),wgts(base_idx+1:base_idx+divpts(idiv)))
             end if
             ! The last node of this subinterval is the first of the next
             base_idx = base_idx + divpts(idiv) - 1
          end do

          divmax = maxval(divpts)

          deallocate (divpts)
       end if

       ! check to make sure the weights do not get too large
       ! @Warning: This looks a bit unusual -- we're taking the magnitude of
       ! the maximum value, whilst usually we would find the maximum value of
       ! the magnitude.
       if (abs(maxval(wgts)) > wgt_max) then
          if (maxpts < 3) then
             err_flag = .true.
             exit
          end if
          maxpts = divmax - 1
       else
          exit
       end if

       ! get_intrvl_weights accumulates, so clear before the next attempt
       wgts = 0.0
    end do

  end subroutine get_weights

  !> Used by [[get_weights]] to find the Lagrange quadrature weights for the given grid.
  !>
  !> Suppose we wish to calculate Int(f[x] dx) over the range [llim, ulim].
  !> We approximate f[x] by its Lagrange interpolating polynomial:
  !>   f[x] ~ Sum_i{ f[x_i] l_i(x) },
  !> where the Lagrange basis polynomial is
  !>   l_i(x) = Pi_{j /= i}{ (x - x_j)/(x_i - x_j) }.
  !> The integral then becomes
  !>   I ~ Sum_i{ f[x_i] c_i },  with  c_i = Int(l_i(x) dx).
  !> This routine computes the `c_i` for the positions {x_i} given in
  !> `nodes`. They are *added* to `wgts`, so the caller must zero `wgts`
  !> first. Accumulation lets a node shared by adjacent subintervals
  !> collect both contributions.
  !>
  !> Each l_i has degree M-1 for M nodes. An N-point Gauss-Legendre rule is
  !> exact for degree 2N-1, so N >= M/2 suffices; we use N = M/2 + 1
  !> (integer division).
  !>
  !> One limitation of Lagrange polynomials is their lack of numerical stability
  !> at high order (i.e. when M is large, typically > 20 is quoted as a limit).
  !> The routine [[get_weights]] allows the integration domain to be sub-divided
  !> in order to limit the Lagrange order and reduce the variation in weights.
  !> This is partially controlled by the input [[nmax]].
  subroutine get_intrvl_weights (llim, ulim, nodes, wgts)
    use gauss_quad, only: get_legendre_grids_from_cheb

    implicit none

    ! llim (ulim) is lower (upper) limit of integration
    real, intent (in) :: llim, ulim
    ! Interpolation nodes; wgts must have the same size
    real, dimension (:), intent (in) :: nodes
    real, dimension (:), intent (in out) :: wgts

    ! stuff needed to do gaussian quadrature
    real, dimension (:), allocatable :: gnodes, gwgts, omprod
    integer :: ix, iw, ngauss_pts

    ! A single Gauss-Legendre size for all three work arrays
    ngauss_pts = size(nodes)/2 + 1
    allocate (gnodes(ngauss_pts), gwgts(ngauss_pts), omprod(ngauss_pts))

    call get_legendre_grids_from_cheb (llim, ulim, gnodes, gwgts)

    do iw=1,size(wgts)
       ! omprod holds l_iw evaluated at every Gauss-Legendre node
       omprod = 1.0

       do ix=1,size(nodes)
          if (ix /= iw) omprod = omprod*(gnodes - nodes(ix))/(nodes(iw) - nodes(ix))
       end do

       ! Gauss-Legendre estimate of Int(l_iw), accumulated into wgts(iw)
       do ix=1,ngauss_pts
          wgts(iw) = wgts(iw) + omprod(ix)*gwgts(ix)
       end do
    end do

    deallocate (gnodes, gwgts, omprod)

  end subroutine get_intrvl_weights

  !> Integrate g over velocity space and sum over species, using the g_lo
  !> layout and a reduction over all processors.
  !>
  !> For every local g_lo element the sum of both sigma components
  !> (g(:,1,iglo) + g(:,2,iglo)) is weighted by weights(is)*w(ie,is)*wl(:,il)
  !> and accumulated into total(:, it, ik). A sum_allreduce then completes the
  !> integral, so every processor receives the full result.
  !>
  !> Side effect: entries of weights whose species is a tracer species are set
  !> to zero. If any kwork_filter entry is set, the flagged (kx, ky) modes are
  !> skipped and stay zero before the reduction.
  subroutine integrate_species_original (g, weights, total)
    use theta_grid, only: ntgrid
    use gs2_layouts, only: g_lo
    use gs2_layouts, only: is_idx, ik_idx, it_idx, ie_idx, il_idx
    use kt_grids, only: kwork_filter
    use mp, only: sum_allreduce
    use species, only: tracer_species, spec
    use array_utils, only: zero_array
    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    real, dimension (:), intent (in out) :: weights
    complex, dimension (-ntgrid:,:,:), intent (out) :: total
    integer :: iglo, i_s, i_l, i_e, i_k, i_t
    logical :: skip_filtered

    ! Start from an empty accumulator and drop tracer species from the sum
    call zero_array(total)
    where (spec%type == tracer_species) weights = 0

    ! Only consult kwork_filter per element when some mode is actually flagged
    skip_filtered = any(kwork_filter)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, i_t, i_k, i_e, i_l, i_s) &
    !$OMP SHARED(g_lo, skip_filtered, kwork_filter, weights, w, wl, g) &
    !$OMP REDUCTION(+ : total) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       i_k = ik_idx(g_lo, iglo)
       i_t = it_idx(g_lo, iglo)
       if (skip_filtered) then
          if (kwork_filter(i_t, i_k)) cycle
       end if
       i_s = is_idx(g_lo, iglo)
       i_e = ie_idx(g_lo, iglo)
       i_l = il_idx(g_lo, iglo)

       total(:, i_t, i_k) = total(:, i_t, i_k) + &
            (weights(i_s)*w(i_e,i_s))*wl(:,i_l)*(g(:,1,iglo)+g(:,2,iglo))
    end do
    !$OMP END PARALLEL DO

    ! Combine the partial sums from every processor
    call sum_allreduce (total)
  end subroutine integrate_species_original

  !> Integrate species on xy subcommunicator - NO GATHER.
  !>
  !> Integrates g over velocity space and sums over species, weighting each
  !> species by weights(is).
  !>
  !> - When intspec_sub is .true., only the local block
  !>   total(:, g_lo%it_min:g_lo%it_max, g_lo%ik_min:g_lo%ik_max) is computed
  !>   and reduced over g_lo%xyblock_comm.
  !> - Otherwise the full (theta, kx, ky) array is reduced over all processors
  !>   and copied into total.
  !>
  !> NOTE(review): in the intspec_sub case, elements of total outside the
  !> local block are not assigned here, even though total is intent(out).
  !> Confirm that callers only read the local block.
  !>
  !> Side effect: entries of weights belonging to tracer species are set to
  !> zero. Modes flagged in kwork_filter are skipped.
  subroutine integrate_species_sub (g, weights, total)
    use theta_grid, only: ntgrid
    use gs2_layouts, only: g_lo, intspec_sub
    use gs2_layouts, only: is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce_sub, sum_allreduce
    use kt_grids, only: kwork_filter
    use species, only: spec, tracer_species
    use array_utils, only: zero_array, copy
    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    real, dimension (:), intent (in out) :: weights
    complex, dimension (-ntgrid:,:,:), intent (out) :: total
    complex, dimension(:,:,:), allocatable :: block_sum
    integer :: iglo, i_s, i_l, i_e, i_k, i_t
    logical :: skip_filtered

    ! The accumulator covers only the local xy block when using sub-comms
    if (intspec_sub) then
       allocate(block_sum(-ntgrid:ntgrid, g_lo%it_min:g_lo%it_max, g_lo%ik_min:g_lo%ik_max))
    else
       allocate(block_sum(-ntgrid:ntgrid, g_lo%ntheta0, g_lo%naky))
    end if
    call zero_array(block_sum)
    where (spec%type == tracer_species) weights = 0

    skip_filtered = any(kwork_filter)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, i_t, i_k, i_e, i_l, i_s) &
    !$OMP SHARED(g_lo, skip_filtered, kwork_filter, weights, w, wl, g) &
    !$OMP REDUCTION(+ : block_sum) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       i_k = ik_idx(g_lo, iglo)
       i_t = it_idx(g_lo, iglo)
       if (skip_filtered) then
          if (kwork_filter(i_t, i_k)) cycle
       end if
       i_s = is_idx(g_lo, iglo)
       i_e = ie_idx(g_lo, iglo)
       i_l = il_idx(g_lo, iglo)

       block_sum(:, i_t, i_k) = block_sum(:, i_t, i_k) + &
            (weights(i_s)*w(i_e,i_s))*wl(:,i_l)*(g(:,1,iglo)+g(:,2,iglo))
    end do
    !$OMP END PARALLEL DO

    ! Finish the velocity-space sum and move the result into total.
    ! The world-comm path pays for an extra copy compared with
    ! integrate_species_original.
    if (intspec_sub) then
       call sum_allreduce_sub(block_sum, g_lo%xyblock_comm)
       total(:, g_lo%it_min:g_lo%it_max, g_lo%ik_min:g_lo%ik_max) = block_sum
    else
       call sum_allreduce(block_sum)
       call copy(block_sum, total)
    end if

    deallocate(block_sum)

  end subroutine integrate_species_sub

  !> Integrate species using the gf_lo data format.
  !>
  !> g (in g_lo) is first gathered into gf_lo with the g2gf redistribution.
  !> Each gf_lo element igf then holds every species, energy and pitch angle
  !> for one (kx, ky) pair. The weighted velocity-space and species sum is
  !> done locally and written to total(:, it, ik) for the (kx, ky) pairs owned
  !> by this processor. No reduction across processors is performed here.
  !>
  !> Side effect: entries of weights belonging to tracer species are set to
  !> zero. Modes flagged in kwork_filter are skipped and left at zero.
  subroutine integrate_species_gf (g, weights, total)
    use species, only: nspec, spec, tracer_species
    use theta_grid, only: ntgrid
    use gs2_layouts, only: gf_lo, g_lo
    use gs2_layouts, only: ik_idx, it_idx
    use kt_grids, only: kwork_filter
    use redistribute, only: gather
    use array_utils, only: zero_array
    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    real, dimension (:), intent (in out) :: weights
    complex, dimension (-ntgrid:ntgrid,gf_lo%ntheta0,gf_lo%naky), intent (out) :: total
    complex, dimension (:, :, :, :, :, :), allocatable :: gf
    integer :: igf, i_t, i_k, i_l, i_e, i_s
    logical :: skip_filtered

    call zero_array(total)
    where (spec%type == tracer_species) weights = 0

    ! Bring all of velocity space for each local (kx, ky) onto this processor
    allocate(gf(-ntgrid:ntgrid,2,nspec,negrid,nlambda,gf_lo%llim_proc:gf_lo%ulim_alloc))
    call gather(g2gf, g, gf, ntgrid)

    skip_filtered = any(kwork_filter)

    do igf = gf_lo%llim_proc, gf_lo%ulim_proc
       i_t = it_idx(gf_lo, igf)
       i_k = ik_idx(gf_lo, igf)
       if (skip_filtered) then
          if (kwork_filter(i_t, i_k)) cycle
       end if
       do i_l = 1, gf_lo%nlambda
          do i_e = 1, gf_lo%negrid
             do i_s = 1, gf_lo%nspec
                total(:,i_t,i_k) = total(:,i_t,i_k) + &
                     (weights(i_s)*w(i_e,i_s))*wl(:,i_l)*(gf(:,1,i_s,i_e,i_l,igf)+gf(:,2,i_s,i_e,i_l,igf))
             end do
          end do
       end do
    end do

    deallocate(gf)
  end subroutine integrate_species_gf

  !> Integrate species using the gf_lo data format, for a gf that has already
  !! been gathered into gf_lo by the caller.
  !!
  !! For each local gf_lo element the weighted sum over pitch angle, energy,
  !! species and both sigma components is stored in total(:, it, ik). No
  !! reduction across processors is performed.
  !! Side effect: entries of weights belonging to tracer species are set to
  !! zero. Modes flagged in kwork_filter are skipped and left at zero.
  !!
  !! Currently this routine isn't being used as the functionality has been directly added to
  !! getan_nogath in dist_fn.
  !!
  !! AJ
  subroutine integrate_species_gf_nogather (gf, weights, total)
    use species, only: nspec, spec, tracer_species
    use theta_grid, only: ntgrid
    use gs2_layouts, only: gf_lo
    use gs2_layouts, only: ik_idx, it_idx
    use kt_grids, only: kwork_filter
    use array_utils, only: zero_array
    implicit none

    complex, dimension(-ntgrid:ntgrid,2,nspec,negrid,nlambda,gf_lo%llim_proc:gf_lo%ulim_alloc), intent(in) :: gf
    real, dimension (:), intent (in out) :: weights
    complex, dimension (-ntgrid:ntgrid,gf_lo%ntheta0,gf_lo%naky), intent (out) :: total
    integer :: igf, i_t, i_k, i_l, i_e, i_s
    logical :: skip_filtered

    call zero_array(total)
    where (spec%type == tracer_species) weights = 0

    skip_filtered = any(kwork_filter)

    do igf = gf_lo%llim_proc, gf_lo%ulim_proc
       i_t = it_idx(gf_lo, igf)
       i_k = ik_idx(gf_lo, igf)
       total(:,i_t,i_k) = 0.
       if (skip_filtered) then
          if (kwork_filter(i_t, i_k)) cycle
       end if
       do i_l = 1, gf_lo%nlambda
          do i_e = 1, gf_lo%negrid
             do i_s = 1, gf_lo%nspec
                total(:,i_t,i_k) = total(:,i_t,i_k) + &
                     (weights(i_s)*w(i_e,i_s))*wl(:,i_l)*(gf(:,1,i_s,i_e,i_l,igf)+gf(:,2,i_s,i_e,i_l,igf))
             end do
          end do
       end do
    end do
  end subroutine integrate_species_gf_nogather
  
  !> Integrate_species on subcommunicator with gather
  !! Falls back to original method if not using xyblock sub comm
  !!
  !! Integrates g over velocity space and sums over species (weighted by
  !! weights(is)). The routine dispatches as follows:
  !!  - gf_lo present and .true.: delegate to integrate_species_gf;
  !!  - intspec_sub .false.: delegate to integrate_species_original;
  !!  - nogath present and .true.: delegate to integrate_species_sub
  !!    (no gather of the xy data);
  !!  - otherwise: sum locally into a flattened (theta, kxky) buffer, reduce
  !!    over g_lo%xyblock_comm, then allgatherv the pieces over
  !!    g_lo%lesblock_comm so every processor in that communicator receives
  !!    the full total(theta, kx, ky).
  !!
  !! Side effect: entries of weights belonging to tracer species are set to
  !! zero, either here or in the delegated routine.
  !!
  !! NOTE(review): unlike the other integrate_species_* variants, the gather
  !! path below does not skip modes flagged in kwork_filter. Confirm this is
  !! intended.
  subroutine integrate_species_master (g, weights, total, nogath, gf_lo)
    use theta_grid, only: ntgrid
    use kt_grids, only: ntheta0, naky
    use gs2_layouts, only: g_lo, intspec_sub
    use gs2_layouts, only: is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce_sub, allgatherv
    use mp, only: nproc_comm,rank_comm,mp_abort
    use species, only: spec, tracer_species
    use array_utils, only: zero_array
    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    real, dimension (:), intent (in out) :: weights
    complex, dimension (0:,:,:), intent (out) :: total
    logical, intent(in), optional :: nogath
    logical, intent(in), optional :: gf_lo
    complex, dimension (:), allocatable :: total_flat
    complex, dimension (:,:,:), allocatable :: total_transp
    integer :: nl,nr, ik, it, iglo, ip, ie,is,il, ig


    if(present(gf_lo)) then
       if(gf_lo) then
          call integrate_species_gf(g,weights,total)
          return
       end if
    end if
    !If not using sub-communicators then just use original method
    !Note that if x and y are entirely local then we force intspec_sub=.false.
    if(.not.intspec_sub) then
       call integrate_species_original(g,weights,total)
       return
    endif

    !If we don't want to gather then use integrate_species_sub
    if(present(nogath))then
       if(nogath)then
          call integrate_species_sub(g,weights,total)
          return
       endif
    endif

    !->First intialise gather vars
    !Note: We only do this on the first call !!May be better to move this to some init routine?
    !recvcnts_intspec, displs_intspec, sz_intspec and local_rank_intspec are not
    !declared here, so they are module-level state kept between calls.
    if(.not.allocated(recvcnts_intspec)) then
       !Get subcomm size
       call nproc_comm(g_lo%lesblock_comm,sz_intspec)

       !Get local rank
       call rank_comm(g_lo%lesblock_comm,local_rank_intspec)

       !Create displacement and receive count arrays
       allocate(recvcnts_intspec(sz_intspec),displs_intspec(sz_intspec))

       !Counts and displacements are in elements: each kxky point contributes
       !(2*ntgrid+1) theta values. Empty ranges give a zero receive count.
       do ip=0,sz_intspec-1
          displs_intspec(ip+1)=MIN(g_lo%les_kxky_range(1,ip)*(2*ntgrid+1),ntheta0*naky*(2*ntgrid+1)-1)
          recvcnts_intspec(ip+1)=MAX((g_lo%les_kxky_range(2,ip)-g_lo%les_kxky_range(1,ip)+1)*(2*ntgrid+1),0)
       enddo
    endif

    !Allocate array and ensure is zero
    !NOTE(review): assuming les_kxky_range holds this rank's first/last 0-based
    !flattened kxky index, the upper bound below is one element larger than
    !the last index written (and the amount sent). Confirm this is harmless.
    allocate(total_flat(g_lo%les_kxky_range(1,local_rank_intspec)*&
         (2*ntgrid+1):(1+g_lo%les_kxky_range(2,local_rank_intspec))*(2*ntgrid+1)))
    call zero_array(total_flat)
    where (spec%type == tracer_species) weights = 0
    !Performed integral (weighted sum) over local velocity space and species
    if(g_lo%x_before_y) then
       !Flattened index ordering (theta, kx, ky), matching the column-major
       !layout of total
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(iglo, it, ik, ie, il, is, nl, nr) &
       !$OMP SHARED(g_lo, ntgrid, weights, w, wl, g, ntheta0) &
       !$OMP REDUCTION(+ : total_flat) &
       !$OMP SCHEDULE(static)
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          !Convert from iglo to the separate indices
          is = is_idx(g_lo,iglo)
          ik = ik_idx(g_lo,iglo)
          it = it_idx(g_lo,iglo)
          ie = ie_idx(g_lo,iglo)
          il = il_idx(g_lo,iglo)
          
          !Calculate extent
          nl=(2*ntgrid+1)*(it-1+ntheta0*(ik-1))
          nr=nl+(2*ntgrid)
          
          !Sum up weighted g
          total_flat(nl:nr) = total_flat(nl:nr) + &
               (weights(is)*w(ie,is))*wl(:,il)*(g(:,1,iglo)+g(:,2,iglo))
       end do
       !$OMP END PARALLEL DO
    else
       !Flattened index ordering (theta, ky, kx): transposed after the gather
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(iglo, it, ik, ie, il, is, nl, nr) &
       !$OMP SHARED(g_lo, ntgrid, weights, w, wl, g, naky) &
       !$OMP REDUCTION(+ : total_flat) &
       !$OMP SCHEDULE(static)
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          !Convert from iglo to the separate indices
          is = is_idx(g_lo,iglo)
          ik = ik_idx(g_lo,iglo)
          it = it_idx(g_lo,iglo)
          ie = ie_idx(g_lo,iglo)
          il = il_idx(g_lo,iglo)
          
          !Calculate extent
          nl=(2*ntgrid+1)*(ik-1+naky*(it-1))
          nr=nl+(2*ntgrid)

          !Sum up weighted g
          total_flat(nl:nr) = total_flat(nl:nr) + &
               (weights(is)*w(ie,is))*wl(:,il)*(g(:,1,iglo)+g(:,2,iglo))
       end do
       !$OMP END PARALLEL DO
    endif

    !Reduce sum across all procs in sub communicator to make integral over all velocity space and species
    call sum_allreduce_sub(total_flat,g_lo%xyblock_comm)

    !Now gather missing xy data from other procs (only talk to procs
    !with the same piece of les)

    if(g_lo%x_before_y)then
       !Layouts agree, so gather straight into total
       call allgatherv(total_flat,recvcnts_intspec(local_rank_intspec+1),total,recvcnts_intspec,displs_intspec,g_lo%lesblock_comm)
    else
       allocate(total_transp(0:2*ntgrid,naky,ntheta0))
       call allgatherv(total_flat,recvcnts_intspec(local_rank_intspec+1),total_transp,recvcnts_intspec,displs_intspec,g_lo%lesblock_comm)
       do ig=0,2*ntgrid
          total(ig,:,:)=transpose(total_transp(ig,:,:))
       enddo
       !This is pretty bad for memory access so can do this all at once
       !using reshape with a specified order :
       !total=RESHAPE(total_transp,(/2*ntgrid+1,ntheta0,naky/),ORDER=(/1,3,2/))
       !BUT timings in a simple test code suggest loop+transpose can be faster.
       !In the case where ntgrid is large and ntheta0 is small reshape can win
       !whilst in the case where ntgrid is small and ntheta0 is large transpose wins.
       !When both are large reshape seems to win.
       !In conclusion it's not clear which method is better but if we assume we care most
       !about nonlinear simulations then small ntgrid with large ntheta0 is most likely so
       !pick transpose method
       deallocate(total_transp)
    endif

    deallocate(total_flat)
  end subroutine integrate_species_master

  !> Compute Legendre-expansion coefficients of g in energy (tote), in the
  !> passing pitch angles (totl) and, optionally, in the trapped pitch-angle
  !> variable (tott), summed over velocity space.
  !>
  !> Each coefficient of order m includes the factor (2*m+1). When running on
  !> more than one processor, the results are reduced onto proc 0 only.
  !>
  !> The polynomial tables are kept in module storage:
  !> - lpe (energy) and lpl (passing pitch angles) are built on the first call.
  !> - lpt is built on the first call that passes tott, which need not be the
  !>   first call overall.
  !>
  !> @note This routine depends on the values stored in [[xx]] as input
  !> to the legendre_polynomials method. However, with Radau grids the
  !> values in [[xx]] are not Legendre roots so this routine is probably
  !> incorrect when [[le_grids_knobs::radau_gauss_grid]] is `.true.`.
  subroutine legendre_transform (g, tote, totl, tott)
    use egrid, only: zeroes
    use mp, only: nproc
    use theta_grid, only: ntgrid, bmag, bmax
    use species, only: nspec
    use kt_grids, only: naky, ntheta0
    use gs2_layouts, only: g_lo, idx, idx_local
    use mp, only: sum_reduce
    implicit none
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    complex, dimension (0:,-ntgrid:,:,:,:), intent (out) :: tote, totl
    complex, dimension (0:,-ntgrid:,:,:,:), intent (out), optional :: tott

    complex :: totfac
    real :: ulim
    integer :: is, il, ie, ik, it, iglo, ig, im, ntrap
    ! Number of energy polynomials, fixed when lpe is first built
    integer, save :: lpesize

    real, dimension (:), allocatable :: nodes
    real, dimension (:,:), allocatable :: lpltmp, lpttmp

    ! Energy (lpe) and passing pitch-angle (lpl) tables: built once
    if (.not. allocated(lpl)) then
       allocate(lpltmp(ng2,0:ng2-1))
       allocate(lpl(nlambda,0:ng2-1))

       lpesize = nesub
       allocate(lpe(negrid,0:lpesize-1,nspec)) ; lpe = 0.0

       ! get value of first nesub legendre polynomials
       ! at each of the grid points on (0,vcut)
       do is = 1,nspec
          call legendre_polynomials (0.0,vcut,zeroes(:lpesize,is),lpe(:lpesize,:,is))
       end do
       ! TEMP FOR TESTING -- MAB
       !          lpe = 2.*lpe/vcut

       ! get value of first ng2 legendre polynomials
       ! at each of the grid points on (0,1)
       call legendre_polynomials (0.0,1.0,xx,lpltmp)

       lpl = 0.0
       lpl(1:ng2,:) = lpltmp

       deallocate (lpltmp)
    end if

    ! Trapped pitch-angle table (lpt): built on the first call that requests
    ! tott. This check is separate from the lpl one so that a first call
    ! without tott does not leave lpt unallocated for later calls that pass it.
    if (present(tott)) then
       if (.not. allocated(lpt)) then
          allocate (lpt(nlambda,0:2*(nlambda-ng2-1),-ntgrid:ntgrid,2))
          lpt = 0.0
          do ig = -ntgrid, ntgrid
             ntrap = 1
             if (jend(ig) > ng2+1) then
                ntrap = jend(ig)-ng2
                allocate (nodes(2*ntrap-1))
                allocate (lpttmp(2*ntrap-1,0:2*(ntrap-1)))
                do il = 1, ntrap
                   nodes(il) = -sqrt(max(0.0,1.0-al(ng2+il)*bmag(ig)))
                end do
                nodes(ntrap+1:) = -nodes(ntrap-1:1:-1)
!Can we remove this?
! TEMP FOR TESTING -- MAB
!                nodes = nodes + sqrt(1.0-bmag(ig)/bmax)
!                ulim = 2.*sqrt(1.0-bmag(ig)/bmax)
                ulim = sqrt(1.0-bmag(ig)/bmax)
                call legendre_polynomials (-ulim,ulim,nodes,lpttmp)
                lpt(ng2+1:jend(ig),0:2*(ntrap-1),ig,2) = lpttmp(1:ntrap,:)
                lpt(ng2+1:jend(ig)-1,0:2*(ntrap-1),ig,1) = lpttmp(2*ntrap-1:ntrap+1:-1,:)
!Can we remove this?
!                lpt(ng2+1:jend(ig),0:2*(ntrap-1),ig,1) = lpttmp(2*ntrap-1:ntrap+1:-1,:)
                deallocate (nodes, lpttmp)
             end if
          end do
       end if
    end if

    ! carry out legendre transform to get coefficients of
    ! legendre polynomial expansion of g
    totfac = 0. ; tote = 0. ; totl = 0.
    if (present(tott)) tott = 0.

    !Loop over all indices, note this loop is optimal only for layout 'xyles' (at least in terms of
    !g memory access)
    do is = 1, nspec
       do ie = 1, negrid
          do il = 1, nlambda
             do ik = 1, naky       !Swapped ik and it loop order.
                do it = 1, ntheta0
                   iglo = idx (g_lo, ik, it, il, ie, is)
                   if (idx_local (g_lo, iglo)) then
                      do ig=-ntgrid,ntgrid
                         totfac = w(ie,is)*wl(ig,il)*(g(ig,1,iglo)+g(ig,2,iglo))
                         do im=0,lpesize-1
                            tote(im, ig, it, ik, is) = tote(im, ig, it, ik, is) + totfac*lpe(ie,im,is)*(2*im+1)
                         end do
                         do im=0,ng2-1
                            totl(im, ig, it, ik, is) = totl(im, ig, it, ik, is) + totfac*lpl(il,im)*(2*im+1)
                         end do
                         if (present(tott)) then
                            do im=0,2*(jend(ig)-ng2-1)
                               tott(im, ig, it, ik, is) = tott(im, ig, it, ik, is) + &
                                    w(ie,is)*wl(ig,il)*(lpt(il,im,ig,1)*g(ig,1,iglo)+lpt(il,im,ig,2)*g(ig,2,iglo))*(2*im+1)
                            end do
                         end if
                      end do
                   end if
                end do
             end do
          end do
       end do
    end do

    !Do we really need this if?
    if (nproc > 1) then
       !Now complete velocity integral, bringing back results to proc0
       call sum_reduce (tote, 0)
       call sum_reduce (totl, 0)
       if (present(tott)) call sum_reduce (tott, 0)
    end if

  end subroutine legendre_transform

  !> Evaluate the Legendre polynomials P_0 ... P_{n-1} at a set of points.
  !>
  !> The interval [llim, ulim] is mapped linearly onto [-1, 1]:
  !>   z = 2*(x - llim)/(ulim - llim) - 1.
  !> The polynomials are then generated with Bonnet's three-term recurrence,
  !>   j P_j(z) = (2j - 1) z P_{j-1}(z) - (j - 1) P_{j-2}(z),
  !> evaluated in double precision.
  !>
  !> On exit lpdum(i, j) = P_j(z_i). The polynomial count n is the second
  !> extent of lpdum, and lpdum must have size(xptsdum) rows.
  subroutine legendre_polynomials (llim, ulim, xptsdum, lpdum)
    implicit none

    real, intent (in) :: ulim, llim
    real, dimension (:), intent (in)   :: xptsdum
    real, dimension (:,0:), intent(out) :: lpdum

    ! lp1 = P_j, lp2 = P_{j-1}, lp3 = P_{j-2}; zshift = points mapped to [-1, 1]
    double precision, dimension (:), allocatable :: lp1, lp2, lp3, zshift
    integer :: j, mmax, npoly

    lpdum = 0.0

    mmax = size(xptsdum)
    ! Use the array extent directly: size(lpdum(1,:)) references row 1, which
    ! is out of bounds when there are no points
    npoly = size(lpdum, 2)
    if (npoly == 0) return

    allocate(lp1(mmax),lp2(mmax),lp3(mmax),zshift(mmax))

    lp1 = real(1.0,kind(lp1(1)))
    lp2 = real(0.0,kind(lp2(1)))

    lpdum(:,0) = real(1.0,kind(lpdum))

! TEMP FOR TESTING -- MAB
!    zshift = real(2.0,kind(zshift))*xptsdum/ulim - real(1.0,kind(zshift))
    zshift = real(2.0,kind(zshift))*(xptsdum-llim)/(ulim-llim) - real(1.0,kind(zshift))

    do j = 1, npoly-1
       lp3 = lp2
       lp2 = lp1
       lp1 = ((2*j-1) * zshift * lp2 - (j-1) * lp3) / j
       lpdum(:,j) = lp1
    end do

    deallocate(lp1,lp2,lp3,zshift)

  end subroutine legendre_polynomials

  !> Integrate a complex g over velocity space for every (theta, kx, ky, species).
  !!
  !! Takes f = f(x, y, z, sigma, lambda, E, species) and returns int f, where
  !! the integral is over all velocity space (both sigma components, energy
  !! and pitch angle).
  !!
  !! Optional arguments:
  !! - all: when .true., every processor needs the result. Otherwise, with
  !!   more than one processor, the result is only reduced onto PE 0.
  !! - full_arr: when .true. together with all, the whole array is reduced
  !!   over all processors. Otherwise (with all) the reduction uses
  !!   g_lo%xysblock_comm.
  !!
  !! When intmom_sub, all and .not.full_arr all hold, only the local block
  !! total(:, it_min:it_max, ik_min:ik_max, is_min:is_max) is written.
  subroutine integrate_moment_c34 (g, total, all, full_arr)
    use gs2_layouts, only: g_lo, is_idx, ik_idx, it_idx, ie_idx, il_idx, intmom_sub
    use theta_grid, only: ntgrid
    use mp, only: sum_reduce, sum_allreduce_sub, nproc, sum_allreduce
    use array_utils, only: zero_array, copy

    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    complex, dimension (-ntgrid:,:,:,:), intent (out) :: total
    logical, optional, intent(in) :: all
    logical, optional, intent(in) :: full_arr
    complex, dimension(:,:,:,:), allocatable :: partial
    logical :: want_full, want_all, block_only
    integer :: iglo, i_s, i_l, i_e, i_k, i_t

    want_full = .false.
    if (present(full_arr)) want_full = full_arr

    want_all = .false.
    if (present(all)) want_all = all

    ! Shrink the accumulator to the local sub-block only when the result is
    ! shared through the sub-communicator. The reduce-to-PE-0 and full-array
    ! paths need the same (full) size on every processor.
    block_only = intmom_sub .and. want_all .and. .not. want_full

    if (block_only) then
       allocate(partial(-ntgrid:ntgrid, g_lo%it_min:g_lo%it_max, &
            g_lo%ik_min:g_lo%ik_max, g_lo%is_min:g_lo%is_max))
    else
       allocate(partial(-ntgrid:ntgrid, g_lo%ntheta0, g_lo%naky, g_lo%nspec))
    end if
    call zero_array(partial)

    ! Local part of the velocity-space integral
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, i_t, i_k, i_e, i_l, i_s) &
    !$OMP SHARED(g_lo, w, wl, g) &
    !$OMP REDUCTION(+ : partial) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       i_k = ik_idx(g_lo, iglo)
       i_t = it_idx(g_lo, iglo)
       i_e = ie_idx(g_lo, iglo)
       i_s = is_idx(g_lo, iglo)
       i_l = il_idx(g_lo, iglo)

       partial(:, i_t, i_k, i_s) = partial(:, i_t, i_k, i_s) + &
            w(i_e,i_s)*wl(:,i_l)*(g(:,1,iglo)+g(:,2,iglo))
    end do
    !$OMP END PARALLEL DO

    ! Complete the distributed integral (skipped on a single processor)
    if (nproc > 1) then
       if (.not. want_all) then
          ! Only PE 0 needs the answer
          call sum_reduce (partial, 0)
       else if (want_full) then
          call sum_allreduce (partial)
       else
          ! If intmom_sub is .false. then xysblock_comm is the world
          ! communicator, which is why partial must then be full size
          call sum_allreduce_sub (partial, g_lo%xysblock_comm)
       end if
    end if

    ! Move the result into the caller's array
    if (block_only) then
       total(:,g_lo%it_min:g_lo%it_max,g_lo%ik_min:g_lo%ik_max,g_lo%is_min:g_lo%is_max) = partial
    else
       call copy(partial, total)
    end if

    deallocate(partial)

  end subroutine integrate_moment_c34

  !> Takes f = f(theta, sigma ; x, y, z, lambda, E, species) and
  !> returns int f, where the integral is over all velocity space
  !> (real-valued counterpart of integrate_moment_c34).
  !>
  !> Optional arguments:
  !> - all: when .true., every processor needs the result. Otherwise, with
  !>   more than one processor, the result is only reduced onto PE 0.
  !> - full_arr: when .true. together with all, the whole array is reduced
  !>   over all processors. Otherwise (with all) the reduction uses
  !>   g_lo%xysblock_comm.
  !>
  !> When intmom_sub, all and .not.full_arr all hold, only the local block
  !> total(:, it_min:it_max, ik_min:ik_max, is_min:is_max) is written.
  subroutine integrate_moment_r34 (g, total, all, full_arr)
    use gs2_layouts, only: g_lo, is_idx, ik_idx, it_idx, ie_idx, il_idx, intmom_sub
    use theta_grid, only: ntgrid
    use mp, only: sum_reduce, sum_allreduce_sub, nproc, sum_allreduce
    use array_utils, only: zero_array, copy
    implicit none

    real, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    real, dimension (-ntgrid:,:,:,:), intent (out) :: total
    logical, optional, intent(in) :: all
    logical, optional, intent(in) :: full_arr
    real, dimension(:,:,:,:), allocatable :: partial
    logical :: want_full, want_all, block_only
    integer :: iglo, i_s, i_l, i_e, i_k, i_t

    want_full = .false.
    if (present(full_arr)) want_full = full_arr

    want_all = .false.
    if (present(all)) want_all = all

    ! Shrink the accumulator to the local sub-block only when the result is
    ! shared through the sub-communicator. The reduce-to-PE-0 and full-array
    ! paths need the same (full) size on every processor.
    block_only = intmom_sub .and. want_all .and. .not. want_full

    if (block_only) then
       allocate(partial(-ntgrid:ntgrid, g_lo%it_min:g_lo%it_max, &
            g_lo%ik_min:g_lo%ik_max, g_lo%is_min:g_lo%is_max))
    else
       allocate(partial(-ntgrid:ntgrid, g_lo%ntheta0, g_lo%naky, g_lo%nspec))
    end if
    call zero_array(partial)

    ! Local part of the velocity-space integral
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, i_t, i_k, i_e, i_l, i_s) &
    !$OMP SHARED(g_lo, w, wl, g) &
    !$OMP REDUCTION(+ : partial) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       i_k = ik_idx(g_lo, iglo)
       i_t = it_idx(g_lo, iglo)
       i_e = ie_idx(g_lo, iglo)
       i_s = is_idx(g_lo, iglo)
       i_l = il_idx(g_lo, iglo)

       partial(:, i_t, i_k, i_s) = partial(:, i_t, i_k, i_s) + &
            w(i_e,i_s)*wl(:,i_l)*(g(:,1,iglo)+g(:,2,iglo))
    end do
    !$OMP END PARALLEL DO

    ! Complete the distributed integral (skipped on a single processor)
    if (nproc > 1) then
       if (.not. want_all) then
          ! Only PE 0 needs the answer
          call sum_reduce (partial, 0)
       else if (want_full) then
          call sum_allreduce (partial)
       else
          ! If intmom_sub is .false. then xysblock_comm is the world
          ! communicator, which is why partial must then be full size
          call sum_allreduce_sub (partial, g_lo%xysblock_comm)
       end if
    end if

    ! Move the result into the caller's array
    if (block_only) then
       total(:,g_lo%it_min:g_lo%it_max,g_lo%ik_min:g_lo%ik_max,g_lo%is_min:g_lo%is_max) = partial
    else
       call copy(partial, total)
    end if

    deallocate(partial)

  end subroutine integrate_moment_r34

  !> Integrate a real g, laid out in p_lo, over velocity space.
  !!
  !! Takes f = f(y, z, sigma, lambda, E, species) and returns int f, where the
  !! integral is over all velocity space. With more than one processor the
  !! result is returned on PE 0 only. If the optional argument `all` is
  !! present (whatever its value), it is returned on every processor.
  subroutine integrate_moment_r33 (g, total, all)
    use gs2_layouts, only: p_lo, is_idx, ik_idx, ie_idx, il_idx
    use mp, only: nproc, sum_reduce, sum_allreduce
    use theta_grid, only: ntgrid
    use array_utils, only: zero_array

    implicit none

    real, dimension (-ntgrid:,:,p_lo%llim_proc:), intent (in) :: g
    real, dimension (-ntgrid:,:,:), intent (out) :: total
    integer, optional, intent(in) :: all
    integer :: ip_lo, i_s, i_l, i_e, i_k

    call zero_array(total)

    ! Local part of the velocity-space integral
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(ip_lo, i_s, i_k, i_e, i_l) &
    !$OMP SHARED(p_lo, w, wl, g) &
    !$OMP REDUCTION(+ : total) &
    !$OMP SCHEDULE(static)
    do ip_lo = p_lo%llim_proc, p_lo%ulim_proc
       i_k = ik_idx(p_lo, ip_lo)
       i_e = ie_idx(p_lo, ip_lo)
       i_s = is_idx(p_lo, ip_lo)
       i_l = il_idx(p_lo, ip_lo)

       total(:, i_k, i_s) = total(:, i_k, i_s) + &
            w(i_e,i_s)*wl(:,i_l)*(g(:,1,ip_lo)+g(:,2,ip_lo))
    end do
    !$OMP END PARALLEL DO

    ! Nothing left to combine on a single processor
    if (nproc <= 1) return

    if (present(all)) then
       call sum_allreduce (total)
    else
       call sum_reduce (total, 0)
    end if

  end subroutine integrate_moment_r33

  !> Integrate over velocity space (xi and energy) while in the le_layout.
  !!
  !! In this layout every processor holds all velocity-space points for its
  !! own le_layout indices, so no MPI reduction is performed. As a result,
  !! each processor only knows the integral for its own points and not the
  !! values held by other processors. Points whose (it, ik) pair is flagged
  !! in kwork_filter are skipped and left at zero.
  subroutine integrate_moment_lec (lo, g, total)
    use layouts_type, only: le_layout_type
    use gs2_layouts, only: ig_idx, it_idx, ik_idx, is_idx
    use kt_grids, only: kwork_filter
    use array_utils, only: zero_array
    implicit none
    !> Layout describing the local le_layout points
    type (le_layout_type), intent (in) :: lo
    !> Integrand, indexed (xi, energy, le_layout point)
    complex, dimension (:,:,lo%llim_proc:), intent (in) :: g
    !> Integral over xi and energy for each local le_layout point
    complex, dimension (lo%llim_proc:), intent (out) :: total
    integer :: i_le, i_x, i_en, i_th, i_kx, i_ky, i_sp, n_xi_used

    call zero_array(total)

    ! When nxi > 2*ng2 the extra point nxi+1 is included. It stores the phase
    ! space point equivalent to g_lo (il=nlambda, isign=2) and MUST contribute
    ! to the v-space integral (CMR, 2/10/2013); it is also needed in the
    ! collision operator for consistency (MRH, 16/08/2018).
    ! Could this test be grid_has_trapped_particles()?
    n_xi_used = nxi
    if (nxi > 2 * ng2) n_xi_used = nxi + 1

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(i_le, i_sp, i_kx, i_ky, i_th, i_en, i_x) &
    !$OMP SHARED(lo, kwork_filter, negrid, n_xi_used, w, wxi, g) &
    !$OMP REDUCTION(+ : total) &
    !$OMP SCHEDULE(guided)
    do i_le = lo%llim_proc, lo%ulim_proc
       i_kx = it_idx (lo, i_le)
       i_ky = ik_idx (lo, i_le)
       if (kwork_filter(i_kx, i_ky)) cycle
       i_sp = is_idx (lo, i_le)
       i_th = ig_idx (lo, i_le)
       do i_en = 1, negrid
          do i_x = 1, n_xi_used
             total(i_le) = total(i_le) + w(i_en, i_sp) * wxi(i_x, i_th) * g(i_x, i_en, i_le)
          end do
       end do
    end do
    !$OMP END PARALLEL DO
    ! No communication: each processor writes only its own entries of total,
    ! and the le_layout index identifies the (ig, it, ik, is) point.

  end subroutine integrate_moment_lec

  !> Integrate the squared ky-sum of g over energy and lambda (not sigma)
  !! at a single theta index `ig`, separately for real and imaginary parts.
  !!
  !! g lives in the p_lo layout, whose points carry (ik, il, ie, is)
  !! indices. For each (ie, il, is) the local values aky(ik)*g are summed
  !! over ik. The real and imaginary parts of that sum are then squared
  !! separately and integrated with the weights w(ie, is)*wl(ig, il),
  !! giving one value per species in total(is).
  !!
  !! The result is returned on PE 0 only, or on all processors when the
  !! optional argument `all` is present (its value is never read).
  !!
  !! NOTE(review): the ky sum is built from the local p_lo points only, so the
  !! squaring is exact only if all ik for a given (ie, il, is) live on one
  !! processor -- confirm against the p_lo decomposition.
  subroutine integrate_kysum (g, ig, total, all)
    use species, only: nspec
    use kt_grids, only: aky
    use constants, only: zi
    use gs2_layouts, only: is_idx, ik_idx, ie_idx, il_idx, p_lo
    use mp, only: sum_reduce, sum_allreduce, nproc
    use array_utils, only: zero_array

    implicit none

    !> Integrand in the p_lo layout
    complex, dimension (p_lo%llim_proc:), intent (in) :: g
    !> Theta index at which the lambda weights wl are evaluated
    integer, intent (in) :: ig
    !> Integral for each species (needs at least nspec elements)
    complex, dimension (:), intent (out) :: total
    !> Presence flag: if present every processor receives the result
    integer, optional, intent(in) :: all

    complex, dimension (negrid,nlambda,nspec) :: gksum
    integer :: is, il, ie, ik, iplo

    !Initialise both arrays to zero
    call zero_array(total) ; call zero_array(gksum)

    ! Accumulate sum_{ky} aky * g for every (energy, lambda, species)
    do iplo = p_lo%llim_proc, p_lo%ulim_proc
       ik = ik_idx(p_lo,iplo)
       ie = ie_idx(p_lo,iplo)
       is = is_idx(p_lo,iplo)
       il = il_idx(p_lo,iplo)
       gksum(ie,il,is) = gksum(ie,il,is) + real(aky(ik)*g(iplo)) + zi*aimag(aky(ik)*g(iplo))
    end do
    ! real part of gksum is | sum_{ky} ky * J0 * real[ ky*(conjg(phi+)*h- + conjg(phi-)*h+ ] |**2
    ! imag part of gksum is | sum_{ky} ky * J0 * aimag[ ky*(conjg(phi+)*h- + conjg(phi-)*h+ ] |**2
    gksum = real(gksum)**2 + zi*aimag(gksum)**2

    ! Integrate over energy and lambda. Loop over the (ie, il, is) grid rather
    ! than over p_lo: a p_lo loop would add each triple once per local ky
    ! point. Triples with no local p_lo points keep gksum == 0 and add nothing.
    do is = 1, nspec
       do il = 1, nlambda
          do ie = 1, negrid
             total(is) = total(is) + w(ie,is)*wl(ig,il)*gksum(ie,il,is)
          end do
       end do
    end do

    !Do we really need this if?
    if (nproc > 1) then
       if (present(all)) then
          call sum_allreduce (total)
       else
          call sum_reduce (total, 0)
       end if
    end if

  end subroutine integrate_kysum

  !> Velocity-space integrals of g using the alternative passing-lambda
  !! weight sets, used to estimate the passing-lambda integration error.
  !!
  !! On first use the weight sets are built on proc0 from wlerr, mapped onto
  !! the lambda grid, cached in wlmod and broadcast to all processors.
  !! Trapped pitch angles, if present, keep the standard wl weights.
  !! total(:, it, ik, iset) receives the integral of g(sigma=1)+g(sigma=2)
  !! with weight set iset, each species scaled by weights(is), and is
  !! summed over all processors.
  subroutine lint_error (g, weights, total)
    use theta_grid, only: ntgrid, bmag, bmax
    use gs2_layouts, only: g_lo, is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce, proc0, broadcast

    implicit none

    !> Integrand, indexed (theta, sigma, g_lo)
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    !> Per-species scaling factors
    real, dimension (:), intent (in) :: weights
    !> Integrals, indexed (theta, it, ik, weight set)
    complex, dimension (-ntgrid:,:,:,:), intent (out) :: total
    integer :: iglo, iset, ilam, i_en, i_sp, i_kx, i_ky

    if (.not. allocated (wlmod)) then
       ! Every element is filled on proc0 and then broadcast, so no
       ! initialisation is needed here.
       allocate (wlmod(-ntgrid:ntgrid,nlambda,ng2))

       if (proc0) then
          do iset = 1, ng2
             ! Passing pitch angles: error weights mapped onto the lambda grid
             do ilam = 1, ng2
                wlmod(:,ilam,iset) = wlerr(ilam,iset)*2.0*sqrt((bmag(:)/bmax) &
                     *((1.0/bmax-al(ilam))/(1.0/bmag(:)-al(ilam))))
             end do
             ! The mapping above only applies to passing particles; trapped
             ! pitch angles keep the precalculated weights
             if (grid_has_trapped_particles()) wlmod(:,ng2+1:,iset) = wl(:,ng2+1:)
          end do
       end if

       ! We could just do the above calculation on all procs instead
       call broadcast (wlmod)
    end if

    total = 0.

    ! Look up the layout indices once per point and then visit every weight
    ! set; each element of total is still accumulated in ascending iglo order.
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       i_ky = ik_idx(g_lo,iglo)
       i_kx = it_idx(g_lo,iglo)
       i_en = ie_idx(g_lo,iglo)
       i_sp = is_idx(g_lo,iglo)
       ilam = il_idx(g_lo,iglo)
       do iset = 1, ng2
          total(:, i_kx, i_ky, iset) = total(:, i_kx, i_ky, iset) &
               + weights(i_sp)*w(i_en,i_sp)*wlmod(:,ilam,iset)*(g(:,1,iglo)+g(:,2,iglo))
       end do
    end do

    ! One reduction covers all weight sets
    call sum_allreduce (total)

  end subroutine lint_error

  !> Velocity-space integrals of g using the alternative trapped-lambda
  !! weight sets, used to estimate the trapped-lambda integration error.
  !!
  !! There are ntrap = nlambda - ng2 weight sets. On first use they are built
  !! on proc0, cached in wtmod and broadcast: passing pitch angles use the
  !! standard wl weights and trapped ones use wlterr. total(:, it, ik, iset)
  !! receives the integral of g(sigma=1)+g(sigma=2) with weight set iset,
  !! each species scaled by weights(is), and is summed over all processors.
  subroutine trap_error (g, weights, total)
    use theta_grid, only: ntgrid
    use gs2_layouts, only: g_lo, is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce, proc0, broadcast

    implicit none

    !> Integrand, indexed (theta, sigma, g_lo)
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    !> Per-species scaling factors
    real, dimension (:), intent (in) :: weights
    !> Integrals, indexed (theta, it, ik, weight set)
    complex, dimension (-ntgrid:,:,:,:), intent (out) :: total
    integer :: iglo, iset, nsets, ilam, i_en, i_sp, i_kx, i_ky

    ! One weight set per trapped pitch angle
    nsets = nlambda - ng2

    if (.not. allocated(wtmod)) then
       ! Every element is filled on proc0 and then broadcast
       allocate (wtmod(-ntgrid:ntgrid,nlambda,nsets))

       if (proc0) then
          wtmod(:,:ng2,:) = spread(wl(:,:ng2), dim=3, ncopies=nsets)
          wtmod(:,ng2+1:,:) = wlterr(:,ng2+1:,:)
       end if

       ! We could just do the above calculation on all procs instead
       call broadcast (wtmod)
    end if

    total = 0.

    ! Look up the layout indices once per point and then visit every weight
    ! set; each element of total is still accumulated in ascending iglo order.
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       i_ky = ik_idx(g_lo,iglo)
       i_kx = it_idx(g_lo,iglo)
       i_en = ie_idx(g_lo,iglo)
       i_sp = is_idx(g_lo,iglo)
       ilam = il_idx(g_lo,iglo)
       do iset = 1, nsets
          total(:, i_kx, i_ky, iset) = total(:, i_kx, i_ky, iset) &
               + weights(i_sp)*w(i_en,i_sp)*wtmod(:,ilam,iset)*(g(:,1,iglo)+g(:,2,iglo))
       end do
    end do

    ! One reduction covers all weight sets
    call sum_allreduce (total)

  end subroutine trap_error

  !> Velocity-space integrals of g using the alternative energy weight sets,
  !! used to estimate the energy integration error.
  !!
  !! There are wdim weight sets. On first use they are built on proc0 and
  !! cached in wmod(ie, iset, is): energies 1..negrid-1 take werr and the
  !! last energy point keeps the standard weight w(negrid, is). The sets
  !! are then broadcast. total(:, it, ik, iset) receives the integral of
  !! g(sigma=1)+g(sigma=2) with weight set iset, each species scaled by
  !! weights(is), and is summed over all processors.
  subroutine eint_error (g, weights, total)
    use theta_grid, only: ntgrid
    use species, only: nspec
    use gs2_layouts, only: g_lo, is_idx, ik_idx, it_idx, ie_idx, il_idx
    use mp, only: sum_allreduce, proc0, broadcast

    implicit none

    !> Integrand, indexed (theta, sigma, g_lo)
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    !> Per-species scaling factors
    real, dimension (:), intent (in) :: weights
    !> Integrals, indexed (theta, it, ik, weight set)
    complex, dimension (-ntgrid:,:,:,:), intent (out) :: total
    integer :: iglo, iset, ilam, i_en, i_sp, i_kx, i_ky

    if (.not. allocated(wmod)) then
       ! Every element is filled on proc0 and then broadcast
       allocate (wmod(negrid,wdim,nspec))

       if (proc0) then
          wmod(:negrid-1,:,:) = werr(:,:,:)
          do i_sp = 1, nspec
             wmod(negrid,:,i_sp) = w(negrid,i_sp)
          end do
       end if

       call broadcast(wmod)
    end if

    total = 0.

    ! Look up the layout indices once per point and then visit every weight
    ! set; each element of total is still accumulated in ascending iglo order.
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       i_ky = ik_idx(g_lo,iglo)
       i_kx = it_idx(g_lo,iglo)
       i_en = ie_idx(g_lo,iglo)
       i_sp = is_idx(g_lo,iglo)
       ilam = il_idx(g_lo,iglo)
       do iset = 1, wdim
          total(:, i_kx, i_ky, iset) = total(:, i_kx, i_ky, iset) &
               + weights(i_sp)*wmod(i_en,iset,i_sp)*wl(:,ilam)*(g(:,1,iglo)+g(:,2,iglo))
       end do
    end do

    ! One reduction covers all weight sets
    call sum_allreduce (total)
  end subroutine eint_error

  !> Set up the velocity-space grids that depend on the energy grid and build
  !! the pitch-angle (lambda) grid.
  !!
  !! Steps:
  !!  * initialise theta_grid and species;
  !!  * allocate speed = sqrt(energy) and the *_maxw copies;
  !!  * copy the energy-grid data of the first Maxwellian species into
  !!    w_maxw, energy_maxw, speed_maxw, zeroes_maxw and x0_maxw (all zero
  !!    except zeroes_maxw/x0_maxw if there is none; a warning is printed
  !!    because collisions will then fail);
  !!  * size the lambda grid: ng2 passing points, plus nbset trapped points
  !!    when trapped_particles is set and eps_trapped > epsilon(0.0);
  !!  * allocate the lambda-grid arrays, set the trapped al from bset and
  !!    call lgridset;
  !!  * on proc0, write a velocity-integration error estimate to the
  !!    '.vspace_integration_error' output file.
  !!
  !! NOTE(review): reads energy, w, zeroes and x0, so presumably the energy
  !! grid has been set up beforehand -- confirm with the caller.
  subroutine set_grids
    use species, only: init_species
    use egrid, only: setvgrid, x0, x0_maxw, zeroes, zeroes_maxw
    use species, only: nspec, spec, f0_maxwellian
    use theta_grid, only: init_theta_grid, ntgrid, nbset, bset, eps_trapped
    use file_utils, only: open_output_file, close_output_file
    use mp, only: proc0
    implicit none

    integer :: is
    logical :: has_maxwellian_species
    integer :: unit

    call init_theta_grid
    call init_species

    allocate (speed(negrid,nspec),speed_maxw(negrid),w_maxw(negrid),energy_maxw(negrid))

    w_maxw = 0.0
    energy_maxw = 0.0
    speed_maxw = 0.0
    has_maxwellian_species = .false.

    speed = sqrt(energy)

    ! Keep a copy of the energy grid of the first Maxwellian species
    do is = 1,nspec
      if  (spec(is)%f0type .EQ. f0_maxwellian) then
         has_maxwellian_species = .true.
         w_maxw(:) = w(:,is)
         energy_maxw(:) = energy(:,is)
         speed_maxw(:) = speed(:,is)
         zeroes_maxw(:) = zeroes(:,is)
         x0_maxw = x0(is)
         exit
      end if
    end do
    ! Every processor reaches this point with the same species data, so warn
    ! from proc0 only rather than once per processor.
    if (proc0 .and. .not. has_maxwellian_species) write(*,*) &
        'Warning; no maxwellian species; collisions will fail'
    !call setvgrid (vcut, negrid, energy, w, nesub)

    ! Trapped pitch angles are only added when requested and when the
    ! geometry has non-negligible trapping
    if (trapped_particles .and. eps_trapped > epsilon(0.0)) then
       nlambda = ng2+nbset
       lmax = nlambda-1
    else
       nlambda = ng2
       lmax = nlambda
    end if
    nxi = max(2*nlambda-1, 2*ng2)
    allocate (al(nlambda))
    allocate (wl(-ntgrid:ntgrid,nlambda))
    allocate (jend(-ntgrid:ntgrid))
    allocate (forbid(-ntgrid:ntgrid,nlambda))
    allocate (is_ttp(-ntgrid:ntgrid,nlambda))
    allocate (is_bounce_point(-ntgrid:ntgrid,nlambda))
    allocate (can_be_ttp(nlambda))
    if (grid_has_trapped_particles()) then
       al(ng2+1:nlambda) = 1.0/bset
    end if
    call lgridset

    if (proc0) then
       call open_output_file(unit, '.vspace_integration_error')
       call report_velocity_integration_error_estimate(unit)
       call close_output_file(unit)
    end if

  end subroutine set_grids

  !> Calculate estimates of the velocity space integration errors
  !> and write them to the unit `unit_in` (standard output if absent).
  !>
  !> Each estimate compares a sum of integration weights with the
  !> analytic value that sum should have (i.e. the integral of a unit
  !> integrand over the relevant region). It therefore estimates a lower
  !> bound for the error on subsequent calls to
  !> integrate_species/integrate_moment.
  !> Alternative error estimates can be calculated following the approach
  !> adopted in [[dist_fn::get_verr]].
  !>
  !> The energy estimates assume every species has a Maxwellian
  !> background (see the notes below).
  subroutine report_velocity_integration_error_estimate(unit_in)
    use iso_fortran_env, only: output_unit
    use theta_grid, only: ntgrid, bmag, bmax
    use species, only: nspec
    use constants, only: sqrt_pi
    implicit none

    !> Output unit; defaults to output_unit
    integer, intent(in), optional :: unit_in
    ! Expected sums of the lambda weights (one sigma) and energy weights
    real, parameter :: expected_lambda = 2.0, expected_energy = 0.25
    real, dimension(:), allocatable :: lambda_error, energy_error, b_ratio
    real, dimension(:), allocatable :: passing_error, expected_passing
    real, dimension(:), allocatable :: trapped_error, expected_trapped
    real, dimension(:), allocatable :: energy_sub_error, energy_super_error
    real, dimension(:, :), allocatable :: mixed_error, mixed_value
    real :: expected_energy_sub, expected_energy_super
    integer :: unit, ie, ig, is

    unit = output_unit
    if (present(unit_in)) unit = unit_in

    !Get estimates of the error on the integration weights.

    ! Total pitch angle grid: error at each theta point
    allocate(lambda_error(-ntgrid:ntgrid))
    lambda_error = 0.0
    do ig = -ntgrid, ntgrid
       lambda_error(ig) = expected_lambda - sum(wl(ig, :))
    end do
    ! Introduce factor 2 here to account for sigma doubling
    lambda_error = 2*lambda_error

    ! The trapped/passing error breakdown is currently only
    ! appropriate when not using the Radau-Gauss grids
    if (.not. radau_gauss_grid) then

       ! Passing pitch angle grid
       allocate(passing_error(-ntgrid:ntgrid))
       passing_error = 0.0
       allocate(expected_passing(-ntgrid:ntgrid))
       allocate(b_ratio(-ntgrid:ntgrid))
       b_ratio = bmax / bmag
       ! This is the analytic integral, over the passing region, of the
       ! function used to map the Legendre weights onto the lambda grid:
       ! 2.0*sqrt((bmag(ig)/bmax)*((1.0/bmax-al(il))/(1.0/bmag(ig)-al(il)))
       ! To integrate this it helps to transform from lambda to the
       ! Legendre variable xx, through xx = sqrt(1-lambda*bmax).
       ! The weight function then becomes:
       ! 2.0*sqrt(bmag/bmax) * sqrt([xx^2]/[(bmax/bmag)-1+xx^2])
       ! integrated from xx=0 to xx=1.
       ! This analytic comparison is only appropriate when
       ! radau_gauss_grid = F, which is why this branch is skipped
       ! otherwise. With the Radau-Gauss grid the wfb carries
       ! contributions from both the passing and trapped regions where
       ! bmag = bmax, so splitting the analytic result into passing and
       ! trapped parts would only be valid away from bmag = bmax.
       expected_passing = 2.0 * (1 - sqrt(b_ratio - 1)/sqrt(b_ratio))
       do ig = -ntgrid, ntgrid
          passing_error(ig) = expected_passing(ig) - sum(wl(ig, :ng2))
       end do
       ! Introduce factor 2 here to account for sigma doubling
       passing_error = 2*passing_error

       ! Trapped pitch angle grid (stays zero without trapped points)
       allocate(trapped_error(-ntgrid:ntgrid))
       trapped_error = 0.0
       ! Only calculate this if we have any trapped pitch angles
       if (grid_has_trapped_particles()) then
          allocate(expected_trapped(-ntgrid:ntgrid))
          ! We're essentially just subtracting the passing analytic
          ! result from the total expectation, i.e. effectively
          !  expected_trapped = expected_lambda - expected_passing
          expected_trapped = 2.0 * sqrt(b_ratio - 1)/sqrt(b_ratio)
          do ig = -ntgrid, ntgrid
             trapped_error(ig) = expected_trapped(ig) - sum(wl(ig, ng2+1:))
          end do
          !Introduce factor 2 here to account for sigma doubling
          trapped_error = 2*trapped_error
       end if
    end if

    ! Total energy grid: error for each species
    allocate(energy_error(nspec))
    energy_error = 0.0
    do is = 1, nspec
       energy_error(is) = expected_energy - sum(w(:,is))
    end do

    ! Energy sub grid (the first nesub points, covering v in [0, vcut])
    allocate(energy_sub_error(nspec))
    energy_sub_error = 0.0
    ! NOTE: This assumes we are considering a species with Maxwellian
    ! background such that f0_values = exp(-energy)/pi^3/2
    ! We are analytically evaluating
    ! Integral_0^vcut{v^2 * pi * f0_values dv}
    ! Assuming Maxwellian this becomes
    ! Integral_0^vcut{v^2 exp(-v^2)/sqrt(pi) dv}
    ! Which has result:
    ! (1/4) [ erf(vcut) - 2 vcut exp(-vcut^2)/sqrt(pi)]
    expected_energy_sub = 0.25 * ( erf(vcut) - 2*vcut*exp(-vcut*vcut)/sqrt_pi)
    do is = 1, nspec
       energy_sub_error(is) = expected_energy_sub - sum(w(:nesub,is))
    end do

    ! Energy super grid (points nesub+1:negrid, beyond vcut); stays zero
    ! when there are no such points
    allocate(energy_super_error(nspec))
    energy_super_error = 0.0
    if (negrid > nesub) then
       ! NOTE: This assumes we are considering a species with Maxwellian
       ! background such that f0_values = exp(-energy)/pi^3/2
       ! We are analytically evaluating
       ! Integral_0^infinity{sqrt(y+vcut^2) * exp(y) * pi * f0_values * W(y) dy}/2
       ! where W(y) are the Laguerre weights = exp(-y)
       ! Assuming Maxwellian this becomes
       ! Integral_0^infinity{sqrt(y+vcut^2) * exp(y) * exp(-(y+vcut^2)) * exp(-y) dy}/2
       ! which simplifies slightly to
       ! Integral_0^infinity{sqrt(y+vcut^2) * exp(-(vcut^2)) * exp(-y) dy}/2
       ! Which has result:
       ! [sqrt(pi)*erfc(vcut)/2 + vcut*exp(-vcut^2)]/(2*sqrt(pi))
       expected_energy_super =  0.5*(sqrt_pi*(1-erf(vcut))/2 + vcut*exp(-vcut*vcut))/sqrt_pi
       do is = 1, nspec
          energy_super_error(is) = expected_energy_super - sum(w(nesub+1:,is))
       end do
    end if

    ! Combined lambda-energy grid: sum of all products wl(ig,il)*w(ie,is)
    ! compared with the product of the two expected values
    allocate(mixed_error(-ntgrid:ntgrid, nspec))
    allocate(mixed_value(size(wl(0,:)), size(w(:,1))))
    mixed_error = 0.0
    do is = 1, nspec
       do ig = -ntgrid, ntgrid
          mixed_value = 0.0
          do ie = 1, size(w(:,1))
             mixed_value(:, ie) = wl(ig,:)*w(ie, is)
          end do
          mixed_error(ig, is) = expected_lambda*expected_energy - sum(mixed_value)
       end do
    end do

    ! Introduce factor 2 here to account for sigma doubling
    mixed_error = 2 * mixed_error

    ! Report errors to unit
    call report_error(unit, "Lambda", lambda_error)
    ! The trapped/passing error breakdown is currently only
    ! appropriate when not using the Radau-Gauss grids
    if (.not. radau_gauss_grid) then
       call report_error(unit, "Passing", passing_error)
       call report_error(unit, "Trapped", trapped_error)
    end if
    call report_error(unit, "Energy", energy_error)
    call report_error(unit, "Energy_Sub", energy_sub_error)
    call report_error(unit, "Energy_Super", energy_super_error)
    call report_error_2d(unit, "Mixed", mixed_error)

  contains
    !> Write one line with the maximum |error|, the mean error and the
    !> mean |error| of a 1D error array
    subroutine report_error(unit, name, error)
      implicit none
      integer, intent(in) :: unit
      character(len=*), intent(in) :: name
      real, dimension(:), intent(in) :: error
      write(unit, '(A," weights errors, max/mean/mean(|error|)  :",3(E14.6E3," "))') &
           name, maxval(abs(error)), sum(error)/size(error), sum(abs(error))/size(error)
    end subroutine report_error

    !> As report_error, but for a 2D error array
    subroutine report_error_2d(unit, name, error)
      implicit none
      integer, intent(in) :: unit
      character(len=*), intent(in) :: name
      real, dimension(:, :), intent(in) :: error
      write(unit, '(A," weights errors, max/mean/mean(|error|)  :",3(E14.6E3," "))') &
           name, maxval(abs(error)), sum(error)/size(error), sum(abs(error))/size(error)
    end subroutine report_error_2d

  end subroutine report_velocity_integration_error_estimate

  !> Build the pitch-angle (lambda) grid and all data derived from it.
  !>
  !> 1. Zero [[wl]], then compute the passing [[al]] points and their
  !>    integration weights.
  !> 2. When trapped particles are enabled and eps_trapped is
  !>    non-negligible, compute the trapped-region weights. The trapped
  !>    [[al]] values have already been set from [[bset]].
  !> 3. Fill [[forbid]], `is_bounce_point`, [[jend]] and the totally
  !>    trapped flags `can_be_ttp`/`is_ttp`.
  !> 4. Call [[xigridset]] to set up the \(\xi = v_\parallel/v\) grid
  !>    and associated quantities used in collisions.
  subroutine lgridset
    use theta_grid, only: eps_trapped
    implicit none
    logical :: want_trapped

    want_trapped = trapped_particles .and. eps_trapped > epsilon(0.0)

    ! Lambda grid locations and integration weights
    wl = 0.0
    call setup_passing_lambda_grids(al, wl)
    if (want_trapped) then
       call setup_trapped_lambda_grids(al, wl)
    end if

    ! Which pitch angles are forbidden at each theta point
    call calculate_forbidden_region(forbid)
    ! Which pitch angles bounce at each theta point
    call calculate_bounce_points(is_bounce_point)
    ! The pitch angle that bounces at each theta point
    call calculate_jend(jend)
    ! Which pitch angles are considered totally trapped
    call calculate_ittp(can_be_ttp, is_ttp)

    ! xi (v||/v) grid used in collisions
    call xigridset

  end subroutine lgridset

  !> Compute the passing-region lambda grid points and their integration
  !> weights.
  !>
  !> Points are generated in xx on [0, 1] and stored in the module array xx.
  !> The rule used is:
  !>  * a Radau-Gauss rule with ng2+1 points, when radau_gauss_grid is set
  !>    and the grid has trapped particles;
  !>  * otherwise a Gauss-Legendre rule, split into two sub-intervals when
  !>    split_passing_region is set and a useful split is found.
  !> The points are mapped to lambda_grid(:ng2) = (1 - xx**2)/bmax.
  !> The Legendre/Radau weights are transformed to weights(theta, lambda).
  subroutine setup_passing_lambda_grids(lambda_grid, weights)
    use gauss_quad, only: get_legendre_grids_from_cheb, get_radau_gauss_grids
    use theta_grid, only: ntgrid, bmag, bmax
    implicit none
    !> The lambda grid points. Only the first ng2 entries are written (the
    !> Radau branch also reads entry ng2+1), hence intent `in out`.
    real, dimension(:), intent(in out) :: lambda_grid
    !> Integration weights, indexed (theta, lambda). Columns 1..ng2 (and
    !> ng2+1 for the Radau grid) are written; other columns are kept.
    real, dimension(-ntgrid:, :), intent(in out) :: weights
    real, dimension (:), allocatable :: wx, xx_radau
    integer :: ig, il, passing_split_index
    real :: passing_split_value

    ! xx is a module-level array kept for possible future use. For
    ! iproc /= 0 it is allocated in [[broadcast_results]].
    if (.not. allocated(xx)) allocate (xx(ng2))

    if (radau_gauss_grid .and. grid_has_trapped_particles()) then
       ! Radau-Gauss rule with a fixed endpoint. Point ng2+1 (the wfb) gets a
       ! finite weight at bounce points that contributes to the passing
       ! integral, which is why the rule has one more point than the
       ! Legendre rule. Only valid when nlambda > ng2.
       allocate (xx_radau(ng2+1), wx(ng2+1))
       call get_radau_gauss_grids(0., 1., xx_radau, wx, report_in=.false.)
       xx = xx_radau(:ng2)
    else
       allocate (wx(ng2))
       ! A non-positive index means "do not split"
       passing_split_index = 0
       if (split_passing_region) &
            call try_to_optimise_passing_grid(passing_split_value, passing_split_index)

       if (passing_split_index > 0) then
          ! Two Gauss-Legendre rules, one either side of the split point
          call get_legendre_grids_from_cheb (1., passing_split_value, &
               xx(:passing_split_index), wx(:passing_split_index))
          call get_legendre_grids_from_cheb (passing_split_value, 0.0, &
               xx(1 + passing_split_index:), wx(1 + passing_split_index:))
       else
          call get_legendre_grids_from_cheb (1., 0., xx, wx)
       end if
    end if

    ! Passing lambda locations
    lambda_grid(:ng2) = (1.0 - xx(:ng2)**2)/bmax

    ! Map the quadrature weights onto the lambda grid at every theta point
    do il = 1, ng2
       weights(:,il) = wx(il)*2.0*sqrt((bmag/bmax) &
            *((1.0/bmax-lambda_grid(il))/(1.0/bmag-lambda_grid(il))))
    end do

    ! wfb weight from the Radau-Gauss rule
    if (radau_gauss_grid .and. grid_has_trapped_particles()) then
       il = ng2 + 1
       do ig = -ntgrid, ntgrid
          if (bmag(ig) >= bmax) then
             ! Bounce point of the wfb
             weights(ig,il) = wx(il)*2.0
          else
             ! Away from the bounce point this is currently always exactly 0
             ! because lambda_grid(ng2+1) == 1.0/bmax
             weights(ig,il) = wx(il)*2.0*sqrt((bmag(ig)/bmax) &
                  *((1.0/bmax-lambda_grid(il))/(1.0/bmag(ig)-lambda_grid(il))))
          end if
       end do
    end if

    if (allocated(wx)) deallocate(wx)
    if (allocated(xx_radau)) deallocate(xx_radau)
  end subroutine setup_passing_lambda_grids

  !> Try to find the optimal way to split the passing lambda grid into
  !> two regions in order to minimise the error on the integration
  !> weights. The approach taken is trial and error -- we try a number
  !> of predetermined splits, spline the error over the trials and then
  !> evaluate the spline on a high resolution grid to determine an
  !> approximate minimum.
  !>
  !> The passing grid uses the variable xx = sqrt(1 - lambda*bmax) on
  !> [0, 1]. A split at xx = s places `optimal_points` Gauss-Legendre
  !> points between 1 and s and the remaining ng2 - optimal_points between
  !> s and 0. The error at each theta point is the analytic passing integral
  !> minus the sum of the transformed weights.
  !>
  !> Both outputs are set to -1 when splitting is not expected to reduce the
  !> error, or is impossible because ng2 < 2. Callers should then use an
  !> unsplit grid.
  subroutine try_to_optimise_passing_grid(optimal_split, optimal_points, &
       use_max_error, noptimal_points)
    use gauss_quad, only: get_legendre_grids_from_cheb
    use theta_grid, only: ntgrid, bmag, bmax
    use splines, only: new_spline, spline, splint
    use optionals, only: get_option_with_default
    implicit none
    !> Split location in xx, or -1 if no split should be used
    real, intent(out) :: optimal_split
    !> Number of points in the upper region (between 1 and the split), or -1
    integer, intent(out) :: optimal_points
    !> If true minimise the maximum |error| over theta, otherwise the mean
    !> |error| (the default)
    logical, intent(in), optional :: use_max_error
    !> Requested number of points in the upper region. Defaults to about
    !> half of ng2; values outside [1, ng2-1] are replaced by ng2/2.
    integer, intent(in), optional :: noptimal_points
    real, dimension(:), allocatable :: lambda
    real, dimension(:, :), allocatable :: weights
    real, dimension (:), allocatable :: wx_tmp, xx_tmp
    real, dimension(:), allocatable :: analytic_sum, b_ratio
    real, dimension(:), allocatable :: error
    real, dimension(:), allocatable :: error_history
    real, dimension(*), parameter :: trial_splits = &
         [0.001, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07, 0.08, 0.1, 0.5]
    integer, parameter :: ntrial = size(trial_splits)
    real, parameter :: min_split = trial_splits(1), max_split = trial_splits(ntrial)
    integer :: high_ntrial, itrial, il
    type(spline) :: the_spline
    integer :: minimum_location
    logical :: use_max_error_local
    logical, parameter :: debug = .false.
    integer, parameter :: high_res_factor = 100
    real :: error_from_no_split
    use_max_error_local = get_option_with_default(use_max_error, .false.)

    ! A split needs at least one point in each region. With fewer than two
    ! passing points, one of the Legendre calls below would get an empty grid.
    if (ng2 < 2) then
       optimal_split = -1
       optimal_points = -1
       if (debug) print*,'Too few passing points to split the grid'
       return
    end if

    allocate (analytic_sum(-ntgrid:ntgrid))
    allocate (b_ratio(-ntgrid:ntgrid))
    allocate (error(-ntgrid:ntgrid))

    ! Analytic value of the passing lambda integral at each theta point
    b_ratio = bmax / bmag
    analytic_sum = 2.0 * (1 - sqrt(b_ratio - 1)/sqrt(b_ratio))

    ! Determine how many points to use in the upper region.
    ! Here we set it to about half if not passed in.
    ! For now this is a user choice/hard-coded value. We could
    ! imagine also optimising this value.
    optimal_points = get_option_with_default(noptimal_points, ng2 - (1 + ng2/2))
    ! Keep both regions non-empty (ng2 >= 2 here, so ng2/2 is in [1, ng2-1])
    if (optimal_points <= 0 .or. optimal_points >= ng2) optimal_points = ng2/2

    allocate(error_history(ntrial))
    allocate(xx_tmp(ng2), wx_tmp(ng2), lambda(ng2))
    allocate(weights(-ntgrid:ntgrid, ng2))

    ! Determine the initial error metric if we don't split
    call get_legendre_grids_from_cheb (1., 0., xx_tmp, wx_tmp)
    lambda = (1.0 - xx_tmp**2)/bmax
    ! Transform the Legendre weights to the lambda grid
    do il = 1, ng2
       weights(:,il) = wx_tmp(il)*2.0*sqrt((bmag/bmax) &
            *((1.0/bmax-lambda(il))/(1.0/bmag-lambda(il))))
    end do

    error = analytic_sum - sum(weights, dim=2)

    error_from_no_split = get_error_metric(error)

    ! Run with a few splits
    do itrial = 1, ntrial
       call get_legendre_grids_from_cheb (1., trial_splits(itrial), xx_tmp(:optimal_points), wx_tmp(:optimal_points))
       call get_legendre_grids_from_cheb (trial_splits(itrial), 0.0, xx_tmp(1+optimal_points:), wx_tmp(1+optimal_points:))

       lambda = (1.0 - xx_tmp**2)/bmax

       ! Transform the Legendre weights to the lambda grid
       do il = 1, ng2
          weights(:,il) = wx_tmp(il)*2.0*sqrt((bmag/bmax) &
               *((1.0/bmax-lambda(il))/(1.0/bmag-lambda(il))))
       end do

       error = analytic_sum - sum(weights, dim=2)

       error_history(itrial) = get_error_metric(error)
    end do

    ! error and xx_tmp are reused below with a different size
    deallocate(error)
    deallocate(xx_tmp)

    if (debug) print*,"History",error_history

    ! Spline the error metric vs splits
    the_spline = new_spline(trial_splits, error_history)

    high_ntrial = ntrial * high_res_factor
    allocate(error(high_ntrial))
    allocate(xx_tmp(high_ntrial))

    ! Evaluate the spline on a high resolution grid
    do itrial = 1, high_ntrial
       xx_tmp(itrial) = min_split + (itrial-1)*(max_split-min_split)/(high_ntrial-1)
       error(itrial) = splint(xx_tmp(itrial), the_spline)
    end do

    ! Choose the location of the minimum error metric as the optimal split point
    minimum_location = minloc(abs(error), dim = 1)

    ! We should probably calculate the actual error with this split
    ! choice in case our spline is not a good representation of the
    ! data. If the real error is larger than any of the real measurements
    ! we could just opt to use that instead?
    optimal_split = xx_tmp(minimum_location)

    if (debug) print*,"Optimal split : ",optimal_split," with error ",error(minimum_location)
    if (minval(abs(error)) > error_from_no_split) then
       optimal_split = -1
       optimal_points = -1
       if (debug) print*,'Minimum error from split larger than no split'
    end if
  contains
    !> Error metric to minimise: max |error| if use_max_error_local is set,
    !> otherwise the mean |error|
    pure real function get_error_metric(errors) result(metric)
      implicit none
      real, dimension(:), intent(in) :: errors
       if (use_max_error_local) then
          ! Maximum error
          metric = maxval(abs(errors))
       else
          ! Average error
          metric = sum(abs(errors))/size(errors)
       end if
    end function get_error_metric

  end subroutine try_to_optimise_passing_grid

  !> Compute the lambda integration weights in the trapped region.
  !>
  !> The method is chosen by [[le_grids_knobs::new_trap_int]]: true selects
  !> the high-order interpolation method, false (the old default) the finite
  !> difference method.
  !>
  !> @note The trapped weights computed here overwrite any existing values
  !> for those points. This is normally safe because passing and trapped
  !> points are disjoint, so the trapped weights are zero on entry. If a
  !> point were ever treated as both passing and trapped (possible with
  !> Radau-like grids, although the present routines give the wfb no weight
  !> where bmag = bmax), its passing weight would be lost. A simple fix is to
  !> accumulate (`weights +=`) instead of assigning in this routine and the
  !> ones it calls.
  subroutine setup_trapped_lambda_grids(lambda_grid, weights)
    use theta_grid, only: ntgrid
    implicit none
    !> The lambda grid points.
    real, dimension(:), intent(in) :: lambda_grid
    !> Integration weights, indexed (theta, lambda). Intent `in out` because
    !> only the trapped portion is written.
    real, dimension(-ntgrid:, :), intent(in out) :: weights

    if (.not. new_trap_int) then
       call setup_trapped_lambda_grids_old_finite_difference(lambda_grid, weights)
       return
    end if

    call setup_trapped_lambda_grids_new_trap_int(lambda_grid, weights)

  end subroutine setup_trapped_lambda_grids

  !> Determine trapped pitch angle weights using "new" polynomial
  !> interpolation.
  !>
  !> The new method uses Lagrange polynomial interpolation to provide
  !> an accurate approximation of the integration over the trapped
  !> region. The old method effectively uses the trapezium rule to
  !> integrate. In both cases the integration variable is the parallel
  !> velocity spanning the range [-v_||_wfb, v_||_wfb].  Both methods
  !> give zero weight to the wfb at the locations where B == Bmax.
  !> Generally the old method appears to be accurate to at most single
  !> precision, whilst the new method is more accurate, reaching the
  !> double precision round off level even with just a few trapped
  !> pitch angles. It can be noted that the old method can reach
  !> the same level of accuracy at some theta grid points but not
  !> all. Further work may be able to optimise the grid to avoid these
  !> locations.  There are a number of reasons why we may not want to
  !> make new_trap_int default to true, including :
  !>   1. It is not compatible with radau_gauss_grid = T (which is the
  !>   default).
  !>   2. One might expect high order Lagrange interpolation to be
  !>      unstable so may anticipate that the new method will not work
  !>      well at higher ntheta. Empirically it seems this does not
  !>      become an issue here. Further, the input `nmax` can be used
  !>      to limit the maximum Lagrange order used by splitting the
  !>      domain.
  !>   3. For large ntheta the new method leads to slow
  !>      initialisation. This is due to the calculation of the error
  !>      coefficients for use with get_verr. This could perhaps be
  !>      optimised.
  !> Both points 2 and 3 could be avoided to some extent if the
  !> trapped pitch angle grid resolution is decoupled from the theta
  !> grid resolution. It may be possible to make the new approach
  !> compatible with radau_gauss_grid or we may decide
  !> radau_gauss_grid should default to false.
  !>
  !> For every theta point the weights of the allowed trapped pitch
  !> angles (il = ng2+1 .. ng2+npts) are assigned, not accumulated. The
  !> module array wlterr (weights for less accurate trapped integrals,
  !> used for error estimation) is (re)allocated and recomputed on every
  !> call, so calling this routine more than once is safe.
  subroutine setup_trapped_lambda_grids_new_trap_int(lambda_grid, weights)
    use theta_grid, only: ntgrid, bmag, bmax
    use file_utils, only: error_unit
    use mp, only: proc0
    implicit none
    !> The lambda grid points.
    real, dimension(:), intent(in) :: lambda_grid
    !> The integration weights for the lambda grid. Note this is
    !> intent `in out` as we only set a portion of the grid so may
    !> want to keep the rest of the input values.
    real, dimension(-ntgrid:, :), intent(in out) :: weights
    real, dimension (:), allocatable :: ytmp, yb, wb
    real, dimension (:,:), allocatable :: wberr
    integer :: npts, ntrap
    real :: llim, ulim, wgt_tmp
    logical :: eflag
    integer :: ig, il, ndiv, divmax

    ! Initialise the error flag
    eflag = .false.

    ! max number of trapped particles (occurs at outboard midplane)
    ntrap = nlambda - ng2

    ! wlterr contains weights for less accurate trapped integrals (for error estimation).
    ! Release any previous allocation first so that repeated calls do not fail
    ! with an "already allocated" runtime error.
    if (allocated(wlterr)) deallocate(wlterr)
    allocate(wlterr(-ntgrid:ntgrid,nlambda,ntrap))

    wlterr = 0.0

    ! Find the integration weights for each theta grid point
    do ig = -ntgrid, ntgrid

       ! First we count how many trapped lambda bounce points we
       ! have to consider at this point. In other words how many
       ! bounce points are not forbidden for the current theta
       ! grid point.
       ! npts is the number of lambda_grid values in the trapped integral (varies with theta)
       npts = 0

       do il = ng2+1, nlambda
          if (1.0 - lambda_grid(il)*bmag(ig) > -bouncefuzz) then
             npts = npts + 1
          end if
       end do

       ! If there are any valid bounce points then we need to
       ! calculate the weights. If npts == 1 then we probably also expect
       ! the integration weights to be 0 so we could probably skip all of
       ! this work for npts = 1 as well.
       if (npts > 0) then

          ! ytmp is an array containing pitch angle grid points (for vpa >= 0)
          ! These are the v||/v points for each lambda between wfb (sets maximum
          ! v||/v here) and the lambda which bounces at this point (v||/v = 0)
          ! yb is an array containing the full set of [-v||/v, v||/v] grid points.
          ! Constructed directly from ytmp by copying and flipping sign+order
          ! wb is an array containing the integration weights calculated for the
          ! given v||/v grid.
          allocate(ytmp(npts), yb(2*npts-1), wb(2*npts-1))
          ! wberr is an array used to hold the weights used in the error estimation
          ! in [[trap_error]] called by [[get_verr]]
          allocate(wberr(2*npts-1,npts))

          ytmp = 0.0; yb = 0.0; wb = 0.0; wberr = 0.0

          ! loop computes transformed variable of integration (v||/v)
          do il = ng2+1, ng2+npts
             ytmp(il-ng2) = sqrt(max(1 - lambda_grid(il)*bmag(ig), 0.0))
          end do

          ! define array (yb) with pitch-angle gridpts
          ! corresponding to both positive and negative vpa
          if (npts > 1) yb(:npts-1) = -ytmp(:npts-1)
          yb(npts:) = ytmp(npts:1:-1)

          ! get grid point weights for trapped particle integral

          ! Note : Here ulim and llim are the upper and lower vpar
          ! values for the valid trapped pitch angles at this theta
          ! location. In other words ulim == v_||_wfb(theta(ig)) =
          ! sqrt(max(1.0-lambda_grid(ng2+1)*bmag(ig),0.0)) = yb(2*npts-1)
          ! We could probably replace ulim and llim with yb(2*npts-1)
          ! and yb(1) respectively.
          ulim = sqrt(max(1.0-bmag(ig)/bmax,0.0))
          llim = -ulim

          ! Call get_weights to find the integration weights for the v||/v grid
          ! Note we don't call this if ulim == 0 as the grid then has no extent.
          ! This situation can arise when bmag(ig) == bmax where we would find
          ! npts = 1, corresponding to just the wfb being a valid lambda. This
          ! means that at this location the wfb has zero integration weight.
          ! Only the weights and the error flag are used from get_weights;
          ! ndiv and divmax are ignored.
          if (ulim > 0.0) then
             if (new_trap_int_split) then
                ! We split the weights calculation into two symmetric domains from
                ! -v||_wfb to 0 and from 0 to v||_wfb. This has been shown to have
                ! favourable properties when compared to treating the full -v||_wfb
                ! to v||_wfb domain in a single call. In particular significantly more
                ! accurate results when integrating certain functions (see mod vpa test
                ! in the le_grids_integrate unit tests).
                call get_weights (nmax, llim, 0., yb(:npts), wb(:npts), ndiv, divmax, eflag)

                if (eflag .and. proc0) then
                   write(error_unit(), '("Error flag set by first call to get_weights in setup_trapped_lambda_grids")')
                end if

                ! Store the v|| = 0 weight for the lower domain as this will be
                ! clobbered by the subsequent get_weights call so we need to add
                ! this on afterwards
                wgt_tmp = wb(npts)
                call get_weights (nmax, 0., ulim, yb(npts:), wb(npts:), ndiv, divmax, eflag)

                if (eflag .and. proc0) then
                   write(error_unit(), '("Error flag set by second call to get_weights in setup_trapped_lambda_grids")')
                end if

                wb(npts) = wb(npts) + wgt_tmp
             else
                ! Single call over the full [-v||_wfb, v||_wfb] domain.
                call get_weights (nmax, llim, ulim, yb, wb, ndiv, divmax, eflag)

                if (eflag .and. proc0) then
                   write(error_unit(), '("Error flag set by call to get_weights in setup_trapped_lambda_grids")')
                end if
             end if
          end if

          ! Calculate the weights for use in the trapped
          ! integration error estimation code.  This can be quite
          ! expensive so we would like to only call this if the
          ! error estimation code is active (i.e. if we're going to
          ! call [[get_verr]]).
          if (npts > 1) call get_trapped_lambda_grid_error_estimate_weights(yb, llim, ulim, npts, wberr)

          ! Convert from v||/v weights to lambda
          ! weights. Essentially just sum up the positive and
          ! negative v|| points corresponding to the same lambda
          ! point. Note we skip the v|| == 0 point currently to ensure we
          ! don't double count.
          if (npts > 1) then
             do il = ng2+1, ng2+npts-1
                ! take into account possible asymmetry of weights about xi = 0
                ! due to unequal # of grid points per integration interval
                ! Whilst the v||/v grid should always be symmetric
                ! the algorithm in [[get_weights]] could
                ! potentially lead to asymmetry in the resulting
                ! weights. This may be an indication that things
                ! aren't behaving well.
                weights(ig,il) = wb(il-ng2) + wb(2*npts-il+ng2)
                wlterr(ig,il,:npts) = wberr(il-ng2,:) + wberr(2*npts-il+ng2,:)
             end do
          end if

          ! avoid double counting of gridpoint at vpa=0
          weights(ig,ng2+npts) = wb(npts)
          wlterr(ig,ng2+npts,:npts) = wberr(npts,:)

          deallocate(ytmp, yb, wb, wberr)
       end if
    end do
  end subroutine setup_trapped_lambda_grids_new_trap_int

  !> Find the trapped pitch angle weights with the old finite-difference
  !> (trapezium rule) scheme, integrating in xi = v||/v.
  !>
  !> For every theta point and every neighbouring pair of trapped pitch
  !> angles (il, il+1) with il >= ng2+1 that are BOTH allowed there
  !> (1 - lambda*B > -bouncefuzz), the spacing
  !>   dxi = sqrt(max(1 - lambda(il)*B, 0)) - sqrt(max(1 - lambda(il+1)*B, 0))
  !> is added to the weights of both il and il+1. This is the trapezium rule
  !>   Int(f) ~ Sum_i (f(xi_i) + f(xi_{i+1})) * (xi_i - xi_{i+1}) / 2
  !> without its factor 1/2: each lambda weight stands for the sum of the
  !> identical +v|| and -v|| weights, which cancels the 1/2.
  !> Contributions are accumulated onto whatever is already in `weights`.
  subroutine setup_trapped_lambda_grids_old_finite_difference(lambda_grid, weights)
    use theta_grid, only: ntgrid, bmag
    implicit none
    !> Pitch angle (lambda) grid points.
    real, dimension(:), intent(in) :: lambda_grid
    !> Integration weights indexed (theta, lambda); trapped entries are
    !> incremented, all others are left untouched.
    real, dimension(-ntgrid:, :), intent(in out) :: weights
    real :: arg_here, arg_next, dxi
    integer :: ig, il

    do ig = -ntgrid, ntgrid
       ! Within one theta point il increases, so each weight still receives
       ! its (il-1, il) contribution before its (il, il+1) contribution.
       do il = ng2 + 1, nlambda - 1
          arg_here = 1.0 - lambda_grid(il)*bmag(ig)
          arg_next = 1.0 - lambda_grid(il+1)*bmag(ig)

          ! Only pairs where both pitch angles are allowed contribute
          if (.not. (arg_here > -bouncefuzz .and. arg_next > -bouncefuzz)) cycle

          dxi = sqrt(max(arg_here, 0.0)) - sqrt(max(arg_next, 0.0))
          weights(ig, il) = weights(ig, il) + dxi
          weights(ig, il+1) = weights(ig, il+1) + dxi
       end do
    end do
  end subroutine setup_trapped_lambda_grids_old_finite_difference

  !> Routine for getting the weights used in estimating the error on
  !> trapped particle integrals as a part of [[trap_error]]
  !>
  !> For each ix = 1..npts, the symmetric v||/v grid yb (2*npts-1 points,
  !> v|| = 0 at index npts) is reduced by dropping point ix and its mirror
  !> 2*npts-ix (for ix == npts these coincide, so only the v|| = 0 point is
  !> dropped). get_weights is called on the reduced grid and the resulting
  !> weights are scattered back into column ix of wberr, with the dropped
  !> point(s) given weight zero.
  !>
  !> This routine is all to calculate wberr/wlterr which is just used
  !> in [[trap_error]], which in turn is only used in
  !> [[get_verr]]. This is used as a part of the [[vary_vnew]]
  !> adaptive collisionality code triggered through the
  !> diagnostics. This routine is rather expensive for high theta
  !> resolution - measured in the order of 4-5 minutes for ntheta=256.
  !> As such we might want to consider skipping this if the variable
  !> vnewk code is not active. As we do this for each npt and each
  !> theta, and npt ~ntheta/2 we might expect this to scale
  !> quadratically with ntheta. In practice it seems to scale somewhat
  !> worse than this. This may in part be due to an increase in the
  !> number of iterations required in get_weights such that this may
  !> scale closer to ntheta cubed.
  !>
  !> Previously the final segment of each interior column
  !> (indices 2*npts-ix+1 onward) was copied as a single element,
  !> leaving the rest of that segment at zero for ix >= 3. The mask-based
  !> pack/unpack below fills every retained point.
  subroutine get_trapped_lambda_grid_error_estimate_weights(yb, llim, ulim, npts, wberr)
    implicit none
    !> Full symmetric v||/v grid; the first 2*npts-1 entries are used.
    real, dimension(:), intent(in) :: yb
    !> Lower and upper limits of the v||/v integration domain.
    real, intent(in) :: llim, ulim
    !> Number of allowed trapped pitch angles at this theta point.
    integer, intent(in) :: npts
    !> Error-estimate weights, shape at least (2*npts-1, npts). Column ix
    !> is fully overwritten for ix = 1..npts.
    real, dimension(:, :), intent(in out) :: wberr
    real, dimension (:), allocatable :: yberr, wberrtmp
    ! keep(i) is .false. for the grid point(s) dropped in the current pass
    logical, dimension(2*npts-1) :: keep
    integer :: ix, nfull, nkeep
    ! Outputs of get_weights that are not used here (the error flag is ignored)
    logical :: eflag
    integer :: divmaxerr, ndiverr

    ! Can't get error estimate for npts = 0 or 1 as we don't have any points
    ! that we can drop.
    if (npts <= 1) return

    nfull = 2*npts - 1

    do ix = 1, npts
       ! Drop grid point ix and its mirror about v|| = 0
       keep = .true.
       keep(ix) = .false.
       keep(nfull + 1 - ix) = .false.
       nkeep = count(keep)

       allocate (yberr(nkeep), wberrtmp(nkeep))
       yberr = pack(yb(:nfull), keep)
       wberrtmp = 0.0

       call get_weights (nmax, llim, ulim, yberr, wberrtmp, ndiverr, divmaxerr, eflag)

       ! Scatter back, inserting a weight of zero at the dropped point(s)
       wberr(:nfull, ix) = unpack(wberrtmp, keep, 0.0)

       deallocate (yberr, wberrtmp)
    end do
  end subroutine get_trapped_lambda_grid_error_estimate_weights

  !> Flag which pitch angles are forbidden at which theta grid points.
  !>
  !> forbid(ig, il) is .true. when 1 - al(il)*bmag(ig) < -bouncefuzz. The
  !> run is aborted (via mp_abort) if any passing pitch angle (il <= ng2)
  !> is forbidden anywhere on the theta grid.
  subroutine calculate_forbidden_region(forbid)
    use theta_grid, only: ntgrid, bmag
    use mp, only: mp_abort
    implicit none
    !> Output flags indexed (theta, lambda).
    logical, dimension(-ntgrid:, :), intent(out) :: forbid
    integer :: il

    forbid = .false.

    do il = 1, nlambda
       forbid(-ntgrid:ntgrid, il) = 1.0 - al(il)*bmag(-ntgrid:ntgrid) < -bouncefuzz
    end do

    ! Passing particles must be allowed everywhere in the domain
    if (any(forbid(:, :ng2))) then
       call mp_abort("Fatal error: supposedly passing particle was trapped, in calculate_forbidden_region, in le_grids.f90", &
            .true.)
    end if
  end subroutine calculate_forbidden_region

  !> Flag which theta grid points are bounce points for each pitch angle.
  !>
  !> bounce_points(ig, il) is .true. when |1 - al(il)*bmag(ig)| <= bouncefuzz.
  !> This requires the pitch angle and magnetic field grids to be
  !> consistent to within bouncefuzz (lambda*B = 1 at a bounce point).
  !> Only il >= ng2+2 is considered: passing particles never bounce, and
  !> the wfb (ng2+1) is excluded here and left to dedicated handling.
  !> This method does not depend on forbid, so it could later be
  !> generalised to wfb and wfb-like particles.
  subroutine calculate_bounce_points(bounce_points)
    use theta_grid, only: ntgrid, bmag
    use mp, only: mp_abort
    implicit none
    !> Output flags indexed (theta, lambda).
    logical, dimension(-ntgrid:, :), intent(out) :: bounce_points
    integer :: il, first_il

    bounce_points = .false.

    ! Start above the wfb; lower this to ng2+1 to let the wfb bounce.
    first_il = ng2 + 2

    do il = first_il, nlambda
       bounce_points(-ntgrid:ntgrid, il) = abs(1.0 - al(il)*bmag(-ntgrid:ntgrid)) <= bouncefuzz
    end do
  end subroutine calculate_bounce_points

  !> Returns true if theta grid index ig is a lower bounce point for
  !> pitch angle index il.
  !>
  !> Requirements, checked in order:
  !>   1. is_bounce_point(ig, il) must hold;
  !>   2. the wfb only qualifies when trapped_wfb is set (currently moot,
  !>      as is_bounce_point does not flag wfb points);
  !>   3. the first theta point (-ntgrid) is always a lower bounce point
  !>      and the last (ntgrid) never is;
  !>   4. at interior points the previous theta point must be forbidden.
  !>      Totally trapped particles (both neighbours forbidden) satisfy
  !>      this automatically.
  !> Testing the previous point, rather than whether the next point is
  !> allowed, means internal wfb-like bounce points are not detected. This
  !> is a temporary workaround that keeps the old behaviour while trapped
  !> particle handling is improved (see issues 148 and 239).
  elemental logical function is_lower_bounce_point(ig, il) result(is_lower)
    use theta_grid, only: ntgrid
    implicit none
    !> Theta grid index.
    integer, intent(in) :: ig
    !> Pitch angle index.
    integer, intent(in) :: il

    if (.not. is_bounce_point(ig, il)) then
       is_lower = .false.
    else if (il_is_wfb(il) .and. .not. trapped_wfb) then
       is_lower = .false.
    else if (ig == -ntgrid) then
       is_lower = .true.
    else if (ig == ntgrid) then
       is_lower = .false.
    else
       ! Only interior points reach here, so ig-1 is a valid index
       is_lower = forbid(ig - 1, il)
    end if
  end function is_lower_bounce_point

  !> Returns true if theta grid index ig is an upper bounce point for
  !> pitch angle index il.
  !>
  !> Requirements, checked in order:
  !>   1. is_bounce_point(ig, il) must hold;
  !>   2. the wfb only qualifies when trapped_wfb is set (currently moot,
  !>      as is_bounce_point does not flag wfb points);
  !>   3. the first theta point (-ntgrid) is never an upper bounce point;
  !>      otherwise the last point (ntgrid) always is. The -ntgrid test
  !>      comes first, so it decides the degenerate case ntgrid == 0;
  !>   4. at interior points the next theta point must be forbidden.
  !>      Totally trapped particles (both neighbours forbidden) satisfy
  !>      this automatically.
  !> Testing the next point, rather than whether the previous point is
  !> allowed, means internal wfb-like bounce points are not detected as
  !> upper bounce points. This is a temporary workaround that keeps the
  !> old behaviour while trapped particle handling is improved (see issues
  !> 148 and 239).
  elemental logical function is_upper_bounce_point(ig, il) result(is_upper)
    use theta_grid, only: ntgrid
    implicit none
    !> Theta grid index.
    integer, intent(in) :: ig
    !> Pitch angle index.
    integer, intent(in) :: il

    if (.not. is_bounce_point(ig, il)) then
       is_upper = .false.
    else if (il_is_wfb(il) .and. .not. trapped_wfb) then
       is_upper = .false.
    else if (ig == -ntgrid) then
       is_upper = .false.
    else if (ig == ntgrid) then
       is_upper = .true.
    else
       ! Only interior points reach here, so ig+1 is a valid index
       is_upper = forbid(ig + 1, il)
    end if
  end function is_upper_bounce_point

  !> Compute jend(ig), the number of allowed pitch angles at each theta
  !> grid point, i.e. the count of il in 1..nlambda with
  !> 1 - al(il)*bmag(ig) > -bouncefuzz.
  !>
  !> With trapped particles this is the index of the pitch angle bouncing
  !> at theta(ig) (CMR, 1/11/2013). Without trapped particles jend is left
  !> at zero everywhere. It might be better to use a value > nlambda in
  !> that case, so that il <= jend(ig) always holds for allowed particles,
  !> but that would need changes elsewhere.
  subroutine calculate_jend(jend)
    use theta_grid, only: ntgrid, bmag
    use mp, only: mp_abort
    implicit none
    !> Output count of allowed pitch angles per theta point.
    integer, dimension(-ntgrid:), intent(out) :: jend
    integer :: ig

    jend = 0

    if (.not. grid_has_trapped_particles()) return

    ! Passing pitch angles (il <= ng2) are always counted; their validity
    ! everywhere is checked in [[calculate_forbidden_region]].
    do ig = -ntgrid, ntgrid
       jend(ig) = count(1.0 - al(1:nlambda)*bmag(ig) > -bouncefuzz)
    end do
  end subroutine calculate_jend

  !> Identify totally trapped pitch angles.
  !>
  !> A trapped pitch angle il (> ng2) is totally trapped at an interior
  !> theta point ig when it is allowed at ig but forbidden at both ig-1
  !> and ig+1. The end points of the theta grid are never considered. For
  !> each interior ig, the smallest such il is recorded (nlambda+1 if none).
  !> Then, for il > ng2:
  !>  - can_be_ttp_flag(il) is .true. if il reaches the recorded index at
  !>    some theta point;
  !>  - is_ttp_value(ig, il) is .true. if il reaches the recorded index at
  !>    ig and il is not forbidden at ig.
  !> Both outputs are entirely .false. when the grid has no trapped particles.
  subroutine calculate_ittp(can_be_ttp_flag, is_ttp_value)
    use theta_grid, only: ntgrid
    implicit none
    !> Per pitch angle: can it be totally trapped anywhere?
    logical, dimension(nlambda), intent(out) :: can_be_ttp_flag
    !> Per (theta, pitch angle): is it totally trapped and allowed here?
    logical, dimension(-ntgrid:ntgrid, nlambda), intent(out) :: is_ttp_value
    ! Lowest totally trapped pitch angle index at each theta point
    integer, dimension(-ntgrid:ntgrid) :: first_ttp
    integer :: ig, il

    can_be_ttp_flag = .false.
    is_ttp_value = .false.
    first_ttp = nlambda + 1

    if (.not. grid_has_trapped_particles()) return

    do ig = -ntgrid + 1, ntgrid - 1
       ttp_search: do il = ng2 + 1, nlambda
          if (.not. forbid(ig, il) .and. forbid(ig - 1, il) .and. forbid(ig + 1, il)) then
             first_ttp(ig) = il
             exit ttp_search
          end if
       end do ttp_search
    end do

    do il = ng2 + 1, nlambda
       can_be_ttp_flag(il) = any(first_ttp <= il)
       is_ttp_value(:, il) = first_ttp <= il .and. .not. forbid(-ntgrid:ntgrid, il)
    end do
  end subroutine calculate_ittp

  !> Build the pitch-angle coordinate xi = v||/v on the combined index
  !> ixi (1..2*nlambda), together with mappings from ixi to the pitch angle
  !> index il and the sign index isgn, and the corresponding weights.
  !>
  !> Module arrays written (xi, ixi_to_il, ixi_to_isgn and wxi are
  !> allocated with shape (2*nlambda, -ntgrid:ntgrid) if not already):
  !>  - sgn(1) = 1, sgn(2) = -1 (sign of vpa for isgn = 1, 2).
  !>  - xi(ixi, ig): for ixi <= jend(ig), sqrt(max(1 - al(ixi)*bmag(ig), 0));
  !>    entries jend(ig)+1 .. 2*jend(ig)-1 hold the same values negated in
  !>    reverse order; all other entries are zero.
  !>  - ixi_to_il, ixi_to_isgn: see the sketch below.
  !>  - wxi(ixi, ig) = wl(ig, ixi_to_il(ixi, ig)).
  !>
  !> NOTE(review): when jend(ig) == 0 (no trapped particles) both xi
  !> assignments are empty sections, so xi(:, ig) stays zero, while
  !> ixi_to_il/ixi_to_isgn are still filled. Confirm callers do not rely on
  !> xi in that case.
  subroutine xigridset
    use theta_grid, only: ntgrid, bmag
    implicit none

    integer :: ig, je, ixi, il

    ! Allocate on first call only; later calls reuse the existing arrays
    if (.not. allocated(xi)) allocate (xi(2*nlambda, -ntgrid:ntgrid))
    if (.not. allocated(ixi_to_il)) allocate (ixi_to_il(2*nlambda, -ntgrid:ntgrid))
    if (.not. allocated(ixi_to_isgn)) allocate (ixi_to_isgn(2*nlambda, -ntgrid:ntgrid))
    if (.not. allocated(wxi)) allocate (wxi(2*nlambda, -ntgrid:ntgrid))

    ! define array 'sgn' that returns sign of vpa associated with isgn=1,2
    sgn(1) = 1
    sgn(2) = -1

    ! xi is vpa / v and goes from 1 --> -1
    xi = 0.0
    do ig = -ntgrid, ntgrid
       ! Positive branch for the jend(ig) allowed pitch angles
       xi(:jend(ig), ig) = sqrt(max(1.0 - al(:jend(ig))*spread(bmag(ig),1,jend(ig)),0.0))
       ! Negative branch: mirror without repeating the bouncing point (index jend)
       xi(jend(ig)+1:2*jend(ig)-1, ig) = -xi(jend(ig)-1:1:-1, ig)
    end do
    
    ! get mapping from ixi (which runs between 1 and 2*nlambda) and il (runs between 1 and nlambda)
    do ig = -ntgrid, ntgrid
       je = jend(ig)
       ! if no trapped particles
       if (je == 0) then
          ! ixi = 1..nlambda -> il = ixi (isgn = 1);
          ! ixi = nlambda+1..2*nlambda -> il = nlambda..1 (isgn = 2)
          do ixi = 1, 2*nlambda
             if (ixi > nlambda) then
                ixi_to_isgn(ixi, ig) = 2
                ixi_to_il(ixi, ig) = 2*nlambda - ixi + 1
             else
                ixi_to_isgn(ixi, ig) = 1
                ixi_to_il(ixi, ig) = ixi
             end if
          end do
       else
!CMR, 1/11/2013:
! Sketch of how ixi=>il mapping is arranged
!===============================================================================================
!  ixi=   1, ... , je-1, je, je+1, ... , 2je-1, || 2je, 2je+1, ... , nl+je, nl+je+1, ... ,  2nl
!   il=   1, ... , je-1, je, je-1, ... ,     1, ||  je,  je+1, ... ,    nl,      nl, ... , je+1
! isgn=   1, ... ,    1,  2,    2, ... ,     2, ||   1,     1, ... ,     1,       2, ... ,    2
!         particles passing through             ||  je,   + forbidden trapped velocity space
!         nb only need one isigma for je, as v||=0 at the bounce point
!===============================================================================================

          ! Branches are tested from the highest ixi range downwards, so each
          ! ixi falls into the first range whose lower bound it reaches.
          do ixi = 1, 2*nlambda
             if (ixi >= nlambda + je + 1) then
                ixi_to_isgn(ixi, ig) = 2
                ixi_to_il(ixi, ig) = 2*nlambda + je + 1 - ixi
             else if (ixi >= 2*je) then
                ixi_to_isgn(ixi, ig) = 1
                ixi_to_il(ixi, ig) = ixi - je
             else if (ixi >= je) then
                ixi_to_isgn(ixi, ig) = 2
                ixi_to_il(ixi, ig) = 2*je - ixi
             else
                ixi_to_isgn(ixi, ig) = 1
                ixi_to_il(ixi, ig) = ixi
             end if
          end do
       end if
    end do

    ! Copy the lambda weight for the pitch angle each ixi maps to
    wxi = 0.0
    do ig = -ntgrid, ntgrid
       do ixi = 1, 2*nlambda
          il = ixi_to_il(ixi, ig)
          wxi(ixi, ig) = wl(ig, il)
       end do
    end do
  end subroutine xigridset

  !> Sum a complex g over the locally held (x, y) modes and reduce across
  !> processors, leaving a result indexed by
  !> (theta, lambda, energy, sign of v||, species).
  !>
  !> Each iglo contributes with weight 1 when aky(ik) == 0 and 0.5
  !> otherwise. The reduced result ends up on PE 0, or on every processor
  !> when the optional argument `all` is present (its value is ignored).
  !!
  !! NOTE: Takes f = f(x, y, z, sigma, lambda, E, species) and returns int f,
  !! where the integral is over x-y space.
  subroutine integrate_volume_c (g, total, all)
    use theta_grid, only: ntgrid
    use kt_grids, only: aky
    use gs2_layouts, only: g_lo, is_idx, ik_idx, ie_idx, il_idx
    use mp, only: nproc, sum_reduce, sum_allreduce

    implicit none

    !> Input, indexed (theta, sign of v||, iglo).
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    !> Output, indexed (theta, lambda, energy, sign of v||, species).
    complex, dimension (-ntgrid:,:,:,:,:), intent (out) :: total
    !> If present, every processor receives the reduced result.
    integer, optional, intent(in) :: all
    real :: weight
    integer :: iglo, ik, il, ie, is

    total = 0.

    ! Local accumulation over this processor's share of g_lo
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo, iglo)
       il = il_idx(g_lo, iglo)
       ie = ie_idx(g_lo, iglo)
       is = is_idx(g_lo, iglo)

       weight = merge(1.0, 0.5, aky(ik) == 0.)

       ! Both signs of v|| at once
       total(:, il, ie, 1:2, is) = total(:, il, ie, 1:2, is) + weight*g(:, 1:2, iglo)
    end do

    ! Nothing to combine on a single processor
    if (nproc <= 1) return

    if (present(all)) then
       call sum_allreduce (total)
    else
       call sum_reduce (total, 0)
    end if
  end subroutine integrate_volume_c

  !> Sum a real g over the locally held (x, y) modes and reduce across
  !> processors, leaving a result indexed by
  !> (theta, lambda, energy, sign of v||, species).
  !>
  !> Each iglo contributes with weight 1 when aky(ik) == 0 and 0.5
  !> otherwise. The reduced result ends up on PE 0, or on every processor
  !> when the optional argument `all` is present (its value is ignored).
  !!
  !! NOTE: Takes f = f(x, y, z, sigma, lambda, E, species) and returns int f,
  !! where the integral is over x-y space.
  subroutine integrate_volume_r (g, total, all)
    use theta_grid, only: ntgrid
    use kt_grids, only: aky
    use gs2_layouts, only: g_lo, is_idx, ik_idx, ie_idx, il_idx
    use mp, only: nproc,sum_reduce, sum_allreduce

    implicit none

    !> Input, indexed (theta, sign of v||, iglo).
    real, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in) :: g
    !> Output, indexed (theta, lambda, energy, sign of v||, species).
    real, dimension (-ntgrid:,:,:,:,:), intent (out) :: total
    !> If present, every processor receives the reduced result.
    integer, optional, intent(in) :: all
    real :: weight
    integer :: iglo, ik, il, ie, is

    total = 0.

    ! Local accumulation over this processor's share of g_lo
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo, iglo)
       il = il_idx(g_lo, iglo)
       ie = ie_idx(g_lo, iglo)
       is = is_idx(g_lo, iglo)

       weight = merge(1.0, 0.5, aky(ik) == 0.)

       ! Both signs of v|| at once
       total(:, il, ie, 1:2, is) = total(:, il, ie, 1:2, is) + weight*g(:, 1:2, iglo)
    end do

    ! Nothing to combine on a single processor
    if (nproc <= 1) return

    if (present(all)) then
       call sum_allreduce (total)
    else
       call sum_reduce (total, 0)
    end if
  end subroutine integrate_volume_r

  !> Toroidal momentum flux as a function of theta and a uniform vpar grid.
  !!
  !! f(theta, lambda, energy, sign, species) is projected onto the normalised
  !! Hermite functions of vpar (orders 0 .. min(negrid,nlambda)/2 - 1) using the
  !! w and wl quadrature weights. The projection is then evaluated on
  !! negrid*nlambda evenly spaced vpar values, running from +sqrt(energy(negrid))
  !! down to -sqrt(energy(negrid)), and multiplied by exp(-vpar**2). The result
  !! is vflx(theta, ivpa, species).
  !!
  !! The vpar grids and the Hermite values on them are cached between calls.
  !! If `dealloc` is present (whatever its value), the call only frees that
  !! cache and returns.
  subroutine get_flux_vs_theta_vs_vpa (f, vflx, dealloc)

    use theta_grid, only: ntgrid, bmag
    use species, only: nspec

    implicit none
    real, dimension (-ntgrid:,:,:,:,:), intent (in) :: f
    real, dimension (-ntgrid:,:,:), intent (out) :: vflx
    logical, intent(in), optional :: dealloc

    ! Hermite moments of f for each (theta, species, order)
    real, dimension (:,:,:), allocatable :: moments
    ! Cached grids, kept alive between calls
    real, dimension (:,:), allocatable, save :: vpa1d
    real, dimension (:,:,:), allocatable, save :: hermp1d
    real, dimension (:,:,:,:,:), allocatable, save :: vpapts
    real, dimension (:,:,:,:,:,:), allocatable, save :: hermp

    integer :: nhermite, nvpa
    integer :: ispec, ien, ilam, itheta, ivpa

    if (present(dealloc)) then
       call release_cache
       return
    end if

    nhermite = min(negrid, nlambda)/2
    nvpa = negrid*nlambda

    allocate (moments(-ntgrid:ntgrid,nspec,0:nhermite-1))

    if (.not. allocated(vpapts)) call build_cache

    ! Project f onto each Hermite order, summing both signs of vpar
    moments = 0.
    do ispec = 1, nspec
       do ien = 1, negrid
          do ilam = 1, nlambda
             do itheta = -ntgrid, ntgrid
                moments(itheta,ispec,:) = moments(itheta,ispec,:) &
                     +w(ien,ispec)*wl(itheta,ilam)*(hermp(itheta,ilam,ien,1,:,ispec)*f(itheta,ilam,ien,1,ispec) &
                     +hermp(itheta,ilam,ien,2,:,ispec)*f(itheta,ilam,ien,2,ispec))
             end do
          end do
       end do
    end do

    ! Re-expand on the uniform vpar grid and apply the exp(-vpar**2) factor
    do ispec = 1, nspec
       do ivpa = 1, nvpa
          do itheta = -ntgrid, ntgrid
             vflx(itheta,ivpa,ispec) = dot_product(moments(itheta,ispec,:), hermp1d(ivpa,:,ispec)) &
                  *exp(-vpa1d(ivpa,ispec)**2)
          end do
       end do
    end do

    deallocate (moments)

  contains

    ! Fill the saved vpar grids and Hermite values (first call only).
    subroutine build_cache
      integer :: js, je, jl, jg, jv

      allocate (vpa1d(nvpa,nspec))
      allocate (hermp1d(nvpa,0:nhermite-1,nspec))
      allocate (vpapts(-ntgrid:ntgrid,nlambda,negrid,2,nspec))
      allocate (hermp(-ntgrid:ntgrid,nlambda,negrid,2,0:nhermite-1,nspec))
      vpapts = 0.0 ; hermp = 0.0 ; vpa1d = 0.0 ; hermp1d = 0.0

      ! vpar at every (theta, lambda, energy) point: + for sign 1, - for sign 2
      do je = 1, negrid
         do jl = 1, nlambda
            do jg = -ntgrid, ntgrid
               vpapts(jg,jl,je,1,:) = sqrt(energy(je,:)*max(0.0, 1.0-al(jl)*bmag(jg)))
               vpapts(jg,jl,je,2,:) = -vpapts(jg,jl,je,1,:)
            end do
         end do
      end do

      ! Uniform vpar grid from +sqrt(energy(negrid)) to -sqrt(energy(negrid))
      do jv = 1, nvpa
         vpa1d(jv,:) = sqrt(energy(negrid,:))*(1. - 2.*(jv-1)/real(nvpa-1))
      end do

      do js = 1, nspec
         call get_hermite_polynomials (vpa1d(:,js), hermp1d(:,:,js))
         call get_hermite_polynomials (vpapts(:,:,:,:,js), hermp(:,:,:,:,:,js))
      end do
    end subroutine build_cache

    ! Free whichever cached arrays are currently allocated.
    subroutine release_cache
      if (allocated(vpa1d)) deallocate (vpa1d)
      if (allocated(hermp1d)) deallocate (hermp1d)
      if (allocated(vpapts)) deallocate (vpapts)
      if (allocated(hermp)) deallocate (hermp)
    end subroutine release_cache

  end subroutine get_flux_vs_theta_vs_vpa

  !> Returns Gn = Hn / sqrt(2^n n!) / pi^(1/4),
  !! where Hn are the hermite polynomials
  !! i.e. int dx Gm * Gn exp(-x^2) = 1
  !!
  !! Orders 0 .. size(hpdum,5)-1 come from the three-term recurrence
  !!   G_j = sqrt(2/j) x G_{j-1} - sqrt((j-1)/j) G_{j-2},
  !! carried out entirely in double precision, including the coefficients.
  !! The work arrays take their bounds from xptsdum itself, so the first
  !! four dimensions of hpdum only need to conform with xptsdum.
  subroutine get_hermite_polynomials_4d (xptsdum, hpdum)
    use constants, only: pi
    use theta_grid, only: ntgrid

    implicit none

    real, dimension (-ntgrid:,:,:,:), intent (in)   :: xptsdum
    real, dimension (-ntgrid:,:,:,:,0:), intent (out) :: hpdum

    integer, parameter :: dp = kind(1.0d0)
    integer :: j, lo, hi, n2, n3, n4
    real (dp), dimension (:,:,:,:), allocatable :: hp1, hp2, hp3

    hpdum = 0.0

    ! Size the work arrays from the argument rather than from module grid sizes
    lo = lbound(xptsdum,1)
    hi = ubound(xptsdum,1)
    n2 = size(xptsdum,2)
    n3 = size(xptsdum,3)
    n4 = size(xptsdum,4)
    allocate (hp1(lo:hi,n2,n3,n4))
    allocate (hp2(lo:hi,n2,n3,n4))
    allocate (hp3(lo:hi,n2,n3,n4))

    ! Seed the recurrence with G_0 = 1 and G_{-1} = 0 (before the pi^(1/4) scaling)
    hp1 = 1.0_dp
    hp2 = 0.0_dp

    hpdum(:,:,:,:,0) = 1.0

    do j = 1, size(hpdum,5)-1
       hp3 = hp2
       hp2 = hp1
       ! Coefficients in double precision so the double work arrays are not
       ! limited by default-real rounding of sqrt(2/j) and sqrt((j-1)/j)
       hp1 = sqrt(2.0_dp/j)*xptsdum*hp2 - sqrt(real(j-1,dp)/j)*hp3
       hpdum(:,:,:,:,j) = hp1
    end do

    hpdum = hpdum/pi**(0.25)

    deallocate (hp1,hp2,hp3)

  end subroutine get_hermite_polynomials_4d

  !> Returns Gn = Hn / sqrt(2^n n!) / pi^(1/4),
  !! where Hn are the hermite polynomials
  !! i.e. int dx Gm * Gn exp(-x^2) = 1
  !!
  !! hpdum(i, n) receives G_n(xptsdum(i)) for n = 0 .. size(hpdum,2)-1, from the
  !! recurrence G_j = sqrt(2/j) x G_{j-1} - sqrt((j-1)/j) G_{j-2}. The
  !! recurrence, including its coefficients, is carried out in double precision.
  subroutine get_hermite_polynomials_1d (xptsdum, hpdum)
    use constants, only: pi
    implicit none

    real, dimension (:), intent (in)   :: xptsdum
    real, dimension (:,0:), intent (out) :: hpdum

    integer, parameter :: dp = kind(1.0d0)
    integer :: j
    real (dp), dimension (:), allocatable :: hp1, hp2, hp3

    hpdum = 0.0

    allocate (hp1(size(xptsdum)))
    allocate (hp2(size(xptsdum)))
    allocate (hp3(size(xptsdum)))

    ! Seed the recurrence with G_0 = 1 and G_{-1} = 0 (before the pi^(1/4) scaling)
    hp1 = 1.0_dp
    hp2 = 0.0_dp

    hpdum(:,0) = 1.0

    do j = 1, size(hpdum,2)-1
       hp3 = hp2
       hp2 = hp1
       ! Coefficients in double precision to match the double work arrays
       hp1 = sqrt(2.0_dp/j)*xptsdum*hp2 - sqrt(real(j-1,dp)/j)*hp3
       hpdum(:,j) = hp1
    end do

    hpdum = hpdum/pi**(0.25)

    deallocate (hp1,hp2,hp3)

  end subroutine get_hermite_polynomials_1d

  !> Returns probabilist's Hermite polynomials, He_n(x), in double precision.
  !!
  !! Uses the recurrence He_k = x He_{k-1} - (k-1) He_{k-2}, seeded with
  !! He_{-1} = 0 and He_0 = 1 (so He_1 = x). There is no polynomial of
  !! negative order, so 0 is returned for n < 0.
  elemental function hermite_prob (n, x)
    implicit none
    integer, intent (in) :: n
    real, intent (in) :: x
    double precision :: hermite_prob
    integer :: k
    double precision :: he_k, he_km1, he_km2

    ! Guard: without it the result would be left undefined for n < 0
    if (n < 0) then
       hermite_prob = 0.0d0
       return
    end if

    he_km1 = 1.0d0   ! He_0
    he_km2 = 0.0d0   ! He_{-1}; makes the k = 1 step produce He_1 = x

    do k = 1, n
       he_k = x*he_km1 - (k-1)*he_km2
       he_km2 = he_km1
       he_km1 = he_k
    end do

    hermite_prob = he_km1

  end function hermite_prob

  !> Set up the redistribute maps out of g_lo that the caller asks for.
  !!
  !! use_lz_layout, use_e_layout and use_le_layout select the lambda, energy
  !! and lambda-energy maps. The g_lo -> gf_lo map is always initialised.
  !! When `test` is set, each selected map is reported or checked, and
  !! proc0 prints progress messages.
  subroutine init_map (use_lz_layout, use_e_layout, use_le_layout, test)
    use mp, only: proc0
    use redistribute, only: report_map_property
    implicit none

    logical, intent (in) :: use_lz_layout, use_e_layout, use_le_layout, test

    ! Lambda map (init_lambda_layout is called within the redistribute setup)
    if (use_lz_layout) call init_lambda_redistribute_local
    if (use_lz_layout .and. test) then
       if (proc0) print *, '=== Lambda map property ==='
       call report_map_property (lambda_map)
    end if

    ! Energy map (init_energy_layout is called within the redistribute setup)
    if (use_e_layout) call init_energy_redistribute_local
    if (use_e_layout .and. test) then
       if (proc0) print *, '=== Energy map property ==='
       call report_map_property (energy_map)
    end if

    ! Combined lambda-energy map
    if (use_le_layout) call init_g2le_redistribute_local
    if (use_le_layout .and. test) call check_g2le

    ! g_lo -> gf_lo is needed unconditionally
    call init_g2gf(test)

    if (test .and. proc0) print *, 'init_map done'

  end subroutine init_map

  !> Initialise the g_lo -> gf_lo redistribute and, when `test` is set,
  !! run its consistency check.
  subroutine init_g2gf(test)
    implicit none

    logical, intent (in) :: test

    call init_g2gf_redistribute

    ! Verification is optional
    if (.not. test) return
    call check_g2gf

  end subroutine init_g2gf

  !> Construct a redistribute for g_lo -> le_lo
  !!
  !! On return g2le describes how to move g(ig, isign, iglo) held in g_lo
  !! into an le_lo array indexed (xi, ie, ile). Here ile = idx(le_lo, ig, ik, it, is),
  !! and xi is an extended index that merges lambda with the sign of vpar
  !! (see the il0/isign -> il mapping below).
  !!
  !! The construction proceeds in passes:
  !!   1. count the points this processor sends to / receives from each processor;
  !!   2. allocate the from/to/sort index lists;
  !!   3. record the send order (from_list);
  !!   4. record the matching receive slots (to_list) together with a sort key;
  !!   5. sort each receive list into send order;
  !!   6. set the array bounds and call init_redist.
  !!
  !! When skip_forbidden_region is .true., points with forbid(ig, il) set are
  !! neither sent nor received.
  subroutine setup_g2le_redistribute_local(g_lo, le_lo, g2le)
    use layouts_type, only: g_layout_type, le_layout_type
    use mp, only: nproc
    use gs2_layouts, only: idx_local, proc_id
    use gs2_layouts, only: ig_idx, isign_idx
    use gs2_layouts, only: ik_idx, it_idx, ie_idx, is_idx, il_idx, idx
    use sorting, only: quicksort
    use redistribute, only: index_list_type, init_redist, delete_list

    implicit none
    type(g_layout_type), intent(in) :: g_lo
    type(le_layout_type), intent(in) :: le_lo
    type(redist_type), intent(in out) :: g2le
    type (index_list_type), dimension(0:nproc-1) :: to_list, from_list, sort_list
    integer, dimension (0:nproc-1) :: nn_to, nn_from
    integer, dimension (3) :: from_low, from_high
    integer, dimension (3) :: to_high
    integer :: to_low
    integer :: ig, isign, iglo, il, ile
    integer :: ie
    integer :: n, ip, je
    integer :: ile_bak, il0
    !> The following is for debug/testing to force all grid
    !> points to be communicated if we wish.
    !> (Setting it to .false. also sends/receives forbidden-region points.)
    logical, parameter :: skip_forbidden_region = .true.

    !Initialise the data counters
    nn_to = 0
    nn_from = 0

    !First count the data to be sent | g_lo-->le_lo
    !Protect against procs with no data
    !NOTE(review): the loop runs to ulim_alloc, not ulim_proc; presumably the
    !two coincide whenever ulim_proc >= llim_proc - confirm in gs2_layouts.
    if(g_lo%ulim_proc.ge.g_lo%llim_proc)then
       do iglo = g_lo%llim_proc,g_lo%ulim_alloc
          !Get le_lo idx for ig=-ntgrid
          ile=idx(le_lo,-g_lo%ntgrid,ik_idx(g_lo,iglo),&
               it_idx(g_lo,iglo),is_idx(g_lo,iglo))

          il = il_idx(g_lo, iglo)

          !Loop over remaining dimensions, note ile is independent of isign
          !so add two to count
          !Stepping ile by one per ig assumes ig is the fastest-varying index
          !of le_lo (TODO confirm in gs2_layouts).
          do ig=-g_lo%ntgrid, g_lo%ntgrid
             if ((.not. forbid(ig, il)) .or. .not. skip_forbidden_region) then
                !Increment the data sent counter for this proc
                nn_from(proc_id(le_lo,ile))=nn_from(proc_id(le_lo,ile))+2
             end if

             !Increment ile
             ile=ile+1
          enddo
       enddo
    endif

    !Now count how much data to receive | le_lo<--g_lo
    !Protect against procs with no data
    if(le_lo%ulim_proc.ge.le_lo%llim_proc)then
       do ile=le_lo%llim_proc,le_lo%ulim_alloc
          ig = ig_idx(le_lo, ile)

          !Loop over local dimensions, adding 2 to account for each sign
          do ie=1,g_lo%negrid !le_lo%?
             do il=1,g_lo%nlambda !le_lo%?
                if (forbid(ig, il) .and. skip_forbidden_region) cycle

                !Get index
                iglo=idx(g_lo,ik_idx(le_lo,ile),it_idx(le_lo,ile),&
                     il,ie,is_idx(le_lo,ile))

                !Increment the data to receive counter
                nn_to(proc_id(g_lo,iglo))=nn_to(proc_id(g_lo,iglo))+2
             enddo
          enddo
       enddo
    endif

    !Now allocate storage for index arrays
    !(receive lists get an extra sort_list used to reorder them later)
    do ip = 0, nproc-1
       if (nn_from(ip) > 0) then
          allocate (from_list(ip)%first (nn_from(ip)))
          allocate (from_list(ip)%second(nn_from(ip)))
          allocate (from_list(ip)%third (nn_from(ip)))
       end if
       if (nn_to(ip) > 0) then
          allocate (to_list(ip)%first (nn_to(ip)))
          allocate (to_list(ip)%second(nn_to(ip)))
          allocate (to_list(ip)%third (nn_to(ip)))
          !For sorting message order later
          allocate (sort_list(ip)%first(nn_to(ip)))
       end if
    end do

    !Reinitialise counters
    nn_to = 0
    nn_from = 0

    !First fill in sending indices, these define the message order
    !Send order per destination: iglo (outer), then isign, then ig (inner).
    !Protect against procs with no data
    if(g_lo%ulim_proc.ge.g_lo%llim_proc)then
       do iglo=g_lo%llim_proc,g_lo%ulim_alloc
          !Get ile for ig=-ntgrid
          ile=idx(le_lo,-g_lo%ntgrid,ik_idx(g_lo,iglo),&
               it_idx(g_lo,iglo),is_idx(g_lo,iglo))
          ile_bak=ile
          il = il_idx(g_lo, iglo)
          !Loop over sign
          do isign=1,2
             do ig=-g_lo%ntgrid, g_lo%ntgrid
                if ((.not. forbid(ig, il)) .or. .not. skip_forbidden_region ) then
                   !Get proc id
                   ip=proc_id(le_lo,ile)

                   !Increment procs message counter
                   n=nn_from(ip)+1
                   nn_from(ip)=n

                   !Store indices
                   from_list(ip)%first(n)=ig
                   from_list(ip)%second(n)=isign
                   from_list(ip)%third(n)=iglo
                end if

                !Increment ile
                ile=ile+1
             enddo

             !Restore ile
             !(ile does not depend on isign, so both signs revisit the same ile range)
             ile=ile_bak
          enddo
       enddo
    endif

    !Now fill in the receiving indices, these must match message data order
    !Protect against procs with no data
    if(le_lo%ulim_proc.ge.le_lo%llim_proc)then
       do ile=le_lo%llim_proc,le_lo%ulim_alloc
          !Get ig index
          ig=ig_idx(le_lo,ile)

          !Loop over local dimensions,Whilst ile is independent of sign this information
          !is in lambda so loop over sign included here
          do ie=1,g_lo%negrid !le_lo%?
             do isign=1,2
                do il0=1,g_lo%nlambda !le_lo%?
                   if (forbid(ig, il0) .and. skip_forbidden_region) cycle
                   !Pick correct extended lambda value
                   !Map (il0, isign) -> extended xi index il, with je = jend(ig):
                   !  je == 0 : isign=1 -> il0,          isign=2 -> 2*nlambda+1-il0
                   !  il0 <  je: isign=1 -> il0,          isign=2 -> 2*je-il0
                   !  il0 == je: isign=1 -> 2*je,         isign=2 -> je
                   !  il0 >  je: isign=1 -> il0+je,       isign=2 -> 2*nlambda+1-il0+je
                   je=jend(ig)
                   il=il0
                   if (je.eq.0) then
                      if (isign.eq.2) il=2*g_lo%nlambda+1-il !le_lo%?
                   else
                      if(il.eq.je) then
                         if(isign.eq.1) il=2*je
                      else if(il.gt.je) then
                         if(isign.eq.1) then
                            il=il+je
                         else
                            il=2*g_lo%nlambda+1-il+je !le_lo%?
                         endif
                      else
                         if(isign.eq.2) il=2*je-il
                      endif
                   endif

                   !Get iglo value
                   iglo=idx(g_lo,ik_idx(le_lo,ile),it_idx(le_lo,ile),&
                        il0,ie,is_idx(le_lo,ile))

                   !Get proc_id
                   ip=proc_id(g_lo,iglo)

                   !Increment counter
                   n=nn_to(ip)+1
                   nn_to(ip)=n

                   !Store indices
                   to_list(ip)%first(n)=il
                   to_list(ip)%second(n)=ie
                   to_list(ip)%third(n)=ile

                   !Store sorting index
                   !Key orders entries by iglo, then isign, then ig - the same
                   !nesting as the send loop above. The constant -1 only shifts
                   !every key. Keys stay distinct provided ntgridtotal >= 2*ntgrid+1
                   !(presumably equal; TODO confirm).
                   sort_list(ip)%first(n)=ig+g_lo%ntgrid-1+g_lo%ntgridtotal*(isign-1+2*(iglo-g_lo%llim_world))
                enddo
             enddo
          enddo
       enddo
    endif

    !Now sort receive indices into message order
    !(quicksort presumably permutes the three to_list arrays alongside the keys)
    do ip=0,nproc-1
       if(nn_to(ip)>0) then
          !Apply quicksort
          call quicksort(nn_to(ip),sort_list(ip)%first,to_list(ip)%first,to_list(ip)%second,to_list(ip)%third)
       endif
    enddo

    !Now setup array range values
    !Source array is g(-ntgrid:ntgrid, 1:2, llim_proc:ulim_alloc)
    from_low (1) = -g_lo%ntgrid
    from_low (2) = 1
    from_low (3) = g_lo%llim_proc

    from_high(1) = g_lo%ntgrid
    from_high(2) = 2
    from_high(3) = g_lo%ulim_alloc

    to_low = le_lo%llim_proc

    ! Note: Need to revisit next two so not dependent on module level nlambda/negrid
    to_high(1) = max(2*nlambda, 2*ng2+1)
    to_high(2) = negrid + 1  ! TT: just followed convention with +1.
    ! TT: It may be good to avoid bank conflict.
    to_high(3) = le_lo%ulim_alloc

    !Create g2le redist object
    call init_redist (g2le, 'c', to_low, to_high, to_list, from_low, from_high, from_list)

    !Deallocate lists
    call delete_list (to_list)
    call delete_list (from_list)
    call delete_list (sort_list)
  end subroutine setup_g2le_redistribute_local

  !> Constructs the redistribute mapping from the global g_lo data
  !> decomposition to the le_lo decomposition.
  !>
  !> Only the first call does any work: it sets up the le_lo layout, then
  !> builds g2le from g_lo and le_lo. The module flag leinit makes every
  !> later call a no-op.
  subroutine init_g2le_redistribute_local
    use mp, only: nproc, iproc
    use species, only: nspec
    use theta_grid, only: ntgrid
    use kt_grids, only: naky, ntheta0
    use gs2_layouts, only: init_le_layouts, g_lo, le_lo

    implicit none

    if (.not. leinit) then
       leinit = .true.
       call init_le_layouts (ntgrid, naky, ntheta0, nspec, nproc, iproc)
       call setup_g2le_redistribute_local(g_lo, le_lo, g2le)
    end if
  end subroutine init_g2le_redistribute_local

  !> Self-test of the g2le redistribute (g_lo <-> le_lo).
  !!
  !! The test runs in three steps:
  !!   1. Fill a g_lo array with the integer pattern given by `rule`.
  !!   2. Gather it into le_lo and check every entry.
  !!   3. Scatter back into a zeroed g_lo array and check again.
  !!
  !! Mismatches are written to ierr (error_unit on proc0, unit 6 elsewhere)
  !! together with the processor number; the run is not stopped.
  !!
  !! Points with forbid(ig, il) set are skipped by both checks.
  !! setup_g2le_redistribute_local does not communicate them while its
  !! skip_forbidden_region is .true., so their destination slots are never
  !! written and would otherwise be reported as spurious errors.
  subroutine check_g2le
    use file_utils, only: error_unit
    use mp, only: iproc, proc0
    use theta_grid, only: ntgrid
    use gs2_layouts, only: g_lo, le_lo
    use gs2_layouts, only: ig_idx, ik_idx, it_idx, il_idx, ie_idx, is_idx
    use redistribute, only: gather, scatter, report_map_property

    implicit none

    integer :: iglo, ile, ig, isgn, ik, it, il, ie, is, ierr
    integer :: ixi
    complex, dimension (:,:,:), allocatable :: gtmp, letmp

    if (proc0) then
       ierr = error_unit()
    else
       ierr = 6
    end if

    ! report the map property
    if (proc0) print *, '=== g2le map property ==='
    call report_map_property (g2le)

    allocate (gtmp(-ntgrid:ntgrid, 2, g_lo%llim_proc:g_lo%ulim_alloc))
    allocate (letmp(nlambda*2+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
    letmp = 0.

    ! gather check: g_lo filled with rule(...) must arrive unchanged in le_lo
    gtmp = 0.0
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          do ig = -ntgrid, ntgrid
             gtmp(ig,isgn,iglo) = rule(ig,isgn,ik,it,il,ie,is)
          end do
       end do
    end do
    call gather (g2le, gtmp, letmp)
    do ile = le_lo%llim_proc, le_lo%ulim_proc
       ig = ig_idx(le_lo,ile)
       ik = ik_idx(le_lo,ile)
       it = it_idx(le_lo,ile)
       is = is_idx(le_lo,ile)
       do ie = 1, negrid
          do ixi = 1, 2*nlambda
             isgn = ixi_to_isgn(ixi, ig)
             il = ixi_to_il(ixi, ig)
             ! Forbidden points are not sent by g2le, so their slot is still 0
             if (forbid(ig, il)) cycle
             if (int(real(letmp(ixi,ie,ile))) /= rule(ig,isgn,ik,it,il,ie,is)) &
                  write (ierr,'(a,8i6)') 'ERROR: gather by g2le broken!', iproc
          end do
       end do
    end do
    if (proc0) write (ierr,'(a)') 'g2le gather check done'

    ! scatter check: le_lo data must come back to the original g_lo slots
    gtmp = 0.0
    call scatter (g2le, letmp, gtmp)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          do ig = -ntgrid, ntgrid
             ! Forbidden points are not returned by the scatter either
             if (forbid(ig, il)) cycle
             if (gtmp(ig,isgn,iglo) /= rule(ig,isgn,ik,it,il,ie,is)) &
                  write (ierr,'(a,i6)') 'ERROR: scatter by g2le broken!', iproc
          end do
       end do
    end do
    if (proc0) write (ierr,'(a)') 'g2le scatter check done'

    deallocate (gtmp,letmp)

  contains

    !> Test pattern: any deterministic function of the indices will do.
    function rule (ig, isgn, ik, it, il, ie, is)
      integer, intent (in) :: ig, isgn, ik, it, il, ie, is
      integer :: rule
      rule = ig + isgn + ik + it + il + ie + is  ! make whatever you want
    end function rule

  end subroutine check_g2le

  !> Construct the redistribute for g_lo -> gf_lo
  subroutine setup_g2gf_redistribute(g_lo, gf_lo, g2gf)
    use mp, only: nproc
    use layouts_type, only: g_layout_type, gf_layout_type
    use gs2_layouts, only: idx_local, proc_id
    use gs2_layouts, only: ig_idx, isign_idx
    use gs2_layouts, only: ik_idx, it_idx, ie_idx, is_idx, il_idx, idx
    use redistribute, only: index_list_type, init_redist, delete_list
    use sorting, only: quicksort

    implicit none
    type(g_layout_type), intent(in) :: g_lo
    type(gf_layout_type), intent(in) :: gf_lo
    type(redist_type), intent(in out) :: g2gf
    type (index_list_type), dimension(0:nproc-1) :: to_list, from_list, sort_list, bak_sort_list
    integer, dimension (0:nproc-1) :: nn_to, nn_from
    integer, dimension (3) :: from_low, from_high
    integer, dimension (6) :: to_low, to_high
    integer :: iglo, il, igf
    integer :: ie, is
    integer :: n, ip
    !Initialise the data counters
    nn_to = 0
    nn_from = 0

    !First count the data to be sent | g_lo-->gf_lo
    !Protect against procs with no data
    if(g_lo%ulim_proc.ge.g_lo%llim_proc)then
       do iglo = g_lo%llim_proc,g_lo%ulim_alloc

          igf = idx(gf_lo, ik_idx(g_lo, iglo), it_idx(g_lo, iglo))
          nn_from(proc_id(gf_lo,igf))=nn_from(proc_id(gf_lo,igf))+1
       enddo
    endif

    !Now count how much data to receive | gf_lo<--g_lo
    !Protect against procs with no data
    if(gf_lo%ulim_proc.ge.gf_lo%llim_proc)then
       do igf=gf_lo%llim_proc,gf_lo%ulim_alloc
          do is=1,g_lo%nspec
             do ie=1,g_lo%negrid
                do il=1,g_lo%nlambda
                   iglo = idx(g_lo, ik_idx(gf_lo, igf), it_idx(gf_lo, igf), il, ie, is)
                   !Increment the data to receive counter
                   nn_to(proc_id(g_lo,iglo))=nn_to(proc_id(g_lo,iglo))+1
                enddo
             enddo
          enddo
       enddo
    endif


    !Now allocate storage for index arrays
    do ip = 0, nproc-1
       if (nn_from(ip) > 0) then
          allocate (from_list(ip)%first (nn_from(ip)))
          allocate (from_list(ip)%second(nn_from(ip)))
          allocate (from_list(ip)%third (nn_from(ip)))
       end if
       if (nn_to(ip) > 0) then
          allocate (to_list(ip)%first (nn_to(ip)))
          allocate (to_list(ip)%second(nn_to(ip)))
          allocate (to_list(ip)%third (nn_to(ip)))
          allocate (to_list(ip)%fourth (nn_to(ip)))
          allocate (to_list(ip)%fifth (nn_to(ip)))
          allocate (to_list(ip)%sixth (nn_to(ip)))
          allocate (sort_list(ip)%first (nn_to(ip)))
          allocate (bak_sort_list(ip)%first (nn_to(ip)))
       end if
    end do

    !Reinitialise counters
    nn_to = 0
    nn_from = 0