[Question] How do the threads in a warp work collectively?

I’m very new to CUDA and CUDA PTX. My question, after reading the PTX documentation (which is very hard for me to understand), is this: taking m16n8k8 as an example, since each thread holds 4 elements from A and 2 from B, how do the threads collectively ‘sum’ each other’s results to generate the final 16x8 matrix?

Presumably you are referring to the matrix-multiply instructions.

The PTX instructions that begin (for example) with mma are “feeding” a functional unit in the SM that does all the work to produce the result. The threads are not executing microcode or otherwise collaborating, except insofar as they feed data to the functional unit (in the form of a register patch holding input data) and receive the result from it (in the form of a register patch holding the output data).

Conceptually, this is similar to other functional unit behavior. The SP (FP32) units, for example, when processing an FFMA instruction, receive input data from registers and put their output data into registers. The “threads” do not otherwise assist in the generation of the result, other than to feed the functional unit.
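
For illustration, a minimal sketch of that register-in/register-out pattern for a single thread (fma.rn.f32 is the PTX form that typically compiles to an FFMA; the helper name is mine):

	// Per-thread fused multiply-add via inline PTX: registers in, registers out.
	// No other thread in the warp is involved.
	__device__ float ffma(float a, float b, float c) {
	    float d;
	    asm("fma.rn.f32 %0, %1, %2, %3;" : "=f"(d) : "f"(a), "f"(b), "f"(c));
	    return d;
	}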

One difference with the matrix-multiply instructions is of course that an entire warp must participate. However, the combining of input data from different threads is not done by the threads themselves, but rather by the tensor core unit. The tensor core unit then distributes results back to each thread’s register patch.
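
To make “each thread’s register patch” concrete for m16n8k8 with f16 operands, here is a sketch of the per-lane element mapping as I read the fragment-layout figures in the PTX ISA manual (double-check against the figures for your PTX version before relying on it):

	#include <cstdio>

	// Which elements of the m16n8k8 f16 operands a given lane owns, per my
	// reading of the PTX ISA fragment-layout figures (verify against the docs).
	__device__ void m16n8k8_fragment_coords(unsigned lane) {
	    unsigned groupID           = lane >> 2;  // 0..7
	    unsigned threadID_in_group = lane % 4;   // 0..3

	    // A (16x8, row-major): four halves a0..a3 per lane.
	    for (int i = 0; i < 4; ++i) {
	        unsigned row = groupID + (i < 2 ? 0 : 8);
	        unsigned col = threadID_in_group * 2 + (i & 1);
	        printf("lane %2u: a%d = A[%u][%u]\n", lane, i, row, col);
	    }
	    // B (8x8, col-major): two halves b0..b1 per lane.
	    for (int i = 0; i < 2; ++i) {
	        unsigned row = threadID_in_group * 2 + (i & 1);
	        unsigned col = groupID;
	        printf("lane %2u: b%d = B[%u][%u]\n", lane, i, row, col);
	    }
	    // C and D (16x8) follow the same per-lane pattern as A.
	}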

Thanks! It really is a very basic concept of the tensor core, lol. That suggests that for now we only need to focus on the data layout, which is critical to the output.

After that, I tried wmma (where I think the layout corresponding to each thread is similar to mma) and accessed the data in the fragment directly (I just wanted to see how the data is stored).

// m16n16k16 wmma (assumes #include <mma.h> and using namespace nvcuda)
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> frag_in;

However, I found that in frag_in (16 elements in total), the first eight elements are the same as the last eight.

Is this normal?

AFAIK the data layout for the wmma instructions is unspecified. Opaque load and store instruction(s) are provided.
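
For reference, a sketch of that supported path: fill memory with known values, load through the opaque wmma load, and dump what each lane ends up holding (the x[] and num_elements members are documented, but what they contain is architecture-dependent; the kernel name is mine):

	#include <cstdio>
	#include <cuda_fp16.h>
	#include <mma.h>
	using namespace nvcuda;

	// Launch with <<<1, 32>>>: fill a 16x16 B tile with its linear index,
	// load it with the opaque wmma load, and print each lane's fragment.
	__global__ void dump_b_fragment() {
	    __shared__ half tile[16 * 16];
	    for (int i = threadIdx.x; i < 16 * 16; i += 32)
	        tile[i] = __float2half((float)i);  // element value == linear index
	    __syncwarp();

	    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::col_major> frag_in;
	    wmma::load_matrix_sync(frag_in, tile, 16);  // 16 = leading dimension

	    for (int i = 0; i < frag_in.num_elements; ++i)
	        printf("lane %2d x[%2d] = %5.0f\n", (int)threadIdx.x, i,
	               __half2float(frag_in.x[i]));
	}

If the duplication you describe shows up in this dump, it is just how that particular architecture packs the fragment; code should not rely on it.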

Just create some code like:

		half2 b = __floats2half2_rn(threadIdx.x, threadIdx.x + .5f);
		unsigned int B0 = reinterpret_cast<unsigned int&>(b); // actually UB, but usually accepted by nvcc
		asm volatile(
			"mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%5, %6};\n"
			: "=r"(D0), "=r"(D1)
			: "r"(A0), "r"(A1), "r"(B0), "r"(C0), "r"(C1));

Initialize A0 and A1 similarly to B0, set C0 and C1 to zero, and try out which results appear where.
Then you can compare with the documentation.
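
Putting that together, a complete version of the experiment might look like this (a sketch; probe_mma_layout is my name, the reinterpret_casts are the same technically-UB trick as above, and you need a tensor-core GPU supporting this shape, sm_75 or later per the PTX ISA, compiled with a matching -arch):

	#include <cstdio>
	#include <cuda_fp16.h>

	__global__ void probe_mma_layout() {
	    unsigned lane = threadIdx.x;  // launched with exactly one warp

	    // Encode the lane into every input element so the output reveals
	    // which lanes' inputs contributed to which lanes' outputs.
	    half2 a0 = __floats2half2_rn(lane, lane + .25f);
	    half2 a1 = __floats2half2_rn(lane + .5f, lane + .75f);
	    half2 b  = __floats2half2_rn(lane, lane + .5f);

	    unsigned A0 = reinterpret_cast<unsigned&>(a0);
	    unsigned A1 = reinterpret_cast<unsigned&>(a1);
	    unsigned B0 = reinterpret_cast<unsigned&>(b);
	    unsigned C0 = 0, C1 = 0;  // zero accumulator
	    unsigned D0, D1;

	    asm volatile(
	        "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 "
	        "{%0, %1}, {%2, %3}, {%4}, {%5, %6};\n"
	        : "=r"(D0), "=r"(D1)
	        : "r"(A0), "r"(A1), "r"(B0), "r"(C0), "r"(C1));

	    half2 d0 = reinterpret_cast<half2&>(D0);
	    half2 d1 = reinterpret_cast<half2&>(D1);
	    printf("lane %2u: d = {%f, %f, %f, %f}\n", lane,
	           __half2float(d0.x), __half2float(d0.y),
	           __half2float(d1.x), __half2float(d1.y));
	}

	int main() {
	    probe_mma_layout<<<1, 32>>>();
	    cudaDeviceSynchronize();
	    return 0;
	}

Simpler input patterns (for example, zeroing everything except one element of A) make it easier to read off the mapping one element at a time.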
