matrix_multiply.fpp Source File


Contents

Source Code


Source Code

!> Provides a consistent wrapper to different matrix multiply implementations.
!>
!> The implementations are identified by constants of type
!> matrix_multiply_method_type (matmul_lapack, matmul_intrinsic and
!> matmul_custom). The array multiply_methods lists the ones compiled into
!> this build, matmul_wrapper dispatches to one of them and
!> recommend_multiplication_method times them to suggest the fastest.
module matrix_multiply
  implicit none

  ! Everything in the module is private unless exported below
  private

  ! Public procedures and the list of available methods
  public :: matmul_wrapper, multiply_methods, recommend_multiplication_method
  ! Public type and the constants naming each method
  public :: matrix_multiply_method_type, matmul_lapack, matmul_intrinsic, matmul_custom

  !> Identifies one matrix multiplication implementation. It pairs an
  !> integer flag, which matmul_wrapper dispatches on, with a short name
  !> used when reporting timings. Both components are private, so code
  !> outside this module reads them through get_flag and get_name.
  type :: matrix_multiply_method_type
     private
     !> Integer identifying the method; 0 unless set by a constructor
     integer :: flag = 0
     !> Label for the method, blank padded to 16 characters
     character(len=16) :: name = "UNSET NAME"
   contains
     !> Returns the value of the flag component
     procedure, public :: get_flag => matrix_multiply_method_get_flag
     !> Returns the value of the name component
     procedure, public :: get_name => matrix_multiply_method_get_name
  end type matrix_multiply_method_type

  !> Multiply by calling gemm directly. The constant exists in every build,
  !> but matmul_wrapper only has a branch for it when the code is compiled
  !> with LAPACK defined.
  type(matrix_multiply_method_type), parameter :: matmul_lapack = &
       matrix_multiply_method_type(name = 'Lapack', flag = 1)

  !> Multiply using the intrinsic matmul function
  type(matrix_multiply_method_type), parameter :: matmul_intrinsic = &
       matrix_multiply_method_type(name = 'Intrinsic', flag = 2)

  !> Multiply using explicit loops parallelised with OpenMP. The name
  !> "custom" is a placeholder that could be improved.
  type(matrix_multiply_method_type), parameter :: matmul_custom = &
       matrix_multiply_method_type(name = 'Custom', flag = 3)

  !> Every method usable in this build, in the order in which
  !> recommend_multiplication_method times them. The gemm based method is
  !> only included when LAPACK is defined at preprocessing time.
  type(matrix_multiply_method_type), parameter :: multiply_methods(*) = [ &
#ifdef LAPACK
       matmul_lapack, &
#endif
       matmul_intrinsic, &
       matmul_custom &
       ]

  !> Generic entry point for matrix multiplication. It currently has a
  !> single specific procedure, which multiplies two rank two complex
  !> arrays.
  interface matmul_wrapper
     module procedure :: matmul_wrapper_complex_2d
  end interface matmul_wrapper

  !> Method used by matmul_wrapper when the caller does not pass one.
  !> Keeping it in a module variable is intended to let the default be
  !> replaced (e.g. after timing the methods) while still starting from a
  !> sensible choice, the intrinsic matmul.
  !>
  !> NOTE(review): the variable is not in the public list and nothing in
  !> this module assigns to it, so as written it always holds
  !> matmul_intrinsic; presumably a setter is still to be added -- TODO confirm.
  type(matrix_multiply_method_type), save :: matmul_default_method = matmul_intrinsic

contains

  !> Read-only accessor for the private flag component of a
  !> matrix_multiply_method_type. Being elemental it accepts either a
  !> single method or an array of methods. Primarily used for testing.
  elemental function matrix_multiply_method_get_flag(self) result(flag_value)
    implicit none
    !> The method whose flag is wanted
    class(matrix_multiply_method_type), intent(in) :: self
    !> Copy of the flag component of self
    integer :: flag_value
    flag_value = self%flag
  end function matrix_multiply_method_get_flag

  !> Read-only accessor for the private name component of a
  !> matrix_multiply_method_type. Being elemental it accepts either a
  !> single method or an array of methods. The result always has length
  !> 16, so callers wanting the bare label should trim it. Primarily used
  !> for testing.
  elemental function matrix_multiply_method_get_name(self) result(method_name)
    implicit none
    !> The method whose name is wanted
    class(matrix_multiply_method_type), intent(in) :: self
    !> Copy of the name component of self
    character(len=16) :: method_name
    method_name = self%name
  end function matrix_multiply_method_get_name

  !> Times each available multiplication method on a pair of randomly
  !> filled square complex matrices and returns the fastest one.
  !>
  !> The result is the entry of multiply_methods with the smallest total
  !> time (the first such entry if several tie). If the trial matrices
  !> cannot be allocated a message is written to the error unit and a
  !> method with the invalid flag -1 is returned, so callers should check
  !> the flag before using the result. If repeats is less than one no
  !> multiplications are performed and the timings only measure the cost of
  !> the empty loop, so the recommendation is not meaningful.
  function recommend_multiplication_method(trial_size, repeats, display_times) result(method)
    use job_manage, only: timer_local
    use file_utils, only: error_unit
    use ran, only: ranf
    use optionals, only: get_option_with_default
    implicit none
    !> Number of rows and of columns of the square trial matrices
    integer, intent(in) :: trial_size
    !> Number of multiplications timed for each method (1 if absent)
    integer, intent(in), optional :: repeats
    !> If true the time taken by each method is printed (false if absent)
    logical, intent(in), optional :: display_times
    !> The recommended method, or one with flag -1 on allocation failure
    type(matrix_multiply_method_type) :: method
    real :: start_time
    integer :: method_index, j, k
    complex, dimension(:, :), allocatable :: a, b, c
    integer :: error, number_of_methods, number_of_repeats, repeat
    real, dimension(:), allocatable :: times
    logical :: should_display_times

    ! Initialise to invalid method flag
    method = matrix_multiply_method_type(flag = -1)

    ! Allocate all three trial matrices in one statement so that a single
    ! status covers every allocation. Separate statements sharing one stat
    ! variable would let a later success hide an earlier failure.
    allocate(a(trial_size, trial_size), b(trial_size, trial_size), &
         c(trial_size, trial_size), stat = error)

    if (error /= 0) then
       write(error_unit(), '("Unable to allocate trial array with size ",I0)') trial_size
       return
    end if

    ! Fill a with random complex values. The inner loop runs over the first
    ! index so that memory is accessed contiguously.
    do k = 1, trial_size
       do j = 1, trial_size
          a(j,k) = cmplx(ranf(), ranf())
       end do
    end do

    ! The second operand is a copy of the first; the result starts at zero
    b = a
    c = 0.0

    number_of_methods = size(multiply_methods)

    allocate(times(number_of_methods))
    times = 0.0

    number_of_repeats = get_option_with_default(repeats, 1)
    should_display_times = get_option_with_default(display_times, .false.)

    ! Time each method, accumulating over all repeats
    do method_index = 1, number_of_methods
       start_time = timer_local()
       do repeat = 1, number_of_repeats
          c = matmul_wrapper(a, b, multiply_methods(method_index))
       end do
       times(method_index) = timer_local() - start_time
    end do

    ! Report the results if requested
    if (should_display_times) then
       write(*,'("Matrix multiplication benchmark : ",I0,&
            &" repeats of square matrices with size ",I0)') number_of_repeats, trial_size
       write(*,'("Method",T18,"Total time (s)")')
       do method_index = 1, number_of_methods
          ! method is only used as scratch here; it is set to the
          ! recommendation after the report
          method = multiply_methods(method_index)
          write(*,'(A16,T18,F12.8)') trim(method%get_name()), times(method_index)
       end do
    end if

    ! Recommend the method with the smallest time
    method = multiply_methods(minloc(times, dim = 1))

  end function recommend_multiplication_method

  !> Multiplies two rank two complex matrices, giving the matrix product of
  !> a and b, using the requested implementation.
  !>
  !> If method is absent the module level matmul_default_method is used.
  !> The number of columns of a must equal the number of rows of b; this is
  !> only verified when GK_DEBUG is defined. A method without a matching
  !> branch (including matmul_lapack in a build without LAPACK) aborts
  !> through mp_abort.
  function matmul_wrapper_complex_2d(a, b, method) result(output)
    use mp, only: mp_abort
#ifdef LAPACK
    use lapack_wrapper, only: gemm
#endif
    implicit none
    !> Left and right operands of the product
    complex, dimension(:, :), intent(in) :: a, b
    !> Implementation to use; matmul_default_method if absent
    type(matrix_multiply_method_type), optional, intent(in) :: method
    type(matrix_multiply_method_type) :: internal_method
    !> The product, with size(a, 1) rows and size(b, 2) columns
    complex, dimension(:, :), allocatable :: output
    complex, parameter :: one = 1.0, zero = 0.0
    complex :: tmp
    integer :: d1, d2, d3
    integer :: i, j, k
    d1 = size(a, dim = 1)
    d2 = size(b, dim = 2)
    d3 = size(a, dim = 2) ! Should match size(b, dim = 1)
#ifdef GK_DEBUG
    if (d3 /= size(b, dim = 1)) error stop "matmul_wrapper: inner dimensions of a and b differ"
#endif
    allocate(output(d1, d2))

    ! Fall back to the module default unless the caller chose a method
    internal_method = matmul_default_method
    if (present(method)) internal_method = method

    select case (internal_method%flag)
    case(matmul_intrinsic%flag)
       output = matmul(a, b)
#ifdef LAPACK
    case(matmul_lapack%flag)
       ! The leading dimensions passed to gemm must be at least one even
       ! when a matrix is empty, hence the max(1, ...) guards
       call gemm('n', 'n', d1, d2, d3, one, a, max(1, d1), b, max(1, d3), &
            zero, output, max(1, d1))
#endif
    case(matmul_custom%flag)
       ! Each thread owns whole columns of output, so no two threads write
       ! to the same element. The innermost loop runs over the first index.
       !$OMP PARALLEL DO DEFAULT(none)&
       !$OMP PRIVATE(i, j, k, tmp) &
       !$OMP SHARED(output, d1, d2, d3, a, b) &
       !$OMP SCHEDULE(static)
       do j = 1, d2
          output(:, j) = zero
          do k = 1, d3
             tmp = b(k, j)
             do i = 1, d1
                output(i, j) = output(i, j) + tmp * a(i, k)
             end do
          end do
       end do
       !$OMP END PARALLEL DO
    case default
       call mp_abort("Unknown multiply method passed to matrix_multiply::matmul_wrapper", .true.)
    end select

  end function matmul_wrapper_complex_2d

end module matrix_multiply