Wrapper for complex matrix multiplication of two rank-2 (2D) matrices.
| Type | Intent | Optional | Attributes | | Name |
|---|---|---|---|---|---|
| complex | intent(in) | | dimension(:, :) | :: | a |
| complex | intent(in) | | dimension(:, :) | :: | b |
| type(matrix_multiply_method_type) | intent(in) | optional | | :: | method |
!> Multiply two rank-2 complex matrices, output = a * b, using the backend
!> selected by `method` (or `matmul_default_method` when `method` is absent).
!>
!> The result is allocated here with shape (size(a, 1), size(b, 2)).
!> Calls `mp_abort` if the inner dimensions of `a` and `b` disagree, or if the
!> method flag matches none of the compiled-in cases (e.g. the LAPACK flag in a
!> build without LAPACK defined).
function matmul_wrapper_complex_2d(a, b, method) result(output)
  use mp, only: mp_abort
#ifdef LAPACK
  use lapack_wrapper, only: gemm
#endif
  implicit none
  !> Left operand, shape (d1, d3)
  complex, dimension(:, :), intent(in) :: a
  !> Right operand, shape (d3, d2)
  complex, dimension(:, :), intent(in) :: b
  !> Optional override of the multiplication backend
  type(matrix_multiply_method_type), optional, intent(in) :: method
  !> Product a * b, shape (d1, d2)
  complex, dimension(:, :), allocatable :: output
  type(matrix_multiply_method_type) :: internal_method
  complex, parameter :: one = (1.0, 0.0), zero = (0.0, 0.0)
  complex :: tmp
  integer :: d1, d2, d3
  integer :: i, j, k

  d1 = size(a, dim = 1)
  d2 = size(b, dim = 2)
  d3 = size(a, dim = 2)

  ! The inner dimensions must agree. This used to be checked only in GK_DEBUG
  ! builds; without it a mismatch makes the custom loop index b out of bounds
  ! and makes gemm read b with the wrong leading dimension.
  if (d3 /= size(b, dim = 1)) then
    call mp_abort("Inner dimensions of a and b do not match in matrix_multiply::matmul_wrapper", .true.)
  end if

  allocate(output(d1, d2))

  internal_method = matmul_default_method
  if (present(method)) internal_method = method

  select case (internal_method%flag)
  case (matmul_intrinsic%flag)
    output = matmul(a, b)
#ifdef LAPACK
  case (matmul_lapack%flag)
    ! BLAS requires every leading dimension to be >= max(1, rows); passing 0
    ! for an empty operand makes the library reject the call via xerbla.
    call gemm('n', 'n', d1, d2, d3, one, a, max(1, d1), b, max(1, d3), &
              zero, output, max(1, d1))
#endif
  case (matmul_custom%flag)
    ! Parallelised over columns j of output; each iteration writes only column
    ! j, and the innermost loop runs over the first (fastest-varying) index.
    !$OMP PARALLEL DO DEFAULT(none) &
    !$OMP PRIVATE(i, j, k, tmp) &
    !$OMP SHARED(output, d1, d2, d3, a, b) &
    !$OMP SCHEDULE(static)
    do j = 1, d2
      output(:, j) = zero
      do k = 1, d3
        tmp = b(k, j)
        do i = 1, d1
          output(i, j) = output(i, j) + tmp * a(i, k)
        end do
      end do
    end do
    !$OMP END PARALLEL DO
  case default
    call mp_abort("Unknown multiply method passed to matrix_multiply::matmul_wrapper", .true.)
  end select
end function matmul_wrapper_complex_2d