[nvshmem4py] nvshmem.core.finalize() does not handle everything

What I expect from nvshmem.core.finalize(): it finds every nvshmem.core.tensor that has not been released and releases it, taking care of everything without an error or exception.

But that is not the case.

If you create

A = nvshmem.core.tensor(...)

and then call

B = nvshmem.core.get_peer_tensor(A, rank)

to increase the buffer's reference count, and then forget to call nvshmem.core.free_tensor(A) and just call nvshmem.core.finalize(), nvshmem4py punishes you: nvshmem.core.finalize() calls deallocate on the buffer only once, but the buffer holds 2 references, so the count only drops to 1 and the buffer is never actually freed.

After nvshmem.core.finalize() thinks it has freed every buffer, it calls nvshmem_finalize.

Then Python teardown continues, cleans up all the buffers that were never destroyed, and deallocates them again. This time it tries to call nvshmem_free, but NVSHMEM is already finalized, so we get a panic.
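
In other words, my understanding of the sequence, sketched with the same calls (rank is the placeholder rank from above; the ref counts come from the log below):

A = nvshmem.core.tensor((n_elements,), dtype=torch.float32)  # buffer ref count: 1
B = nvshmem.core.get_peer_tensor(A, rank)                    # same buffer returned for the local rank, ref count: 2
nvshmem.core.finalize()                                      # deallocates once: ref count drops to 1, then nvshmem_finalize() runs
# Python teardown later drops the remaining reference and deallocates again;
# only now is nvshmem_free() reached, but NVSHMEM is already finalized -> abort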

This is the sample code:


import torch.distributed as dist
import torch
import nvshmem.core
import os
from cuda.core.experimental import Device, system
import nvshmem.core.utils

class PyTorchStreamWrapper:
    def __init__(self, pt_stream):
        self.pt_stream = pt_stream
        self.handle = pt_stream.cuda_stream

    def __cuda_stream__(self):
        stream_id = self.pt_stream.cuda_stream
        return (0, stream_id)  # Return format required by CUDA Python

def torchrun_uid_init():
    """
    Initialize NVSHMEM using UniqueID with `torchrun` as the launcher
    """
    # Set Torch device
    local_rank = int(os.environ['LOCAL_RANK'])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # nvshmem4py requires a cuda.core Device at init time
    global dev
    dev = Device(device.index)
    dev.set_current()
    global stream
    # Get PyTorch's current stream
    pt_stream = torch.cuda.current_stream()
    stream = PyTorchStreamWrapper(pt_stream)

    # Initialize torch.distributed process group
    world_size = torch.cuda.device_count()
    dist.init_process_group(
        backend="cpu:gloo,cuda:nccl",
        rank=local_rank,
        world_size=world_size,
        device_id=device
    )

    # Extract rank, nranks from process group
    num_ranks = dist.get_world_size()
    rank_id = dist.get_rank()

    # Create an empty uniqueid for all ranks
    uniqueid = nvshmem.core.get_unique_id(empty=True)
    if rank_id == 0:
        # Rank 0 gets a real uniqueid
        uniqueid = nvshmem.core.get_unique_id()
        broadcast_objects = [uniqueid]
    else:
        broadcast_objects = [None]

    # We use torch.distributed.broadcast_object_list to send the UID to all ranks
    dist.broadcast_object_list(broadcast_objects, src=0)
    dist.barrier()

    nvshmem.core.init(device=dev, uid=broadcast_objects[0], rank=rank_id, nranks=num_ranks, initializer_method="uid")


if __name__ == '__main__':
    torchrun_uid_init()
    n_elements = 867530
    nvshmem.core.utils._configure_logging(level="DEBUG")
    tensor_out = nvshmem.core.tensor((n_elements,), dtype=torch.float32)
    ts = [nvshmem.core.get_peer_tensor(tensor_out, peer) for peer in range(nvshmem.core.n_pes())]
    # nvshmem.core.free_tensor(tensor_out)
    nvshmem.core.finalize()
    dist.destroy_process_group()

This is the log:

[W704 06:55:14.322782761 ProcessGroupGloo.cpp:727] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Creating NvshmemResource for device 0
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Created Buffer on resource <NvshmemResource device=<Device 0 (NVIDIA H800)>> at address 1406077503488 with size 3470120 on stream None
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Found already tracked peer buffer with address 1406077503488. Returning it. Ref count 2
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Did not find peer buffer with address 1543516456960. Creating a new one.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Did not find peer buffer with address 1680955410432. Creating a new one.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Did not find peer buffer with address 1818394363904. Creating a new one.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Did not find peer buffer with address 1955833317376. Creating a new one.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Did not find peer buffer with address 2093272270848. Creating a new one.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Did not find peer buffer with address 2230711224320. Creating a new one.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Did not find peer buffer with address 2368150177792. Creating a new one.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : nvshmem_finalize() called
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM ERROR : Found un-freed memory object with address 1406077503488 at fini time
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM ERROR : Found 1 un-freed memory objects at fini time
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM INFO : Found object open at pointer 1406077503488 and ref count 2. Freeing it.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1406077503488
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : New ref count on  buf 1406077503488 1
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM INFO : Found object open at pointer 1543516456960 and ref count 1. Freeing it.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1543516456960
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : New ref count on peer buf 1543516456960 0
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : free() requested on a peer buffer. Not calling free()
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM INFO : Found object open at pointer 1680955410432 and ref count 1. Freeing it.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1680955410432
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : New ref count on peer buf 1680955410432 0
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : free() requested on a peer buffer. Not calling free()
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM INFO : Found object open at pointer 1818394363904 and ref count 1. Freeing it.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1818394363904
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : New ref count on peer buf 1818394363904 0
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : free() requested on a peer buffer. Not calling free()
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM INFO : Found object open at pointer 1955833317376 and ref count 1. Freeing it.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1955833317376
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : New ref count on peer buf 1955833317376 0
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : free() requested on a peer buffer. Not calling free()
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM INFO : Found object open at pointer 2093272270848 and ref count 1. Freeing it.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 2093272270848
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : New ref count on peer buf 2093272270848 0
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : free() requested on a peer buffer. Not calling free()
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM INFO : Found object open at pointer 2230711224320 and ref count 1. Freeing it.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 2230711224320
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : New ref count on peer buf 2230711224320 0
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : free() requested on a peer buffer. Not calling free()
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM INFO : Found object open at pointer 2368150177792 and ref count 1. Freeing it.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 2368150177792
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : New ref count on peer buf 2368150177792 0
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : free() requested on a peer buffer. Not calling free()
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 2368150177792
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Ref count on 2368150177792 is already 0. Already freed.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 2230711224320
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Ref count on 2230711224320 is already 0. Already freed.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 2093272270848
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Ref count on 2093272270848 is already 0. Already freed.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1955833317376
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Ref count on 1955833317376 is already 0. Already freed.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1818394363904
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Ref count on 1818394363904 is already 0. Already freed.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1680955410432
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Ref count on 1680955410432 is already 0. Already freed.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1543516456960
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Ref count on 1543516456960 is already 0. Already freed.
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : Free called on buffer with address 1406077503488
H800-1-docker-n122-200-178:1585472:1585472 [0] NVSHMEM DEBUG : New ref count on  buf 1406077503488 0
/dvs/p4/build/sw/rel/gpgpu/toolkit/r12.8/main_nvshmem/src/host/mem/mem_heap.cpp:nvshmem_free:1702: NVSHMEM API called before NVSHMEM initialization has completed

/dvs/p4/build/sw/rel/gpgpu/toolkit/r12.8/main_nvshmem/src/host/util/cs.cpp:21: non-zero status: 16: Device or resource busy, exiting... mutex destroy failed

Hi, thank you for the report.

I think there are a couple of things going on here:

  1. If NVSHMEM4Py finds buffers with ref count above 1 at finalize time, it probably makes sense to free everything. This matches the C++ NVSHMEM user experience. I will take this request back to the rest of the team for discussion.
  2. We have heard similar, but slightly different feedback from others where they requested the ability to allocate buffers but not track them internally to NVSHMEM4Py. In this mode, it’d be on the user to make sure that references are freed in the total global order satisfied by NVSHMEM4Py.

Note that you are never expected to explicitly free NVSHMEM4Py Peer tensors. Freeing the parent tensor will take care of that for you.
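
For example, a minimal sketch of that cleanup order, reusing the calls from the repro script above (and assuming, as described, that freeing the parent tensor is enough to release the peer views):

tensor_out = nvshmem.core.tensor((n_elements,), dtype=torch.float32)
peers = [nvshmem.core.get_peer_tensor(tensor_out, peer) for peer in range(nvshmem.core.n_pes())]
# ... use tensor_out and the peer views ...
nvshmem.core.free_tensor(tensor_out)  # no free_tensor() on the peer tensors themselves
nvshmem.core.finalize()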

And one more thing: torch tensors use a stream-aware allocator, in most cases the CUDACachingAllocator, so memory allocation/deallocation is effectively asynchronous. But here NVSHMEM is used together with PyTorch, and its allocate/deallocate is also asynchronous, yet without any stream semantics. That is error prone.

I suggest making nvshmem.core.tensor/free_tensor synchronous, just like cudaMalloc/cudaFree with their implicit synchronization semantics.

Here is a sample that causes the problem (a workaround sketch follows after it):

with torch.cuda.stream(stream_A):
  buffer_A = nvshmem.core.tensor(shape, dtype)
  buffer_A.fill_(1)
  # do some other things.
  nvshmem.core.free_tensor(buffer_A)

with torch.cuda.stream(stream_B):
  buffer_B = nvshmem.core.tensor(shape, dtype)  # buffer_B may share the same symmetric memory with buffer_A, which may cause undefined behavior.
  buffer_B.fill_(1)
  # do some other things.
  nvshmem.core.free_tensor(buffer_B)
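
A possible workaround sketch, under the assumption that nvshmem.core.tensor/free_tensor are host-side calls that are not ordered with respect to the calling stream: synchronize the stream that last touched the buffer before freeing it, so a later allocation cannot reuse the memory while work is still in flight.

with torch.cuda.stream(stream_A):
  buffer_A = nvshmem.core.tensor(shape, dtype)
  buffer_A.fill_(1)
  # do some other things on stream_A.
stream_A.synchronize()               # ensure all work touching buffer_A has finished
nvshmem.core.free_tensor(buffer_A)   # only now is it safe for the memory to be reused

with torch.cuda.stream(stream_B):
  buffer_B = nvshmem.core.tensor(shape, dtype)
  buffer_B.fill_(1)
  # do some other things on stream_B.
stream_B.synchronize()
nvshmem.core.free_tensor(buffer_B)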

Thank you for your feedback.

nvshmem_malloc() and nvshmem_free() have implicit barriers, so they will already be synchronized across the NVSHMEM internal stream. We can consider your feedback about the explicit stream semantics. I would like to understand more of your use case.

nvshmem_free does have an implicit barrier, but on the device side. On the CPU side I found a cudaStreamSynchronize(nvshmemi_device_start->my_stream); I did not notice a cudaStreamWaitEvent or anything similar, so I think it has nothing to do with torch.cuda.current_stream().

So I suppose nvshmem_malloc/nvshmem_free are still async with respect to the calling stream.

I wrote some sample code:


import torch.distributed as dist
import torch
import triton
import triton.language as tl
import nvshmem.core
import os
from cuda.core.experimental import Device, system
import nvshmem.core.utils

from triton_dist.utils import group_profile
from torch.profiler import record_function


class PyTorchStreamWrapper:
    def __init__(self, pt_stream):
        self.pt_stream = pt_stream
        self.handle = pt_stream.cuda_stream

    def __cuda_stream__(self):
        stream_id = self.pt_stream.cuda_stream
        return (0, stream_id)  # Return format required by CUDA Python


def torchrun_uid_init():
    """
    Initialize NVSHMEM using UniqueID with `torchrun` as the launcher
    """
    # Set Torch device
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)

    # nvshmem4py requires a cuda.core Device at init time
    global dev
    dev = Device(device.index)
    dev.set_current()
    global stream
    # Get PyTorch's current stream
    pt_stream = torch.cuda.current_stream()
    stream = PyTorchStreamWrapper(pt_stream)

    # Initialize torch.distributed process group
    world_size = torch.cuda.device_count()
    dist.init_process_group(
        backend="cpu:gloo,cuda:nccl", rank=local_rank, world_size=world_size, device_id=device
    )

    # Extract rank, nranks from process group
    num_ranks = dist.get_world_size()
    rank_id = dist.get_rank()

    # Create an empty uniqueid for all ranks
    uniqueid = nvshmem.core.get_unique_id(empty=True)
    if rank_id == 0:
        # Rank 0 gets a real uniqueid
        uniqueid = nvshmem.core.get_unique_id()
        broadcast_objects = [uniqueid]
    else:
        broadcast_objects = [None]

    # We use torch.distributed.broadcast_object_list to send the UID to all ranks
    dist.broadcast_object_list(broadcast_objects, src=0)
    dist.barrier()

    nvshmem.core.init(
        device=dev,
        uid=broadcast_objects[0],
        rank=rank_id,
        nranks=num_ranks,
        initializer_method="uid",
    )


@triton.jit(do_not_specialize=["value"])
def sleep_and_memset_and_memcpy(duration_ms, symm_ptr, out_ptr, N, value, BLOCK_SIZE: tl.constexpr):

    nblocks = tl.cdiv(N, BLOCK_SIZE)
    for bid in range(nblocks):
        start = bid * BLOCK_SIZE
        offs = start + tl.arange(0, BLOCK_SIZE)
        tl.store(symm_ptr + offs, value, offs < N)

    for _ in range(duration_ms):
        tl.inline_asm_elementwise(
            """
            nanosleep.u32 1000000;
            bar.sync 0;
            mov.u32 $0, 0;""",
            constraints="=r",
            args=[],
            dtype=tl.uint32,
            is_pure=False,
            pack=1,
        )  # assume this takes roughly 1 ms

    for bid in range(nblocks):
        start = bid * BLOCK_SIZE
        offs = start + tl.arange(0, BLOCK_SIZE)
        x = tl.load(symm_ptr + offs, offs < N)
        tl.store(out_ptr + offs, x, mask=offs < N)


def test_multistream():
    n_elements = 1024 * 1024 * 200
    stream_A = torch.cuda.Stream()
    stream_B = torch.cuda.Stream()
    buffer_A = torch.empty(n_elements, dtype=torch.float32, device="cuda")
    buffer_B = torch.empty(n_elements, dtype=torch.float32, device="cuda")
    buffer_A.fill_(-1)
    buffer_B.fill_(-1)

    # as warmup
    sleep_and_memset_and_memcpy[(1,)](1000, buffer_A, buffer_B, n_elements, 0, BLOCK_SIZE=1024)

    torch.cuda.synchronize()
    torch.cuda._sleep(100000000)  # warmup

    with torch.cuda.stream(stream_A):
        with record_function("nvshmem_malloc_A"):
            out_A = nvshmem.core.tensor((n_elements,), dtype=torch.float32)
        sleep_and_memset_and_memcpy[(1,)](1000, out_A, buffer_A, n_elements, 2, BLOCK_SIZE=1024)
        with record_function("nvshmem_free_A"):
            nvshmem.core.free_tensor(out_A)


    with torch.cuda.stream(stream_B):
        with record_function("nvshmem_malloc_B"):
            out_B = nvshmem.core.tensor((n_elements,), dtype=torch.float32)
        sleep_and_memset_and_memcpy[(1,)](1000, out_B, buffer_B, n_elements, 3, BLOCK_SIZE=1024)
        with record_function("nvshmem_free_B"):
            nvshmem.core.free_tensor(out_B)

    torch.cuda.synchronize()
    print(buffer_A)  # 2 expected
    print(buffer_B)  # 3 expected


if __name__ == "__main__":
    torchrun_uid_init()
    nvshmem.core.utils._configure_logging(level="DEBUG")

    with group_profile(name="nvshmem_malloc", do_prof=True, group=torch.distributed.group.WORLD):
        test_multistream()

    nvshmem.core.finalize()
    dist.destroy_process_group()

And I got a log like this:

torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 --rdzv_endpoint=10.122.200.178:12345 --log_dir log -r 3 -t 0:3 --rdzv_id 20250707_085901 /data01/houqi.1993/ProgramFiles/nvshmem_python-source-0.1.0.36132199_cuda12-archive/examples/torch_triton_test.py
+ torchrun --node_rank=0 --nproc_per_node=8 --nnodes=1 --rdzv_endpoint=10.122.200.178:12345 --log_dir log -r 3 -t 0:3 --rdzv_id 20250707_085901 /data01/houqi.1993/ProgramFiles/nvshmem_python-source-0.1.0.36132199_cuda12-archive/examples/torch_triton_test.py
[default0]:[W707 08:59:09.992420982 ProcessGroupGloo.cpp:727] Warning: Unable to resolve hostname to a (local) address. Using the loopback address as fallback. Manually set the network interface to bind to with GLOO_SOCKET_IFNAME. (function operator())
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : Creating NvshmemResource for device 0
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : Created Buffer on resource <NvshmemResource device=<Device 0 (NVIDIA H800)>> at address 139852068685312 with size 838860800 on stream None
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : Free called on buffer with address 139852068685312
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : New ref count on  buf 139852068685312 0
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : Freed buffer at address 139852068685312
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : Created Buffer on resource <NvshmemResource device=<Device 0 (NVIDIA H800)>> at address 139852068685312 with size 838860800 on stream None
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : Free called on buffer with address 139852068685312
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : New ref count on  buf 139852068685312 0
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : Freed buffer at address 139852068685312
[default0]:tensor([3., 3., 3.,  ..., 3., 3., 3.], device='cuda:0')
[default0]:tensor([3., 3., 3.,  ..., 3., 3., 3.], device='cuda:0')
[default0]:DEBUG:nvshmem:nvshmem_finalize() called
[default0]:INFO:nvshmem:Found object open at pointer 139852068685312 and ref count 0. Freeing it.
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : nvshmem_finalize() called
[default0]:DEBUG:nvshmem:Free called on buffer with address 139852068685312
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM INFO : Found object open at pointer 139852068685312 and ref count 0. Freeing it.
[default0]:DEBUG:nvshmem:Ref count on 139852068685312 is already 0. Already freed.
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : Free called on buffer with address 139852068685312
[default0]:H800-1-docker-n122-200-178:2043386:2043386 [0] NVSHMEM DEBUG : Ref count on 139852068685312 is already 0. Already freed.

For both stream_A and stream_B, the output is 3, which is not expected: buffer_A should be 2 (and buffer_B 3).

And by the way, I found that the kernel void barrier_on_stream_kernel_threadgroup<(threadgroup_t)1>(int, int) launched by nvshmem_free does not overlap with the other kernels, except the last one, which is beyond my understanding.

I see no cudaEventRecord and no cudaStreamWaitEvent. Why does barrier_on_stream_kernel_threadgroup wait for the other kernels? Where is the code in NVSHMEM that ensures that?

Yes, that's correct. The barriers are on the device side. I will bring your feature request for alloc_on_stream/free_on_stream back to the rest of the team. Thanks for the request.