I have a three-dimensional array S with dimensions [A, R, M] and a two-dimensional kernel K with dimensions [M, M].
I need to compute the output array Q with dimensions [A, R], where each element Q[i,j] is
Q[i,j] = S[i,j,:] * K * S[i,j,:]^H
Here S[i,j,:] is treated as a row vector and ^H denotes the Hermitian conjugate (conjugate transpose).
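Written out element by element, each entry is a quadratic form in one length-M vector, so only A*R scalars are needed. Writing s_ij = S[i,j,:] as a row vector:

$$Q_{ij} = \mathbf{s}_{ij}\, K\, \mathbf{s}_{ij}^{H} = \sum_{m=1}^{M} \sum_{n=1}^{M} S_{ijm}\, K_{mn}\, \overline{S_{ijn}}, \qquad \mathbf{s}_{ij} = S[i,j,:] \in \mathbb{C}^{1 \times M}$$

Grouping the inner sum over m as (s_ij K)_n shows that one GEMM of the flattened array with K, followed by a row-wise dot product with the conjugate of the same row, gives every Q_ij.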
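I can do this by brute force: flatten S into S' with dimensions [A*R, M], compute S' * K * S'^H with GEMMs, and extract the diagonal of the resulting [A*R, A*R] array. This is not efficient, because it computes all (A*R)^2 entries when only the A*R diagonal ones are needed.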
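A minimal NumPy sketch of the cheaper formulation, assuming S and K fit in memory as complex arrays (the function names `quadratic_forms` and `brute_force` are hypothetical, for illustration only). It keeps the single GEMM S' @ K ([A*R, M] x [M, M]) and replaces the second GEMM with a row-wise dot product, so the cost drops from roughly O((A*R)^2 * M) to O(A*R*M^2) and the [A*R, A*R] intermediate is never formed.

```python
import numpy as np

def quadratic_forms(S, K):
    """Q[i, j] = S[i, j, :] @ K @ S[i, j, :].conj().T for every (i, j)."""
    A, R, M = S.shape
    Sf = S.reshape(A * R, M)            # flattened S' with shape [A*R, M]
    T = Sf @ K                          # one GEMM: row r of T is S'[r, :] @ K
    Q = np.sum(T * Sf.conj(), axis=1)   # row-wise dot product with conj(S'[r, :])
    return Q.reshape(A, R)

def brute_force(S, K):
    """Reference version: full [A*R, A*R] product, keep only the diagonal."""
    A, R, M = S.shape
    Sf = S.reshape(A * R, M)
    full = Sf @ K @ Sf.conj().T
    return np.diag(full).reshape(A, R)

# Small random complex example
rng = np.random.default_rng(0)
S = rng.standard_normal((3, 4, 5)) + 1j * rng.standard_normal((3, 4, 5))
K = rng.standard_normal((5, 5)) + 1j * rng.standard_normal((5, 5))
Q = quadratic_forms(S, K)
print(Q.shape)  # (3, 4)
```

The same contraction can also be written as `np.einsum('arm,mn,arn->ar', S, K, S.conj())`. On a GPU, one possible mapping (an assumption, not benchmarked here) is a batched or flattened GEMM for S' @ K followed by a simple reduction kernel that assigns one thread or warp to each of the A*R rows.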
Any pointers or suggestions on writing a kernel to handle this task?