init_inverse_matrix Subroutine

private subroutine init_inverse_matrix(am, ic)

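Inverts the field response matrix `am` for supercell class `ic` by Gauss–Jordan elimination (leaving out internal cell-boundary points), then redistributes the result to the `jf_lo` layout and stores it in `aminv(ic)`.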

Arguments

Type | Intent | Optional | Attributes | Name
complex, intent(inout), dimension(:,f_lo(ic)%llim_proc:) :: am
integer, intent(in) :: ic

Contents

Source Code

  !> Invert the field response matrix for supercell class `ic` and store the
  !> result in `aminv(ic)` in the `jf_lo` ("fast field") layout.
  !>
  !> Stage 1 (skipped entirely when `skip_initialisation` is set):
  !> Gauss-Jordan elimination on `am`, leaving out the rows/columns at an
  !> internal `ig == ntgrid` point of a multi-cell supercell. Afterwards the
  !> skipped rows are zeroed and each skipped column is copied from its
  !> matched point `jlo + nfield`. Columns of the ky = kx = 0 mode are set to
  !> zero. The resulting inverse overwrites `am`.
  !>
  !> Stage 2: every column of `am` is moved to the processor that owns the
  !> corresponding `jf_lo` index and stored, re-ordered, in
  !> `aminv(ic)%dcell(:)%supercell(:)`. Size mismatches of pre-existing
  !> storage are reported on `error_unit` but are not fatal.
  !>
  !> Communication: when `local_field_solve` is false the pivot column of
  !> every class is broadcast from its owner; point-to-point transfers use
  !> blocking `send`/`receive` inside loops that every processor walks in
  !> the same global order, so presumably all processors must call this
  !> routine together.
  !>
  !> NOTE(review): the elimination divides by the pivot `lhscol(i,m)`
  !> without any row pivoting; a zero pivot would produce non-finite values.
  !>
  !> @param am  On entry the response matrix for class `ic`: one row per
  !>            supercell entry (`nidx*N_class(ic)` rows), columns distributed
  !>            according to `f_lo(ic)`. On exit its inverse, unless
  !>            `skip_initialisation` is set, in which case it is unchanged.
  !> @param ic  Index of the supercell class being processed.
  subroutine init_inverse_matrix (am, ic)
    use file_utils, only: error_unit
    use kt_grids, only: aky, akx
    use theta_grid, only: ntgrid
    use mp, only: broadcast, send, receive, iproc, get_mp_times
    use job_manage, only: time_message
    use fields_arrays, only: time_field_invert, time_field_invert_mpi
    use gs2_layouts, only: f_lo, idx, idx_local, proc_id, jf_lo
    use gs2_layouts, only: if_idx, im_idx, in_idx, local_field_solve
    use gs2_layouts, only: ig_idx, ifield_idx, ij_idx
    use dist_fn, only: i_class, M_class, N_class
    use array_utils, only: zero_array
    implicit none
    integer, intent (in) :: ic
    complex, dimension(:,f_lo(ic)%llim_proc:), intent (in out) :: am
    complex, dimension(:,:), allocatable :: a_inv, lhscol, rhsrow, col_row_tmp
    complex, dimension (:), allocatable :: am_tmp
    complex :: fac
    integer :: i, j, k, ik, it, m, n, nn, ifld, ig, jsc, jf, jg
    integer :: irow, ilo, jlo, dc, iflo, ierr, nsc
    real :: mp_total, mp_total_after

    call time_message(.false.,time_field_invert,' Field Matrix Invert')
    call get_mp_times(total_time = mp_total)

    ! Number of entries in one supercell of this class.
    nsc = nidx*N_class(ic)

    if (.not. skip_initialisation) then
       ! Work arrays are only needed for the elimination, so they are
       ! allocated here rather than unconditionally.
       allocate (a_inv(nsc, f_lo(ic)%llim_proc:f_lo(ic)%ulim_alloc))
       call zero_array(a_inv)
       allocate (lhscol(nsc, M_class(ic)))
       allocate (rhsrow(nsc, M_class(ic)))
       ! Broadcast buffer holding (pivot column of am, same column of a_inv).
       ! Allocated once instead of once per pivot: the owner always fills it
       ! before each broadcast and the broadcast overwrites it elsewhere.
       if (.not. local_field_solve) then
          allocate (col_row_tmp(nsc, 2))
          col_row_tmp = 0.
       end if

       ! Start a_inv as the identity: a unit entry at row if_idx(ilo) of
       ! every local column ilo.
       do ilo = f_lo(ic)%llim_proc, f_lo(ic)%ulim_proc
          a_inv(if_idx(f_lo(ic),ilo),ilo) = 1.0
       end do

       ! Gauss-Jordan elimination, one pivot per supercell entry i, leaving
       ! out the internal ntgrid points of each supercell.
       do i = 1, nsc
          if (skip_row_point(i)) cycle

          ! Gather the pivot column of am and a_inv for every class m.
          if (local_field_solve) then
             do m = 1, M_class(ic)
                ilo = idx(f_lo(ic),i,m)
                if (idx_local(f_lo(ic),ilo)) then
                   lhscol(:,m) = am(:,ilo)
                   rhsrow(:,m) = a_inv(:,ilo)
                end if
             end do
          else
             do m = 1, M_class(ic)
                ilo = idx(f_lo(ic),i,m)
                if (idx_local(f_lo(ic),ilo)) then
                   col_row_tmp(:,1) = am(:,ilo)
                   col_row_tmp(:,2) = a_inv(:,ilo)
                end if
                ! The owner of ilo depends on m, so this broadcast cannot be
                ! hoisted out of the m loop. These broadcasts can be
                ! expensive, which is why local_field_solve may be preferable.
                call broadcast (col_row_tmp, proc_id(f_lo(ic),ilo))
                lhscol(:,m) = col_row_tmp(:,1)
                rhsrow(:,m) = col_row_tmp(:,2)
             end do
             ! Every processor now holds identical lhscol and rhsrow.
          end if

          ! Eliminate against the pivot in every local column.
          do jlo = f_lo(ic)%llim_proc, f_lo(ic)%ulim_proc
             if (skip_col_point(jlo)) cycle

             ! Cell n and class m of this column, and the (ky, kx) they map to.
             n = in_idx(f_lo(ic),jlo)
             m = im_idx(f_lo(ic),jlo)
             ik = f_lo(ic)%ik(m,n)
             it = f_lo(ic)%it(m,n)

             ! Compound theta*field row index of this column's own entry.
             irow = if_idx(f_lo(ic),jlo)

             if (aky(ik) /= 0.0 .or. akx(it) /= 0.0) then
                fac = am(i,jlo)/lhscol(i,m)
                ! Update the whole column, then overwrite the pivot row with
                ! fac; identical to updating the off-pivot rows separately.
                am(:,jlo) = am(:,jlo) - lhscol(:,m)*fac
                am(i,jlo) = fac

                if (irow == i) then
                   a_inv(:,jlo) = a_inv(:,jlo)/lhscol(i,m)
                else
                   a_inv(:,jlo) = a_inv(:,jlo) &
                        - rhsrow(:,m)*lhscol(irow,m)/lhscol(i,m)
                end if
             else
                ! The ky = kx = 0 mode gets no response.
                a_inv(:,jlo) = 0.0
             end if
          end do
       end do

       deallocate (lhscol, rhsrow)
       if (allocated(col_row_tmp)) deallocate (col_row_tmp)

       ! Zero the rows of the internal ntgrid points skipped above.
       do i = 1, nsc
          if (skip_row_point(i)) a_inv(i,:) = 0
       end do

       ! Make the response at each internal ntgrid point identical to the
       ! response at its matched internal -ntgrid point (jlo + nfield),
       ! moving the column between processors when needed. Every processor
       ! walks the same global order, so the blocking send/receive pairs match.
       do jlo = f_lo(ic)%llim_world, f_lo(ic)%ulim_world
          if (.not. skip_col_point(jlo)) cycle
          ilo = jlo + nfield
          if (idx_local(f_lo(ic), ilo)) then
             if (idx_local(f_lo(ic), jlo)) then
                a_inv(:,jlo) = a_inv(:,ilo)
             else
                call send(a_inv(:,ilo), proc_id(f_lo(ic), jlo))
             end if
          else if (idx_local(f_lo(ic), jlo)) then
             call receive(a_inv(:,jlo), proc_id(f_lo(ic), ilo))
          end if
       end do

       am = a_inv
       deallocate (a_inv)
    end if ! .not. skip_initialisation

    ! Re-sort this class of aminv for runtime application.
    if (.not.allocated(aminv)) allocate (aminv(i_class))

    ! Count the jf_lo points of this class held on this processor, recording
    ! each one's dcell address (dj) and supercell class (mj), so that only
    ! the required storage is allocated below.
    dc = 0
    do ilo = f_lo(ic)%llim_world, f_lo(ic)%ulim_world
       m = im_idx(f_lo(ic), ilo)
       n = in_idx(f_lo(ic), ilo)
       ig = ig_idx(f_lo(ic), ilo)
       ifld = ifield_idx(f_lo(ic), ilo)
       ik = f_lo(ic)%ik(m,n)
       it = f_lo(ic)%it(m,n)
       jlo = ij_idx(jf_lo, ig, ifld, ik, it)
       if (idx_local(jf_lo,jlo)) then
          dc = dc + 1
          jf_lo%dj(ic,jlo) = dc
          jf_lo%mj(jlo) = m
       end if
    end do

    ! Allocate the dcells and supercells of this class on this processor,
    ! or check that existing storage has the expected size.
    do jlo = jf_lo%llim_proc, jf_lo%ulim_proc
       if (.not.associated(aminv(ic)%dcell)) then
          allocate (aminv(ic)%dcell(dc))
       else
          j = size(aminv(ic)%dcell)
          if (j /= dc) then
             ierr = error_unit()
             write(ierr,*) 'Error (1) in init_inverse_matrix: ',&
                  iproc,':',jlo,':',dc,':',j
          end if
       end if

       ! Guard against an unset (zero) dcell address.
       k = jf_lo%dj(ic,jlo)
       if (k > 0) then
          if (.not.associated(aminv(ic)%dcell(k)%supercell)) then
             allocate (aminv(ic)%dcell(k)%supercell(nsc))
          else
             j = size(aminv(ic)%dcell(k)%supercell)
             if (j /= nsc) then
                ierr = error_unit()
                write(ierr,*) 'Error (2) in init_inverse_matrix: ', &
                     iproc,':',jlo,':',nsc,':',j
             end if
          end if
       end if
    end do

    ! Fill aminv for this class: each column of am travels to the owner of
    ! its jf_lo index and is scattered into supercell order there.
    allocate (am_tmp(nsc))
    do ilo = f_lo(ic)%llim_world, f_lo(ic)%ulim_world
       m = im_idx(f_lo(ic), ilo)
       n = in_idx(f_lo(ic), ilo)
       ig = ig_idx(f_lo(ic), ilo)
       ifld = ifield_idx(f_lo(ic), ilo)
       ik = f_lo(ic)%ik(m,n)
       it = f_lo(ic)%it(m,n)
       iflo = ij_idx(jf_lo, ig, ifld, ik, it)

       if (idx_local(f_lo(ic),ilo)) then
          if (idx_local(jf_lo,iflo)) then
             am_tmp = am(:,ilo)
          else
             call send(am(:,ilo), proc_id(jf_lo,iflo))
          end if
       else if (idx_local(jf_lo,iflo)) then
          call receive(am_tmp, proc_id(f_lo(ic),ilo))
       end if

       if (idx_local(jf_lo, iflo)) then
          k = jf_lo%dj(ic,iflo)
          do jlo = 0, nsc-1
             nn = in_idx(f_lo(ic), jlo)
             jg = ig_idx(f_lo(ic), jlo)
             jf = ifield_idx(f_lo(ic), jlo)
             jsc = ij_idx(f_lo(ic), jg, jf, nn) + 1
             aminv(ic)%dcell(k)%supercell(jsc) = am_tmp(jlo+1)
          end do
       end if
    end do
    deallocate (am_tmp)

    call time_message(.false.,time_field_invert,' Field Matrix Invert')
    call get_mp_times(total_time = mp_total_after)
    time_field_invert_mpi = time_field_invert_mpi + (mp_total_after - mp_total)

  contains

    !> True when 1-based supercell entry `isc` is at the last theta point of
    !> a 2pi cell that is neither the first nor the last entry group of a
    !> multi-cell supercell; such rows are left out of the elimination.
    logical function skip_row_point(isc) result(skip)
      integer, intent(in) :: isc
      skip = N_class(ic) > 1 &
           .and. isc > nfield &
           .and. isc <= nidx*N_class(ic) - nfield &
           .and. mod((isc+nfield-1)/nfield, 2*ntgrid+1) == 0
    end function skip_row_point

    !> True when f_lo(ic) index `lo` sits at ig == ntgrid of a cell other
    !> than the last one in a multi-cell supercell; such columns are left out
    !> of the elimination and filled in afterwards.
    logical function skip_col_point(lo) result(skip)
      integer, intent(in) :: lo
      skip = N_class(ic) > 1 &
           .and. ig_idx(f_lo(ic), lo) == ntgrid &
           .and. in_idx(f_lo(ic), lo) < N_class(ic)
    end function skip_col_point

  end subroutine init_inverse_matrix