!> Provides a consistent wrapper to different matrix multiply implementations
module matrix_multiply

  implicit none

  private

  public :: matmul_wrapper, multiply_methods, recommend_multiplication_method
  public :: matrix_multiply_method_type, matmul_lapack, matmul_intrinsic, matmul_custom

  !> A simple wrapper type around an integer to represent different
  !> multiply methods. Components are private; use get_flag/get_name
  !> to read them from outside this module.
  type matrix_multiply_method_type
     private
     integer :: flag = 0
     character(len=16) :: name = "UNSET NAME"
   contains
     procedure, public :: get_flag => matrix_multiply_method_get_flag
     procedure, public :: get_name => matrix_multiply_method_get_name
  end type matrix_multiply_method_type

  !> Method calling blas' gemm directly
  type(matrix_multiply_method_type), parameter :: matmul_lapack = &
       matrix_multiply_method_type(flag = 1, name = 'Lapack')

  !> Method simply calling intrinsic matmul
  type(matrix_multiply_method_type), parameter :: matmul_intrinsic = &
       matrix_multiply_method_type(flag = 2, name = 'Intrinsic')

  !> Method using explicit loops with OpenMP parallelisation. Could do with a better name.
  type(matrix_multiply_method_type), parameter :: matmul_custom = &
       matrix_multiply_method_type(flag = 3, name = 'Custom')

  !> Array of all available methods (matmul_lapack is only included when
  !> compiled with LAPACK defined)
  type(matrix_multiply_method_type), dimension(*), parameter :: multiply_methods = [ &
#ifdef LAPACK
       matmul_lapack, &
#endif
       matmul_intrinsic, &
       matmul_custom &
       ]

  !> Provides a wrapper to different matrix multiplication methods
  interface matmul_wrapper
     module procedure matmul_wrapper_complex_2d
  end interface matmul_wrapper

  !> Used to store the default method to use if not passed to matmul_wrapper.
  !> This way we can override the default (e.g. after timing performance)
  !> whilst still having a reasonable default.
  !> NOTE(review): no routine in this module assigns to it; confirm whether
  !> a setter is expected elsewhere.
  type(matrix_multiply_method_type) :: matmul_default_method = matmul_intrinsic

contains

  !> Helper routine to get access to the matrix_multiply_method's flag
  !> in a read-only way. Primarily used for testing.
  elemental integer function matrix_multiply_method_get_flag(self) result(flag)
    implicit none
    class(matrix_multiply_method_type), intent(in) :: self
    flag = self%flag
  end function matrix_multiply_method_get_flag

  !> Helper routine to get access to the matrix_multiply_method's name
  !> in a read-only way. Primarily used for testing.
  elemental character(len=16) function matrix_multiply_method_get_name(self) result(name)
    implicit none
    class(matrix_multiply_method_type), intent(in) :: self
    name = self%name
  end function matrix_multiply_method_get_name

  !> Times each entry of multiply_methods on random complex square matrices
  !> of the given size and returns the fastest method.
  !>
  !> @param trial_size     Edge length of the square trial matrices.
  !> @param repeats        Number of multiplications timed per method
  !>                       (default 1; values below 1 are treated as 1).
  !> @param display_times  If true, print a table of total times to stdout
  !>                       (default false).
  !> @return The method with the smallest total time, or a method with
  !>         flag -1 if the trial matrices could not be allocated.
  function recommend_multiplication_method(trial_size, repeats, display_times) result(method)
    use job_manage, only: timer_local
    use file_utils, only: error_unit
    use ran, only: ranf
    use optionals, only: get_option_with_default
    implicit none
    integer, intent(in) :: trial_size
    integer, intent(in), optional :: repeats
    logical, intent(in), optional :: display_times
    type(matrix_multiply_method_type) :: method
    real :: start_time
    integer :: method_index, j, k
    complex, dimension(:, :), allocatable :: a, b, c
    integer :: error, number_of_methods, number_of_repeats, repeat_index
    real, dimension(:), allocatable :: times
    logical :: should_display_times

    ! Initialise to invalid method flag
    method = matrix_multiply_method_type(flag = -1)

    ! Allocate all trial matrices in a single statement so that a failure of
    ! any one of them is reflected in the single stat value. Separate
    ! statements would overwrite error, so only the last would be checked.
    allocate(a(trial_size, trial_size), b(trial_size, trial_size), &
         c(trial_size, trial_size), stat = error)
    if (error /= 0) then
      write(error_unit(), '("Unable to allocate trial array with size ",I0)') trial_size
      return
    end if

    ! Fill a with random values, walking the first index fastest to match
    ! Fortran's column-major storage; b is a copy of a and c starts at zero.
    do k = 1, trial_size
      do j = 1, trial_size
        a(j, k) = cmplx(ranf(), ranf())
      end do
    end do
    b = a
    c = 0.0

    number_of_methods = size(multiply_methods)
    allocate(times(number_of_methods))
    times = 0.0

    ! At least one repeat is needed, otherwise no method is actually timed
    ! and the recommendation below would be arbitrary.
    number_of_repeats = max(1, get_option_with_default(repeats, 1))
    should_display_times = get_option_with_default(display_times, .false.)

    ! Time each method
    do method_index = 1, number_of_methods
      start_time = timer_local()
      do repeat_index = 1, number_of_repeats
        c = matmul_wrapper(a, b, multiply_methods(method_index))
      end do
      times(method_index) = timer_local() - start_time
    end do

    ! Report the results if requested
    if (should_display_times) then
      write(*,'("Matrix multiplication benchmark : ",I0,&
           &" repeats of square matrices with size ",I0)') number_of_repeats, trial_size
      write(*,'("Method",T18,"Total time (s)")')
      do method_index = 1, number_of_methods
        write(*,'(A16,T18,F12.8)') trim(multiply_methods(method_index)%name), &
             times(method_index)
      end do
    end if

    ! Recommend the method with the smallest time
    method = multiply_methods(minloc(times, dim = 1))
  end function recommend_multiplication_method

  !> Wrapper to complex matrix multiplication with two 2D matrices,
  !> returning output = matmul(a, b).
  !>
  !> @param a       Left operand of shape (d1, d3).
  !> @param b       Right operand; its first extent should equal d3
  !>                (only checked when compiled with GK_DEBUG).
  !> @param method  Optional method; defaults to matmul_default_method.
  !>                An unrecognised flag (including matmul_lapack when built
  !>                without LAPACK) calls mp_abort.
  !> @return The product, of shape (d1, size(b, dim = 2)).
  function matmul_wrapper_complex_2d(a, b, method) result(output)
    use mp, only: mp_abort
#ifdef LAPACK
    use lapack_wrapper, only: gemm
#endif
    implicit none
    complex, dimension(:, :), intent(in) :: a, b
    type(matrix_multiply_method_type), optional, intent(in) :: method
    type(matrix_multiply_method_type) :: internal_method
    complex, dimension(:, :), allocatable :: output
    complex, parameter :: one = 1.0, zero = 0.0
    complex :: tmp
    integer :: d1, d2, d3
    integer :: i, j, k

    d1 = size(a, dim = 1)
    d2 = size(b, dim = 2)
    d3 = size(a, dim = 2) ! Should match size(b, dim = 1)

#ifdef GK_DEBUG
    if (d3 /= size(b, dim = 1)) error stop
#endif

    allocate(output(d1, d2))

    internal_method = matmul_default_method
    if (present(method)) internal_method = method

    select case (internal_method%flag)
    case(matmul_intrinsic%flag)
      output = matmul(a, b)
#ifdef LAPACK
    case(matmul_lapack%flag)
      ! BLAS requires each leading dimension to be at least max(1, rows),
      ! so clamp them; passing 0 for an empty operand is an argument error.
      call gemm('n', 'n', d1, d2, d3, one, a, max(1, d1), b, max(1, d3), &
           zero, output, max(1, d1))
#endif
    case(matmul_custom%flag)
      ! Each thread owns whole columns j of output; the innermost loop runs
      ! over the first index so accesses to output and a are contiguous.
      !$OMP PARALLEL DO DEFAULT(none) &
      !$OMP PRIVATE(i, j, k, tmp) &
      !$OMP SHARED(output, d1, d2, d3, a, b) &
      !$OMP SCHEDULE(static)
      do j = 1, d2
        output(:, j) = zero
        do k = 1, d3
          tmp = b(k, j)
          do i = 1, d1
            output(i, j) = output(i, j) + tmp * a(i, k)
          end do
        end do
      end do
      !$OMP END PARALLEL DO
    case default
      call mp_abort("Unknown multiply method passed to matrix_multiply::matmul_wrapper", .true.)
    end select
  end function matmul_wrapper_complex_2d

end module matrix_multiply