Hi all,
I’m seeing some interesting behavior when colocating two kernels on the GPU with and without MPS and I’m having trouble explaining the difference in behavior.
I’d like to colocate a PyTorch softmax kernel with a custom compute kernel of mine that just does some arithmetic and is supposed to keep all of its operands in registers. Below is an extract of the code. I call the compute kernel from Python through ctypes; a simplified sketch of that glue code is included after the script below (happy to share the full version if helpful).
import argparse
import json
import os

import torch
import torch.multiprocessing as mp

from vllm.interference.inter_funcs import interference_class
def run_softmax(args, barrier):
    logits = torch.rand([16, 128256]).cuda()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    if barrier:
        barrier.wait()
    for i in range(100):
        start_event.record()
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        end_event.record()
        torch.cuda.synchronize()
        print(f"Step {i} took {start_event.elapsed_time(end_event)} ms")
    torch.cuda.synchronize()
def run_inter(args, barrier):
    # for scenario 3, set the MPS share before this process creates
    # its CUDA context:
    # os.environ['CUDA_MPS_ACTIVE_THREAD_PERCENTAGE'] = "50"
    init_args = {"device": torch.device('cuda'), "num_floats": 1024}
    run_args = {"num_tb": args.num_tb, "threads_per_block": 1024, "num_itrs": 500000000}
    # run the kernel once to trigger lazy loading
    inter_kernel = interference_class(**init_args)
    inter_kernel.run(**run_args)
    barrier.wait()
    inter_kernel.run(**run_args)
    print("Exiting....")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--num_tb", type=int, default=0, help="Number of thread blocks of the interfering kernel")
user_args = parser.parse_args()
mp.set_start_method('spawn')
do_interference = False
if user_args.inter_tb > 0:
do_interference = True
barrier = mp.Barrier(2) if do_interference else None
softmax = mp.Process(target=run_softmax, args=(user_args, barrier))
softmax.start()
if do_interference:
interf = mp.Process(target=run_inter, args=(user_args, barrier))
interf.start()
softmax.join()
if do_interference:
interf.join()
// the compute kernel
__global__ void compute_kernel(float *a, float *b, float *c, long long num_itrs)
{
    // load the operands into registers once
    float op1 = a[threadIdx.x];
    float op2 = b[threadIdx.x];
    float op3 = 1.0f;
    float op4 = 1.0f;
    float op5 = 1.0f;
    float op6 = 1.0f;
    // four independent FMA chains, so the loop is purely compute-bound
    // and all operands stay in registers
    for (long long i = 0; i < num_itrs; i++)
    {
        op3 = __fmaf_rn(op1, op2, op3);
        op4 = __fmaf_rn(op1, op2, op4);
        op5 = __fmaf_rn(op1, op2, op5);
        op6 = __fmaf_rn(op1, op2, op6);
    }
    // write the results so the compiler cannot optimize the loop away
    c[threadIdx.x] = op3 + op4 + op5 + op6;
}
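The host-side entry point that ctypes binds to is just an extern "C" wrapper around the launch (again simplified; whether to synchronize here depends on the scenario):

extern "C" void launch_compute_kernel(float *a, float *b, float *c,
                                      long long num_itrs,
                                      int num_tb, int threads_per_block)
{
    compute_kernel<<<num_tb, threads_per_block>>>(a, b, c, num_itrs);
    // block until the kernel finishes so run() only returns when the work is done
    cudaDeviceSynchronize();
}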
I’m running on an H100 with CUDA 12.9 and driver 575.57.08, and I’m seeing the following:
- when the softmax runs alone (num_tb set to 0), its latency is roughly 65 microseconds
- when I colocate with the compute kernel WITHOUT MPS, launching the compute kernel with 132 thread blocks (matching the 132 SMs on the H100), the softmax latency stays on the order of 65 microseconds. I checked with nsys that the two kernels do overlap. Note that the GPU is in the default compute mode. I also printed the SM ID of each thread block of the compute kernel (using the %smid check sketched after this list) and can see that every one of the 132 SMs hosts one thread block.
- when I colocate with the compute kernel WITH MPS (just starting the MPS server with nvidia-cuda-mps-control -d) but WITHOUT setting the CUDA_MPS_ACTIVE_THREAD_PERCENTAGE variable, the softmax latency suddenly increases to roughly 100 ms, several orders of magnitude higher. I don’t understand what exactly is happening here. How is this case different from the previous one? Again I see both kernels overlapping in the nsys trace, and each SM hosting one thread block of the compute kernel.
- when I repeat the previous scenario but this time set CUDA_MPS_ACTIVE_THREAD_PERCENTAGE to 50, the softmax latency drops back to roughly 65 microseconds. My assumption here was that the compute kernel saturates the 66 SMs it is given access to: 132 blocks of 1024 threads spread over 66 SMs works out to 2 blocks, i.e. 2048 threads, per SM, which is exactly the per-SM thread limit on the H100. The softmax kernel therefore has to run on the other 66 SMs and should not suffer any interference. But that still doesn’t explain the difference between the second and third scenario.
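For reference, the SM-ID check mentioned in the second bullet is the usual inline-PTX trick, roughly:

__device__ unsigned int smid()
{
    // read the special register holding the ID of the SM this thread runs on
    unsigned int id;
    asm volatile("mov.u32 %0, %%smid;" : "=r"(id));
    return id;
}

// e.g. at the top of compute_kernel:
// if (threadIdx.x == 0) printf("block %d on SM %u\n", blockIdx.x, smid());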
Does anyone have an idea what is happening here? I don’t see how to explain the increase in latency. If it were due to interference, why wouldn’t I see that interference in the second scenario as well? And I wouldn’t expect latency to explode by three orders of magnitude purely from interference.
Any help is as always much appreciated!! Thanks a lot!