Performance of multi-thread vs multi-process with MPS

Hello,

I did some benchmarking of different architectures to deploy some CUDA enabled software and this raised some questions I’m unable to answer. I hope someone here might have a clue.

The architectures studied here for deployment are:

  1. Multi-threaded application using one CUDA stream per thread, MPS not running (“threads” in plot legend)
  2. Multiple instances of a single-threaded application running over MPS (“processes + MPS” in plot legend)
  3. Multi-threaded application using one CUDA stream per thread running over MPS (“threads + MPS” in plot legend)
  4. Multiple instances of a multi-threaded application running over MPS (not drawn in the plot to keep the sample code simple enough; this gives the best throughput for our actual workload, using 2 threads per process and 16 processes over MPS)

The work (cf. the doWork function in the code below) consists of copying memory from host to device, running a bunch of CUDA kernels, and copying some memory back from the device. I’ve tested different configurations, such as using pageable memory or page-locked/pinned memory on the host side. When the host-side memory is pageable, I’ve tested with and without an explicit stream synchronization before the cudaMemcpyAsync(DeviceToHost).

Here’s a plot of the different experiments (all done on Ubuntu 16.04, CUDA 9.2, driver 396.37, Tesla P40, 2 x Xeon E5-2680 v4 with 56 threads total; all processes started with affinity on the processor closest to the GPU):
https://user-images.githubusercontent.com/9768336/44277788-121e7400-a24c-11e8-995c-f0da744f0db3.png

  • A. Let’s compare “threads + nosync + pageable” and “threads + sync + pageable”. The only difference is that a cudaStreamSynchronize is done before copying memory back from the GPU into host pageable memory, i.e. before the cudaMemcpyAsync(DeviceToHost). The CUDA Runtime API documentation only states that “For transfers from device memory to pageable host memory, the function will return only once the copy has completed.” (Chapter 2, API synchronization behavior). It never states that it also blocks CUDA Runtime API calls happening on other threads in different streams (my understanding). Am I reading the plot correctly in this case?
  • B. Why does multi-thread not achieve the throughput of multi-process with MPS? (Note that, whatever the host-side memory type and synchronization mechanism, multi-process with MPS achieves the same throughput for a given number of processes.) Is there something wrong in my multi-threaded code, or is there some contention in the CUDA runtime?
  • Here is the code used to draw the plot, compiled with nvcc --machine 64 -O3 --use_fast_math -gencode="arch=compute_61,code=\"sm_61,compute_61\"" -o testConcurrency testConcurrency.cu

    #include <stdlib.h>
    #include <string.h>
    #include <stdio.h>
    #include <time.h>
    #include <stdint.h>
    #include <inttypes.h>
    #include <pthread.h>
    #include <sys/mman.h>
    #include <sys/types.h>
    #include <sys/wait.h>
    #include <unistd.h>
    
    #define NBADD 220
    #define NBINNERITER 1000
    #define NBOUTERITER 8
    #define LEN (352 * 1024)
    
    #define checkCudaErrors(val) check( (val), #val, __FILE__, __LINE__ )
    
    static void check(cudaError_t result, char const *const func, const char *const file, int const line)
    {
      if (result != cudaSuccess) {
        fprintf(stderr, "CUDA error at %s:%d code=%u \"%s\" \n", file, line, static_cast<unsigned int>(result), func);
        cudaDeviceReset();
        // Make sure we call CUDA Device Reset before exiting
        exit(EXIT_FAILURE);
      }
    }
    
    static uint64_t getTicksMicroSeconds()
    {
      struct timespec ts;
      clock_gettime(CLOCK_MONOTONIC, &ts);
      return ((uint64_t)ts.tv_sec) * UINT64_C(1000000) + (uint64_t)(ts.tv_nsec / 1000);
    }
    
    __device__ int4 nop4(int4 v)
    {
      v.x = __sad(v.x, v.x, v.x);
      v.y = __sad(v.y, v.y, v.y);
      v.z = __sad(v.z, v.z, v.z);
      v.w = __sad(v.w, v.w, v.w);
      return v;
    }
    
    __global__ void kernelAddConstant(const int* pSrc, int* pDst, const int value, const int length)
    {
      const int4* pSrcI4 = (const int4*)pSrc;
      int4* pDstI4 = (int4*)pDst;
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int gridSize = blockDim.x * gridDim.x;
    
      for (int i = idx; i < length/4; i += gridSize) {
        int4 v = pSrcI4[i];
        v.x += value;
        v.y += value;
        v.z += value;
        v.w += value;
        if (idx == i) {
          for (int j = 0; j < 3072; ++j) {
            v = nop4(v);
          }
        }
        pDstI4[i] = v;
      }
    }
    
    static int doWork(const int* cpuExpected, const int* cpuInput, int* cpuOutput, const size_t len, int* gpuInput, int* gpuOutput, cudaStream_t stream, int usePinnedMemory, int noSync)
    {
      checkCudaErrors(cudaMemcpyAsync(gpuInput, cpuInput, len * sizeof(int), cudaMemcpyHostToDevice, stream));
    
      for (int i = 0; i < NBADD; ++i) {
        kernelAddConstant<<<dim3(5), dim3(64), 0, stream>>>(gpuInput, gpuOutput, i, (int)(len / 64));
        int* tmp = gpuInput;
        gpuInput = gpuOutput;
        gpuOutput = tmp;
      }
    
      /* experiment variants: optionally synchronize the stream before the
         DtoH copy when the host buffer is pageable */
      if (!usePinnedMemory && !noSync) {
        checkCudaErrors(cudaStreamSynchronize(stream));
      }
      checkCudaErrors(cudaMemcpyAsync(cpuOutput, gpuOutput, len * sizeof(int), cudaMemcpyDeviceToHost, stream));
      /* with pinned memory the copy is truly asynchronous, so wait for it
         to complete before validating the result */
      if (usePinnedMemory) {
        checkCudaErrors(cudaStreamSynchronize(stream));
      }
      return memcmp(cpuOutput, cpuExpected, (len / 64) * sizeof(int));
    }
    
    typedef struct {
      int const* cpuExpected;
      size_t len;
      uint64_t volatile* elapsed;
      int usePinnedMemory;
      int noSync;
    } runThreadArgs;
    
    static void* runThread(void* argsVoid)
    {
      runThreadArgs* args = (runThreadArgs*)argsVoid;
      const int* cpuExpected = args->cpuExpected;
      const size_t len = args->len;
      const int usePinnedMemory = args->usePinnedMemory;
      const int noSync = args->noSync;
      int* cpuInput;
      int* cpuOutput;
      int* gpuInput;
      int* gpuOutput;
      cudaStream_t stream;
    
      if (!usePinnedMemory) {
        cpuInput = (int*)malloc(len * sizeof(int));
        cpuOutput = (int*)malloc(len * sizeof(int));
      }
      else {
        checkCudaErrors(cudaMallocHost(&cpuInput, len * sizeof(int)));
        checkCudaErrors(cudaMallocHost(&cpuOutput, len * sizeof(int)));
      }
      if ((cpuInput == NULL) || (cpuOutput == NULL)) {
        fprintf(stderr, "Allocation failure\n");
        exit(EXIT_FAILURE);
      }
      memset(cpuInput, 0, len * sizeof(int));
    
      checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
      checkCudaErrors(cudaMalloc(&gpuInput, len * sizeof(int)));
      checkCudaErrors(cudaMalloc(&gpuOutput, len * sizeof(int)));
      *args->elapsed = 1U;
      while (*args->elapsed == 1U);
      uint64_t start = getTicksMicroSeconds();
      for (int i = 0; i < NBINNERITER; ++i) {
        if (doWork(cpuExpected, cpuInput, cpuOutput, len, gpuInput, gpuOutput, stream, usePinnedMemory, noSync)) {
          fprintf(stderr, "Invalid result, iter %d\n", i);
          exit(EXIT_FAILURE);
        }
      }
      uint64_t stop = getTicksMicroSeconds();
      *args->elapsed = stop - start;
      checkCudaErrors(cudaFree(gpuInput));
      checkCudaErrors(cudaFree(gpuOutput));
      checkCudaErrors(cudaStreamDestroy(stream));
      if (!usePinnedMemory) {
        free(cpuInput);
        free(cpuOutput);
      }
      else {
        checkCudaErrors(cudaFreeHost(cpuInput));
        checkCudaErrors(cudaFreeHost(cpuOutput));
      }
      return NULL;
    }
    
    int main(int argc, char* argv[])
    {
      const size_t len = LEN;
      int* cpuExpected;
      pthread_t thread[32];
      runThreadArgs args[32];
      pid_t pid[32];
      uint64_t volatile* pElapsed;
      pthread_attr_t attr;
    
      if (argc < 5) {
        exit(EXIT_FAILURE);
      }
      int useFork = atoi(argv[1]);
      int usePinnedMemory = atoi(argv[2]);
      int noSync = atoi(argv[3]);
      argc -= 4;
      argv += 4;
    
      cpuExpected = (int*)malloc(len *sizeof(int));
      if (cpuExpected == NULL) {
        fprintf(stderr, "Allocation failure\n");
        exit(EXIT_FAILURE);
      }
      for (size_t i = 0; i < len; ++i) {
        cpuExpected[i] = ((NBADD-1) * (NBADD-2)) / 2;
      }
    
      pElapsed = (uint64_t volatile*)mmap( 0, (sizeof(args) / sizeof(args[0])) * sizeof(uint64_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1 /*fd*/, 0);
      if (pElapsed == MAP_FAILED) {
        fprintf(stderr, "mmap failed\n");
        exit(EXIT_FAILURE);
      }
      if (useFork) {
        fprintf(stdout, "Using sub-processes\n");
      }
      else {
        fprintf(stdout, "Using threads\n");
        cudaFree(NULL); /* init runtime API */
        pthread_attr_init(&attr);
        pthread_attr_setstacksize( &attr, 1024 * 1024);
        pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
      }
      for (size_t t = 0; t < (sizeof(args) / sizeof(args[0])); ++t) {
        args[t].cpuExpected = cpuExpected;
        args[t].len = len;
        args[t].elapsed = pElapsed + t;
        args[t].usePinnedMemory = usePinnedMemory;
        args[t].noSync = noSync;
      }
      while (argc > 0)
      {
        double maxThroughput = 0.0F;
        size_t nbThreads = atoi(argv[0]);
        argc--;
        argv++;
        for (int i = 0; i < NBOUTERITER; ++i)
        {
          memset((void*)pElapsed, 0, (sizeof(args) / sizeof(args[0])) * sizeof(uint64_t));
          if (useFork) {
            memset(pid, 0, sizeof(pid));
            for (size_t t = 0; t < nbThreads; ++t) {
              pid[t] = fork();
              if (pid[t] == 0) {
                /* I'm a child process */
                cudaFree(NULL); /* init runtime API */
                runThread(args + t);
                exit(EXIT_SUCCESS);
              }
            }
            /* Parent process */
            /* wait process ready */
            for (size_t t = 0; t < nbThreads; ++t) {
              while (*args[t].elapsed != 1U);
            }
            /* release processes */
            for (size_t t = 0; t < nbThreads; ++t) {
              *args[t].elapsed = 0U;
            }
            int status = EXIT_SUCCESS;
            for (size_t t = 0; t < nbThreads; ++t) {
              if (pid[t] < 0) {
                /* fork failed */
                status = EXIT_FAILURE;
              }
              else {
                int wstatus;
                waitpid(pid[t], &wstatus, 0);
                if (WIFEXITED(wstatus)) {
                  if (WEXITSTATUS(wstatus) != EXIT_SUCCESS) {
                    status = EXIT_FAILURE;
                  }
                }
                else {
                  status = EXIT_FAILURE;
                }
              }
            }
            if (status != EXIT_SUCCESS) {
              exit(EXIT_FAILURE);
            }
          }
          else {
            for (size_t t = 0; t < nbThreads; ++t) {
              pthread_create(thread + t, &attr, &runThread, (void*)(args + t));
            }
            /* wait process ready */
            for (size_t t = 0; t < nbThreads; ++t) {
              while (*args[t].elapsed != 1U);
            }
            /* release processes */
            for (size_t t = 0; t < nbThreads; ++t) {
              *args[t].elapsed = 0U;
            }
            for (size_t t = 0; t < nbThreads; ++t) {
              pthread_join(thread[t], NULL);
            }
          }
          uint64_t elapsed = 0;
          for (size_t t = 0; t < nbThreads; ++t) {
            elapsed += pElapsed[t];
          }
          /* elapsed is summed over all threads, so the average per-thread
             time is elapsed / nbThreads and the total number of iterations
             is NBINNERITER * nbThreads, hence the nbThreads * nbThreads */
          double throughput = (1000000.0 / (double)elapsed) * ((double)NBINNERITER * (double)(nbThreads * nbThreads));
          if (throughput > maxThroughput) {
            maxThroughput = throughput;
          }
        }
        fprintf(stdout, "%d threads - %f ips\n", (int)nbThreads, maxThroughput);
      }
      munmap((void*)pElapsed, (sizeof(args) / sizeof(args[0])) * sizeof(uint64_t));
      return EXIT_SUCCESS;
    }
    

Obviously, the kernel does nothing interesting here. It has only been tuned to more or less mimic the actual workload.
One thing to note is the number of kernel launches per iteration. I think it’s quite high, but I have little to no leverage to reduce it, and it does not change the fact that multi-process + MPS has higher throughput than multi-thread.
I can attach some profiling data for the multi-threaded case (mostly, it shows the average time going up for cudaLaunchKernel and cudaMemcpyAsync while the cudaStreamSynchronize average time goes down).

Thanks


First, let’s clarify that the case where cudaStreamSynchronize() is issued before the cudaMemcpyAsync call is the faster of the two cases, when the host memory is pageable. I think you are saying this, but I wanted to be explicit.

I think this makes sense. The short answer to your question is “yes”. The issuance of a cudaMemcpyAsync with a destination of pageable host memory may have an impact on other CUDA API calls issued in other streams/threads. Stated another way, it is no longer fully asynchronous. This is referred to somewhat in the programming guide sections on asynchronous activity:

https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
“Async memory copies will also be synchronous if they involve host memory that is not page-locked.”

https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#overlap-of-data-transfer-and-kernel-execution
For overlap of data transfers with kernel execution:
“If host memory is involved in the copy, it must be page-locked.”

https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-data-transfers
For overlap of data transfers with other data transfers:
“If host memory is involved in the copy, it must be page-locked.”

When you don’t do that, you’re exploring what will happen. Basically, bad things.
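
To see this in isolation, here is a minimal standalone sketch (my own illustration, not part of your app; spinKernel and the buffer names are invented, and error checking is omitted for brevity) that queues a long-running kernel and then measures how long the cudaMemcpyAsync call itself takes to return for a pageable vs. a pinned destination:

    #include <stdio.h>
    #include <stdlib.h>
    #include <stdint.h>
    #include <time.h>

    /* spin for roughly `cycles` GPU clock cycles */
    __global__ void spinKernel(long long cycles)
    {
      long long start = clock64();
      while (clock64() - start < cycles);
    }

    static uint64_t nowUs(void)
    {
      struct timespec ts;
      clock_gettime(CLOCK_MONOTONIC, &ts);
      return (uint64_t)ts.tv_sec * UINT64_C(1000000) + (uint64_t)(ts.tv_nsec / 1000);
    }

    int main(void)
    {
      const size_t bytes = 4 << 20;   /* 4 MB, size chosen arbitrarily */
      char *dev, *pageable, *pinned;
      cudaStream_t stream;
      cudaMalloc(&dev, bytes);
      cudaStreamCreate(&stream);
      pageable = (char*)malloc(bytes);
      cudaMallocHost(&pinned, bytes);

      /* pageable destination: the "async" call is expected to block until
         the preceding kernel and the copy itself have completed */
      spinKernel<<<1, 1, 0, stream>>>(1000000000LL);
      uint64_t t0 = nowUs();
      cudaMemcpyAsync(pageable, dev, bytes, cudaMemcpyDeviceToHost, stream);
      printf("pageable DtoH: call returned after %lu us\n", (unsigned long)(nowUs() - t0));
      cudaStreamSynchronize(stream);

      /* pinned destination: the call should return almost immediately */
      spinKernel<<<1, 1, 0, stream>>>(1000000000LL);
      t0 = nowUs();
      cudaMemcpyAsync(pinned, dev, bytes, cudaMemcpyDeviceToHost, stream);
      printf("pinned   DtoH: call returned after %lu us\n", (unsigned long)(nowUs() - t0));
      cudaStreamSynchronize(stream);
      return 0;
    }

With a pageable destination the call is expected to block until both the kernel and the copy have finished; with a pinned destination it should return in microseconds.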

The details of the transfer process to a pageable host buffer are not spelled out in any CUDA documents that I am aware of; however, it’s evident that multiple steps are involved:

  1. programming of a DMA engine to point to a primary pinned buffer owned by the GPU driver
  2. executing a DMA transfer to that buffer
  3. updating the DMA engine to point to an alternate pinned buffer
  4. executing a DMA transfer to that buffer
  5. simultaneously with 4, copying data from the primary buffer to the pageable area
  6. alternating the sense of the primary and alternate buffers
  7. repeating steps 3 through 6 as needed to effect the entire transfer

If we compare that with a pinned copy, the process is quite simple:

  1. programming of a DMA engine to point to a pinned buffer provided by the user
  2. executing a DMA transfer to that buffer

The pageable sequence requires a host CPU thread (created by the GPU driver) to manage the process (this is stated in the documentation).
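
The actual driver code isn’t public, so the following is only a rough user-level emulation of that ping-pong scheme (the buffer count and chunk size are invented, error checking is omitted, and unlike the real driver it does not overlap the CPU drain of one buffer with the DMA into the other), but it illustrates why a CPU thread stays busy for the whole transfer:

    /* Hypothetical emulation of the pageable DtoH path; NOT the actual
       driver implementation. Two pinned staging buffers in a ping-pong. */
    #define STAGE_BYTES (64 * 1024)   /* invented chunk size */

    static void pageableCopyDtoH(void* pageableDst, const void* devSrc,
                                 size_t bytes, cudaStream_t stream)
    {
      char* stage[2];                  /* stand-ins for the driver's pinned buffers */
      cudaMallocHost((void**)&stage[0], STAGE_BYTES);
      cudaMallocHost((void**)&stage[1], STAGE_BYTES);
      size_t done = 0;
      int cur = 0;
      while (done < bytes) {
        size_t n = (bytes - done < STAGE_BYTES) ? (bytes - done) : STAGE_BYTES;
        /* steps 1-4: DMA a chunk from the device into the current pinned buffer */
        cudaMemcpyAsync(stage[cur], (const char*)devSrc + done, n,
                        cudaMemcpyDeviceToHost, stream);
        /* the managing CPU thread must wait for the DMA before draining */
        cudaStreamSynchronize(stream);
        /* step 5: CPU copy from the pinned buffer into the pageable area */
        memcpy((char*)pageableDst + done, stage[cur], n);
        done += n;
        cur ^= 1;                      /* step 6: alternate the buffers */
      }
      cudaFreeHost(stage[0]);
      cudaFreeHost(stage[1]);
    }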

It seems fairly evident that the processing involved in the first sequence is not something that fits within the definition of asynchrony. In fact, if we imagine a limited number of sets of ping-pong buffers, then we cannot really have many transfers “outstanding”, and it seems evident they will serialize.

I believe that issuing transfers to pageable buffers effectively results in transfers being completed in issue order rather than in the order of completion of previous work (“ready” order). Ordinarily, with multiple async transfers to pinned buffers in separate streams/threads, the transfers should be able to take place in any order (in “ready” order), depending on when the previous GPU activity in each stream/thread has completed. I suspect that the above mechanism requires that transfers to pageable buffers take place in issue order.

The complexity of your application makes demonstrating any of these ideas unwieldy, but we can get an idea of some of the problems by reducing your application to one that just spins up 4 threads and issues a single kernel from each thread. Then we can study nvprof output from --print-gpu-trace and --print-api-trace to get an idea of some of the interactions/hazards.

Here’s a modified version of your app. I reduced it to one kernel per thread, launching with 4 threads. I’ve replaced your kernel with a simple delay kernel so that each thread has a different execution duration, the longest being the first thread and the shortest being the last thread. I’ve also introduced a 10 ms “stagger” in the release of each thread:

    $ cat t276.cu
    #include <stdlib.h>
    #include <string.h>
    #include <stdio.h>
    #include <time.h>
    #include <stdint.h>
    #include <inttypes.h>
    #include <pthread.h>
    #include <sys/mman.h>
    #include <sys/types.h>
    #include <sys/wait.h>
    #include <unistd.h>
    
    #define NBADD 1
    #define NBINNERITER 1
    #define NBOUTERITER 1
    #define LEN (352 * 1024)
    
    #define checkCudaErrors(val) check( (val), #val, __FILE__, __LINE__ )
    
    static void check(cudaError_t result, char const *const func, const char *const file, int const line)
    {
      if (result != cudaSuccess) {
        fprintf(stderr, "CUDA error at %s:%d code=%u \"%s\" \n", file, line, static_cast<unsigned int>(result), func);
        cudaDeviceReset();
        // Make sure we call CUDA Device Reset before exiting
        exit(EXIT_FAILURE);
      }
    }
    
    static uint64_t getTicksMicroSeconds()
    {
      struct timespec ts;
      clock_gettime(CLOCK_MONOTONIC, &ts);
      return ((uint64_t)ts.tv_sec) * UINT64_C(1000000) + (uint64_t)(ts.tv_nsec / 1000);
    }
    
    __device__ int4 nop4(int4 v)
    {
      v.x = __sad(v.x, v.x, v.x);
      v.y = __sad(v.y, v.y, v.y);
      v.z = __sad(v.z, v.z, v.z);
      v.w = __sad(v.w, v.w, v.w);
      return v;
    }
    __global__ void kernelDelay(size_t dt){
      size_t start = clock64();
      while (clock64() < start+dt);
    }
    
    __global__ void kernelAddConstant(const int* pSrc, int* pDst, const int value, const int length)
    {
      const int4* pSrcI4 = (const int4*)pSrc;
      int4* pDstI4 = (int4*)pDst;
      int idx = blockIdx.x * blockDim.x + threadIdx.x;
      int gridSize = blockDim.x * gridDim.x;
    
      for (int i = idx; i < length/4; i += gridSize) {
        int4 v = pSrcI4[i];
        v.x += value;
        v.y += value;
        v.z += value;
        v.w += value;
        if (idx == i) {
          for (int j = 0; j < 3072; ++j) {
            v = nop4(v);
          }
        }
        pDstI4[i] = v;
      }
    }
    #define DELAYT 1000000000ULL
    
    static int doWork(const int* cpuExpected, const int* cpuInput, int* cpuOutput, const size_t len, int* gpuInput, int* gpuOutput, cudaStream_t stream, int usePinnedMemory, int noSync, size_t id)
    {
      checkCudaErrors(cudaMemcpyAsync(gpuInput, cpuInput, len * sizeof(int), cudaMemcpyHostToDevice, stream));
    
      for (int i = 0; i < NBADD; ++i) {
    //    kernelAddConstant<<<dim3(5), dim3(64), 0, stream>>>(gpuInput, gpuOutput, i, (int)(len / 64));
        kernelDelay<<<1,1,0,stream>>>((8-id)*DELAYT);
        int* tmp = gpuInput;
        gpuInput = gpuOutput;
        gpuOutput = tmp;
      }
    
      if (!usePinnedMemory && !noSync) {
        checkCudaErrors(cudaStreamSynchronize(stream));
      }
      checkCudaErrors(cudaMemcpyAsync(cpuOutput, gpuOutput, len * sizeof(int), cudaMemcpyDeviceToHost, stream));
      if (usePinnedMemory) {
        checkCudaErrors(cudaStreamSynchronize(stream));
      }
      return memcmp(cpuOutput, cpuExpected, (len / 64) * sizeof(int));
    }
    
    typedef struct {
      int tid;
      int const* cpuExpected;
      size_t len;
      uint64_t volatile* elapsed;
      int usePinnedMemory;
      int noSync;
    } runThreadArgs;
    
    static void* runThread(void* argsVoid)
    {
      runThreadArgs* args = (runThreadArgs*)argsVoid;
      const int* cpuExpected = args->cpuExpected;
      const size_t len = args->len;
      const int usePinnedMemory = args->usePinnedMemory;
      const int noSync = args->noSync;
      int* cpuInput;
      int* cpuOutput;
      int* gpuInput;
      int* gpuOutput;
      cudaStream_t stream;
    
      if (!usePinnedMemory) {
        cpuInput = (int*)malloc(len * sizeof(int));
        cpuOutput = (int*)malloc(len * sizeof(int));
      }
      else {
        checkCudaErrors(cudaMallocHost(&cpuInput, len * sizeof(int)));
        checkCudaErrors(cudaMallocHost(&cpuOutput, len * sizeof(int)));
      }
      if ((cpuInput == NULL) || (cpuOutput == NULL)) {
        fprintf(stderr, "Allocation failure\n");
        exit(EXIT_FAILURE);
      }
      memset(cpuInput, 0, len * sizeof(int));
    #ifndef USE_BLOCKING
      checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
    #else
      checkCudaErrors(cudaStreamCreateWithFlags(&stream, cudaStreamDefault));
    #endif
      checkCudaErrors(cudaMalloc(&gpuInput, len * sizeof(int)));
      checkCudaErrors(cudaMalloc(&gpuOutput, len * sizeof(int)));
      *args->elapsed = 1U;
      while (*args->elapsed == 1U);
      uint64_t start = getTicksMicroSeconds();
      for (int i = 0; i < NBINNERITER; ++i) {
        if (doWork(cpuExpected, cpuInput, cpuOutput, len, gpuInput, gpuOutput, stream, usePinnedMemory, noSync, args->tid)) {
          fprintf(stderr, "Invalid result, iter %d\n", i);
          exit(EXIT_FAILURE);
        }
      }
      uint64_t stop = getTicksMicroSeconds();
      *args->elapsed = stop - start;
      checkCudaErrors(cudaFree(gpuInput));
      checkCudaErrors(cudaFree(gpuOutput));
      checkCudaErrors(cudaStreamDestroy(stream));
      if (!usePinnedMemory) {
        free(cpuInput);
        free(cpuOutput);
      }
      else {
        checkCudaErrors(cudaFreeHost(cpuInput));
        checkCudaErrors(cudaFreeHost(cpuOutput));
      }
      return NULL;
    }
    
    int main(int argc, char* argv[])
    {
      const size_t len = LEN;
      int* cpuExpected;
      pthread_t thread[32];
      runThreadArgs args[32];
      pid_t pid[32];
      uint64_t volatile* pElapsed;
      pthread_attr_t attr;
      checkCudaErrors(cudaSetDeviceFlags(cudaDeviceScheduleYield));
      if (argc < 5) {
        fprintf(stderr, "./app useFork usePinned noSync #cputhreads\n");
        exit(EXIT_FAILURE);
      }
      int useFork = atoi(argv[1]);
      int usePinnedMemory = atoi(argv[2]);
      int noSync = atoi(argv[3]);
      argc -= 4;
      argv += 4;
    
      cpuExpected = (int*)malloc(len *sizeof(int));
      if (cpuExpected == NULL) {
        fprintf(stderr, "Allocation failure\n");
        exit(EXIT_FAILURE);
      }
      for (size_t i = 0; i < len; ++i) {
        cpuExpected[i] = ((NBADD-1) * (NBADD-2)) / 2;
      }
    
      pElapsed = (uint64_t volatile*)mmap( 0, (sizeof(args) / sizeof(args[0])) * sizeof(uint64_t), PROT_READ | PROT_WRITE, MAP_SHARED | MAP_ANONYMOUS, -1 /*fd*/, 0);
      if (pElapsed == MAP_FAILED) {
        fprintf(stderr, "mmap failed\n");
        exit(EXIT_FAILURE);
      }
      if (useFork) {
        fprintf(stdout, "Using sub-processes\n");
      }
      else {
        fprintf(stdout, "Using threads\n");
        cudaFree(NULL); /* init runtime API */
        pthread_attr_init(&attr);
        pthread_attr_setstacksize( &attr, 1024 * 1024);
        pthread_attr_setdetachstate(&attr, PTHREAD_CREATE_JOINABLE);
      }
      for (size_t t = 0; t < (sizeof(args) / sizeof(args[0])); ++t) {
        args[t].cpuExpected = cpuExpected;
        args[t].len = len;
        args[t].elapsed = pElapsed + t;
        args[t].usePinnedMemory = usePinnedMemory;
        args[t].noSync = noSync;
        args[t].tid = t;
      }
      while (argc > 0)
      {
        double maxThroughput = 0.0F;
        size_t nbThreads = atoi(argv[0]);
        argc--;
        argv++;
        for (int i = 0; i < NBOUTERITER; ++i)
        {
          memset((void*)pElapsed, 0, (sizeof(args) / sizeof(args[0])) * sizeof(uint64_t));
          if (useFork) {
            memset(pid, 0, sizeof(pid));
            for (size_t t = 0; t < nbThreads; ++t) {
              pid[t] = fork();
              if (pid[t] == 0) {
                /* I'm a child process */
                cudaFree(NULL); /* init runtime API */
                runThread(args + t);
                exit(EXIT_SUCCESS);
              }
            }
            /* Parent process */
            /* wait process ready */
            for (size_t t = 0; t < nbThreads; ++t) {
              while (*args[t].elapsed != 1U);
            }
            /* release processes */
            for (size_t t = 0; t < nbThreads; ++t) {
              *args[t].elapsed = 0U;
            }
            int status = EXIT_SUCCESS;
            for (size_t t = 0; t < nbThreads; ++t) {
              if (pid[t] < 0) {
                /* fork failed */
                status = EXIT_FAILURE;
              }
              else {
                int wstatus;
                waitpid(pid[t], &wstatus, 0);
                if (WIFEXITED(wstatus)) {
                  if (WEXITSTATUS(wstatus) != EXIT_SUCCESS) {
                    status = EXIT_FAILURE;
                  }
                }
                else {
                  status = EXIT_FAILURE;
                }
              }
            }
            if (status != EXIT_SUCCESS) {
              exit(EXIT_FAILURE);
            }
          }
          else {
            for (size_t t = 0; t < nbThreads; ++t) {
              pthread_create(thread + t, &attr, &runThread, (void*)(args + t));
            }
            /* wait process ready */
            for (size_t t = 0; t < nbThreads; ++t) {
              while (*args[t].elapsed != 1U);
            }
            /* release processes */
            for (size_t t = 0; t < nbThreads; ++t) {
              *args[t].elapsed = 0U;
              usleep(10000);  // 10ms delay
            }
            for (size_t t = 0; t < nbThreads; ++t) {
              pthread_join(thread[t], NULL);
            }
          }
          uint64_t elapsed = 0;
          for (size_t t = 0; t < nbThreads; ++t) {
            elapsed += pElapsed[t];
          }
          /* elapsed is summed over all threads, so the average per-thread
             time is elapsed / nbThreads and the total number of iterations
             is NBINNERITER * nbThreads, hence the nbThreads * nbThreads */
          double throughput = (1000000.0 / (double)elapsed) * ((double)NBINNERITER * (double)(nbThreads * nbThreads));
          if (throughput > maxThroughput) {
            maxThroughput = throughput;
          }
        }
        fprintf(stdout, "%d threads - %f ips\n", (int)nbThreads, maxThroughput);
      }
      munmap((void*)pElapsed, (sizeof(args) / sizeof(args[0])) * sizeof(uint64_t));
      return EXIT_SUCCESS;
    }
    $ nvcc -arch=sm_60 -o t276 t276.cu
    $
    

I think the most interesting output is the --print-api-trace output, but let’s take a look at the --print-gpu-trace output first:

    $ nvprof --print-gpu-trace ./t276 0 0 1 4
    ==15785== NVPROF is profiling process 15785, command: ./t276 0 0 1 4
    Using threads
    4 threads - 0.337067 ips
    ==15785== Profiling application: ./t276 0 0 1 4
    ==15785== Profiling result:
       Start  Duration            Grid Size      Block Size     Regs*    SSMem*    DSMem*      Size  Throughput  SrcMemType  DstMemType           Device   Context    Stream  Name
    605.51ms  231.07us                    -               -         -         -         -  1.3750MB  5.8111GB/s    Pageable      Device  Tesla P100-PCIE         1        18  [CUDA memcpy HtoD]
    605.85ms  6.04460s              (1 1 1)         (1 1 1)         8        0B        0B         -           -           -           -  Tesla P100-PCIE         1        18  kernelDelay(unsigned long) [423]
    6.65045s  114.88us                    -               -         -         -         -  1.3750MB  11.689GB/s      Device    Pageable  Tesla P100-PCIE         1        18  [CUDA memcpy DtoH]
    6.65222s  225.28us                    -               -         -         -         -  1.3750MB  5.9605GB/s    Pageable      Device  Tesla P100-PCIE         1        20  [CUDA memcpy HtoD]
    6.65260s  5.26744s              (1 1 1)         (1 1 1)         8        0B        0B         -           -           -           -  Tesla P100-PCIE         1        20  kernelDelay(unsigned long) [429]
    6.65381s  405.69us                    -               -         -         -         -  1.3750MB  3.3098GB/s    Pageable      Device  Tesla P100-PCIE         1        17  [CUDA memcpy HtoD]
    6.65436s  3.76247s              (1 1 1)         (1 1 1)         8        0B        0B         -           -           -           -  Tesla P100-PCIE         1        17  kernelDelay(unsigned long) [431]
    10.4168s  131.04us                    -               -         -         -         -  1.3750MB  10.247GB/s      Device    Pageable  Tesla P100-PCIE         1        17  [CUDA memcpy DtoH]
    11.9219s  354.20us                    -               -         -         -         -  1.3750MB  3.7910GB/s    Pageable      Device  Tesla P100-PCIE         1        19  [CUDA memcpy HtoD]
    11.9224s  4.51495s              (1 1 1)         (1 1 1)         8        0B        0B         -           -           -           -  Tesla P100-PCIE         1        19  kernelDelay(unsigned long) [436]
    16.4374s  122.53us                    -               -         -         -         -  1.3750MB  10.959GB/s      Device    Pageable  Tesla P100-PCIE         1        19  [CUDA memcpy DtoH]
    16.4385s  138.49us                    -               -         -         -         -  1.3750MB  9.6955GB/s      Device    Pageable  Tesla P100-PCIE         1        20  [CUDA memcpy DtoH]
    
    Regs: Number of registers used per CUDA thread. This number includes registers used internally by the CUDA driver and/or tools and can be more than what the compiler shows.
    SSMem: Static shared memory allocated per CUDA block.
    DSMem: Dynamic shared memory allocated per CUDA block.
    SrcMemType: The type of source memory accessed by memory operation/copy
    DstMemType: The type of destination memory accessed by memory operation/copy
    $
    

We see the 4 kernel launches from the 4 threads, but instead of the fully asynchronous (depth-first) work issuance of:

    H->D
    kernel
    D->H
    H->D
    kernel
    D->H
    H->D
    kernel
    D->H
    H->D
    kernel
    D->H

we see a somewhat “jumbled” pattern.

To get a better idea of why, let’s study the tail end of the --print-api-trace output:

    $ nvprof --print-api-trace ./t276 0 0 1 4
    ==15858== NVPROF is profiling process 15858, command: ./t276 0 0 1 4
    Using threads
    4 threads - 0.317309 ips
    ==15858== Profiling application: ./t276 0 0 1 4
    ==15858== Profiling result:
       Start  Duration  Name
    162.11ms  8.0580us  cuDeviceGetPCIBusId
    ...
    289.79ms     348ns  cuDeviceGetAttribute
    289.80ms  16.606us  cudaSetDeviceFlags
    292.76ms  245.34ms  cudaFree
    539.84ms  75.357us  cudaStreamCreateWithFlags
    539.92ms  310.11us  cudaMalloc
    540.23ms  311.43us  cudaMalloc
    540.50ms  170.83us  cudaStreamCreateWithFlags
    540.53ms  2.1818ms  cudaStreamCreateWithFlags
    540.53ms  937.96us  cudaStreamCreateWithFlags
    540.67ms  727.39us  cudaMalloc
    541.40ms  1.1850ms  cudaMalloc
    541.47ms  571.17us  cudaMalloc
    542.04ms  2.4356ms  cudaMalloc
    542.71ms  591.51us  cudaMalloc
    543.31ms  525.93us  cudaMalloc
    544.51ms  1.0149ms  cudaMemcpyAsync
    545.54ms  133.49us  cudaLaunchKernel (kernelDelay(unsigned long) [423])
    545.68ms  6.03009s  cudaMemcpyAsync
    554.54ms  6.02233s  cudaMemcpyAsync
    564.61ms  6.01336s  cudaMemcpyAsync
    574.70ms  10.5227s  cudaMemcpyAsync
    6.57583s  4.51955s  cudaFree
    6.57691s  8.28431s  cudaLaunchKernel (kernelDelay(unsigned long) [429])
    6.57799s  169.46us  cudaLaunchKernel (kernelDelay(unsigned long) [430])
    6.57816s  4.51590s  cudaMemcpyAsync
    11.0941s  396.31us  cudaFree
    11.0945s  223.10us  cudaFree
    11.0947s  17.061us  cudaStreamDestroy
    11.0954s  456.86us  cudaFree
    11.0958s  30.409us  cudaStreamDestroy
    11.0975s  142.79us  cudaLaunchKernel (kernelDelay(unsigned long) [437])
    11.0976s  3.76340s  cudaMemcpyAsync
    14.8610s  5.26797s  cudaFree
    14.8612s  5.27025s  cudaMemcpyAsync
    20.1290s  273.26us  cudaFree
    20.1293s  9.7960us  cudaStreamDestroy
    20.1315s  492.53us  cudaFree
    20.1320s  406.05us  cudaFree
    20.1324s  20.843us  cudaStreamDestroy
    $
    

Starting with the first cudaMemcpyAsync, we see the expected work issuance pattern for thread 0: H->D, kernel, D->H. However, we then see 3 more cudaMemcpyAsync operations, approximately 10 ms apart. These are the 3 H->D copies in threads 1, 2 and 3. After that is a 6-second gap (due to the delay kernel in thread 0). Since a cudaMemcpyAsync operation doesn’t return until the copy is complete when a pageable buffer is involved, we don’t get the kernel launches for threads 1, 2 and 3 until 6 seconds into the timeline.

This means that the H->D copies for threads 1, 2 and 3, which are issued after the D->H copy in thread 0, are waiting for that thread-0 operation to complete before they will begin. This is evident from the 6+ second duration of each of those operations, when they should only require ~1 ms. So, effectively, these operations serialize and execute in issue order rather than with asynchrony. Because they execute in issue order, the D->H operation for thread 0 (which must wait for the kernel to complete) is also holding up the H->D operations from the other threads.

The issuance of the cudaStreamSynchronize before the D->H calls actually improves this scenario, because it prevents the issuance of the disruptive pageable transfers until they are approximately ready to execute/complete. By holding these out of the work stream until the last possible moment, more asynchrony is possible.
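
That is exactly the pattern already in your doWork(): with a pageable cpuOutput, drain the stream before issuing the blocking copy:

    /* the "sync + pageable" variant from doWork() above */
    checkCudaErrors(cudaStreamSynchronize(stream));
    checkCudaErrors(cudaMemcpyAsync(cpuOutput, gpuOutput, len * sizeof(int),
                                    cudaMemcpyDeviceToHost, stream));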

The moral of the story is simple: if you want understandable asynchrony, use pinned buffers.
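
For completeness: if a host allocation comes from code you can’t change, an existing range can also be page-locked in place with cudaHostRegister (a minimal sketch, reusing the checkCudaErrors macro from the programs above):

    /* pin an existing pageable allocation in place; sketch only */
    int* buf = (int*)malloc(len * sizeof(int));
    checkCudaErrors(cudaHostRegister(buf, len * sizeof(int),
                                     cudaHostRegisterDefault));
    /* ... cudaMemcpyAsync to/from buf now behaves like pinned memory ... */
    checkCudaErrors(cudaHostUnregister(buf));
    free(buf);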

Yes, I’m aware you asked a question B. I’m not responding to that at this time. I haven’t thought about it.

Thanks for the detailed answer regarding question A.
I was indeed expecting no overlap of transfers and kernel execution in the case of pageable memory, and I was aware (to some degree; thanks for the extra explanation) of what’s happening in this case. What surprised me is that I expected cudaMemcpyAsync(DeviceToHost) to be synchronous only for the effective transfer time (plus setup/management/…).
Anyway, as you said, to get understandable asynchrony, use pinned buffers. I have already moved the code that was still using pageable memory over to pinned memory, so no issue here.
That leaves question B, which I think is much more complex.