collisions.fpp Source File


Contents

Source Code


Source Code

!> Routines for implementing the model collision operator defined by
!> [Barnes, Abel et
!> al. 2009](https://pubs.aip.org/aip/pop/article/16/7/072107/263154/Linearized-model-Fokker-Planck-collision-operators)
!> or on [arxiv](https://arxiv.org/abs/0809.3945v2).
!> The collision operator causes physically motivated smoothing of
!> structure in velocity space which is necessary to prevent buildup
!> of structure at fine scales in velocity space, while conserving
!> energy and momentum.
module collisions
  use abstract_config, only: abstract_config_type, CONFIG_MAX_NAME_LEN
  use redistribute, only: redist_type

  implicit none

  ! Everything is private unless explicitly exported below.
  private 

  public :: init_collisions, finish_collisions, reset_init
  public :: read_parameters, wnml_collisions, check_collisions
  public :: dtot, fdf, fdb, vnmult, vnfac, ncheck, vnslow, vary_vnew
  public :: etol, ewindow, etola, ewindowa, init_lorentz_error
  public :: init_lorentz, init_ediffuse, init_lorentz_conserve, init_diffuse_conserve
  public :: colls, hyper_colls, heating, adjust, split_collisions,use_le_layout, solfp1
  public :: collision_model_switch, collision_model_lorentz, collision_model_none
  public :: collision_model_lorentz_test, collision_model_full, collision_model_ediffuse

  public :: collisions_config_type
  public :: set_collisions_config
  public :: get_collisions_config
  
  ! The generic interfaces below each have one specific procedure for
  ! the combined energy/pitch-angle (le_lo) layout and one for the
  ! standard layouts. The compiler picks the specific from the actual
  ! argument list (e.g. solfp_lorentz is called with a single le_lo
  ! array, or with the g_lo array plus two extra arguments).
  interface solfp1
     module procedure solfp1_le_layout
     module procedure solfp1_standard_layout
  end interface

  interface solfp_lorentz
     module procedure solfp_lorentz_le_layout
     module procedure solfp_lorentz_standard_layout
  end interface

  interface conserve_lorentz
     module procedure conserve_lorentz_le_layout
     module procedure conserve_lorentz_standard_layout
  end interface

  interface conserve_diffuse
     module procedure conserve_diffuse_le_layout
     module procedure conserve_diffuse_standard_layout
  end interface 

  ! knobs
  ! Module-level copies of the collisions_knobs inputs. Unless noted
  ! otherwise these are copied from collisions_config in read_parameters.
  logical :: use_le_layout
  logical :: const_v, conserve_moments
  logical :: conservative, resistivity
  !> Selected collision model: one of the collision_model_* constants
  !> below (may be reset to collision_model_none in init_arrays).
  integer :: collision_model_switch
  !> Selected coefficient schemes: lorentz_scheme_* / ediff_scheme_*.
  integer :: lorentz_switch, ediff_switch
  logical :: adjust
  logical :: heating
  !> True if any species has nu_h > epsilon(0.0); set in init_collisions.
  logical :: hyper_colls
  logical :: ei_coll_only
  logical :: test
  logical :: special_wfb_lorentz
  logical :: vpar_zero_mean
  logical :: conserve_forbid_zero
  
  ! Integer codes for collision_model_switch
  integer, parameter :: collision_model_lorentz = 1      ! if this changes, check gs2_diagnostics
  integer, parameter :: collision_model_none = 3
  integer, parameter :: collision_model_lorentz_test = 5 ! if this changes, check gs2_diagnostics
  integer, parameter :: collision_model_full = 6
  integer, parameter :: collision_model_ediffuse = 7

  ! Integer codes for lorentz_switch
  integer, parameter :: lorentz_scheme_default = 1
  integer, parameter :: lorentz_scheme_old = 2

  ! Integer codes for ediff_switch
  integer, parameter :: ediff_scheme_default = 1
  integer, parameter :: ediff_scheme_old = 2

  !> Collision-frequency multipliers. vnmult(1) scales vnew_D, vnew_s and
  !> the electron drag frequency in init_lorentz_conserve. The -1.0
  !> initial value is presumably a "not yet set" sentinel -- TODO confirm.
  real, dimension (2) :: vnmult = -1.0
  ! Adaptive collisionality controls (see collisions_config_type)
  integer :: ncheck
  logical :: vary_vnew
  real :: vnfac, vnslow
  real :: etol, ewindow, etola, ewindowa
  !> Currently ignored (see collisions_config_type documentation).
  integer :: timesteps_between_collisions
  !> has_lorentz / has_diffuse: whether the selected model includes pitch
  !> angle scattering / energy diffusion; set in read_parameters.
  logical :: force_collisions, has_lorentz, has_diffuse
  real, dimension (2) :: vnm_init = 1.0

  real, dimension (:,:,:), allocatable :: dtot
  ! (-ntgrid:ntgrid,nlambda,max(ng2,nlambda-ng2)) lagrange coefficients for derivative error estimate

  real, dimension (:,:), allocatable :: fdf, fdb
  ! (-ntgrid:ntgrid,nlambda) finite difference coefficients for derivative error estimate

  real, dimension (:,:,:), allocatable :: vnew, vnew_s, vnew_D, vnew_E, delvnew
  ! (naky,negrid,nspec) replicated

  real, dimension (:,:), allocatable :: vpdiffle
  real, dimension (:,:,:), allocatable :: vpdiff
  ! (-ntgrid:ntgrid,2,nlambda) replicated

  ! only for hyper-diffusive collisions
  real, dimension (:,:,:,:), allocatable :: vnewh
  ! (-ntgrid:ntgrid,ntheta0,naky,nspec) replicated

  ! only for momentum conservation due to Lorentz operator (8.06)
  real, dimension(:,:,:), allocatable :: s0, w0, z0
  ! The following (between the start and finish comments) are only used for LE layout
  ! start
  real, dimension (:,:,:), allocatable :: s0le, w0le, z0le
  ! Bessel-function factors held in le layout shaped arrays, filled by
  ! init_le_bessel: aj0le = aj0, vperp_aj1le = aj1 * energy * al, and
  ! vpa_aj0_le = vpa * aj0.
  real, dimension (:,:,:), allocatable :: aj0le, vperp_aj1le, vpa_aj0_le
  ! finish

  ! needed for momentum and energy conservation due to energy diffusion (3.08)
  real, dimension(:,:,:), allocatable :: bs0, bw0, bz0

  ! The following (between the start and finish comments) are only used for LE layout
  ! start
  real, dimension (:,:,:), allocatable :: bs0le, bw0le, bz0le
  ! finish

  !> Multiplier of the finite Larmor radius (classical diffusion) terms.
  real :: cfac
  integer :: nxi_lim !< Sets the upper xi index which we consider in loops
  ! The following (between the start and finish comments) are only used for LE layout
  ! start
  ! only for lorentz
  real, dimension (:,:,:), allocatable :: c1le, betaale, qle, d1le, h1le
  ! only for energy diffusion
  real, dimension (:,:,:), allocatable :: ec1le, ebetaale, eqle
  ! finish
  ! The following (between the start and finish comments) are only used for non-LE layout
  ! start
  ! only for lorentz
  real, dimension (:,:), allocatable :: c1, betaa, ql, d1, h1
  ! only for energy diffusion
  real, dimension (:,:), allocatable :: ec1, ebetaa, eql
  ! finish

  real, dimension (:, :), allocatable :: pitch_weights

  !> True when the drag (resistivity) term is included in the Lorentz
  !> operator; computed in read_parameters.
  logical :: drag = .false.
  !> Set to .false. in init_arrays when collisions are disabled.
  logical :: colls = .true.
  logical :: split_collisions

  logical :: hypermult
  !> Guards init_collisions against repeated initialisation.
  logical :: initialized = .false.
  !> Copied from collisions_config%exist; wnml_collisions writes nothing
  !> when this is false.
  logical :: exist

  !> Input configuration for the collision operator, read from the
  !> `collisions_knobs` namelist.
  type, extends(abstract_config_type) :: collisions_config_type
     ! namelist : collisions_knobs
     ! indexed : false
     !> If true (default) then transform from the gyro-averaged
     !> distribution function \(g\) evolved by GS2 to the
     !> non-Boltzmann part of \(\delta f\), (\(h\)), when applying the
     !> collision operator. This is the physically appropriate choice,
     !> this parameter is primarily for numerical testing.
     logical :: adjust = .true.
     !> Factor multiplying the finite Larmor radius terms in the
     !> collision operator. This term is essentially just classical
     !> diffusion. Set `cfac` to 0 to turn off this diffusion.
     !>
     !> @note Default changed to 1.0 in order to include classical
     !> diffusion April 18 2006
     real :: cfac = 1.0
     !> Selects the collision model used in the simulation. Can be one
     !> of
     !>
     !> - `'default'` : Include both pitch angle scattering and energy
     !> diffusion.
     !> - `'collisionless'` : Disable the collision operator.
     !> - `'none'` : Equivalent to `'collisionless'`.
     !> - `'lorentz'` : Only pitch angle scattering.
     !> - `'lorentz-test'` : Only pitch angle scattering. For testing,
     !> disables some components of the operator.
     !> - `'ediffuse'` : Only energy diffusion.
     !>
     !> If no species have a non-zero collision frequency, `vnewk`, then
     !> the collision operator is also automatically disabled.
     character(len = 20) :: collision_model = 'default'
     !> If true (default) then guarantee exact conservation
     !> properties.
     logical :: conservative = .true.
     !> If true (default) then forces conservation corrections to zero
     !> in the forbidden region to avoid introducing unphysical
     !> non-zero values for the distribution function in the forbidden
     !> region of trapped particles.
     !>
     !> @todo Confirm above documentation is accurate.
     !>
     !> @note The conserving terms calculated as part of the field
     !> particle collision operator should respect the forbidden region
     !> of the distribution function for trapped particles. This is a
     !> cosmetic change, but has the result that plots of the
     !> distribution function for trapped particles make sense. Note
     !> that terms involving vpa do not need to be modified because
     !> vpa = 0 in the forbidden region; for this reason explicit
     !> forbid statements have not been added to the drag term
     !> involving apar.
     logical :: conserve_forbid_zero = .true.
     !> If true (default) then guarantee collision operator conserves
     !> momentum and energy.
     !>
     !> @todo Clarify the difference with [[collisions_knobs:conservative]].
     !>
     !> @note Default changed to reflect improved momentum and energy
     !> conservation 07/08
     logical :: conserve_moments = .true.  
     !> If true (not the default) then use the thermal velocity to
     !> evaluate the collision frequencies to be used. This results in
     !> an energy independent collision frequency being used for all
     !> species.
     logical :: const_v = .false.
     !> Controls how the coefficients in the matrix representing the
     !> energy diffusion operator are obtained. Can be one of
     !>
     !> - `'default'` : Use a conservative scheme.
     !> - `'old'` : Use the original non-conservative scheme.
     !>
     !> @todo Consider deprecating/removing the `'old'` option.
     character(len = 20) :: ediff_scheme = 'default'
     !> If true (not the default) then force the collision frequency
     !> used for all non-electron species to zero and force all
     !> electron-electron terms to zero.
     logical :: ei_coll_only = .false.
     !> Only used in [[get_verr]] as a part of the adaptive
     !> collisionality algorithm. Sets the maximum relative error
     !> allowed, above which the collision frequency must be
     !> increased.
     !>
     !> @todo Confirm this is really to set the relative error limit.
     real :: etol = 2.e-2
     !> Only used in [[get_verr]] as a part of the adaptive
     !> collisionality algorithm. Sets the maximum absolute error
     !> allowed, above which the collision frequency must be
     !> increased.
     !>
     !> @todo Confirm this is really to set the absolute error limit.     
     real :: etola = 2.e-2
     !> Only used in [[get_verr]] as a part of the adaptive
     !> collisionality algorithm. Sets an offset to apply to the
     !> relative error limit set by [[collisions_knobs:etol]]. This is used to provide
     !> hysteresis in the adaptive collisionality algorithm so as to
     !> avoid adjusting the collision frequency up and down every step
     !> (similar to [[reinit_knobs:delt_cushion]]).
     !>
     !> @todo Confirm the above description.
     real :: ewindow = 1.e-2
     !> Only used in [[get_verr]] as a part of the adaptive
     !> collisionality algorithm. Sets an offset to apply to the
     !> absolute error limit set by [[collisions_knobs:etola]]. This is used to provide
     !> hysteresis in the adaptive collisionality algorithm so as to
     !> avoid adjusting the collision frequency up and down every step
     !> (similar to the [[reinit_knobs:delt_cushion]]).
     !>
     !> @todo Confirm the above description.     
     real :: ewindowa = 1.e-2
     !> Currently we skip the application of the selected collision
     !> operator for species with zero collision frequency and disable
     !> collisions entirely if no species have non-zero collision
     !> frequency. This is generally sensible, but it can be useful to
     !> force the use of the collisional code path in some situations
     !> such as in code testing. Setting this flag to `.true.` forces
     !> the selected collision model operator to be applied even if
     !> the collision frequency is zero.
     logical :: force_collisions = .false.
     !> If true (not the default) then calculate collisional heating
     !> when applying the collision operator. This is purely a
     !> diagnostic calculation. It should not change the evolution.
     !>
     !> @todo : Verify this does not influence the evolution.
     logical :: heating = .false.
     !> If true (not the default) then multiply the hyper collision
     !> frequency by the species' collision frequency
     !> [[species_parameters:nu_h]]. This only impacts the pitch
     !> angle scattering operator.
     !>
     !> @note The hyper collision frequency is only non-zero if any
     !> species have a non-zero `nu_h` value set in the input. If any
     !> are set then the hyper collision frequency is simply `nu_h *
     !> kperp2 * kperp2` (where `kperp2` here is normalised to the
     !> maximum `kperp2`).
     logical :: hypermult = .false.
     !> Controls how the coefficients in the matrix representing the
     !> pitch angle scattering operator are obtained. Can be one of
     !>
     !> - `'default'` : Use a conservative scheme.
     !> - `'old'` : Use the original non-conservative scheme.
     !>
     !> @todo Consider deprecating/removing the `'old'` option.
     character(len = 20) :: lorentz_scheme = 'default'
     !> Used as a part of the adaptive collisionality algorithm. When
     !> active we check the velocity space error with [[get_verr]]
     !> every `ncheck` time steps. This check can be relatively
     !> expensive so it is recommended to avoid small values of
     !> `ncheck`.
     !>
     !> @warning The new diagnostics module currently ignores this
     !> value and instead uses its own input variable named `ncheck`
     !> (which has a different default). See [this
     !> bug](https://bitbucket.org/gyrokinetics/gs2/issues/88).
     integer :: ncheck = 100
     !> If true (default) then potentially include the drag term in
     !> the pitch angle scattering operator. This is a necessary but
     !> not sufficient criterion. For the drag term to be included we
     !> also require \(\beta\neq 0\), more than one simulated species
     !> and finite \(A_\|\) perturbations included in the simulation
     !> (i.e. `fapar /= 0`).
     logical :: resistivity = .true.
     !> If true (not the default) then use special handling for the
     !> wfb particle in the pitch angle scattering operator.
     !>
     !> @note MRH changed default 16/08/2018. Previous default of true
     !> seemed to cause a numerical issue in flux tube simulations for
     !> the zonal modes at large \(k_x\).
     !>
     !> @todo Improve this documentation
     logical :: special_wfb_lorentz = .false.
     !> If true (not the default) then remove the collision operator
     !> from the usual time advance algorithm. Instead the collision
     !> operator is only applied every
     !> [[collisions_knobs:timesteps_between_collisions]]
     !> timesteps. This can potentially substantially speed up
     !> collisional simulations, both in the initialisation and
     !> advance phases.
     !>
     !> @warning Currently the input
     !> [[collisions_knobs:timesteps_between_collisions]] is ignored
     !> so collisions are applied every time step. The primary result
     !> of `split_collisions = .true.` currently is that collisions are
     !> not applied in the first linear solve used as a part of a
     !> single time step. Hence the cost of collisions in advance are
     !> roughly halved. The full saving in the initialisation phase is
     !> still realised.
     logical :: split_collisions = .false.
     !> If true (not the default) then performs some additional checks
     !> of the data redistribution routines associated with
     !> transforming between the standard and collisional data
     !> decompositions.
     logical :: test = .false.
     !> Should set the number of timesteps between application of the
     !> collision operator if [[collisions_knobs:split_collisions]] is
     !> true. Currently this is ignored.
     !>
     !> @warning This value is currently ignored.
     integer :: timesteps_between_collisions = 1
     !> If true (default) then use a data decomposition for collisions
     !> that brings both pitch angle and energy local to a
     !> processor. This is typically the most efficient option,
     !> however for collisional simulations that only use part of the
     !> collision operator (either energy diffusion or pitch angle
     !> scattering) then it may be beneficial to set this flag to
     !> false such that we use a decomposition that only brings either
     !> energy or pitch angle local.
     logical :: use_le_layout = .true.
     !> Set to true (not the default) to activate the adaptive
     !> collisionality algorithm.
     !>
     !> @todo Provide more documentation on the adaptive
     !> collisionality algorithm.
     logical :: vary_vnew = .false.
     !> If the collisionality is to be increased as a part of the
     !> adaptive collisionality algorithm then increase it by this
     !> factor.
     real :: vnfac = 1.1
     !> If the collisionality is to be decreased as a part of the
     !> adaptive collisionality algorithm then decrease it by this
     !> factor.
     real :: vnslow = 0.9
     !> Controls how the duplicate `vpar = 0` point is handled.  When
     !> `vpar_zero_mean = .true.` (the default) the average of `g(vpar
     !> = 0)` for both signs of the parallel velocity (`isgn`) is used
     !> in the collision operator instead of just `g(vpar = 0)` at
     !> `isgn=2`.  This is seen to suppress a numerical instability
     !> when `special_wfb_lorentz =.false.`. With these defaults pitch
     !> angle scattering at \(\theta = \pm \pi \) is now being handled
     !> physically i.e. `vpar = 0` at this theta location is no longer
     !> being skipped.
     !>
     !> @todo Consider removing this option.
     logical :: vpar_zero_mean = .true.
   contains
     procedure, public :: read => read_collisions_config
     procedure, public :: write => write_collisions_config
     procedure, public :: reset => reset_collisions_config
     procedure, public :: broadcast => broadcast_collisions_config
     procedure, public, nopass :: get_default_name => get_default_name_collisions_config
     procedure, public, nopass :: get_default_requires_index => get_default_requires_index_collisions_config
  end type collisions_config_type

  !> Module-level configuration instance. read_parameters replaces it with
  !> its optional argument (when given), initialises it and copies its
  !> components into the module variables above.
  type(collisions_config_type) :: collisions_config
contains
  
  !> Write a short human-readable description of the selected
  !> collision operator to `report_unit`.
  !>
  !> Only the full model and the Lorentz-type models (`lorentz`,
  !> `lorentz-test`) produce output; any other model writes nothing.
  !>
  !> @param report_unit Fortran unit the report lines are written to.
  subroutine check_collisions(report_unit)
    implicit none
    integer, intent(in) :: report_unit
    logical :: is_lorentz_model

    if (collision_model_switch == collision_model_full) then
       write (report_unit, fmt="('Full GS2 collision operator has been selected.')")
       return
    end if

    is_lorentz_model = collision_model_switch == collision_model_lorentz &
         .or. collision_model_switch == collision_model_lorentz_test
    if (.not. is_lorentz_model) return

    write (report_unit, fmt="('A Lorentz collision operator has been selected.')")
    ! cfac scales the classical (FLR) diffusion part of the operator
    if (cfac > 0) then
       write (report_unit, fmt="('This has both terms of the Lorentz collision operator: cfac=',e12.4)") cfac
    else if (cfac == 0) then
       write (report_unit, fmt="('This is only a partial Lorentz collision operator (cfac=0.0)')")
    end if
    if (const_v) then
       write (report_unit, fmt="('This is an energy independent Lorentz collision operator (const_v=true)')")
    end if
  end subroutine check_collisions

  !> Write the collision settings that were actually used to `unit`
  !> as a `collisions_knobs` namelist.
  !>
  !> Nothing is written when `exist` (copied from
  !> `collisions_config%exist` in read_parameters) is false.
  !>
  !> `collision_model` is written for every model other than the full
  !> operator, which is the input default (`'default'`), so that
  !> re-reading the written namelist selects the same model.
  !> Previously the `ediffuse` model was omitted and would have been
  !> read back as the full operator. `hypermult` is written whenever
  !> it is set, independent of the model, since it affects the pitch
  !> angle scattering present in several models.
  !>
  !> @param unit Fortran unit to write the namelist to.
  subroutine wnml_collisions(unit)
    implicit none
    integer, intent(in) :: unit
    if (.not.exist) return
    write (unit, *)
    write (unit, fmt="(' &',a)") "collisions_knobs"
    select case (collision_model_switch)
    case (collision_model_lorentz)
       write (unit, fmt="(' collision_model = ',a)") '"lorentz"'
    case (collision_model_lorentz_test)
       write (unit, fmt="(' collision_model = ',a)") '"lorentz-test"'
    case (collision_model_ediffuse)
       write (unit, fmt="(' collision_model = ',a)") '"ediffuse"'
    case (collision_model_none)
       write (unit, fmt="(' collision_model = ',a)") '"collisionless"'
    end select
    if (hypermult) write (unit, fmt="(' hypermult = ',L1)") hypermult
    write (unit, fmt="(' cfac = ',f6.3)") cfac
    write (unit, fmt="(' heating = ',L1)") heating
    write (unit, fmt="(' /')")
  end subroutine wnml_collisions

  !> Set up the collision operator.
  !>
  !> Ensures the layouts, species, grids and run parameters it depends
  !> on are initialised, determines whether hyper collisions are
  !> active, reads the collisions configuration and precomputes the
  !> operator arrays. Returns immediately if already initialised.
  !>
  !> @param collisions_config_in Optional configuration passed on to
  !> read_parameters, where it replaces the module-level configuration.
  subroutine init_collisions(collisions_config_in)
    use species, only: init_species, nspec, spec
    use theta_grid, only: init_theta_grid, ntgrid
    use kt_grids, only: init_kt_grids, naky, ntheta0
    use le_grids, only: init_le_grids, nlambda, negrid
    use run_parameters, only: init_run_parameters
    use gs2_layouts, only: init_dist_fn_layouts, init_gs2_layouts
    use mp, only: nproc, iproc
    implicit none
    type(collisions_config_type), intent(in), optional :: collisions_config_in

    if (initialized) return
    initialized = .true.

    call init_gs2_layouts
    call init_species

    ! Hyper collisions are on if any species has a non-negligible nu_h
    hyper_colls = any(spec%nu_h > epsilon(0.0))

    ! Grids and run parameters needed to size the distribution layouts
    call init_theta_grid
    call init_kt_grids
    call init_le_grids
    call init_run_parameters
    call init_dist_fn_layouts (ntgrid, naky, ntheta0, nlambda, negrid, nspec, nproc, iproc)

    call read_parameters(collisions_config_in)
    call init_arrays
  end subroutine init_collisions

  !> Load the collision settings into module variables.
  !>
  !> When `collisions_config_in` is present it first replaces the
  !> module-level `collisions_config`; that object is then initialised
  !> through its `init` method under the name `collisions_knobs`.
  !>
  !> After initialisation:
  !>  - every configuration component is copied to the module variable
  !>    of the same name;
  !>  - the text options `collision_model`, `lorentz_scheme` and
  !>    `ediff_scheme` are mapped to their integer switches;
  !>  - the derived flags `has_lorentz`, `has_diffuse`, `drag` and
  !>    `nxi_lim` are computed.
  !>
  !> @param collisions_config_in Optional configuration overriding the
  !> module-level `collisions_config`.
  subroutine read_parameters(collisions_config_in)
    use file_utils, only: error_unit
    use text_options, only: text_option, get_option_value
    use run_parameters, only: beta, fapar
    use species, only: nspec
    use le_grids, only: nxi, ng2
    implicit none
    type(collisions_config_type), intent(in), optional :: collisions_config_in

    ! Accepted spellings of the text-valued inputs and the switch value
    ! each one selects.
    type (text_option), dimension (6), parameter :: modelopts = &
         (/ text_option('default', collision_model_full), &
            text_option('lorentz', collision_model_lorentz), &
            text_option('ediffuse', collision_model_ediffuse), &
            text_option('lorentz-test', collision_model_lorentz_test), &
            text_option('none', collision_model_none), &
            text_option('collisionless', collision_model_none) /)
    type (text_option), dimension (2), parameter :: schemeopts = &
         (/ text_option('default', lorentz_scheme_default), &
            text_option('old', lorentz_scheme_old) /)
    type (text_option), dimension (2), parameter :: eschemeopts = &
         (/ text_option('default', ediff_scheme_default), &
            text_option('old', ediff_scheme_old) /)
    character(20) :: collision_model, lorentz_scheme, ediff_scheme
    integer :: err_unit

    if (present(collisions_config_in)) collisions_config = collisions_config_in
    call collisions_config%init(name = 'collisions_knobs', requires_index = .false.)

    ! Text options (converted to integer switches further down)
    collision_model = collisions_config%collision_model
    lorentz_scheme = collisions_config%lorentz_scheme
    ediff_scheme = collisions_config%ediff_scheme

    ! Logical flags
    adjust = collisions_config%adjust
    conservative = collisions_config%conservative
    conserve_forbid_zero = collisions_config%conserve_forbid_zero
    conserve_moments = collisions_config%conserve_moments
    const_v = collisions_config%const_v
    ei_coll_only = collisions_config%ei_coll_only
    force_collisions = collisions_config%force_collisions
    heating = collisions_config%heating
    hypermult = collisions_config%hypermult
    resistivity = collisions_config%resistivity
    special_wfb_lorentz = collisions_config%special_wfb_lorentz
    split_collisions = collisions_config%split_collisions
    test = collisions_config%test
    use_le_layout = collisions_config%use_le_layout
    vary_vnew = collisions_config%vary_vnew
    vpar_zero_mean = collisions_config%vpar_zero_mean
    exist = collisions_config%exist

    ! Real-valued parameters
    cfac = collisions_config%cfac
    etol = collisions_config%etol
    etola = collisions_config%etola
    ewindow = collisions_config%ewindow
    ewindowa = collisions_config%ewindowa
    vnfac = collisions_config%vnfac
    vnslow = collisions_config%vnslow

    ! Integer parameters
    ncheck = collisions_config%ncheck
    timesteps_between_collisions = collisions_config%timesteps_between_collisions

    err_unit = error_unit()
    call get_option_value (collision_model, modelopts, collision_model_switch, &
         err_unit, "collision_model in collisions_knobs", .true.)
    call get_option_value (lorentz_scheme, schemeopts, lorentz_switch, &
         err_unit, "lorentz_scheme in collisions_knobs", .true.)
    call get_option_value (ediff_scheme, eschemeopts, ediff_switch, &
         err_unit, "ediff_scheme in collisions_knobs", .true.)

    ! Which parts of the operator the selected model contains
    has_lorentz = any(collision_model_switch == (/ collision_model_full, &
         collision_model_lorentz, collision_model_lorentz_test /))
    has_diffuse = any(collision_model_switch == (/ collision_model_full, &
         collision_model_ediffuse /))

    ! The drag term needs pitch angle scattering, resistivity enabled,
    ! finite beta, more than one species and A_parallel perturbations.
    drag = has_lorentz .and. resistivity .and. (beta > epsilon(0.0)) &
         .and. (nspec > 1) .and. (fapar > 0)

    ! nxi > 2 * ng2 appears to test whether the grid has trapped
    ! particles (could be replaced by grid_has_trapped_particles());
    ! if so, xi loops run to one index further.
    nxi_lim = merge(nxi + 1, nxi, nxi > 2 * ng2)
  end subroutine read_parameters

  !> Square root that treats negative arguments as zero.
  !>
  !> Small negative values can arise from floating point round-off in
  !> quantities that are mathematically non-negative; clamping them to
  !> zero avoids producing NaNs. A debug check that aborts or warns
  !> when the argument is clearly negative (an error rather than
  !> round-off) could be added in future.
  !>
  !> @param arg Value whose root is required.
  !> @return sqrt(arg) for arg > 0, otherwise 0.0.
  elemental real function safe_sqrt(arg) result(root)
    implicit none
    real, intent(in) :: arg
    real :: clipped
    clipped = max(0.0, arg)
    root = sqrt(clipped)
  end function safe_sqrt

  !> Precompute the arrays used to apply the selected collision operator.
  !>
  !> Sets `colls = .false.` and returns early when the collisionless
  !> model is selected. It does the same, and also switches the model
  !> to `collision_model_none`, when every `vnew(:,1,:)` entry is
  !> negligible and `force_collisions` is not set.
  !>
  !> Otherwise it:
  !>  - allocates and zeroes `c_rate` for the heating diagnostic if
  !>    needed;
  !>  - initialises the velocity-space redistribution maps;
  !>  - sets up the Lorentz and/or energy diffusion operators, plus
  !>    their conserving terms, for the selected model;
  !>  - moves the Bessel factors into the le layout when required.
  subroutine init_arrays
    use species, only: nspec
    use le_grids, only: init_map
    use kt_grids, only: naky, ntheta0
    use theta_grid, only: ntgrid
    use dist_fn_arrays, only: c_rate
    use array_utils, only: zero_array
    implicit none
    logical :: need_lz_map, need_e_map, negligible_vnew

    ! Lowflow terms (higher-order corrections to the GK equation such as
    ! the parallel nonlinearity) need velocity-space derivatives, which
    ! are cheapest when all energies and lambdas are local (le_lo).
    ! Because this runs after read_parameters (see init_collisions), it
    ! overrides the input value, so the user cannot disable the le layout
    ! in LOWFLOW builds. Given use_le_layout defaults to true, this
    ! conditional compilation may no longer be needed.
# ifdef LOWFLOW
    use_le_layout = .true.
# endif

    if (collision_model_switch == collision_model_none) then
       colls = .false.
       return
    end if

    call init_vnew

    ! Disable collisions when the collision frequency is effectively zero
    ! for every naky/species at the first energy grid point.
    negligible_vnew = all(abs(vnew(:,1,:)) <= 2.0*epsilon(0.0))
    if (negligible_vnew .and. .not. force_collisions) then
       collision_model_switch = collision_model_none
       colls = .false.
       return
    end if

    if (heating .and. .not. allocated(c_rate)) then
       allocate (c_rate(-ntgrid:ntgrid, ntheta0, naky, nspec, 3))
       call zero_array(c_rate)
    end if

    ! Separate lambda / energy maps are only needed when the combined
    ! le layout is not in use.
    need_lz_map = has_lorentz .and. .not. use_le_layout
    need_e_map = has_diffuse .and. .not. use_le_layout
    call init_map (need_lz_map, need_e_map, use_le_layout, test)

    if (has_lorentz) then
       call init_lorentz
       if (conserve_moments) call init_lorentz_conserve
    end if

    if (has_diffuse) then
       call init_ediffuse
       if (conserve_moments) call init_diffuse_conserve
    end if

    if (use_le_layout .and. (conserve_moments .or. drag)) call init_le_bessel
  end subroutine init_arrays

  !> Communicate Bessel-function factors from g_lo to le_lo.
  !>
  !> Fills the le layout shaped module arrays (for this processor's
  !> le_lo range):
  !>  - `aj0le`       : aj0
  !>  - `vperp_aj1le` : aj1 * energy_maxw(ie) * al(il), with the xi
  !>                    index mapped to il through ixi_to_il
  !>  - `vpa_aj0_le`  : vpa * aj0
  !>
  !> The module arrays are allocated on the first call and reused on
  !> later calls.
  subroutine init_le_bessel
    use gs2_layouts, only: g_lo, le_lo, ig_idx
    use dist_fn_arrays, only: aj0, aj1, vpa
    use le_grids, only: negrid, nxi, g2le, ixi_to_il, energy => energy_maxw, al
    use theta_grid, only: ntgrid
    use redistribute, only: gather
    use array_utils, only: zero_array
    implicit none
    complex, dimension (:,:,:), allocatable :: ctmp, z_big
    integer :: ile, ig, ie, ixi, il

    allocate (z_big(-ntgrid:ntgrid, 2, g_lo%llim_proc:g_lo%ulim_alloc))
    allocate (ctmp(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
    ! We need to initialise ctmp as it is used as receiving buffer in
    ! g2le redistribute, which doesn't populate all elements
    call zero_array(ctmp)

    ! Next set aj0le and vperp_aj1le. Pack aj0 (real part) and aj1
    ! (imaginary part) into one complex array so that a single gather
    ! moves both; aj0/aj1 have no sign-of-vpar index, so both isgn
    ! slots carry the same values.
    z_big(:,1,:) = cmplx(aj0,aj1)
    z_big(:,2,:) = z_big(:,1,:)

    call gather (g2le, z_big, ctmp)

    if (.not. allocated(aj0le)) then
       allocate (aj0le(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
       allocate (vperp_aj1le(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
    end if

    aj0le = real(ctmp)
    vperp_aj1le = aimag(ctmp) !< Currently just aj1

    ! Scale aj1 by energy(ie) * al(il); only the first negrid energy
    ! slots are scaled.
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(ile, ig, ixi, ie, il) &
    !$OMP SHARED(le_lo, negrid, nxi, ixi_to_il, vperp_aj1le, energy, al) &
    !$OMP SCHEDULE(static)
    do ile = le_lo%llim_proc, le_lo%ulim_proc
       ig = ig_idx(le_lo, ile)
       do ie = 1, negrid
          do ixi = 1, nxi + 1
             il = ixi_to_il(ixi, ig)
             vperp_aj1le(ixi, ie, ile) = vperp_aj1le(ixi, ie, ile) * energy(ie) * al(il)
          end do
       end do
    end do
    !$OMP END PARALLEL DO

    ! Reuse the buffers to move vpa into the le layout and form vpa * aj0
    z_big = vpa
    call gather (g2le, z_big, ctmp)
    deallocate (z_big)
    if (.not. allocated(vpa_aj0_le)) then
       allocate (vpa_aj0_le(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
    end if
    vpa_aj0_le = real(ctmp) * aj0le
    deallocate (ctmp)
  end subroutine init_le_bessel

  !> Precompute three quantities needed for momentum and energy conservation:
  !> z0, w0, s0 (z0le, w0le, s0le if use_le_layout chosen in the input file defined)
  subroutine init_lorentz_conserve
    use gs2_layouts, only: g_lo, ie_idx, is_idx, ik_idx, il_idx, it_idx
    use species, only: nspec, spec, is_electron_species
    use kt_grids, only: kperp2, naky, ntheta0, kwork_filter
    use theta_grid, only: ntgrid, bmag
    use le_grids, only: energy => energy_maxw, speed => speed_maxw, al, &
         integrate_moment, negrid, forbid
    use gs2_time, only: code_dt, tunits
    use dist_fn_arrays, only: aj0, aj1, vpa
    use le_grids, only: g2le, nxi
    use gs2_layouts, only: le_lo
    use redistribute, only: gather, scatter
    use array_utils, only: zero_array
    implicit none
    complex, dimension (1,1,1) :: dum1 = 0., dum2 = 0.
    real, dimension (:,:,:), allocatable :: gtmp
    real, dimension (:,:,:,:), allocatable :: duinv, dtmp, vns
    integer :: ie, il, ik, is, isgn, iglo,  it, ig
    complex, dimension (:,:,:), allocatable :: ctmp, z_big
    complex, dimension(:,:,:), allocatable :: s0tmp, w0tmp, z0tmp
    logical, parameter :: all_procs = .true.

    if(use_le_layout) then
       allocate (ctmp(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
       ! We need to initialise ctmp as it is used as receiving buffer in
       ! g2le redistribute, which doesn't populate all elements
       call zero_array(ctmp)
    end if

    allocate(s0tmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
    allocate(w0tmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
    allocate(z0tmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))

    allocate (gtmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
    allocate (duinv(-ntgrid:ntgrid, ntheta0, naky, nspec))
    allocate (dtmp(-ntgrid:ntgrid, ntheta0, naky, nspec))
    allocate (vns(naky,negrid,nspec,3))

    !This initialisation is needed in case kwork_filter is true anywhere
    if (any(kwork_filter)) then
       call zero_array(duinv)
       call zero_array(dtmp)
    end if

    vns(:,:,:,1) = vnmult(1)*vnew_D
    vns(:,:,:,2) = vnmult(1)*vnew_s
    vns(:,:,:,3) = 0.0

    if (drag) then
       do is = 1, nspec
          if (.not. is_electron_species(spec(is))) cycle
          do ik = 1, naky
             vns(ik,:,is,3) = vnmult(1)*spec(is)%vnewk*tunits(ik)/energy**1.5
          end do
       end do

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get z0 (first form)
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(iglo, ik, ie, il, is, isgn) &
       !$OMP SHARED(g_lo, conservative, z0tmp, code_dt, vns, vpdiff, speed, aj0, vpa) &
       !$OMP SCHEDULE(static)
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          ik = ik_idx(g_lo,iglo)
          ie = ie_idx(g_lo,iglo)
          il = il_idx(g_lo,iglo)
          is = is_idx(g_lo,iglo)
          do isgn = 1, 2
             ! u0 = -2 nu_D^{ei} vpa J0 dt f0
             if (conservative) then
                z0tmp(:,isgn,iglo) = -2.0*code_dt*vns(ik,ie,is,3)*vpdiff(:,isgn,il) &
                     * speed(ie)*aj0(:,iglo)
             else
                z0tmp(:,isgn,iglo) = -2.0*code_dt*vns(ik,ie,is,3)*vpa(:,isgn,iglo)*aj0(:,iglo)
             end if
          end do
       end do
       !$OMP END PARALLEL DO

       call zero_out_passing_hybrid_electrons(z0tmp)

       if(use_le_layout) then
          call gather (g2le, z0tmp, ctmp)
          call solfp_lorentz (ctmp)
          call scatter (g2le, ctmp, z0tmp)   ! z0 is redefined below
       else
          call solfp_lorentz (z0tmp,dum1,dum2)   ! z0 is redefined below
       end if

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v0z0

       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(iglo, isgn) &
       !$OMP SHARED(g_lo, vpa, gtmp, aj0, z0tmp) &
       !$OMP COLLAPSE(2) &
       !$OMP SCHEDULE(static)
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          do isgn = 1, 2
             ! v0 = vpa J0 f0
             gtmp(:,isgn,iglo) = vpa(:,isgn,iglo) * aj0(:,iglo) * real(z0tmp(:,isgn,iglo))
          end do
       end do
       !$OMP END PARALLEL DO

       call integrate_moment (gtmp, dtmp, all_procs) ! v0z0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine z0 = z0 / (1 + v0z0)

       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(iglo, ik, it, is, isgn) &
       !$OMP SHARED(g_lo, z0tmp, dtmp) &
       !$OMP SCHEDULE(static)
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          it = it_idx(g_lo,iglo)
          ik = ik_idx(g_lo,iglo)
          is = is_idx(g_lo,iglo)
          do isgn = 1, 2
             z0tmp(:,isgn,iglo) = z0tmp(:,isgn,iglo) / (1.0 + dtmp(:,it,ik,is))
          end do
       end do
       !$OMP END PARALLEL DO

    else
       !If drag is false vns(...,3) is zero and hence z0tmp is zero here.
       call zero_array(z0tmp)
    end if

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    ! du == int (E nu_s f_0);  du = du(z, kx, ky, s)
    ! duinv = 1/du
    if (conservative) then
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(iglo, ik, ie, il, is, isgn) &
       !$OMP SHARED(g_lo, gtmp, vns, vpa, vpdiff, speed) &
       !$OMP SCHEDULE(static)
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          ik = ik_idx(g_lo,iglo)
          ie = ie_idx(g_lo,iglo)
          il = il_idx(g_lo,iglo)
          is = is_idx(g_lo,iglo)
          do isgn = 1, 2
             gtmp(:,isgn,iglo)  = vns(ik,ie,is,1)*vpa(:,isgn,iglo) &
                  * vpdiff(:,isgn,il)*speed(ie)
          end do
       end do
       !$OMP END PARALLEL DO
    else
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(iglo, ik, ie, is, isgn) &
       !$OMP SHARED(g_lo, gtmp, vpa, vns) &
       !$OMP SCHEDULE(static)
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          ik = ik_idx(g_lo,iglo)
          ie = ie_idx(g_lo,iglo)
          is = is_idx(g_lo,iglo)
          do isgn = 1, 2
             gtmp(:,isgn,iglo)  = vns(ik,ie,is,1)*vpa(:,isgn,iglo)**2
          end do
       end do
       !$OMP END PARALLEL DO
    end if

    call integrate_moment (gtmp, duinv, all_procs)  ! not 1/du yet

    ! Could replace this with OpenMP using an explicit loop. TAG
    where (abs(duinv) > epsilon(0.0))  ! necessary b/c some species may have vnewk=0
       !duinv=0 iff vnew=0 so ok to keep duinv=0.
       duinv = 1./duinv  ! now it is 1/du
    end where

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get s0 (first form)
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, ie, il, is, isgn, ig) &
    !$OMP SHARED(g_lo, conservative, s0tmp, vns, vpdiff, speed, aj0, code_dt, duinv, vpa) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! u1 = -3 nu_s vpa dt J0 f_0 / du
          if (conservative) then
             s0tmp(:,isgn,iglo) = -vns(ik,ie,is,1)*vpdiff(:,isgn,il)*speed(ie) &
                  * aj0(:,iglo)*code_dt*duinv(:,it,ik,is)
          else
             s0tmp(:,isgn,iglo) = -vns(ik,ie,is,1)*vpa(:,isgn,iglo) &
                  * aj0(:,iglo)*code_dt*duinv(:,it,ik,is)
          end if
       end do
    end do
    !$OMP END PARALLEL DO

    call zero_out_passing_hybrid_electrons(s0tmp)

    if(use_le_layout) then
       call gather (g2le, s0tmp, ctmp)
       call solfp_lorentz (ctmp)
       call scatter (g2le, ctmp, s0tmp)   ! s0
    else
       call solfp_lorentz (s0tmp,dum1,dum2)   ! s0
    end if

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v0s0

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, isgn) &
    !$OMP SHARED(g_lo, gtmp, vpa, aj0, s0tmp) &
    !$OMP COLLAPSE(2) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       do isgn = 1, 2
          ! v0 = vpa J0 f0
          gtmp(:,isgn,iglo) = vpa(:,isgn,iglo) * aj0(:,iglo) * real(s0tmp(:,isgn,iglo))
       end do
    end do
    !OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)    ! v0s0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine s0 = s0 - v0s0 * z0 / (1 + v0z0)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, s0tmp, dtmp, z0tmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn=1,2
          s0tmp(:,isgn,iglo) = s0tmp(:,isgn,iglo) - dtmp(:,it,ik,is) * z0tmp(:,isgn,iglo)
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v1s0

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, il, is, isgn) &
    !$OMP SHARED(g_lo, conservative, gtmp, vns, speed, vpdiff, aj0, s0tmp, vpa) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v1 = nu_D vpa J0
          if (conservative) then
             gtmp(:,isgn,iglo) = vns(ik,ie,is,1)*speed(ie)*vpdiff(:,isgn,il) &
                  * aj0(:,iglo) * real(s0tmp(:,isgn,iglo))
          else
             gtmp(:,isgn,iglo) = vns(ik,ie,is,1)*vpa(:,isgn,iglo)*aj0(:,iglo) &
                  * real(s0tmp(:,isgn,iglo))
          end if
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)    ! v1s0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine s0 = s0 / (1 + v0s0)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, s0tmp, dtmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn=1,2
          s0tmp(:,isgn,iglo) = s0tmp(:,isgn,iglo) / (1.0 + dtmp(:,it,ik,is))
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get w0
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, ie, il, is, isgn, ig) &
    !$OMP SHARED(g_lo, ntgrid, forbid, w0tmp, vns, energy, al , aj1, code_dt, &
    !$OMP spec, kperp2, duinv, bmag, conserve_forbid_zero) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          do ig=-ntgrid,ntgrid
             ! u2 = -3 dt J1 vperp vus a f0 / du
             if ( .not. (forbid(ig,il) .and. conserve_forbid_zero) ) then
                ! Note no conservative branch here, is that right?
                ! Note: energy * al * smz^2 * kperp2 / bmag is alpha^2 where
                ! alpha is the argument to the Bessel function, i.e. aj1 = J1(alpha) / alpha
                ! This appears to leave us with alpha * J1(alpha) whilst Barnes' paper
                ! only includes terms with J1(alpha). Note that alpha = vperp kperp smz/B
                ! so alpha * J1(alpha) = (kperp * smz / B) * vperp J1
                w0tmp(ig, isgn, iglo) = -vns(ik, ie, is, 1) * energy(ie) * al(il) &
                     * aj1(ig, iglo) * code_dt * spec(is)%smz**2 * kperp2(ig, it, ik) * &
                     duinv(ig, it, ik, is) / bmag(ig)
             else
                w0tmp(ig,isgn,iglo) = 0.
             endif
          end do
       end do
    end do
    !$OMP END PARALLEL DO

    call zero_out_passing_hybrid_electrons(w0tmp)

    if(use_le_layout) then
       call gather (g2le, w0tmp, ctmp)
       call solfp_lorentz (ctmp)
       call scatter (g2le, ctmp, w0tmp)
    else
       call solfp_lorentz (w0tmp,dum1,dum2)
    end if

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v0w0

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, isgn) &
    !$OMP SHARED(g_lo, gtmp, vpa, aj0, w0tmp) &
    !$OMP COLLAPSE(2) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       do isgn = 1, 2
          ! v0 = vpa J0 f0
          gtmp(:,isgn,iglo) = vpa(:,isgn,iglo) * aj0(:,iglo) * real(w0tmp(:,isgn,iglo))
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)    ! v0w0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine w0 = w0 - v0w0 * z0 / (1 + v0z0) (this is w1 from MAB notes)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, w0tmp, z0tmp, dtmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn=1,2
          w0tmp(:,isgn,iglo) = w0tmp(:,isgn,iglo) - z0tmp(:,isgn,iglo)*dtmp(:,it,ik,is)
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Get v1w1

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, il, is, isgn) &
    !$OMP SHARED(g_lo, conservative, gtmp, vns, speed, vpdiff, aj0, w0tmp, vpa) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v1 = nud vpa J0 f0
          if (conservative) then
             gtmp(:,isgn,iglo) = vns(ik,ie,is,1)*speed(ie)*vpdiff(:,isgn,il) &
                  * aj0(:,iglo) * real(w0tmp(:,isgn,iglo))
          else
             gtmp(:,isgn,iglo) = vns(ik,ie,is,1)*vpa(:,isgn,iglo)*aj0(:,iglo) &
                  * real(w0tmp(:,isgn,iglo))
          end if
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)    ! v1w1

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine w0 = w1 - v1w1 * s1 / (1 + v1s1) (this is w2 from MAB notes)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, w0tmp, s0tmp, dtmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn=1,2
          w0tmp(:,isgn,iglo) = w0tmp(:,isgn,iglo) - s0tmp(:,isgn,iglo)*dtmp(:,it,ik,is)
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Get v2w2

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, il, is, isgn) &
    !$OMP SHARED(g_lo, gtmp, vns, energy, al, aj1, w0tmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v2 = nud vperp J1 f0
          ! Note : aj1 = J1(alpha) / alpha, where we have
          ! alpha^2 = energy * al * smz^2 * kperp2 / bmag = vperp^2 smz^2 kperp2 / bmag^2
          ! so energy * al * aj1 = (B / [kperp2 * smz^2]) alpha^2 aj1
          ! = (B / [kperp2 * smz^2]) alpha J1 = vperp J1 / (kperp * smz). As w0tmp
          ! appears to have an extra factor of kperp * smz / B this _may_ work out to
          ! (vperp J1 / B) * ... Is there an extra factor of 1 / B here?
          gtmp(:, isgn, iglo) = vns(ik, ie, is, 1) * energy(ie) * al(il) * aj1(:, iglo) &
               * real(w0tmp(:, isgn, iglo))
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)   ! v2w2

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine w0 = w2 / (1 + v2w2)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn, ig) &
    !$OMP SHARED(g_lo, w0tmp, dtmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn=1,2
          w0tmp(:,isgn,iglo) = w0tmp(:,isgn,iglo) / (1.0 + dtmp(:,it,ik,is))
       end do
    end do
    !$OMP END PARALLEL DO

    deallocate (gtmp, duinv, dtmp, vns)

    if(use_le_layout) then
       allocate (z_big(-ntgrid:ntgrid, 2, g_lo%llim_proc:g_lo%ulim_alloc))

       ! first set s0le, w0le & z0le
       z_big = cmplx(real(s0tmp), real(w0tmp))

       ! get rid of z0, s0, w0 now that we've converted to z0le, s0le, w0le
       if (allocated(s0tmp)) deallocate(s0tmp)
       if (allocated(w0tmp)) deallocate(w0tmp)

       call gather (g2le, z_big, ctmp)

       if (.not. allocated(s0le)) then
          allocate (s0le(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
          allocate (w0le(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
       end if

       s0le = real(ctmp)
       w0le = aimag(ctmp)

       ! set z0le
       call gather (g2le, z0tmp, ctmp)
       if (allocated(z0tmp)) deallocate(z0tmp)
       if (.not. allocated(z0le)) allocate (z0le(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
       z0le = real(ctmp)

       deallocate (ctmp, z_big)
    else
       !Only need the real components (imaginary part should be zero, just
       !use complex arrays to allow reuse of existing integrate routines etc.)
       if (.not. allocated(s0)) then
          allocate (s0(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
          s0=real(s0tmp)
          deallocate(s0tmp)
       endif

       if (.not. allocated(w0)) then
          allocate (w0(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
          w0=real(w0tmp)
          deallocate(w0tmp)
       endif

       if (.not. allocated(z0)) then
          allocate (z0(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
          z0=real(z0tmp)
          deallocate(z0tmp)
       endif
    end if

  end subroutine init_lorentz_conserve

  !> Precompute three quantities needed for momentum and energy conservation
  !> with the energy-diffusion operator: bz0, bw0, bs0.
  !>
  !> Each quantity is built the same way:
  !> 1. Form a source term.
  !> 2. Apply the implicit energy-diffusion solve (solfp_ediffuse_*) to it.
  !> 3. Correct it using velocity moments of the quantities built before it.
  !>
  !> The order is z0 -> s0 -> w0, following the "MAB notes" comments below.
  !>
  !> Where the results go depends on the layout:
  !> - With use_le_layout they are stored in bs0le, bw0le and bz0le (le_lo layout).
  !> - Otherwise they are stored in bs0, bw0 and bz0 (g_lo layout).
  !>
  !> The results depend on code_dt, so the output arrays are recomputed on
  !> every call. They are allocated only on first use.
  subroutine init_diffuse_conserve
    use gs2_layouts, only: g_lo, ie_idx, is_idx, ik_idx, il_idx, it_idx
    use species, only: nspec, spec
    use kt_grids, only: naky, ntheta0, kperp2
    use theta_grid, only: ntgrid, bmag
    use le_grids, only: energy => energy_maxw, al, integrate_moment, negrid, forbid
    use gs2_time, only: code_dt
    use dist_fn_arrays, only: aj0, aj1, vpa
    use le_grids, only: g2le, nxi
    use gs2_layouts, only: le_lo
    use redistribute, only: gather, scatter
    use array_utils, only: zero_array
    implicit none
    real, dimension (:,:,:), allocatable :: gtmp
    real, dimension (:,:,:,:), allocatable :: duinv, dtmp, vns
    integer :: ie, il, ik, is, isgn, iglo, it, ig
    complex, dimension (:,:,:), allocatable :: ctmp, z_big
    complex, dimension (:,:,:), allocatable :: bs0tmp, bw0tmp, bz0tmp
    logical, parameter :: all_procs = .true.

    if(use_le_layout) then
       allocate (ctmp(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
       ! We need to initialise ctmp as it is used as receiving buffer in
       ! g2le redistribute, which doesn't populate all elements
       call zero_array(ctmp)
    end if

    allocate(bs0tmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
    allocate(bw0tmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
    allocate(bz0tmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))

    allocate (gtmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
    allocate (duinv(-ntgrid:ntgrid, ntheta0, naky, nspec))
    allocate (dtmp(-ntgrid:ntgrid, ntheta0, naky, nspec))
    allocate (vns(naky,negrid,nspec,2))

    ! Following might only be needed if any kwork_filter
    call zero_array(duinv)
    call zero_array(dtmp)

    ! NOTE(review): vns(:,:,:,2) is filled here but not referenced anywhere
    ! in this routine; only vns(:,:,:,1) is used below.
    vns(:,:,:,1) = vnmult(2)*delvnew
    vns(:,:,:,2) = vnmult(2)*vnew_s

    ! first obtain 1/du, with du = int (vnmult(2) vnew_E E f_0)
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, is, isgn) &
    !$OMP SHARED(g_lo, gtmp, energy, vnmult, vnew_E) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          gtmp(:,isgn,iglo) = energy(ie)*vnmult(2)*vnew_E(ik,ie,is)
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, duinv, all_procs)  ! not 1/du yet

    ! Could replace this with OpenMP using an explicit loop. TAG
    where (abs(duinv) > epsilon(0.0))  ! necessary b/c some species may have vnewk=0
       !  duinv=0 iff vnew=0 so ok to keep duinv=0.
       duinv = 1 / duinv  ! now it is 1/du
    end where

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    ! Now get z0 (first form)
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, ie, il, is, isgn, ig) &
    !$OMP SHARED(g_lo, ntgrid, forbid, bz0tmp, code_dt, vnmult, vnew_E, &
    !$OMP aj0, duinv, conserve_forbid_zero) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       do isgn = 1, 2
          do ig=-ntgrid,ntgrid
             ! u0 = -vnmult(2) nu_E dt J0 f_0 / du
             if ( .not. (forbid(ig,il) .and. conserve_forbid_zero) ) then
                bz0tmp(ig,isgn,iglo) = -code_dt*vnmult(2)*vnew_E(ik,ie,is) &
                     * aj0(ig,iglo)*duinv(ig,it,ik,is)
             else
                bz0tmp(ig,isgn,iglo) = 0.
             endif
          end do
       end do
    end do
    !$OMP END PARALLEL DO

    call zero_out_passing_hybrid_electrons(bz0tmp)

    if(use_le_layout) then
       call gather (g2le, bz0tmp, ctmp)
       call solfp_ediffuse_le_layout (ctmp)
       call scatter (g2le, ctmp, bz0tmp)   ! bz0 is redefined below
    else
       call solfp_ediffuse_standard_layout (bz0tmp)   ! bz0 is redefined below
    end if

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v0z0

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, is, isgn) &
    !$OMP SHARED(g_lo, gtmp, vnmult, vnew_E, aj0, bz0tmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v0 = vnmult(2) nu_E J0 f_0
          gtmp(:,isgn,iglo) = vnmult(2) * vnew_E(ik,ie,is) * aj0(:,iglo) &
               * real(bz0tmp(:,isgn,iglo))
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs) ! v0z0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine z0 = z0 / (1 + v0z0)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, bz0tmp, dtmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          bz0tmp(:,isgn,iglo) = bz0tmp(:,isgn,iglo) / (1.0 + dtmp(:,it,ik,is))
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

    ! redefine dq = du (for momentum-conserving terms)
    ! here du == int (vns(:,:,:,1) vpa^2 f_0), with vns(:,:,:,1) = vnmult(2) delvnew;
    ! du = du(z, kx, ky, s)
    ! duinv = 1/du
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, is, isgn) &
    !$OMP SHARED(g_lo, gtmp, vns, vpa) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          gtmp(:, isgn, iglo)  = vns(ik, ie, is, 1) * vpa(:, isgn, iglo)**2
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, duinv, all_procs)  ! not 1/du yet

    ! Could replace this with OpenMP using an explicit loop. TAG
    where (abs(duinv) > epsilon(0.0))  ! necessary b/c some species may have vnewk=0
       !  duinv=0 iff vnew=0 so ok to keep duinv=0.
       duinv = 1 / duinv  ! now it is 1/du
    end where

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get s0 (first form)
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, ie, is, isgn) &
    !$OMP SHARED(g_lo, bs0tmp, vns, vpa, aj0, code_dt, duinv) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! u1 = -vnmult(2) delvnew vpa dt J0 f_0 / du
          bs0tmp(:, isgn, iglo) = -vns(ik, ie, is, 1) * vpa(:, isgn, iglo) &
               * aj0(:, iglo) * code_dt * duinv(:, it, ik, is)
       end do
    end do
    !$OMP END PARALLEL DO

    call zero_out_passing_hybrid_electrons(bs0tmp)

    if(use_le_layout) then
       call gather (g2le, bs0tmp, ctmp)
       call solfp_ediffuse_le_layout (ctmp)
       call scatter (g2le, ctmp, bs0tmp)   ! bs0
    else
       call solfp_ediffuse_standard_layout (bs0tmp)    ! bs0
    end if

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v0s0

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, is, isgn) &
    !$OMP SHARED(g_lo, gtmp, vnmult, vnew_E, aj0, bs0tmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v0 = vnmult(2) nu_E J0
          gtmp(:, isgn, iglo) = vnmult(2) * vnew_E(ik, ie, is) * aj0(:, iglo) &
               * real(bs0tmp(:, isgn, iglo))
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)    ! v0s0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine s0 = s0 - v0s0 * z0 / (1 + v0z0)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, bs0tmp, dtmp, bz0tmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn=1,2
          bs0tmp(:, isgn, iglo) = bs0tmp(:, isgn, iglo) - dtmp(:, it, ik, is) &
               * bz0tmp(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v1s0

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, is, isgn) &
    !$OMP SHARED(g_lo, gtmp, vns, vpa, aj0, bs0tmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v1 = (nu_s - nu_D) vpa J0
          gtmp(:, isgn, iglo) = vns(ik, ie, is, 1) * vpa(:, isgn, iglo) * aj0(:, iglo) &
               * real(bs0tmp(:, isgn, iglo))
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)    ! v1s0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine s0 = s0 / (1 + v0s0)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, bs0tmp, dtmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn=1,2
          bs0tmp(:, isgn, iglo) = bs0tmp(:, isgn, iglo) / (1.0 + dtmp(:, it, ik, is))
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get w0
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, ie, il, is, isgn, ig) &
    !$OMP SHARED(g_lo, ntgrid, forbid, bw0tmp, vns, energy, al, aj1, code_dt, &
    !$OMP spec, kperp2, duinv, bmag, conserve_forbid_zero) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          do ig=-ntgrid,ntgrid
             ! u2 = -vns(...,1) E lambda aj1 dt smz^2 kperp2 / (bmag du),
             ! set to zero at forbidden points when conserve_forbid_zero is set
             if ( .not. (forbid(ig,il) .and. conserve_forbid_zero)) then
                ! Note: energy * al * smz^2 * kperp2 / bmag is alpha^2 where
                ! alpha is the argument to the Bessel function, i.e. aj1 = J1(alpha) / alpha
                ! This appears to leave us with alpha * J1(alpha) whilst Barnes' paper
                ! only includes terms with J1(alpha). Note that alpha = vperp kperp smz/B
                ! so alpha * J1(alpha) = (kperp * smz / B) * vperp J1
                bw0tmp(ig, isgn, iglo) = -vns(ik, ie, is, 1) * energy(ie) * al(il) &
                     * aj1(ig, iglo) * code_dt * spec(is)%smz**2 * kperp2(ig, it, ik) * &
                     duinv(ig, it, ik, is) / bmag(ig)
             else
                bw0tmp(ig, isgn, iglo) = 0.
             endif
          end do
       end do
    end do
    !$OMP END PARALLEL DO

    call zero_out_passing_hybrid_electrons(bw0tmp)

    if(use_le_layout) then
       call gather (g2le, bw0tmp, ctmp)
       call solfp_ediffuse_le_layout (ctmp)
       call scatter (g2le, ctmp, bw0tmp)
    else
       call solfp_ediffuse_standard_layout (bw0tmp)
    end if

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v0w0

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, is, isgn) &
    !$OMP SHARED(g_lo, gtmp, vnmult, vnew_E, aj0, bw0tmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v0 = vnmult(2) nu_E J0
          gtmp(:, isgn, iglo) = vnmult(2) * vnew_E(ik, ie, is) * aj0(:, iglo) &
               * real(bw0tmp(:, isgn, iglo))
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)    ! v0w0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine w0 = w0 - v0w0 * z0 / (1 + v0z0) (this is w1 from MAB notes)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, bw0tmp, bz0tmp, dtmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          bw0tmp(:, isgn, iglo) = bw0tmp(:, isgn, iglo) - bz0tmp(:, isgn, iglo) &
               * dtmp(:, it, ik, is)
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Get v1w1

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, is, isgn) &
    !$OMP SHARED(g_lo, gtmp, vns, vpa, aj0, bw0tmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v1 = (nus-nud) vpa J0 f0
          gtmp(:, isgn, iglo) = vns(ik, ie, is, 1) * vpa(:, isgn, iglo) * aj0(:, iglo) &
               * real(bw0tmp(:, isgn, iglo))
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)    ! v1w1

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine w0 = w1 - v1w1 * s1 / (1 + v1s1) (this is w2 from MAB notes)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, bw0tmp, bs0tmp, dtmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          bw0tmp(:, isgn, iglo) = bw0tmp(:, isgn, iglo) - bs0tmp(:, isgn, iglo) &
               * dtmp(:, it, ik, is)
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Get v2w2

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, ie, il, is, isgn) &
    !$OMP SHARED(g_lo, gtmp, vns, energy, al, aj1, bw0tmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v2 = (nus-nud) vperp J1 f0
          ! Note : aj1 = J1(alpha) / alpha, where we have
          ! alpha^2 = energy * al * smz^2 * kperp2 / bmag = vperp^2 smz^2 kperp2 / bmag^2
          ! so energy * al * aj1 = (B / [kperp2 * smz^2]) alpha^2 aj1
          ! = (B / [kperp2 * smz^2]) alpha J1 = vperp J1 / (kperp * smz). As w0tmp
          ! appears to have an extra factor of kperp * smz / B this _may_ work out to
          ! (vperp J1 / B) * ... Is there an extra factor of 1 / B here?
          gtmp(:, isgn, iglo) = vns(ik, ie, is, 1) * energy(ie) * al(il) * aj1(:, iglo) &
               * real(bw0tmp(:, isgn, iglo))
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, dtmp, all_procs)   ! v2w2

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Redefine w0 = w2 / (1 + v2w2)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, ik, it, is, isgn) &
    !$OMP SHARED(g_lo, bw0tmp, dtmp) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn=1,2
          bw0tmp(:, isgn, iglo) = bw0tmp(:, isgn, iglo) / (1.0 + dtmp(:, it, ik, is))
       end do
    end do
    !$OMP END PARALLEL DO

    deallocate (gtmp, duinv, dtmp, vns)

    if(use_le_layout) then
       allocate (z_big(-ntgrid:ntgrid, 2, g_lo%llim_proc:g_lo%ulim_alloc))
       ! Pack bs0 and bw0 into the real and imaginary parts of one complex
       ! array so a single redistribute moves both. The result kind is pinned
       ! to z_big's kind so CMPLX does not fall back to default-kind complex.
       z_big = cmplx(real(bs0tmp), real(bw0tmp), kind=kind(z_big))
       deallocate (bs0tmp, bw0tmp)
       call gather (g2le, z_big, ctmp)
       deallocate (z_big)
       if (.not. allocated(bs0le)) allocate (bs0le(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
       if (.not. allocated(bw0le)) allocate (bw0le(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
       bs0le = real(ctmp)
       bw0le = aimag(ctmp)

       call gather (g2le, bz0tmp, ctmp)
       deallocate (bz0tmp)
       if (.not. allocated(bz0le)) allocate (bz0le(nxi+1, negrid+1, le_lo%llim_proc:le_lo%ulim_alloc))
       bz0le = real(ctmp)
       deallocate (ctmp)
    else
       ! Only need the real components (imaginary part should be zero, just
       ! use complex arrays to allow reuse of existing integrate routines etc.)
       ! As in the le_layout branch, the output arrays are allocated on first
       ! use but ALWAYS overwritten. A repeat call (the values depend on
       ! code_dt) therefore does not leave stale results behind.
       if (.not. allocated(bs0)) allocate (bs0(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
       if (.not. allocated(bw0)) allocate (bw0(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
       if (.not. allocated(bz0)) allocate (bz0(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
       bs0 = real(bs0tmp)
       bw0 = real(bw0tmp)
       bz0 = real(bz0tmp)
       deallocate (bs0tmp, bw0tmp, bz0tmp)
    end if

  end subroutine init_diffuse_conserve
  
  !> Remove the adiabatic passing-electron contribution from a
  !> conservation-term array when hybrid electrons are in use.
  !>
  !> Every g_lo element (iglo) whose (species, ky, lambda) combination is
  !> reported as a passing hybrid electron by `is_passing_hybrid_electron`
  !> has its whole (:, :, iglo) slab set to zero. If no species is a
  !> hybrid electron species the array is left untouched.
  subroutine zero_out_passing_hybrid_electrons(arr)
    use gs2_layouts, only: g_lo, ik_idx, il_idx, is_idx
    use species, only: has_hybrid_electron_species, spec
    use le_grids, only: is_passing_hybrid_electron
    implicit none
    !> Array distributed over g_lo in its last dimension; modified in place
    complex, dimension(:, :, g_lo%llim_proc:), intent(in out) :: arr
    integer :: i_glo, i_species, i_ky, i_lambda

    ! Nothing to remove unless a hybrid electron species exists
    if (.not. has_hybrid_electron_species(spec)) return

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(i_glo, i_species, i_ky, i_lambda) &
    !$OMP SHARED(g_lo, arr) &
    !$OMP SCHEDULE(static)
    do i_glo = g_lo%llim_proc, g_lo%ulim_proc
       i_species = is_idx(g_lo, i_glo)
       i_ky = ik_idx(g_lo, i_glo)
       i_lambda = il_idx(g_lo, i_glo)
       ! Keep everything that is not a passing hybrid electron
       if (.not. is_passing_hybrid_electron(i_species, i_ky, i_lambda)) cycle
       arr(:, :, i_glo) = 0.0
    end do
    !$OMP END PARALLEL DO
  end subroutine zero_out_passing_hybrid_electrons

  !> Initialise the velocity-dependent collision frequencies used by the
  !> model collision operator.
  !>
  !> On exit the following module arrays, each dimensioned
  !> (naky, negrid, nspec), are defined for every ky, energy and species
  !> (below E = energy, v = sqrt(E)):
  !>  - vnew    : coefficient used for pitch-angle scattering;
  !>              electrons: vnewk * (zeff + hee) * 0.5 / E**1.5,
  !>              others:    vnewk * hee * 0.5 / E**1.5
  !>  - vnew_s  : vnewk * 4 * G(v) / v, with G the Hirshman-Sigmar
  !>              function (hsg_func)
  !>  - vnew_D  : electrons: vnewk * hee / E**1.5, others: 2 * vnew
  !>  - vnew_E, delvnew : energy-diffusion coefficients, either from the
  !>              conservative discretisation on the energy grid or,
  !>              pointwise, from vnew_s and vnew_D
  !> Every entry includes the factor tunits(ik). That factor is the only
  !> reason the arrays carry a ky index (more memory, fewer operations later).
  !>
  !> With `ei_coll_only`, only electron species get a non-zero `vnew`
  !> (vnewk * zeff * 0.5 / E**1.5). In that case vnew_s and vnew_D are zero.
  !> NOTE(review): if `conservative` is also set, vnew_E/delvnew are still
  !> computed from vnewk by the conservative branch; confirm this is intended.
  !>
  !> With `const_v`, E = 1 is used in the vnew/vnew_s/vnew_D formulas (and in
  !> the non-conservative vnew_E); the conservative branch always uses the
  !> real energy grid.
  !>
  !> With `hyper_colls`, vnewh(theta, kx, ky, species) is also filled with
  !> nu_h * (kperp2 / max(kperp2))**2.
  subroutine init_vnew
    use species, only: nspec, spec, is_electron_species
    use le_grids, only: negrid, energy => energy_maxw, w => w_maxw
    use kt_grids, only: naky, ntheta0, kperp2
    use theta_grid, only: ntgrid
    use run_parameters, only: zeff, delt_option_switch, delt_option_auto
    use gs2_time, only: tunits
    use constants, only: pi, sqrt_pi
    use gs2_save, only: init_vnm
    real, dimension(negrid):: hee, hsg, local_energy
    integer :: ik, ie, is, it
    integer :: istatus
    real :: k4max
    real :: vl, vr, dv2l, dv2r
    real, dimension (2) :: vnm_saved

    ! With automatic timestep control, try to restore the vnmult values
    ! saved by a previous run; istatus == 0 signals success.
    if (delt_option_switch == delt_option_auto) then
       call init_vnm (vnm_saved, istatus)
       if (istatus == 0) vnm_init = vnm_saved
    endif

    ! Initialise vnmult from the values in run_parameters. This is
    ! either what has been restored from the restart file  if
    ! `delt_option == 'check_restart'` or 1 otherwise.
    if (all(vnmult == -1)) then
       vnmult = vnm_init
    end if

    if (const_v) then
       local_energy = 1.0
    else
       local_energy = energy
    end if

    do ie = 1, negrid
       hee(ie) = exp(-local_energy(ie))/sqrt(pi*local_energy(ie)) &
            + (1.0 - 0.5/local_energy(ie))*erf(sqrt(local_energy(ie)))
       !>MAB
       ! hsg is the G of Hirshman and Sigmar
       ! added to allow for momentum conservation with energy diffusion
       hsg(ie) = hsg_func(sqrt(local_energy(ie)))
       !<MAB
    end do

    if (.not.allocated(vnew)) then
       ! Here the ky dependence is just to allow us to scale by tunits.
       ! This saves some operations in later use at expense of more memory
       allocate (vnew(naky,negrid,nspec))
       allocate (vnew_s(naky,negrid,nspec))
       allocate (vnew_D(naky,negrid,nspec))
       allocate (vnew_E(naky,negrid,nspec))
       allocate (delvnew(naky,negrid,nspec))
    end if

    if (ei_coll_only) then
       vnew_s = 0 ; vnew_D = 0 ; vnew_E = 0 ; delvnew = 0.
       do is = 1, nspec
          if (is_electron_species(spec(is))) then
             do ie = 1, negrid
                vnew(:, ie, is) = spec(is)%vnewk / local_energy(ie)**1.5 &
                        * zeff * 0.5 * tunits(:)
             end do
          else
             vnew(:, :, is) = 0.0
          end if
       end do
    else
       do is = 1, nspec
          ! Don't include following lines, as haven't yet accounted for possibility of timestep changing.
          ! Correct collision frequency if applying collisions only every nth timestep:
          !spec(is)%vnewk = spec(is)%vnewk * float(timesteps_between_collisions)

          do ie = 1, negrid
             vnew_s(:, ie, is) = spec(is)%vnewk / sqrt(local_energy(ie)) &
                  * hsg(ie) * 4.0 * tunits(:)
          end do

          if (is_electron_species(spec(is))) then
             do ie = 1, negrid
                vnew(:, ie, is) = spec(is)%vnewk / local_energy(ie)**1.5 &
                     * (zeff + hee(ie)) * 0.5 * tunits(:)
                vnew_D(:, ie, is) = spec(is)%vnewk / local_energy(ie)**1.5 &
                     * hee(ie) * tunits(:)
             end do
          else
             do ie = 1, negrid
                vnew(:, ie, is) = spec(is)%vnewk / local_energy(ie)**1.5 &
                     * hee(ie) * 0.5 * tunits(:)
                vnew_D(:, ie, is) = 2.0 * vnew(:, ie, is)
             end do
          end if
       end do
    end if

    if (conservative) then !Presumably not compatible with const_v so we use energy
       do is = 1, nspec
          ! Interior points: flux differences across the cell faces at
          ! vl (between ie-1 and ie) and vr (between ie and ie+1).
          do ie = 2, negrid-1
             vr = 0.5*(sqrt(energy(ie+1)) + sqrt(energy(ie)))
             vl = 0.5*(sqrt(energy(ie  )) + sqrt(energy(ie-1)))
             dv2r = (energy(ie+1) - energy(ie)) / (sqrt(energy(ie+1)) - sqrt(energy(ie)))
             dv2l = (energy(ie) - energy(ie-1)) / (sqrt(energy(ie)) - sqrt(energy(ie-1)))

             vnew_E(:, ie, is) =  spec(is)%vnewk * tunits * &
                  vnew_E_conservative(vr, vl, dv2r, dv2l) / (sqrt_pi * w(ie))
             delvnew(:, ie, is) = spec(is)%vnewk * tunits * &
                  delvnew_conservative(vr, vl) / (sqrt_pi * sqrt(energy(ie)) * w(ie))
          end do

          ! boundary at v = 0
          ie = 1
          vr = 0.5 * (sqrt(energy(ie+1)) + sqrt(energy(ie)))
          vl = 0.0
          dv2r = (energy(ie+1) - energy(ie)) / (sqrt(energy(ie+1)) - sqrt(energy(ie)))
          dv2l = 0.0
          vnew_E(:, ie, is) =  spec(is)%vnewk * tunits * &
               vnew_E_conservative(vr, vl, dv2r, dv2l) / (sqrt_pi * w(ie))
          delvnew(:, ie, is) = spec(is)%vnewk * tunits * &
               delvnew_conservative(vr, vl) / (sqrt_pi * sqrt(energy(ie)) * w(ie))

          ! boundary at v -> infinity
          ie = negrid
          vr = 0.0
          vl = 0.5 * (sqrt(energy(ie)) + sqrt(energy(ie-1)))
          dv2r = 0.0
          dv2l = (energy(ie) - energy(ie-1)) / (sqrt(energy(ie)) - sqrt(energy(ie-1)))
          vnew_E(:, ie, is) =  spec(is)%vnewk * tunits * &
               vnew_E_conservative(vr, vl, dv2r, dv2l) / (sqrt_pi * w(ie))
          delvnew(:, ie, is) = spec(is)%vnewk * tunits * &
               delvnew_conservative(vr, vl) / (sqrt_pi * sqrt(energy(ie)) * w(ie))
       end do
    else
       ! The non-conservative coefficients are purely local in energy (no
       ! neighbouring grid points), so they must be set on the WHOLE grid.
       ! Restricting this loop to 2..negrid-1 (as the conservative branch
       ! does for its interior) would leave the ie = 1 and ie = negrid
       ! entries of the freshly allocated arrays undefined.
       do is = 1, nspec
          do ie = 1, negrid
             vnew_E(:, ie, is) = local_energy(ie) * (vnew_s(:, ie, is) * &
                  (2.0 - 0.5 / local_energy(ie)) - 2.0 * vnew_D(:, ie, is))
             delvnew(:, ie, is) = vnew_s(:, ie, is) - vnew_D(:, ie, is)
          end do
       end do
    end if

       ! add hyper-terms inside collision operator
!BD: Warning!
!BD: For finite magnetic shear, this is different in form from what appears in hyper.f90
!BD: because kperp2 /= akx**2 + aky**2;  there are cross terms that are dropped in hyper.f90
!BD: Warning!
!BD: Also: there is no "grid_norm" option here and the exponent is fixed to 4 for now
    if (hyper_colls) then
       if (.not. allocated(vnewh)) allocate (vnewh(-ntgrid:ntgrid,ntheta0,naky,nspec))
       ! Normalise by the largest kperp**4 so vnewh <= nu_h
       k4max = (maxval(kperp2))**2
       do is = 1, nspec
          do ik = 1, naky
             do it = 1, ntheta0
                vnewh(:,it,ik,is) = spec(is)%nu_h * kperp2(:,it,ik)**2/k4max
             end do
          end do
       end do
    end if

  contains
    !> Conservative energy-diffusion numerator: difference of the
    !> G-weighted, Maxwellian-weighted fluxes at the left (vl) and right
    !> (vr) cell faces, each scaled by its face width factor dv2.
    elemental real function vnew_E_conservative(vr, vl, dv2r, dv2l) result(vnE)
      real, intent(in) :: vr, vl, dv2l, dv2r
      vnE = (vl * exp(-vl**2) * dv2l * hsg_func(vl) - vr * exp(-vr**2) * dv2r * hsg_func(vr))
    end function vnew_E_conservative

    !> Same face-flux difference as vnew_E_conservative without the
    !> face width factors.
    elemental real function delvnew_conservative(vr, vl) result(delVn)
      real, intent(in) :: vr, vl
      delVn = (vl * exp(-vl**2) * hsg_func(vl) - vr * exp(-vr**2) * hsg_func(vr))
    end function delvnew_conservative

  end subroutine init_vnew

  !> The G function of Hirshman and Sigmar evaluated at speed `vel`:
  !>
  !>   G(v) = erf(v) / (2 v**2) - exp(-v**2) / (sqrt(pi) v)
  !>
  !> For |vel| <= epsilon(0.0) the result is set to 0 so the 0/0 at
  !> the origin is never evaluated.
  elemental real function hsg_func (vel) result(g_hs)
    use constants, only: sqrt_pi
    implicit none
    real, intent (in) :: vel
    real :: vel_sq

    if (abs(vel) <= epsilon(0.0)) then
       g_hs = 0.0
       return
    end if

    vel_sq = vel**2
    g_hs = 0.5 * erf(vel) / vel_sq - exp(-vel_sq) / (sqrt_pi * vel)
  end function hsg_func

  !> Build the tridiagonal-solve coefficients for the energy-diffusion part
  !> of the collision operator.
  !>
  !> For every point owned by this processor (le_lo when `use_le_layout`,
  !> otherwise e_lo), get_ediffuse_matrix supplies the sub-diagonal (aa),
  !> diagonal (bb) and super-diagonal (cc) on the energy grid. This routine
  !> then performs the forward sweep of the Thomas algorithm once and
  !> stores:
  !>  - ec1  / ec1le    : the super-diagonal cc,
  !>  - eql  / eqle     : the elimination multipliers aa(ie+1) * ebeta(ie),
  !>  - ebetaa / ebetaale : 1 / (modified pivot).
  !> Presumably these are reused by the energy-diffusion solve elsewhere in
  !> this module (not visible here). Entries for forbidden (ig, il) points
  !> are left at zero.
  !>
  !> @param vnmult_target optional; if present, vnmult(2) is set to
  !>        max(vnmult_target, 1.0) before the coefficients are built.
  subroutine init_ediffuse (vnmult_target)
    use le_grids, only: negrid, nxi, forbid, ixi_to_il
    use egrid, only: zeroes => zeroes_maxw, x0 => x0_maxw
    use gs2_layouts, only: le_lo, e_lo, il_idx
    use gs2_layouts, only: ig_idx, it_idx, ik_idx, is_idx
    use array_utils, only: zero_array
    implicit none
    real, intent (in), optional :: vnmult_target
    integer :: ie, is, ik, il, ig, it, ile, ixi, ielo
    real, dimension (:), allocatable :: aa, bb, cc, xe

    allocate (aa(negrid), bb(negrid), cc(negrid), xe(negrid))

    ! want to use x variables instead of e because we want conservative form
    ! for the x-integration
    ! The first negrid-1 nodes come from zeroes_maxw, the last from x0_maxw
    ! (egrid module).
    xe(1:negrid-1) = zeroes
    xe(negrid) = x0

    if (present(vnmult_target)) then
       vnmult(2) = max (vnmult_target, 1.0)
    end if

    if(use_le_layout) then

       ! First call: allocate storage and make sure vnmult(2) is at least 1
       if (.not.allocated(ec1le)) then
          allocate (ec1le   (nxi+1, negrid, le_lo%llim_proc:le_lo%ulim_alloc))
          allocate (ebetaale(nxi+1, negrid, le_lo%llim_proc:le_lo%ulim_alloc))
          allocate (eqle    (nxi+1, negrid, le_lo%llim_proc:le_lo%ulim_alloc))
          vnmult(2) = max(1.0, vnmult(2))
       end if
       ! Zero everything so forbidden/unused entries stay at zero
       call zero_array(ec1le)
       call zero_array(ebetaale)
       call zero_array(eqle)

       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(ile, ik, it, is, ig, ie, ixi, il, aa, bb, cc) &
       !$OMP SHARED(le_lo, nxi_lim, forbid, ec1le, ebetaale, negrid, eqle, ixi_to_il, xe) &
       !$OMP SCHEDULE(static)
       do ile = le_lo%llim_proc, le_lo%ulim_proc
          ik = ik_idx(le_lo, ile)
          it = it_idx(le_lo, ile)
          is = is_idx(le_lo, ile)
          ig = ig_idx(le_lo, ile)
          ! In le_layout each le_lo element holds all xi and energies
          do ixi = 1, nxi_lim
             il = ixi_to_il(ixi, ig)
             if (forbid(ig, il)) cycle
             call get_ediffuse_matrix (aa, bb, cc, ig, ik, it, il, is, xe)
             ec1le(ixi, :, ile) = cc
             ! Thomas forward sweep over the energy index
             ebetaale(ixi, 1, ile) = 1.0 / bb(1)
             do ie = 1,  negrid - 1
                eqle(ixi, ie+1, ile) = aa(ie+1) * ebetaale(ixi, ie, ile)
                ebetaale(ixi, ie+1, ile) = 1.0 / ( bb(ie+1) - eqle(ixi, ie+1, ile) &
                     * ec1le(ixi, ie, ile) )
             end do
          end do
       end do
       !$OMP END PARALLEL DO
    else

       ! First call: allocate storage and make sure vnmult(2) is at least 1
       if (.not.allocated(ec1)) then
          allocate (ec1    (negrid,e_lo%llim_proc:e_lo%ulim_alloc))
          allocate (ebetaa (negrid,e_lo%llim_proc:e_lo%ulim_alloc))
          allocate (eql    (negrid,e_lo%llim_proc:e_lo%ulim_alloc))
          vnmult(2) = max(1.0, vnmult(2))
       endif
       ! Zero everything so forbidden/unused entries stay at zero
       call zero_array(ec1)
       call zero_array(ebetaa)
       call zero_array(eql)

       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(ielo, ik, it, is, ig, ie, ixi, il, aa, bb, cc) &
       !$OMP SHARED(e_lo, forbid, xe, ec1, ebetaa, negrid, eql) &
       !$OMP SCHEDULE(static)
       do ielo = e_lo%llim_proc, e_lo%ulim_proc
          il = il_idx(e_lo, ielo)
          ig = ig_idx(e_lo, ielo)
          if (forbid(ig, il)) cycle
          is = is_idx(e_lo, ielo)
          ik = ik_idx(e_lo, ielo)
          it = it_idx(e_lo, ielo)
          call get_ediffuse_matrix (aa, bb, cc, ig, ik, it, il, is, xe)
          ec1(:, ielo) = cc

          ! fill in the arrays for the tridiagonal
          ebetaa(1, ielo) = 1.0 / bb(1)
          do ie = 1, negrid - 1
             eql(ie+1, ielo) = aa(ie+1) * ebetaa(ie, ielo)
             ebetaa(ie+1, ielo) = 1.0 / (bb(ie+1) - eql(ie+1, ielo) * ec1(ie, ielo))
          end do
       end do
       !$OMP END PARALLEL DO
    end if

    deallocate(aa, bb, cc, xe)

  end subroutine init_ediffuse

  !> Build the tridiagonal energy-diffusion matrix on the energy grid for a
  !> single (theta, ky, kx, lambda, species) point.
  !>
  !> Row ie couples points ie-1, ie, ie+1 through aa(ie), bb(ie), cc(ie).
  !> In both schemes bb = 1 - (aa + cc) + ee * code_dt, where
  !>   ee = 0.125 * (1 - xi**2) * vnew_s * kperp2 * cfac / (bmag * zstm)**2
  !> is the finite-kperp contribution, and the diffusion strength is
  !>   vn = vnmult(2) * vnewk * tunits(ik).
  !> `ediff_switch` selects the default scheme, whose face fluxes are
  !> weighted by the energy-grid weights w, or the old non-conservative
  !> scheme. aa(1) = 0 and cc(negrid) = 0 close the system at the two ends
  !> of the grid.
  !>
  !> @param aa, bb, cc  (out) sub-, main and super-diagonal, length negrid
  !> @param ig, ik, it, il, is  theta, ky, kx, lambda and species indices
  !> @param xe  grid coordinates prepared by init_ediffuse
  subroutine get_ediffuse_matrix (aa, bb, cc, ig, ik, it, il, is, xe)
    use species, only: spec
    use theta_grid, only: bmag
    use le_grids, only: al, negrid, w=>w_maxw
    use kt_grids, only: kperp2
    use gs2_time, only: code_dt, tunits

    implicit none

    integer, intent (in) :: ig, ik, it, il, is
    real, dimension (:), intent (in) :: xe
    real, dimension (:), intent (out) :: aa, bb, cc

    integer :: ie
    real :: vn, slb1, xe0, xe1, xe2, xel, xer, capgr, capgl, ee

    vn = vnmult(2) * spec(is)%vnewk * tunits(ik)

    slb1 = safe_sqrt(1.0 - bmag(ig)*al(il))     ! xi_j

    select case (ediff_switch)
    case (ediff_scheme_default)
       ! Interior points: fluxes evaluated at the cell faces xel and xer
       do ie = 2, negrid - 1
          xe0 = xe(ie-1)
          xe1 = xe(ie)
          xe2 = xe(ie+1)

          xel = (xe0 + xe1) * 0.5
          xer = (xe1 + xe2) * 0.5

          capgr = capg(xer)
          capgl = capg(xel)

          ee = 0.125 * (1 - slb1**2) * vnew_s(ik,ie,is) * kperp2(ig,it,ik)*cfac &
               / (bmag(ig) * spec(is)%zstm)**2

          ! coefficients for tridiagonal matrix:
          cc(ie) = -0.25 * vn * code_dt * capgr / (w(ie) * (xe2 - xe1))
          aa(ie) = -0.25 * vn * code_dt * capgl / (w(ie) * (xe1 - xe0))
          bb(ie) = 1.0 - (aa(ie) + cc(ie)) + ee*code_dt
       end do

       ! boundary at v = 0
       xe1 = xe(1)
       xe2 = xe(2)
       xer = (xe1 + xe2) * 0.5

       capgr = capg(xer)

       ee = 0.125 * (1 - slb1**2) * vnew_s(ik,1,is) * kperp2(ig,it,ik)*cfac &
            / (bmag(ig) * spec(is)%zstm)**2

       cc(1) = -0.25 * vn * code_dt * capgr / (w(1) * (xe2 - xe1))
       aa(1) = 0.0
       bb(1) = 1.0 - cc(1) + ee * code_dt

       ! boundary at v = infinity
       xe0 = xe(negrid-1)
       xe1 = xe(negrid)

       xel = (xe1 + xe0) * 0.5

       capgl = capg(xel)

       ee = 0.125 * (1.-slb1**2) * vnew_s(ik,negrid,is) * kperp2(ig,it,ik) * cfac &
            / (bmag(ig) * spec(is)%zstm)**2

       cc(negrid) = 0.0
       aa(negrid) = -0.25 * vn * code_dt * capgl / (w(negrid) * (xe1 - xe0))
       bb(negrid) = 1.0 - aa(negrid) + ee*code_dt

    case (ediff_scheme_old)

       ! non-conservative scheme
       do ie = 2, negrid-1
          xe0 = xe(ie-1)
          xe1 = xe(ie)
          xe2 = xe(ie+1)

          xel = (xe0 + xe1) * 0.5
          xer = (xe1 + xe2) * 0.5

          capgr = capg_old(xe1, xer)
          capgl = capg_old(xe1, xel)

          ee = 0.125 * (1 - slb1**2) * vnew_s(ik,ie,is) * kperp2(ig,it,ik) * cfac &
               / (bmag(ig) * spec(is)%zstm)**2

          ! coefficients for tridiagonal matrix:
          cc(ie) = -vn * code_dt * capgr / ((xer-xel) * (xe2 - xe1))
          aa(ie) = -vn * code_dt * capgl / ((xer-xel) * (xe1 - xe0))
          bb(ie) = 1.0 - (aa(ie) + cc(ie)) + ee * code_dt
       end do

       ! boundary at xe = 0
       ! (the left face is taken at 0, so the cell width is xer - 0 = xer)
       xe1 = xe(1)
       xe2 = xe(2)
       xer = (xe1 + xe2) * 0.5

       capgr = capg_old(xe1, xer)

       ee = 0.125 * (1 - slb1**2) * vnew_s(ik,1,is) * kperp2(ig,it,ik)*cfac &
            / (bmag(ig) * spec(is)%zstm)**2

       cc(1) = -vn * code_dt * capgr / (xer * (xe2 - xe1))
       aa(1) = 0.0
       bb(1) = 1.0 - cc(1) + ee * code_dt

       ! boundary at xe = 1
       ! NOTE(review): the cell width below is (1.0 - xel), i.e. the outer
       ! face is assumed to sit at xe = 1. Confirm this is consistent with
       ! xe(negrid) = x0_maxw as set up by init_ediffuse.
       xe0 = xe(negrid-1)
       xe1 = xe(negrid)
       xel = (xe1 + xe0) * 0.5

       capgl = capg_old(xe1, xel)

       ee = 0.125 * (1 - slb1**2) * vnew_s(ik,negrid,is) * kperp2(ig,it,ik)*cfac &
            / (bmag(ig) * spec(is)%zstm)**2

       cc(negrid) = 0.0
       aa(negrid) = -vn * code_dt * capgl / ((1.0-xel) * (xe1 - xe0))
       bb(negrid) = 1.0 - aa(negrid) + ee * code_dt
    end select

  contains
    !> erf(x) - 2 x exp(-x**2) / sqrt(pi); the common factor of the
    !> face-flux coefficients below.
    elemental real function capg_kernel(xe)
      use constants, only: sqrt_pi
      real, intent(in) :: xe
      ! This is the same as hsg_func(xe) * 2 * xe**2
      capg_kernel =  erf(xe) - 2 * xe * exp(-xe**2) / sqrt_pi
    end function capg_kernel

    !> Face coefficient of the default scheme:
    !> 2 exp(-x**2) * capg_kernel(x) / (x sqrt(pi)).
    elemental real function capg(xe)
      use constants, only: sqrt_pi
      real, intent(in) :: xe
      capg = 2 * exp(-xe**2) * capg_kernel(xe) / (xe * sqrt_pi)
    end function capg

    !> Face coefficient of the old scheme, evaluated at face xeB for the
    !> cell centred on xeA.
    elemental real function capg_old(xeA, xeB)
      real, intent(in) :: xeA, xeB
      capg_old = (0.5 * exp(xeA**2-xeB**2) / xeA**2) * capg_kernel(xeB) / xeB
    end function capg_old
  end subroutine get_ediffuse_matrix

  !> Build the tridiagonal-solve coefficients for the Lorentz
  !> (pitch-angle scattering) part of the collision operator.
  !>
  !> On the first call, pitch_weights(il, ig) is computed from the lambda
  !> grid. The weights are recomputed rather than copied because, with
  !> Radau-Gauss grids, the WFB is treated as both passing and trapped.
  !> init_vpdiff is then called.
  !>
  !> For each point owned by this processor, get_lorentz_matrix supplies
  !> the sub-diagonal aa, diagonal bb and super-diagonal cc (plus dd and hh
  !> for the heating diagnostic) on the xi grid of length nxi+1. When
  !> hybrid electron species are present, set_hzero_lorentz_collisions_matrix
  !> may then modify aa, bb and cc.
  !>
  !> The forward sweep of the Thomas algorithm is performed here over the
  !> te2 valid xi points, and the results are stored as:
  !>  - c1 / c1le       : the super-diagonal cc,
  !>  - ql / qle        : the multipliers aa(ixi+1) * beta(ixi),
  !>  - betaa / betaale : 1 / (modified pivot),
  !>  - d1, h1 / d1le, h1le : dd and hh (only allocated when heating).
  !> ql/betaa (and, in the standard layout, c1) are zeroed beyond te2.
  !>
  !> @param vnmult_target optional; if present, vnmult(1) is set to
  !>        max(vnmult_target, 1.0) before the coefficients are built.
  subroutine init_lorentz (vnmult_target)
    use le_grids, only: negrid, jend, ng2, nxi, nlambda, al, il_is_passing, il_is_wfb
    use le_grids, only: setup_trapped_lambda_grids_old_finite_difference, setup_passing_lambda_grids
    use gs2_layouts, only: ig_idx, ik_idx, ie_idx, is_idx, it_idx, lz_lo, le_lo
    use theta_grid, only: ntgrid
    use species, only: has_hybrid_electron_species, spec
    use array_utils, only: zero_array
    implicit none
    real, intent (in), optional :: vnmult_target
    integer :: ig, ixi, it, ik, ie, is, je, te2, ile, ilz
    real, dimension (:), allocatable :: aa, bb, cc, dd, hh
    real, dimension (:, :), allocatable :: local_weights
    allocate (aa(nxi+1), bb(nxi+1), cc(nxi+1), dd(nxi+1), hh(nxi+1))

    if (.not. allocated(pitch_weights)) then
       ! Start by just copying the existing weights
       allocate(local_weights(-ntgrid:ntgrid, nlambda))
       local_weights = 0.0
       ! We have to recaculate the passing weights here as when Radau-Gauss
       ! grids are used the WFB (il=ng2+1) is treated as both passing and
       ! trapped. Otherwise we could just copy the passing weights from wl.
       call setup_passing_lambda_grids(al, local_weights)
       call setup_trapped_lambda_grids_old_finite_difference(al, local_weights)
       ! Stored transposed, i.e. indexed as pitch_weights(il, ig)
       allocate(pitch_weights(nlambda, -ntgrid:ntgrid))
       pitch_weights = transpose(local_weights)
    end if

    call init_vpdiff

    if(use_le_layout) then
       ! First call: allocate storage and make sure vnmult(1) is at least 1
       if (.not.allocated(c1le)) then
          allocate (c1le    (nxi+1, negrid, le_lo%llim_proc:le_lo%ulim_alloc))
          allocate (betaale (nxi+1, negrid, le_lo%llim_proc:le_lo%ulim_alloc))
          allocate (qle     (nxi+1, negrid, le_lo%llim_proc:le_lo%ulim_alloc))
          if (heating) then
             allocate (d1le    (nxi+1, negrid, le_lo%llim_proc:le_lo%ulim_alloc))
             allocate (h1le    (nxi+1, negrid, le_lo%llim_proc:le_lo%ulim_alloc))
             call zero_array(d1le)
             call zero_array(h1le)
          end if
          vnmult(1) = max(1.0, vnmult(1))
       endif
       call zero_array(c1le)
       call zero_array(betaale)
       call zero_array(qle)

       if (present(vnmult_target)) then
          vnmult(1) = max (vnmult_target, 1.0)
       end if

       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(ile, ik, it, is, ig, je, te2, ie, aa, bb, cc, dd, hh, ixi) &
       !$OMP SHARED(le_lo, jend, special_wfb_lorentz, ng2, negrid, spec, c1le, &
       !$OMP d1le, h1le, qle, betaale) &
       !$OMP SCHEDULE(static)
       do ile = le_lo%llim_proc, le_lo%ulim_proc
          ig = ig_idx(le_lo,ile)
          ik = ik_idx(le_lo,ile)
          it = it_idx(le_lo,ile)
          is = is_idx(le_lo,ile)
          je = jend(ig)
          !if (je <= ng2+1) then
          ! MRH the above line is likely the original cause of the wfb bug in collisions
          ! by treating collisions arrays here as if there are only passing particles
          ! when there are only passing particles plus wfb (and no other trapped particles)
          ! introduced an inconsistency in the tri-diagonal solve for theta =+/- pi
          ! Fixed below.
          ! Here te2 is the total number of unique valid xi at this point.
          ! For passing particles this is 2*ng2 whilst for trapped we subtract one
          ! from the number of valid pitch angles due to the "degenerate" duplicate
          ! vpar = 0 point.
          if (il_is_passing(je) .or. (il_is_wfb(je) .and. special_wfb_lorentz) ) then ! MRH
             te2 = 2 * ng2
          else
             te2 = 2 * je - 1
          end if
          do ie = 1, negrid

             call get_lorentz_matrix (aa, bb, cc, dd, hh, ig, ik, it, ie, is)

             if (has_hybrid_electron_species(spec)) &
                  call set_hzero_lorentz_collisions_matrix(aa, bb, cc, je, ik, is)

             c1le(:, ie, ile) = cc
             if (allocated(d1le)) then
                d1le(:, ie, ile) = dd
                h1le(:, ie, ile) = hh
             end if

             ! Thomas forward sweep over the te2 valid xi points
             qle(1, ie, ile) = 0.0
             betaale(1, ie, ile) = 1.0 / bb(1)
             do ixi = 1, te2-1
                qle(ixi+1, ie, ile) = aa(ixi+1) * betaale(ixi, ie, ile)
                betaale(ixi+1, ie, ile) = 1.0 / ( bb(ixi+1) - qle(ixi+1, ie, ile) &
                     * c1le(ixi, ie, ile) )
             end do
             ! Clear the unused tail beyond te2
             qle(te2+1:, ie, ile) = 0.0
             betaale(te2+1:, ie, ile) = 0.0
          end do
       end do
       !$OMP END PARALLEL DO
    else

       ! First call: allocate storage and make sure vnmult(1) is at least 1
       if (.not.allocated(c1)) then
          allocate (c1(nxi+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
          allocate (betaa(nxi+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
          allocate (ql(nxi+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
          if (heating) then
             allocate (d1   (nxi+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
             allocate (h1   (nxi+1,lz_lo%llim_proc:lz_lo%ulim_alloc))
             call zero_array(d1)
             call zero_array(h1)
          end if
          vnmult(1) = max(1.0, vnmult(1))
       end if

       call zero_array(c1)
       call zero_array(betaa)
       call zero_array(ql)

       if (present(vnmult_target)) then
          vnmult(1) = max (vnmult_target, 1.0)
       end if

       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(ilz, ik, it, is, ig, je, te2, ie, aa, bb, cc, dd, hh, ixi) &
       !$OMP SHARED(lz_lo, jend, special_wfb_lorentz, spec, c1, d1, h1, betaa, ql, ng2) &
       !$OMP SCHEDULE(static)
       do ilz = lz_lo%llim_proc, lz_lo%ulim_proc
          is = is_idx(lz_lo,ilz)
          ik = ik_idx(lz_lo,ilz)
          it = it_idx(lz_lo,ilz)
          ie = ie_idx(lz_lo,ilz)
          ig = ig_idx(lz_lo,ilz)
          je = jend(ig)
          !if (je <= ng2+1) then
          ! MRH the above line is likely the original cause of the wfb bug in collisions
          ! by treating collisions arrays here as if there are only passing particles
          ! when there are only passing particles plus wfb (and no other trapped particles)
          ! introduced an inconsistency in the tri-diagonal solve for theta =+/- pi
          ! Fixed below
          if (il_is_passing(je) .or. (il_is_wfb(je) .and. special_wfb_lorentz) ) then ! MRH
             te2 = 2 * ng2
          else
             te2 = 2 * je - 1
          end if

          call get_lorentz_matrix (aa, bb, cc, dd, hh, ig, ik, it, ie, is)

          if (has_hybrid_electron_species(spec)) &
               call set_hzero_lorentz_collisions_matrix(aa, bb, cc, je, ik, is)

          c1(:, ilz) = cc
          if (allocated(d1)) then
             d1(:, ilz) = dd
             h1(:, ilz) = hh
          end if

          ! Thomas forward sweep over the te2 valid xi points
          ql(1, ilz) = 0.0
          betaa(1, ilz) = 1.0 / bb(1)
          do ixi = 1, te2-1
             ql(ixi+1, ilz) = aa(ixi+1) * betaa(ixi, ilz)
             betaa(ixi+1, ilz) = 1.0 / (bb(ixi+1) - ql(ixi+1, ilz) * c1(ixi, ilz))
          end do
          ! Clear the unused tail beyond te2
          ql(te2+1:, ilz) = 0.0
          c1(te2+1:, ilz) = 0.0
          betaa(te2+1:, ilz) = 0.0
       end do
       !$OMP END PARALLEL DO
    end if

    deallocate (aa, bb, cc, dd, hh)

  end subroutine init_lorentz

  !> Special behaviour when h=0 for passing non-zonal electrons
  !>
  !> For a hybrid electron species at a non-zonal ky (aky /= 0), every
  !> row of the pitch-angle tridiagonal system outside the non-WFB trapped
  !> range [ng2+2, 2*(je-1)-ng2] becomes an identity row (aa = cc = 0,
  !> bb = 1). The coupling from the first and last trapped points into
  !> that region is also removed. Passing electrons are thereby excluded
  !> from pitch-angle scattering, and zero passing h acts as the boundary
  !> condition for trapped-particle scattering. All other species and
  !> zonal modes are left untouched.
  subroutine set_hzero_lorentz_collisions_matrix(aa, bb, cc, je, ik, is)
    use species, only: is_hybrid_electron_species, spec
    use kt_grids, only: aky
    use le_grids, only: ng2
    implicit none
    real, dimension(:), intent(in out) :: aa, bb, cc
    integer, intent(in) :: je, ik, is
    !> xi index of the first non-WFB trapped point (xi > 0 side)
    integer :: first_trapped
    !> xi index of the last non-WFB trapped point (xi < 0 side)
    integer :: last_trapped

    ! Only non-zonal modes of hybrid electron species are modified
    if (.not. is_hybrid_electron_species(spec(is))) return
    if (aky(ik) == 0.) return

    first_trapped = ng2 + 2
    last_trapped = 2 * (je - 1) - ng2

    ! Sub-diagonal: no coupling below first_trapped or above last_trapped
    aa(:first_trapped) = 0.
    aa(last_trapped + 1:) = 0.

    ! Diagonal: identity rows outside the trapped range
    bb(:first_trapped - 1) = 1.
    bb(last_trapped + 1:) = 1.

    ! Super-diagonal: no coupling below first_trapped or above last_trapped
    cc(:first_trapped - 1) = 0.
    cc(last_trapped:) = 0.
  end subroutine set_hzero_lorentz_collisions_matrix

  !> FIXME : Add documentation
  subroutine get_lorentz_matrix (aa, bb, cc, dd, hh, ig, ik, it, ie, is)
    use species, only: spec
    use le_grids, only: al, energy => energy_maxw, xi, ng2
    use le_grids, only: jend, al, il_is_passing, il_is_wfb
    use gs2_time, only: code_dt, tunits
    use kt_grids, only: kperp2
    use theta_grid, only: bmag
    implicit none
    real, dimension (:), intent (out) :: aa, bb, cc, dd, hh
    integer, intent (in) :: ig, ik, it, ie, is
    integer :: il, je, te, te2, teh
    real :: slb0, slb1, slb2, slbl, slbr, vhyp, vn, vnh, vnc, ee, deltaxi

    je = jend(ig)
!
!CMR, 17/2/2014:
!         te, te2, teh: indices in xi, which runs from +1 -> -1.
!         te   :  index of minimum xi value >= 0.
!         te2  :  total #xi values = index of minimum xi value (= -1)
!         teh  :  index of minimum xi value > 0.
!                 teh = te if no bouncing particle at this location
!              OR teh = te-1 if there is a bouncing particle
!
    if (il_is_passing(je) .or. (il_is_wfb(je) .and. special_wfb_lorentz)) then
       !CMRDDGC, 17/2/2014:
       !   This clause is appropriate for Lorentz collisons with
       !         SPECIAL (unphysical) treatment of wfb at its bounce point
       te = ng2
       te2 = 2*ng2
       teh = ng2
    else
       !CMRDDGC, 17/2/2014:
       !   This clause is appropriate for Lorentz collisons with
       !         STANDARD treatment of wfb at its bounce point
       te = je
       te2 = 2*je-1
       teh = je-1
    end if

    if (collision_model_switch == collision_model_lorentz_test) then
       vn = vnmult(1) * abs(spec(is)%vnewk) * tunits(ik)
       vnc = 0.
       vnh = 0.
    else
       if (hyper_colls) then
          vhyp = vnewh(ig, it, ik, is)
       else
          vhyp = 0.0
       end if
       ! vnc and vnh only needed when heating is true
       vnc = vnmult(1) * vnew(ik, ie, is)
       if (hypermult) then
          vn = vnmult(1) * vnew(ik, ie, is) * (1 + vhyp)
          vnh = vhyp * vnc
       else
          vn = vnmult(1) * vnew(ik, ie, is) + vhyp
          vnh = vhyp
       end if
    end if

    aa = 0.0 ; bb = 0.0 ; cc = 0.0 ; dd = 0.0 ; hh = 0.0

    select case (lorentz_switch)
    case (lorentz_scheme_default)

       do il = 2, te-1
          slb0 = safe_sqrt(1.0 - bmag(ig) * al(il-1))
          slb1 = safe_sqrt(1.0 - bmag(ig) * al(il))
          slb2 = safe_sqrt(1.0 - bmag(ig) * al(il+1))

          slbl = (slb1 + slb0) * 0.5  ! xi(j-1/2)
          slbr = (slb1 + slb2) * 0.5  ! xi(j+1/2)

          ee = 0.5 * energy(ie)*(1 + slb1**2) * kperp2(ig,it,ik) * cfac &
               / (bmag(ig) * spec(is)%zstm)**2

          ! coefficients for tridiagonal matrix:
          cc(il) = 2.0 * vn * code_dt * (1 - slbr**2) / &
               (pitch_weights(il, ig) * (slb2 - slb1))
          aa(il) = 2.0 * vn * code_dt * (1 - slbl**2) / &
               (pitch_weights(il, ig) * (slb1 - slb0))
          bb(il) = 1.0 - (aa(il) + cc(il)) + ee * vn * code_dt

          ! coefficients for entropy heating calculation
          if (heating) then
             dd(il) = vnc * (-2.0 * (1 - slbr**2) / &
                  (pitch_weights(il, ig) * (slb2 - slb1)) + ee)
             hh(il) = vnh * (-2.0 * (1 - slbr**2) / &
                  (pitch_weights(il, ig) * (slb2 - slb1)) + ee)
          end if
       end do

       ! boundary at xi = 1
       slb1 = safe_sqrt(1.0 - bmag(ig) * al(1))
       slb2 = safe_sqrt(1.0 - bmag(ig) * al(2))

       slbr = (slb1 + slb2) * 0.5

       ee = 0.5 * energy(ie) * (1 + slb1**2) * kperp2(ig,it,ik) * cfac &
            / (bmag(ig) * spec(is)%zstm)**2

       cc(1) = 2.0 * vn * code_dt * (1.0 - slbr**2) &
            / (pitch_weights(1, ig) * (slb2 - slb1))
       aa(1) = 0.0
       bb(1) = 1.0 - cc(1) + ee * vn * code_dt

       if (heating) then
          dd(1) = vnc * (-2.0 * (1 - slbr**2) &
               / (pitch_weights(1, ig) * (slb2 - slb1)) + ee)
          hh(1) = vnh * (-2.0 * (1 - slbr**2) &
               / (pitch_weights(1, ig) * (slb2 - slb1)) + ee)
       end if

       ! boundary at xi = 0
       il = te
       slb0 = safe_sqrt(1.0 - bmag(ig) * al(il-1))
       if (te == ng2) then
          slb1 = safe_sqrt(1.0 - bmag(ig) * al(il))
          slb2 = -slb1
       else
          slb1 = 0.0
          slb2 = -slb0
       end if

       slbl = (slb1 + slb0) * 0.5
       slbr = (slb1 + slb2) * 0.5

       ee = 0.5 * energy(ie) * (1 + slb1**2) * kperp2(ig,it,ik) * cfac &
            / (bmag(ig) * spec(is)%zstm)**2

!CMR, 6/3/2014:
! STANDARD treatment of pitch angle scattering must resolve T-P boundary.
! NEED special_wfb= .false. to resolve T-P boundary at wfb bounce point
!     (special_wfb= .true. AVOIDS TP boundary at wfb bounce point)
!
! Original code (pre-r2766) used eq.(42) Barnes et al, Phys Plasmas 16, 072107
! (2009), with pitch angle weights to enhance accuracy in derivatives.
! NB THIS FAILS at wfb bounce point, giving aa=cc=infinite,
!    because weight wl(ig,il)=0 at wfb bounce point.
!    UNPHYSICAL as d/dxi ((1-xi^2)g) IS NOT resolved numerically for wl=0.
! MUST accept limitations of the grid resolution and USE FINITE coefficients!
! FIX here by setting a FINITE width of the trapped region at wfb B-P
!              deltaxi=xi(ig,ng2)-xi(ig,ng2+2)
! ASIDE: NB    deltaxi=wl is actually 2*spacing in xi !!!
!              which explains upfront factor 2 in definition of aa, cc
       deltaxi = pitch_weights(il, ig)
       if (te == je) deltaxi = 2 * deltaxi ! MRH vpar = 0 fix
       ! MRH appropriate when endpoint (te) is a vpar = 0 point (je)
       ! factor of 2 required above because xi=0 is a repeated point
       ! on vpar > 0, vpar < 0 grids, and hence has half the weight associated
       ! to it at each appearance compared to what the weight should be
       ! as calculated in the continuum of points xi = (-1,1)
       if ((.not. special_wfb_lorentz) .and. (deltaxi == 0) .and. il_is_wfb(il)) then
          deltaxi = xi(ng2, ig) - xi(ng2 + 2, ig)
       endif
       cc(il) = 2.0 * vn * code_dt * (1 - slbr**2) / (deltaxi * (slb2 - slb1))
       aa(il) = 2.0 * vn * code_dt * (1 - slbl**2) / (deltaxi * (slb1 - slb0))
       bb(il) = 1.0 - (aa(il) + cc(il)) + ee * vn * code_dt

       if (heating) then
          dd(il) = vnc * (-2.0 * (1.0 - slbr**2) / (deltaxi * (slb2 - slb1)) + ee)
          hh(il) = vnh * (-2.0 * (1.0 - slbr**2) / (deltaxi * (slb2 - slb1)) + ee)
       end if
!CMRend

    case (lorentz_scheme_old)
       do il = 2, te-1
          slb0 = safe_sqrt(1.0 - bmag(ig) * al(il-1))
          slb1 = safe_sqrt(1.0 - bmag(ig) * al(il))
          slb2 = safe_sqrt(1.0 - bmag(ig) * al(il+1))

          slbl = (slb1 + slb0) * 0.5  ! xi(j-1/2)
          slbr = (slb1 + slb2) * 0.5  ! xi(j+1/2)

          ee = 0.5 * energy(ie) * (1 + slb1**2) * kperp2(ig,it,ik) * cfac &
               / (bmag(ig) * spec(is)%zstm)**2

          ! coefficients for tridiagonal matrix:
          cc(il) = -vn * code_dt * (1 - slbr**2) / ((slbr - slbl) * (slb2 - slb1))
          aa(il) = -vn * code_dt * (1 - slbl**2) / ((slbr - slbl) * (slb1 - slb0))
          bb(il) = 1.0 - (aa(il) + cc(il)) + ee * vn * code_dt

          ! coefficients for entropy heating calculation
          if (heating) then
             dd(il) = vnc * ((1 - slbr**2) / ((slbr - slbl) * (slb2 - slb1)) + ee)
             hh(il) = vnh * ((1 - slbr**2) / ((slbr - slbl) * (slb2 - slb1)) + ee)
          end if
       end do

       ! boundary at xi = 1
       slb0 = 1.0
       slb1 = safe_sqrt(1.0 - bmag(ig) * al(1))
       slb2 = safe_sqrt(1.0 - bmag(ig) * al(2))

       slbl = (slb1 + slb0) * 0.5
       slbr = (slb1 + slb2) * 0.5

       ee = 0.5 * energy(ie) * (1 + slb1**2) * kperp2(ig,it,ik) * cfac &
            / (bmag(ig) * spec(is)%zstm)**2

       cc(1) = -vn * code_dt * (-1.0 - slbr) / (slb2 - slb1)
       aa(1) = 0.0
       bb(1) = 1.0 - (aa(1) + cc(1)) + ee * vn * code_dt

       if (heating) then
          dd(1) = vnc * ((1 - slbr**2) / ((slbr - slbl) * (slb2 - slb1)) + ee)
          hh(1) = vnh * ((1 - slbr**2) / ((slbr - slbl) * (slb2 - slb1)) + ee)
       end if

       ! boundary at xi = 0
       il = te
       slb0 = safe_sqrt(1.0 - bmag(ig) * al(il-1))
       if (te == ng2) then
          slb1 = safe_sqrt(1.0 - bmag(ig) * al(il))
          slb2 = -slb1
       else
          slb1 = 0.0
          slb2 = -slb0
       end if

       slbl = (slb1 + slb0) * 0.5
       slbr = (slb1 + slb2) * 0.5

       ee = 0.5 * energy(ie) * (1 + slb1**2) * kperp2(ig,it,ik) * cfac &
            / (bmag(ig) * spec(is)%zstm)**2

       cc(il) = -vn * code_dt * (1 - slbr**2) / ((slbr - slbl) * (slb2 - slb1))
       aa(il) = -vn * code_dt * (1 - slbl**2) / ((slbr - slbl) * (slb1 - slb0))
       bb(il) = 1.0 - (aa(il) + cc(il)) + ee * vn * code_dt

       if (heating) then
          dd(il) = vnc * ((1 - slbr**2) / ((slbr - slbl) * (slb2 - slb1)) + ee)
          hh(il) = vnh * ((1 - slbr**2) / ((slbr - slbl) * (slb2 - slb1)) + ee)
       end if

    end select

    ! assuming symmetry in xi, fill in the rest of the arrays.
    aa(te+1:te2) = cc(teh:1:-1)
    bb(te+1:te2) = bb(teh:1:-1)
    cc(te+1:te2) = aa(teh:1:-1)

    if (heating) then
       dd(te+1:te2) = dd(teh:1:-1)
       hh(te+1:te2) = hh(teh:1:-1)
    end if

  end subroutine get_lorentz_matrix

  !> Precompute coefficients used to estimate the error of the Lorentz
  !> (pitch-angle scattering) operator d/dxi[(1-xi^2) dh/dxi].
  !>
  !> For every theta point ig, with xi(il) = sqrt(1 - al(il)*bmag(ig)),
  !> this fills the module arrays
  !>  * fdf(ig,il), fdb(ig,il): compact (3-point, flux-form) forward/backward
  !>    finite-difference coefficients built from the half-grid values
  !>    xi_{j+1/2}, xi_{j-1/2};
  !>  * dtot(ig,il,k): coefficients of a Lagrange-interpolation stencil for the
  !>    same operator (5 points il-2..il+2 for il >= 3, 3 points 1..3 for
  !>    il = 1, 2), i.e. (1-xi^2)*L'' - 2*xi*L' evaluated at xi(il), with k
  !>    the position within the stencil.
  !> Stencils centred near xi = 0 reach across it using the mirrored values
  !> -xi on the vpar < 0 half of the pitch-angle grid.
  !>
  !> Entries not covered by any stencil keep the initial values (1.0 in the
  !> first-derivative part, 0.0 in the second), so those dtot entries are 1.0.
  !> NOTE(review): confirm consumers of dtot never read them.
  !> NOTE(review): dtot, fdf and fdb are allocated here without a guard, so
  !> this routine presumably must not be called twice without freeing them.
  subroutine init_lorentz_error
    use le_grids, only: jend, al, ng2, nlambda, &
         il_is_wfb, il_is_passing, il_is_trapped
    use theta_grid, only: ntgrid, bmag
    use array_utils, only: zero_array
    implicit none

    integer :: je, ig, il
    real :: slb0, slb1, slb2, slbr, slbl
    ! xi at the current theta point, extended by mirror symmetry (-xi) onto the
    ! vpar < 0 half so that stencils near xi = 0 have neighbours on both sides
    real, dimension (:), allocatable :: slb
    ! contributions of the first (-2 xi d/dxi) and second ((1-xi^2) d2/dxi2)
    ! derivative terms to dtot
    real, dimension (:,:,:), allocatable :: dlcoef, d2lcoef

    allocate (slb(2*nlambda))

    allocate (dlcoef(-ntgrid:ntgrid,nlambda,5))
    allocate (d2lcoef(-ntgrid:ntgrid,nlambda,5))
    allocate (dtot(-ntgrid:ntgrid,nlambda,5))
    allocate (fdf(-ntgrid:ntgrid,nlambda), fdb(-ntgrid:ntgrid,nlambda))

    dlcoef = 1.0; call zero_array(d2lcoef); call zero_array(dtot)
    call zero_array(fdf); call zero_array(fdb); slb = 0.0

    do ig = -ntgrid, ntgrid
       je = jend(ig)

       if (il_is_passing(je) .or. il_is_wfb(je)) then            ! no trapped particles

          ! calculation of xi and finite difference coefficients for non-boundary points
          do il = 2, ng2-1
             slb(il) = safe_sqrt(1.0-al(il)*bmag(ig))   ! xi_{j}

             slb2 = safe_sqrt(1.0-al(il+1)*bmag(ig))    ! xi_{j+1}
             slb1 = slb(il)
             slb0 = safe_sqrt(1.0-al(il-1)*bmag(ig))    ! xi_{j-1}

             slbr = (slb2+slb1)*0.5                     ! xi_{j+1/2}
             slbl = (slb1+slb0)*0.5                     ! xi_{j-1/2}

             ! finite difference coefficients
             fdf(ig,il) = (1.0 - slbr*slbr)/(slbr - slbl)/(slb2 - slb1)
             fdb(ig,il) = (1.0 - slbl*slbl)/(slbr - slbl)/(slb1 - slb0)
          end do

          ! boundary at xi = 1:
          ! derivative of [(1-xi**2)*df/dxi] at xi_{j=1} is centered, with upper
          ! xi=1 and lower xi = xi_{j+1/2}
          slb(1) = safe_sqrt(1.0-al(1)*bmag(ig))
          slb1 = slb(1)
          slb2 = slb(2)
          slbr = (slb1 + slb2)/2.0

          fdf(ig,1) = (-1.0-slbr)/(slb2-slb1)
          fdb(ig,1) = 0.0

          ! boundary at xi = 0: the last passing point il = ng2, whose neighbour
          ! across xi = 0 is its mirror image -xi
          il = ng2
          slb(il) = safe_sqrt(1.0 - al(il)*bmag(ig))
          slb0 = safe_sqrt(1.0 - bmag(ig)*al(il-1))
          slb1 = slb(il)
          slb2 = -slb1

          slbl = (slb1 + slb0)/2.0
          slbr = (slb1 + slb2)/2.0

          fdf(ig,il) = (1.0 - slbr*slbr)/(slbr-slbl)/(slb2-slb1)
          fdb(ig,il) = (1.0 - slbl*slbl)/(slbr-slbl)/(slb1-slb0)

          ! Mirror xi onto the vpar < 0 half. Exactly ng2 values exist to mirror,
          ! so the target section must be ng2 long: the open-ended slb(ng2+1:)
          ! has 2*nlambda-ng2 elements and does not conform with the RHS
          ! whenever nlambda > ng2 (trapped pitch angles present elsewhere).
          slb(ng2+1:2*ng2) = -slb(ng2:1:-1)
       else          ! run with trapped particles
          do il = 2, je-1
             slb(il) = safe_sqrt(1.0-al(il)*bmag(ig))

             slb2 = safe_sqrt(1.0-al(il+1)*bmag(ig))
             slb1 = slb(il)
             slb0 = safe_sqrt(1.0-al(il-1)*bmag(ig))

             slbr = (slb2+slb1)*0.5
             slbl = (slb1+slb0)*0.5

             fdf(ig,il) = (1.0 - slbr*slbr)/(slbr - slbl)/(slb2 - slb1)
             fdb(ig,il) = (1.0 - slbl*slbl)/(slbr - slbl)/(slb1 - slb0)
          end do

          ! boundary at xi = 1
          slb(1) = safe_sqrt(1.0-bmag(ig)*al(1))
          slb1 = slb(1)
          slb2 = slb(2)
          slbr = (slb1 + slb2)/2.0

          fdf(ig,1) = (-1.0 - slbr)/(slb2-slb1)
          fdb(ig,1) = 0.0

          ! boundary at xi = 0: the last pitch-angle point je is treated as
          ! xi = 0 with symmetric neighbours +xi_{je-1} and -xi_{je-1}
          il = je
          slb(il) = safe_sqrt(1.0-bmag(ig)*al(il))
          slb0 = slb(je-1)
          slb1 = 0.

          slbl = (slb1 + slb0)/2.0

          fdf(ig,il) = (1.0 - slbl*slbl)/slb0/slb0
          fdb(ig,il) = fdf(ig,il)

          slb(je+1:2*je-1) = -slb(je-1:1:-1)
       end if

       ! Lagrange-stencil coefficients multiplying the first and second
       ! derivatives of h: 3-point stencils at the xi = 1 end, 5-point
       ! stencils elsewhere.
       call set_stencil_coefficients(1, 1, 3)
       call set_stencil_coefficients(2, 1, 3)
       do il = 3, ng2
          call set_stencil_coefficients(il, il-2, il+2)
       end do

       if (il_is_trapped(je)) then      ! have to handle trapped particles
          do il = ng2+1, je
             call set_stencil_coefficients(il, il-2, il+2)
          end do
       end if
    end do

    dtot = dlcoef + d2lcoef

    deallocate (slb, dlcoef, d2lcoef)

  contains

    !> Fill dlcoef(ig,ic,k) and d2lcoef(ig,ic,k), k = ip-lo+1, for the stencil
    !> points ip = lo..hi: -2*xi*L_ip'(xi_ic) and (1-xi^2)*L_ip''(xi_ic), where
    !> L_ip is the Lagrange basis polynomial through slb(lo:hi) that is 1 at
    !> slb(ip). Uses the host's ig, slb, dlcoef and d2lcoef.
    subroutine set_stencil_coefficients(ic, lo, hi)
      integer, intent(in) :: ic, lo, hi
      integer :: ip, ij, im, k
      real :: d1, d2, weight

      do ip = lo, hi
         k = ip - lo + 1
         if (ip == ic) then
            ! L_ic'(x_ic)  = sum_{j/=ic} 1/(x_ic - x_j)
            ! L_ic''(x_ic) = sum_{j/=ic} sum_{m/=ic,j} 1/((x_ic - x_m)(x_ic - x_j))
            d1 = 0.0
            d2 = 0.0
            do ij = lo, hi
               if (ij == ic) cycle
               d1 = d1 + 1.0/(slb(ic)-slb(ij))
               do im = lo, hi
                  if (im /= ic .and. im /= ij) &
                       d2 = d2 + 1./((slb(ic)-slb(im))*(slb(ic)-slb(ij)))
               end do
            end do
         else
            ! with P = prod_{j/=ic,ip} (x_ic - x_j)/(x_ip - x_j):
            ! L_ip'(x_ic)  = P / (x_ip - x_ic)
            ! L_ip''(x_ic) = 2 P sum_{j/=ic,ip} 1/(x_ic - x_j) / (x_ip - x_ic)
            d1 = 1.0
            weight = 2.0
            d2 = 0.0
            do ij = lo, hi
               if (ij == ic .or. ij == ip) cycle
               d1 = d1*(slb(ic)-slb(ij))/(slb(ip)-slb(ij))
               weight = weight*(slb(ic)-slb(ij))/(slb(ip)-slb(ij))
               d2 = d2 + 1./(slb(ic)-slb(ij))
            end do
            d1 = d1/(slb(ip)-slb(ic))
            d2 = weight*d2/(slb(ip)-slb(ic))
         end if
         dlcoef(ig,ic,k) = -2.0*slb(ic)*d1
         d2lcoef(ig,ic,k) = (1.0-slb(ic)**2)*d2
      end do
    end subroutine set_stencil_coefficients

  end subroutine init_lorentz_error

  !> Apply the model collision operator to g in the standard (g_lo) layout.
  !>
  !> Steps, in order:
  !>  1. if has_diffuse: energy diffusion, then its moment-conserving
  !>     correction when conserve_moments is set;
  !>  2. if has_lorentz: the drag term for electron species (when drag is set),
  !>     the Lorentz pitch-angle scattering solve, then its moment-conserving
  !>     correction when conserve_moments is set.
  !>
  !> @param[in,out] g        distribution function, updated in place
  !> @param[in,out] g1       scratch array passed to the conserving corrections
  !> @param[in,out] gc1,gc2  arrays passed through to solfp_lorentz
  !> @param[in] diagnostics  optional, passed through to solfp_lorentz
  !> @param[in] gtoc,ctog    optional flags kept for interface compatibility.
  !>                         CMR, 12/9/2013: they were introduced to control the
  !>                         g_lo <-> collision_lo redistributes; this routine
  !>                         does not read them.
  subroutine solfp1_standard_layout (g, g1, gc1, gc2, diagnostics, gtoc, ctog)
    use gs2_layouts, only: g_lo, it_idx, ik_idx, ie_idx, is_idx
    use theta_grid, only: ntgrid
    use run_parameters, only: beta, ieqzip
    use gs2_time, only: code_dt
    use le_grids, only: energy => energy_maxw
    use species, only: spec, is_electron_species
    use dist_fn_arrays, only: vpa, aj0
    use fields_arrays, only: aparnew
    use kt_grids, only: kwork_filter, kperp2
    implicit none

    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in out) :: g, g1, gc1, gc2
    integer, optional, intent (in) :: diagnostics
    logical, optional, intent (in) :: gtoc, ctog
    integer :: isgn, it, ik, ie, is, iglo

    if (has_diffuse) then
       call solfp_ediffuse_standard_layout (g)
       if (conserve_moments) call conserve_diffuse (g, g1)
    end if

    if (has_lorentz) then
       if (drag) then
          ! Drag term, applied to electron species only, for modes that are
          ! neither filtered (kwork_filter) nor flagged by ieqzip.
          !$OMP PARALLEL DO DEFAULT(none) &
          !$OMP PRIVATE(iglo, is, it, ik, ie, isgn) &
          !$OMP SHARED(g_lo, spec, kwork_filter, g, ieqzip, vnmult, code_dt, vpa, &
          !$OMP kperp2, aparnew, aj0, beta, energy) &
          !$OMP SCHEDULE(static)
          do iglo = g_lo%llim_proc, g_lo%ulim_proc
             is = is_idx(g_lo,iglo)
             if (.not. is_electron_species(spec(is))) cycle
             it = it_idx(g_lo,iglo)
             ik = ik_idx(g_lo,iglo)
             if(kwork_filter(it,ik)) cycle
             if(ieqzip(it,ik)) cycle
             ie = ie_idx(g_lo,iglo)
             do isgn = 1, 2
                g(:, isgn, iglo) = g(:, isgn, iglo) + vnmult(1)*spec(is)%vnewk*code_dt &
                     * vpa(:,isgn,iglo)*kperp2(:,it,ik)*aparnew(:,it,ik)*aj0(:,iglo) &
                     / ((-spec(is)%z*spec(is)%dens)*beta*spec(is)%stm*energy(ie)**1.5)
                ! probably need 1/(spec(is_ion)%z*spec(is_ion)%dens) above
                ! This has been implemented as 1/-spec(electron)%z*spec(electron)%dens
                ! in an attempt handle the multi-ion species case.
             end do
          end do
          !$OMP END PARALLEL DO
       end if

       call solfp_lorentz (g, gc1, gc2, diagnostics)
       if (conserve_moments) call conserve_lorentz (g, g1)
    end if
  end subroutine solfp1_standard_layout

  !> Apply the model collision operator to gle in the le_lo layout.
  !>
  !> Energy diffusion (and, with conserve_moments, its conserving correction)
  !> is applied first when has_diffuse is set. When has_lorentz is set, the
  !> drag term is then added for electron species (if drag), followed by the
  !> Lorentz solve and, with conserve_moments, its conserving correction.
  !>
  !> @param[in,out] gle      distribution function in le_lo layout
  !> @param[in] diagnostics  optional, passed through to solfp_lorentz
  subroutine solfp1_le_layout (gle, diagnostics)
    use gs2_layouts, only: le_lo, it_idx, ik_idx, ig_idx, is_idx
    use run_parameters, only: beta, ieqzip
    use gs2_time, only: code_dt
    use le_grids, only: energy => energy_maxw, negrid
    use species, only: spec, is_electron_species
    use fields_arrays, only: aparnew
    use kt_grids, only: kwork_filter, kperp2
    implicit none

    complex, dimension (:,:,le_lo%llim_proc:), intent (in out) :: gle
    integer, optional, intent (in) :: diagnostics
    complex :: drag_factor
    integer :: ig, it, ik, ie, is, ile

    if (has_diffuse) then
       call solfp_ediffuse_le_layout (gle)
       if (conserve_moments) call conserve_diffuse (gle)
    end if

    if (.not. has_lorentz) return

    if (drag) then
       ! For each electron slab and energy, add drag_factor * vpa_aj0_le over
       ! the nxi_lim pitch-angle points.
       ! probably need 1/(spec(is_ion)%z*spec(is_ion)%dens) below
       ! This has been implemented as 1/-spec(electron)%z*spec(electron)%dens
       ! in an attempt handle the multi-ion species case.
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(ile, is, it, ik, ie, ig, drag_factor) &
       !$OMP SHARED(le_lo, spec, kwork_filter, negrid, ieqzip, vnmult, code_dt, kperp2, &
       !$OMP aparnew, beta, energy, nxi_lim, gle, vpa_aj0_le) &
       !$OMP SCHEDULE(static)
       do ile = le_lo%llim_proc, le_lo%ulim_proc
          is = is_idx(le_lo, ile)
          if (.not. is_electron_species(spec(is))) cycle
          it = it_idx(le_lo, ile)
          ik = ik_idx(le_lo, ile)
          if (kwork_filter(it, ik) .or. ieqzip(it, ik)) cycle
          ig = ig_idx(le_lo, ile)
          do ie = 1, negrid
             ! aparnew may be needed for an {it, ik} not owned by this
             ! processor in g_lo.
             drag_factor = vnmult(1)*spec(is)%vnewk*code_dt &
                  * kperp2(ig,it,ik)*aparnew(ig,it,ik) &
                  / ((-spec(is)%z*spec(is)%dens)*beta*spec(is)%stm*energy(ie)**1.5)
             gle(1:nxi_lim, ie, ile) = gle(1:nxi_lim, ie, ile) &
                  + drag_factor * vpa_aj0_le(1:nxi_lim, ie, ile)
          end do
       end do
       !$OMP END PARALLEL DO
    end if

    call solfp_lorentz (gle, diagnostics)
    if (conserve_moments) call conserve_lorentz (gle)
  end subroutine solfp1_le_layout

  !> Apply the moment-conserving corrections to the Lorentz operator
  !> solution in the standard (g_lo) layout.
  !>
  !> Performs up to three successive corrections of the form
  !>   y_{n+1} = y_n - (v_n . y_n) * u_n
  !> where (v_n . y_n) is a velocity-space moment from integrate_moment and
  !> u_n is one of the module arrays z0, s0, w0:
  !>   * drag only: v0 = vpa J0 f0                               (u = z0)
  !>   * always:    v1 = nud vpa J0 f0, with speed*vpdiff in place
  !>                of vpa when conservative is set               (u = s0)
  !>   * always:    v2 = nud vperp J1 f0, computed as
  !>                vns*energy*al*aj1                             (u = w0)
  !> The step comments quote y - v.y * z / (1 + v.z); the code subtracts
  !> v.y * u directly, so the 1/(1 + v.z) factor is presumably folded into
  !> z0, s0 and w0 where they are built (not visible here).
  !>
  !> Modes filtered by kwork_filter or flagged by ieqzip leave g unchanged.
  !>
  !> @param[in,out] g   distribution function, corrected in place
  !> @param[in,out] g1  scratch array holding the intermediate y1, y2
  subroutine conserve_lorentz_standard_layout (g, g1)
    use theta_grid, only: ntgrid
    use species, only: nspec
    use kt_grids, only: naky, ntheta0, kwork_filter
    use gs2_layouts, only: g_lo, ik_idx, it_idx, ie_idx, il_idx, is_idx
    use le_grids, only: energy => energy_maxw, speed => speed_maxw, al, &
         integrate_moment, negrid
    use dist_fn_arrays, only: aj0, aj1, vpa
    use run_parameters, only: ieqzip
    use array_utils, only: copy, zero_array
    implicit none
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in out) :: g, g1
    complex, dimension (:,:,:), allocatable :: gtmp
    real, dimension (:,:,:), allocatable :: vns
    complex, dimension (:,:,:,:), allocatable :: v0y0, v1y1, v2y2
    integer :: isgn, iglo, ik, ie, il, is, it
    logical, parameter :: all_procs = .true.

    allocate (v0y0(-ntgrid:ntgrid, ntheta0, naky, nspec))
    allocate (v1y1(-ntgrid:ntgrid, ntheta0, naky, nspec))
    allocate (v2y2(-ntgrid:ntgrid, ntheta0, naky, nspec))

    allocate (gtmp(-ntgrid:ntgrid,2,g_lo%llim_proc:g_lo%ulim_alloc))
    allocate (vns(naky,negrid,nspec))
    vns = vnmult(1) * vnew_D

    ! The loops below only fill gtmp for modes that pass kwork_filter. Zero it
    ! first so the skipped (it, ik) entries are not integrated from
    ! uninitialised memory (conserve_diffuse_standard_layout does the same).
    if (any(kwork_filter)) call zero_array(gtmp)

    if (drag) then

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! First get v0y0
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(iglo, it, ik, isgn) &
       !$OMP SHARED(g_lo, kwork_filter, gtmp, vpa, aj0, g) &
       !$OMP SCHEDULE(static)
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          it = it_idx(g_lo,iglo)
          ik = ik_idx(g_lo,iglo)
          if(kwork_filter(it,ik))cycle
          do isgn = 1, 2
             ! v0 = vpa J0 f0, y0 = g
             gtmp(:, isgn, iglo) = vpa(:, isgn, iglo) * aj0(:, iglo) * g(:, isgn, iglo)
          end do
       end do
       !$OMP END PARALLEL DO
       call integrate_moment (gtmp, v0y0, all_procs)    ! v0y0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Get y1 = y0 - v0y0 * z0 / (1 + v0z0)
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(iglo, it, ik, is, isgn) &
       !$OMP SHARED(g_lo, kwork_filter, g1, g, ieqzip, v0y0, z0) &
       !$OMP SCHEDULE(static)
       do iglo = g_lo%llim_proc, g_lo%ulim_proc
          it = it_idx(g_lo,iglo)
          ik = ik_idx(g_lo,iglo)
          if(kwork_filter(it,ik)) cycle
          if(ieqzip(it,ik)) cycle
          is = is_idx(g_lo,iglo)
          do isgn = 1, 2
             g1(:, isgn, iglo) = g(:, isgn, iglo) - v0y0(:, it, ik, is) * z0(:, isgn, iglo)
          end do
       end do
       !$OMP END PARALLEL DO
    else
       call copy(g, g1)
    end if

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v1y1

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, il, ie, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, conservative, gtmp, vns, speed, vpdiff, aj0, g1, vpa) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       if(kwork_filter(it,ik))cycle
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v1 = nud vpa J0 f0, y1 = g1
          if (conservative) then
             il = il_idx(g_lo,iglo)
             gtmp(:, isgn, iglo) = vns(ik, ie, is) * speed(ie) * vpdiff(:, isgn, il) &
                  * aj0(:, iglo) * g1(:, isgn, iglo)
          else
             gtmp(:, isgn, iglo) = vns(ik, ie, is) * vpa(:, isgn, iglo) * aj0(:, iglo) &
                  * g1(:, isgn, iglo)
          end if
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, v1y1, all_procs)    ! v1y1

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Get y2 = y1 - v1y1 * s1 / (1 + v1s1)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, g1, ieqzip, v1y1, s0) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       if(kwork_filter(it,ik)) cycle
       if(ieqzip(it,ik)) cycle
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          g1(:, isgn, iglo) = g1(:, isgn, iglo) - v1y1(:, it, ik, is) * s0(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v2y2

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, ie, il, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, gtmp, vns, energy, al, aj1, g1) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       if(kwork_filter(it,ik))cycle
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v2 = nud vperp J1 f0
          gtmp(:, isgn, iglo) = vns(ik, ie, is) * energy(ie) * al(il) * aj1(:, iglo) &
               * g1(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, v2y2, all_procs)    ! v2y2

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Finally get x = y2 - v2y2 * w2 / (1 + v2w2)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, g, g1, ieqzip, v2y2, w0) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       if(kwork_filter(it,ik)) cycle
       if(ieqzip(it,ik)) cycle
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          g(:, isgn, iglo) = g1(:, isgn, iglo) - v2y2(:, it, ik, is) * w0(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

    deallocate (vns, v0y0, v1y1, v2y2, gtmp)

  end subroutine conserve_lorentz_standard_layout

  !> Apply the moment-conserving corrections to the Lorentz operator
  !> solution in the le_lo layout.
  !>
  !> Each ile slab that is neither filtered by kwork_filter nor flagged by
  !> ieqzip receives up to three successive corrections
  !>   gle <- gle - moment * u,
  !> where moment is a weighted sum over energy (w) and pitch angle (wxi):
  !>   * drag only: moment of vpa_aj0_le * gle,                  u = z0le
  !>   * always:    nud-weighted moment of vpdiffle*aj0le*gle
  !>                (conservative, with an extra speed factor in nud)
  !>                or of vpa_aj0_le*gle (otherwise),            u = s0le
  !>   * always:    nud-weighted moment of vperp_aj1le * gle,    u = w0le
  !> with nud = vnmult(1) * vnew_D * w.
  !>
  !> @param[in,out] gle  distribution function in le_lo layout
  subroutine conserve_lorentz_le_layout (gle)
    use gs2_layouts, only: ik_idx, it_idx, is_idx, ig_idx, le_lo
    use le_grids, only: speed => speed_maxw, w, wxi, negrid
    use run_parameters, only: ieqzip
    use kt_grids, only: kwork_filter
    implicit none
    complex, dimension (:,:,le_lo%llim_proc:), intent (in out) :: gle
    complex :: moment
    real :: nu_w
    integer :: ig, ik, ie, is, it, ile, ixi

    ! Step 1 (drag only): y1 = y0 - v0y0 * z0 / (1 + v0z0), v0 = vpa J0 f0
    if (drag) then
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(ile, it, ik, ixi, ie, is, ig, moment) &
       !$OMP SHARED(le_lo, kwork_filter, ieqzip, negrid, nxi_lim, &
       !$OMP z0le, w, wxi, vpa_aj0_le, gle) &
       !$OMP SCHEDULE(static)
       do ile = le_lo%llim_proc, le_lo%ulim_proc
          it = it_idx(le_lo, ile)
          ik = ik_idx(le_lo, ile)
          if (kwork_filter(it, ik) .or. ieqzip(it, ik)) cycle
          is = is_idx(le_lo, ile)
          ig = ig_idx(le_lo, ile)
          moment = 0.0
          do ie = 1, negrid
             do ixi = 1, nxi_lim
                ! Open question from the original: should vpdiff be used here
                ! when conservative is set?
                moment = moment + vpa_aj0_le(ixi, ie, ile) * gle(ixi, ie, ile) &
                     * w(ie, is) * wxi(ixi, ig)
             end do
          end do
          gle(:, :, ile) = gle(:, :, ile) - z0le(:, :, ile) * moment
       end do
       !$OMP END PARALLEL DO
    end if

    ! Step 2: y2 = y1 - v1y1 * s1 / (1 + v1s1), v1 = nud vpa J0 f0
    if (.not. conservative) then
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(ile, is, it, ik, ig, ixi, ie, nu_w, moment) &
       !$OMP SHARED(le_lo, kwork_filter, ieqzip, negrid, nxi_lim, vpa_aj0_le, vnmult, &
       !$OMP vnew_D, gle, w, wxi, s0le) &
       !$OMP SCHEDULE(static)
       do ile = le_lo%llim_proc, le_lo%ulim_proc
          ik = ik_idx(le_lo, ile)
          it = it_idx(le_lo, ile)
          if (kwork_filter(it, ik) .or. ieqzip(it, ik)) cycle
          is = is_idx(le_lo, ile)
          ig = ig_idx(le_lo, ile)
          moment = 0.0
          do ie = 1, negrid
             nu_w = vnmult(1) * vnew_D(ik, ie, is) * w(ie, is)
             do ixi = 1, nxi_lim
                moment = moment + nu_w * vpa_aj0_le(ixi, ie, ile) &
                     * gle(ixi, ie, ile) * wxi(ixi, ig)
             end do
          end do
          gle(:, :, ile) = gle(:, :, ile) - s0le(:, :, ile) * moment
       end do
       !$OMP END PARALLEL DO
    else
       !$OMP PARALLEL DO DEFAULT(none) &
       !$OMP PRIVATE(ile, is, it, ik, ig, ixi, ie, nu_w, moment) &
       !$OMP SHARED(le_lo, kwork_filter, ieqzip, negrid, nxi_lim, vpdiffle, &
       !$OMP speed, vnmult, vnew_D, aj0le, gle, s0le, w, wxi) &
       !$OMP SCHEDULE(static)
       do ile = le_lo%llim_proc, le_lo%ulim_proc
          ik = ik_idx(le_lo, ile)
          it = it_idx(le_lo, ile)
          if (kwork_filter(it, ik) .or. ieqzip(it, ik)) cycle
          is = is_idx(le_lo, ile)
          ig = ig_idx(le_lo, ile)
          moment = 0.0
          do ie = 1, negrid
             nu_w = speed(ie) * vnmult(1) * vnew_D(ik, ie, is) * w(ie, is)
             do ixi = 1, nxi_lim
                moment = moment + vpdiffle(ixi, ig) * nu_w * aj0le(ixi, ie, ile) &
                     * gle(ixi, ie, ile) * wxi(ixi, ig)
             end do
          end do
          gle(:, :, ile) = gle(:, :, ile) - s0le(:, :, ile) * moment
       end do
       !$OMP END PARALLEL DO
    end if

    ! Step 3: x = y2 - v2y2 * w2 / (1 + v2w2), v2 = nud vperp J1 f0
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(ile, it, ik, is, ie, ig, ixi, nu_w, moment) &
    !$OMP SHARED(le_lo, kwork_filter, ieqzip, negrid, nxi_lim, &
    !$OMP vnmult, vnew_D, vperp_aj1le, gle, w, wxi, w0le) &
    !$OMP SCHEDULE(static)
    do ile = le_lo%llim_proc, le_lo%ulim_proc
       it = it_idx(le_lo, ile)
       ik = ik_idx(le_lo, ile)
       if (kwork_filter(it, ik) .or. ieqzip(it, ik)) cycle
       is = is_idx(le_lo, ile)
       ig = ig_idx(le_lo, ile)
       moment = 0.0
       do ie = 1, negrid
          nu_w = vnmult(1) * vnew_D(ik, ie, is) * w(ie, is)
          do ixi = 1, nxi_lim
             ! aj1vp2 = 2 * J1(arg)/arg * vperp^2
             moment = moment + nu_w * vperp_aj1le(ixi, ie, ile) * gle(ixi, ie, ile) &
                  * wxi(ixi, ig)
          end do
       end do
       gle(:, :, ile) = gle(:, :, ile) - w0le(:, :, ile) * moment
    end do
    !$OMP END PARALLEL DO
  end subroutine conserve_lorentz_le_layout

  !> FIXME : Add documentation
  subroutine conserve_diffuse_standard_layout (g, g1)
    use theta_grid, only: ntgrid
    use species, only: nspec
    use kt_grids, only: naky, ntheta0, kwork_filter
    use gs2_layouts, only: g_lo, ik_idx, it_idx, ie_idx, il_idx, is_idx
    use le_grids, only: energy => energy_maxw, al, integrate_moment, negrid
    use dist_fn_arrays, only: aj0, aj1, vpa
    use run_parameters, only: ieqzip
    use array_utils, only: zero_array
    implicit none
    complex, dimension (-ntgrid:,:,g_lo%llim_proc:), intent (in out) :: g, g1
    complex, dimension (:,:,:), allocatable :: gtmp
    real, dimension (:,:,:), allocatable :: vns
    integer :: isgn, iglo, ik, ie, il, is, it 
    complex, dimension (:,:,:,:), allocatable :: v0y0, v1y1, v2y2    
    logical, parameter :: all_procs = .true.

    allocate (v0y0(-ntgrid:ntgrid, ntheta0, naky, nspec))
    allocate (v1y1(-ntgrid:ntgrid, ntheta0, naky, nspec))
    allocate (v2y2(-ntgrid:ntgrid, ntheta0, naky, nspec))

    allocate (vns(naky, negrid, nspec))
    allocate (gtmp(-ntgrid:ntgrid, 2, g_lo%llim_proc:g_lo%ulim_alloc))

    vns = vnmult(2)*delvnew

    !This is needed to to ensure the it,ik values we don't set aren't included
    !in the integral (can also be enforced in integrate_moment routine)
    if(any(kwork_filter)) call zero_array(gtmp)

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! First get v0y0

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, ie, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, gtmp, vnmult, vnew_E, aj0, g) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       if(kwork_filter(it,ik))cycle
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v0 = nu_E E J0 f0
          gtmp(:, isgn, iglo) = vnmult(2) * vnew_E(ik, ie, is) * aj0(:, iglo) &
               * g(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, v0y0, all_procs)    ! v0y0

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Get y1 = y0 - v0y0 * z0 / (1 + v0z0)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, g1, g, ieqzip, v0y0, bz0) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       if(kwork_filter(it,ik)) cycle
       if(ieqzip(it,ik)) cycle
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          g1(:, isgn, iglo) = g(:, isgn, iglo) - v0y0(:, it, ik, is) * bz0(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v1y1

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, ie, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, gtmp, vns, vpa, aj0, g1) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       if(kwork_filter(it,ik))cycle
       ie = ie_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v1 = (nus-nud) vpa J0 f0
          gtmp(:, isgn, iglo) = vns(ik, ie, is) * vpa(:, isgn, iglo) * aj0(:, iglo) &
               * g1(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, v1y1, all_procs)    ! v1y1

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Get y2 = y1 - v1y1 * s1 / (1 + v1s1)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, g1, ieqzip, v1y1, bs0) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       if(kwork_filter(it,ik)) cycle
       if(ieqzip(it,ik)) cycle
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          g1(:, isgn, iglo) = g1(:, isgn, iglo) - v1y1(:, it, ik, is) * bs0(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Now get v2y2

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, ie, il, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, gtmp, vns, energy, al, aj1, g1) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       ik = ik_idx(g_lo,iglo)
       it = it_idx(g_lo,iglo)
       if(kwork_filter(it,ik))cycle
       ie = ie_idx(g_lo,iglo)
       il = il_idx(g_lo,iglo)
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          ! v2 = (nus-nud) vperp J1 f0
          gtmp(:, isgn, iglo) = vns(ik, ie, is) * energy(ie) * al(il) * aj1(:, iglo) &
               * g1(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

    call integrate_moment (gtmp, v2y2, all_procs)    ! v2y2

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
! Finally get x = y2 - v2y2 * w2 / (1 + v2w2)

    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(iglo, it, ik, is, isgn) &
    !$OMP SHARED(g_lo, kwork_filter, g, g1, ieqzip, v2y2, bw0) &
    !$OMP SCHEDULE(static)
    do iglo = g_lo%llim_proc, g_lo%ulim_proc
       it = it_idx(g_lo,iglo)
       ik = ik_idx(g_lo,iglo)
       if(kwork_filter(it,ik)) cycle
       if(ieqzip(it,ik)) cycle
       is = is_idx(g_lo,iglo)
       do isgn = 1, 2
          g(:, isgn, iglo) = g1(:, isgn, iglo) - v2y2(:, it, ik, is) * bw0(:, isgn, iglo)
       end do
    end do
    !$OMP END PARALLEL DO

    deallocate (vns, v0y0, v1y1, v2y2, gtmp)

  end subroutine conserve_diffuse_standard_layout

  !> FIXME : Add documentation
  subroutine conserve_diffuse_le_layout (gle)
    use gs2_layouts, only: ik_idx, it_idx, ie_idx, il_idx, is_idx, ig_idx, le_lo
    use le_grids, only: wxi, w, negrid
    use run_parameters, only: ieqzip
    use kt_grids, only: kwork_filter
    implicit none
    complex, dimension (:,:,le_lo%llim_proc:), intent (in out) :: gle
    real :: delnu, vnE
    integer :: ig, ik, ie, is, it, ile, ixi
    complex :: v0y0, v1y1, v2y2

    !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
    ! First get v0y0 and then y1 = y0 - v0y0 * z0 / (1 + v0z0)

    !$OMP PARALLEL DO DEFAULT(none) &