Thank you very much!
I am using the official example `cutlass/examples/cute/tutorial/sgemm_nt_1.cu` from the `main` branch of the NVIDIA/cutlass repository on GitHub.
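However, this kernel template takes CuTe layouts such as `auto sC = make_layout(make_shape(bM, bN));` as arguments, so I cannot simply name its instantiation when calling `cudaFuncSetAttribute`. I tried spelling out the template arguments with `decltype`: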
// Types of the arguments that gemm() passes to gemm_device<<<...>>>.
// These must match the launch arguments EXACTLY (shapes, strides, thread
// layouts). Otherwise cudaFuncSetAttribute below configures a different
// template instantiation than the one that is actually launched.
// NOTE(review): if the surrounding host function already holds the
// dA/sA/tA/... objects as `auto` locals, spelling these as `decltype(dA)`,
// `decltype(sA)`, ... keeps them in sync automatically - confirm against the
// launch site.
using StrideType_dA = decltype(make_stride(Int<1>{}, ldA));
using StrideType_dB = decltype(make_stride(Int<1>{}, ldB));
using StrideType_dC = decltype(make_stride(Int<1>{}, ldC));
using LayoutType_sA = decltype(make_layout(make_shape(bM,bK)));
using LayoutType_sB = decltype(make_layout(make_shape(bN,bK)));
using LayoutType_sC = decltype(make_layout(make_shape(bM,bN)));
using LayoutType_tA = decltype(make_layout(make_shape(Int<16>{}, Int< 8>{})));
using LayoutType_tB = decltype(make_layout(make_shape(Int<16>{}, Int< 8>{})));
using LayoutType_tC = decltype(make_layout(make_shape(Int<8>{}, Int<16>{})));
// Host code
int maxbytes = 99328; // 97 KB = 96 KB + 1 KB
// gemm_device's TA/TB/TC template parameters are ELEMENT types: the kernel
// itself declares its pointer parameters as `const TA*`, `const TB*`, `TC*`.
// Passing `TA const*` / `TC *` here instantiated it with TA = const float*
// and TC = float* (pointer-to-pointer parameters). The CuTe errors in
// make_fragment_like / clear / gemm / axpby come from that instantiation.
//
// NOTE: cudaFuncAttributeMaxDynamicSharedMemorySize only raises the limit for
// DYNAMIC shared memory (extern __shared__ + byte count as the 3rd launch
// argument). If gemm_device declares its smem tiles as statically sized
// __shared__ arrays, they remain capped at 48 KB regardless of this call,
// and the kernel must be switched to dynamic shared memory to use more.
//
// TODO(review): the returned cudaError_t is not checked; wrap this call in
// the project's CUDA error-check macro so an invalid `maxbytes` is reported.
cudaFuncSetAttribute(gemm_device<int, int, int,
                                 TA, StrideType_dA, LayoutType_sA, LayoutType_tA,
                                 TB, StrideType_dB, LayoutType_sB, LayoutType_tB,
                                 TC, StrideType_dC, LayoutType_sC, LayoutType_tC,
                                 Alpha, Beta>,
                     cudaFuncAttributeMaxDynamicSharedMemorySize, maxbytes);
It fails to compile:
(base) a100-04% nvcc -o sgemm_nt_1 sgemm_nt_1.cu -arch=sm_80 -std=c++17 -I /home/zyhuang/cutlass/include -I /home/zyhuang/cutlass/tools/util/include
/home/zyhuang/cutlass/include/cute/tensor.hpp(391): error: no instance of overloaded function "cute::MakeTensor<T>::operator() [with T=float *]" matches the argument list
argument types are: (const cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>)
object type is: cute::MakeTensor<float *>
detected during:
instantiation of "auto cute::make_tensor<T,Args...>(const Args &...) [with T=float *, Args=<cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>>]"
(445): here
instantiation of "auto cute::make_fragment_like<NewT,Layout>(const Layout &) [with NewT=float *, Layout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_8, int>>]"
(461): here
instantiation of "auto cute::make_fragment_like(const cute::Tensor<Engine, Layout> &) [with Engine=cute::ViewEngine<cute::gmem_ptr<float *>>, Layout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_8, int>>]"
sgemm_nt_1.cu(153): here
instantiation of "void gemm_device(MShape, NShape, KShape, const TA *, AStride, ABlockLayout, AThreadLayout, const TB *, BStride, BBlockLayout, BThreadLayout, TC *, CStride, CBlockLayout, CThreadLayout, Alpha, Beta) [with MShape=int, NShape=int, KShape=int, TA=const float *, AStride=cute::tuple<cute::C<1>, int>, ABlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<64>>, cute::tuple<cute::_1, cute::_128>>, AThreadLayout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>, TB=const float *, BStride=cute::tuple<cute::C<1>, int>, BBlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<64>>, cute::tuple<cute::_1, cute::_128>>, BThreadLayout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>, TC=float *, CStride=cute::tuple<cute::C<1>, int>, CBlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<128>>, cute::tuple<cute::_1, cute::_128>>, CThreadLayout=cute::Layout<cute::tuple<cute::C<8>, cute::C<16>>, cute::tuple<cute::_1, cute::_8>>, Alpha=float, Beta=float]"
sgemm_nt_1.cu(331): here
instantiation of "void gemm(int, int, int, Alpha, const TA *, int, const TB *, int, Beta, TC *, int, cudaStream_t) [with TA=float, TB=float, TC=float, Alpha=float, Beta=float]"
sgemm_nt_1.cu(432): here
sgemm_nt_1.cu(156): error: no instance of overloaded function "clear" matches the argument list
argument types are: (<error-type>)
detected during:
instantiation of "void gemm_device(MShape, NShape, KShape, const TA *, AStride, ABlockLayout, AThreadLayout, const TB *, BStride, BBlockLayout, BThreadLayout, TC *, CStride, CBlockLayout, CThreadLayout, Alpha, Beta) [with MShape=int, NShape=int, KShape=int, TA=const float *, AStride=cute::tuple<cute::C<1>, int>, ABlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<64>>, cute::tuple<cute::_1, cute::_128>>, AThreadLayout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>, TB=const float *, BStride=cute::tuple<cute::C<1>, int>, BBlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<64>>, cute::tuple<cute::_1, cute::_128>>, BThreadLayout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>, TC=float *, CStride=cute::tuple<cute::C<1>, int>, CBlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<128>>, cute::tuple<cute::_1, cute::_128>>, CThreadLayout=cute::Layout<cute::tuple<cute::C<8>, cute::C<16>>, cute::tuple<cute::_1, cute::_8>>, Alpha=float, Beta=float]"
(331): here
instantiation of "void gemm(int, int, int, Alpha, const TA *, int, const TB *, int, Beta, TC *, int, cudaStream_t) [with TA=float, TB=float, TC=float, Alpha=float, Beta=float]"
(432): here
sgemm_nt_1.cu(240): error: no instance of overloaded function "gemm" matches the argument list
argument types are: (cute::Tensor<cute::ViewEngine<cute::smem_ptr<const float *>>, cute::Layout<cute::tuple<cute::C<16>, cute::C<64>>, cute::tuple<cute::_8, cute::_128>>>, cute::Tensor<cute::ViewEngine<cute::smem_ptr<const float *>>, cute::Layout<cute::tuple<cute::_8, cute::_64>, cute::tuple<cute::_16, cute::C<128>>>>, <error-type>)
detected during:
instantiation of "void gemm_device(MShape, NShape, KShape, const TA *, AStride, ABlockLayout, AThreadLayout, const TB *, BStride, BBlockLayout, BThreadLayout, TC *, CStride, CBlockLayout, CThreadLayout, Alpha, Beta) [with MShape=int, NShape=int, KShape=int, TA=const float *, AStride=cute::tuple<cute::C<1>, int>, ABlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<64>>, cute::tuple<cute::_1, cute::_128>>, AThreadLayout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>, TB=const float *, BStride=cute::tuple<cute::C<1>, int>, BBlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<64>>, cute::tuple<cute::_1, cute::_128>>, BThreadLayout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>, TC=float *, CStride=cute::tuple<cute::C<1>, int>, CBlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<128>>, cute::tuple<cute::_1, cute::_128>>, CThreadLayout=cute::Layout<cute::tuple<cute::C<8>, cute::C<16>>, cute::tuple<cute::_1, cute::_8>>, Alpha=float, Beta=float]"
(331): here
instantiation of "void gemm(int, int, int, Alpha, const TA *, int, const TB *, int, Beta, TC *, int, cudaStream_t) [with TA=float, TB=float, TC=float, Alpha=float, Beta=float]"
(432): here
sgemm_nt_1.cu(251): error: no instance of overloaded function "axpby" matches the argument list
argument types are: (float, <error-type>, float, cute::Tensor<cute::ViewEngine<cute::gmem_ptr<float *>>, cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_8, int>>>)
detected during:
instantiation of "void gemm_device(MShape, NShape, KShape, const TA *, AStride, ABlockLayout, AThreadLayout, const TB *, BStride, BBlockLayout, BThreadLayout, TC *, CStride, CBlockLayout, CThreadLayout, Alpha, Beta) [with MShape=int, NShape=int, KShape=int, TA=const float *, AStride=cute::tuple<cute::C<1>, int>, ABlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<64>>, cute::tuple<cute::_1, cute::_128>>, AThreadLayout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>, TB=const float *, BStride=cute::tuple<cute::C<1>, int>, BBlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<64>>, cute::tuple<cute::_1, cute::_128>>, BThreadLayout=cute::Layout<cute::tuple<cute::_16, cute::_8>, cute::tuple<cute::_1, cute::_16>>, TC=float *, CStride=cute::tuple<cute::C<1>, int>, CBlockLayout=cute::Layout<cute::tuple<cute::C<128>, cute::C<128>>, cute::tuple<cute::_1, cute::_128>>, CThreadLayout=cute::Layout<cute::tuple<cute::C<8>, cute::C<16>>, cute::tuple<cute::_1, cute::_8>>, Alpha=float, Beta=float]"
(331): here
instantiation of "void gemm(int, int, int, Alpha, const TA *, int, const TB *, int, Beta, TC *, int, cudaStream_t) [with TA=float, TB=float, TC=float, Alpha=float, Beta=float]"
(432): here
4 errors detected in the compilation of "sgemm_nt_1.cu".
Do you have any idea what is causing these errors? Thanks!