Time intervals and lack of concurrency with multiple CUDA streams

Hi,
I tried to speed up my program with CUDA streams. However, looking at the nvvp timeline, I see gaps between the operations issued into the same CUDA stream, and the streams are not completely parallel: much of the work executes sequentially. Is this unavoidable, or can it be optimized? Here is a screenshot from nvvp and my code.

#include <cuda_runtime_api.h>
#include "cublas_v2.h"
#include "cuda_runtime.h"
#include <cstdio>
#include <cstdlib>
#include <iostream>

#define SMGR_N_STREAMS 8

#ifndef CUDA_CHECK
#define CUDA_CHECK(status)                                                     \
    if (status != cudaSuccess) {                                                 \
      std::cout << "Cuda failure! Error=" << cudaGetErrorString(status) << std::endl; \
    }
#endif

// cublas: various checks for different function calls.
#ifndef CUBLAS_CHECK
#define CUBLAS_CHECK(status)                                 \
    if (status != CUBLAS_STATUS_SUCCESS) {                     \
      std::cout << "Cublas failure! Error=" << status << std::endl; \
    }
#endif

__device__ inline float sigmoid(const float& x) {
  return 1.0 / (1.0 + exp(-x));
}

__global__ void ScatterMappingKernel(const int* gate_idx, const int num_expert, const int idx_num, int* mapping,
                                     int* acc_histogram) {
  int idx = threadIdx.x;
  extern __shared__ int his[];
  if (idx < num_expert + 1) his[idx] = 0;

  __syncthreads();

  for (int i = threadIdx.x; i < idx_num; i += blockDim.x) {
    // calc his
    /*if (gate_idx[i] < 0 || gate_idx[i] > num_expert) return;*/
    auto old = atomicAdd(&his[gate_idx[i] + 1], 1);
    mapping[i] = old;
  }

  __syncthreads();

  // acc his
  if (threadIdx.x == 0) {
    for (int i = 0; i < num_expert; i++) his[i + 1] += his[i];
  }
  __syncthreads();

  for (int i = threadIdx.x; i < idx_num; i += blockDim.x) {
  // calc his
    mapping[i] += his[gate_idx[i]];
  }

  if (idx < num_expert + 1) acc_histogram[idx] = his[idx];
}

int ComputeScatterMapping(const int* gate_idx, const int num_expert, const int idx_num, int* mapping,
                          int* acc_histogram, cudaStream_t stream) {
  int block_size = 0;
  if (idx_num < 1024)
    block_size = 256;
  else if (idx_num < 4096)
    block_size = 512;
  else
    block_size = 1024;

  ScatterMappingKernel<<<1, block_size, (num_expert + 1) * sizeof(int), stream>>>(gate_idx, num_expert, idx_num,
                                                          mapping, acc_histogram);
  return 0;
}

template <class T>
__global__ void ScatterMappingCopyKernel(const T* input, const int* mapping, const int dim, const int numel,
                                         T* output) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= numel) return;

  int s = idx / dim;
  int i = idx % dim;

  int mapping_idx = mapping[s];

  output[mapping_idx * dim + i] = input[idx];
}

template <class T>
int ComputeScatterMappingCopyTpl(const T* input, const int* mapping, const int S, const int dim, T* output,
                                 cudaStream_t stream) {
  auto numel = S * dim;

  int block_size = 256;
  int grid_size = (numel + block_size - 1) / block_size;

  ScatterMappingCopyKernel<T><<<grid_size, block_size, 0, stream>>>(input, mapping, dim, numel, output);

  return 0;
}

int ComputeScatterMappingCopy(const float* input, const int* mapping, const int S, const int dim, float* output,
                              cudaStream_t stream) {
  return ComputeScatterMappingCopyTpl(input, mapping, S, dim, output, stream);
}


template <typename T>
__global__ void BiasSiluKernel(const T* input, const T* bias, const int N, const int dim, T* output) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < N) {
    int bias_idx = idx % dim;
    auto tmp = input[idx] + bias[bias_idx];
    output[idx] = tmp * sigmoid(tmp);
  }
}

template <typename T>
int ComputeBiasSiluTpl(const T* input, const T* bias, const int N, const int dim, T* output, cudaStream_t stream) {
  constexpr int block_size = 512;
  const int grid_size = (N + block_size - 1) / block_size;
  BiasSiluKernel<T><<<grid_size, block_size, 0, stream>>>(input, bias, N, dim, output);

  return 0;
}

int ComputeBiasSilu(const float* input, const float* bias, const int N, const int dim, float* output,
                    cudaStream_t stream) {
  return ComputeBiasSiluTpl<float>(input, bias, N, dim, output, stream);
}

template <typename T>
__global__ void BiasKernel(const T* input, const T* bias, const int N, const int dim, T* output) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (idx < N) {
    int bias_idx = idx % dim;
    output[idx] = input[idx] + bias[bias_idx];
  }
}

template <typename T>
int ComputeBiasTpl(const T* input, const T* bias, const int N, const int dim, T* output, cudaStream_t stream) {
  constexpr int block_size = 512;
  const int grid_size = (N + block_size - 1) / block_size;
  BiasKernel<T><<<grid_size, block_size, 0, stream>>>(input, bias, N, dim, output);

  return 0;
}

int ComputeBias(const float* input, const float* bias, const int N, const int dim, float* output, cudaStream_t stream) {
  return ComputeBiasTpl<float>(input, bias, N, dim, output, stream);
}

template <class T>
__global__ void GatherrMappingCopyKernel(const T* input, const int* mapping, const int dim, const int numel,
                                         T* output) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= numel) return;
  int s = idx / dim;
  int i = idx % dim;

  int mapping_idx = mapping[s];

  output[idx] = input[mapping_idx * dim + i];
}

template <class T>
int ComputeGatherMappingCopyTpl(const T* input, const int* mapping, const int S, const int dim, T* output,
                                cudaStream_t stream) {
  auto numel = S * dim;

  int block_size = 256;
  int grid_size = (numel + block_size - 1) / block_size;

  GatherrMappingCopyKernel<T><<<grid_size, block_size, 0, stream>>>(input, mapping, dim, numel, output);

  return 0;
}

int ComputeGatherrMappingCopy(const float* input, const int* mapping, const int S, const int dim, float* output,
                              cudaStream_t stream) {
  return ComputeGatherMappingCopyTpl(input, mapping, S, dim, output, stream);
}

cublasStatus_t cublasGemm(cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
                          int m, int n, int k,
                          const float alpha, const float* A, const float* B, 
                          const float beta, float* C) {
  int lda = k;
  if (transa == CUBLAS_OP_T) lda = m;
  int ldb = n;
  if (transb == CUBLAS_OP_T) ldb = k;
  int ldc = n;

  auto status = cublasSgemm(handle, transb, transa, 
                            n, m, k, 
                            &alpha, B, ldb, A, lda,
                            &beta, C, ldc);
  return status;
}

int main() {
  int seq = 150;
  int dim = 512;
  int hidden_units = 512;
  int input_volume = seq * dim;
  int num_expert = 32;

  int* gate_idx = (int*)malloc(seq * sizeof(int));
  float* input = (float*)malloc(seq * dim * sizeof(float));
  float* output = (float*)malloc(seq * dim * sizeof(float));
  float* w1_weight = (float*)malloc(num_expert * hidden_units * dim * sizeof(float));
  float* w2_weight = (float*)malloc(num_expert * hidden_units * dim * sizeof(float));
  float* w1_bias = (float*)malloc(num_expert * dim * sizeof(float));
  float* w2_bias = (float*)malloc(num_expert * dim * sizeof(float));
  int* h_acc_his = (int*)malloc((num_expert + 1) * sizeof(int));

  int* d_gate_idx;
  float* d_input;
  float* d_output;
  float* w1_weight_ptr;
  float* w2_weight_ptr;
  float* w1_bias_ptr;
  float* w2_bias_ptr;

  int* mapping;
  int* his;
  int* acc_his;
  float* input_buffer;
  float* hidden_buffer;

  cudaMalloc((void**)&d_gate_idx, seq * sizeof(int));
  cudaMalloc((void**)&d_input, input_volume * sizeof(float));
  cudaMalloc((void**)&d_output, input_volume * sizeof(float));
  cudaMalloc((void**)&w1_weight_ptr, num_expert * dim * hidden_units * sizeof(float));
  cudaMalloc((void**)&w2_weight_ptr, num_expert * dim * hidden_units * sizeof(float));
  cudaMalloc((void**)&w1_bias_ptr, num_expert * dim * sizeof(float));
  cudaMalloc((void**)&w2_bias_ptr, num_expert * dim * sizeof(float));

  cudaMalloc((void**)&mapping, seq * sizeof(int));
  cudaMalloc((void**)&his, (num_expert + 1) * sizeof(int));
  cudaMalloc((void**)&acc_his, (num_expert + 1) * sizeof(int));
  cudaMalloc((void**)&input_buffer, input_volume * sizeof(float));
  cudaMalloc((void**)&hidden_buffer, seq * hidden_units * sizeof(float));

  for(int i = 0; i < seq; i++) {
    gate_idx[i] = rand() & 0x1f;
  }

  for(int i = 0; i < input_volume; i++) {
    input[i] = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
  }

  for(int i = 0; i < num_expert * dim * hidden_units; i++) {
    w1_weight[i] = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
    w2_weight[i] = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
  }

  for(int i = 0; i < num_expert * dim; i++) {
    w1_bias[i] = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
    w2_bias[i] = static_cast <float> (rand()) / static_cast <float> (RAND_MAX);
  }

  cudaMemcpy(d_gate_idx, gate_idx, seq * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_input, input, input_volume * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(w1_weight_ptr, w1_weight, num_expert * dim * hidden_units * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(w2_weight_ptr, w2_weight, num_expert * dim * hidden_units * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(w1_bias_ptr, w1_bias, num_expert * dim * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(w2_bias_ptr, w2_bias, num_expert * dim * sizeof(float), cudaMemcpyHostToDevice);
  int status = -1;
  status = ComputeScatterMapping(d_gate_idx, num_expert, seq, mapping, acc_his, 0);
  if (status != 0) {
    std::cout << "compute_scatter_mapping error!" << std::endl;
    return status;
  }
  const size_t word_size = sizeof(float);
  status = ComputeScatterMappingCopy(d_input, mapping, seq, dim, input_buffer, 0);
  if (status != 0) {
    std::cout << "ComputeScatterMappingCopy error!" << std::endl;
    return status;
  }

  cudaMemcpyAsync(h_acc_his, acc_his, sizeof(int) * (num_expert + 1), cudaMemcpyDeviceToHost, 0);

  cudaStreamSynchronize(0);
  cublasOperation_t transa = CUBLAS_OP_N;
  cublasOperation_t transb = CUBLAS_OP_T;

  cublasHandle_t* handles_;
  cudaStream_t* streams_;
  streams_ = new cudaStream_t[SMGR_N_STREAMS];
  handles_ = new cublasHandle_t[SMGR_N_STREAMS];
  for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
    cudaStreamCreate(streams_ + i);
    cublasCreate(handles_ + i);
    cublasSetStream(handles_[i], streams_[i]);
  }
  printf("his acc is %d\n", h_acc_his[31]);
  for (int i = 0; i < num_expert; i++) {
    auto cur_stream = streams_[i % SMGR_N_STREAMS];
    auto handle = handles_[i % SMGR_N_STREAMS];
    int m = h_acc_his[i + 1] - h_acc_his[i];
    if (m == 0) continue;
    // float* input_buffer_ptr = input_buffer + h_acc_his[i] * idim;
    // float* hidden_buffer_ptr = hidden_buffer + h_acc_his[i] * hidden_units;
    auto input_buffer_ptr = input_buffer + h_acc_his[i] * dim;
    auto hidden_buffer_ptr = hidden_buffer + h_acc_his[i] * hidden_units;

    // Weights offset
    auto w_offset = i * dim * hidden_units;
    auto cur_w1_weight_ptr = w1_weight_ptr + w_offset;
    auto cur_w1_bias_ptr = w1_bias_ptr + i * hidden_units;
    auto cur_w2_weight_ptr = w2_weight_ptr + w_offset;
    auto cur_w2_bias_ptr = w2_bias_ptr + i * dim;

    // print_data(input_buffer_ptr, 10, "input_buffer_ptr");
    // w1 gemm, tmp => output
    CUBLAS_CHECK(cublasGemm(handle, transa, transb, m, hidden_units, dim, 1.0f, input_buffer_ptr, cur_w1_weight_ptr,
                            0.0f, hidden_buffer_ptr));
    // print_data(hidden_buffer_ptr, 10, "w1_weight");

    // w1 bias + activate, tmp2
    status = ComputeBiasSilu(hidden_buffer_ptr, cur_w1_bias_ptr, m * hidden_units, hidden_units, hidden_buffer_ptr,
                             cur_stream);
    if (status != 0) {
      std::cout << "ComputeBiasSilu error!" << std::endl;
      return status;
    }

    // print_data(hidden_buffer_ptr, 10, "silu");
    // w2 gemm tmp2 => tmp1
    cublasGemm(handle, transa, transb, m, dim, hidden_units, 1.0f, hidden_buffer_ptr, cur_w2_weight_ptr,
                            0.0f, input_buffer_ptr);
    // w2 bias tmp1
    status = ComputeBias(input_buffer_ptr, cur_w2_bias_ptr, m * dim, dim, input_buffer_ptr, cur_stream);
    if (status != 0) {
      std::cout << "ComputeBias error!" << std::endl;
      return status;
    }
    // print_data(input_buffer_ptr, 10, "w2");
    // cout << "=================" << endl;
  }
  for (int i = 0; i < SMGR_N_STREAMS; ++i) cudaStreamSynchronize(streams_[i]);

  status = ComputeGatherrMappingCopy(input_buffer, mapping, seq, dim, d_output, 0);
  if (status != 0) {
    std::cout << "ComputeGatherrMappingCopy error!" << std::endl;
    return status;
  }
  cudaMemcpy(output, d_output, input_volume * sizeof(float), cudaMemcpyDeviceToHost);
  for(int i = 0; i < 10; i++) {
    printf("%f ", output[i]);
  }
  // print_data(output, 10, "output");
  // cout << "=================" << endl;
  for (size_t i = 0; i < SMGR_N_STREAMS; ++i) {
    cudaStreamDestroy(streams_[i]);
    cublasDestroy(handles_[i]);
  }
  delete[] streams_;
  delete[] handles_;
  return status;
}

I agree that they are not “completely parallel”. However, they are not purely sequential, either. You can see overlap between (at least 2) streams in your picture. So you are getting concurrency, which is one of the benefits of using streams.

This sort of question/reasoning/logic comes up frequently, and I don’t understand it. I guess what you are suggesting is that the timeline of all 8 streams should be completely overlapped and all work issuance should be back-to-back - no gaps anywhere.

I’m not sure where that expectation comes from; it’s illogical to me. Taken to its logical conclusion, it implies that the machine has no limits of any kind and has infinite capacity.

Let’s look at 2 aspects of possible limits. If all 8 streams were fully overlapped and all work issuance were back-to-back (i.e. no gaps, anywhere), you would expect to find instances of anywhere between 2 and 8 kernels running simultaneously (concurrently). You might also find instances of 2 or more cudaMemcpyAsync ops in the same direction that are overlapped.

  1. cudaMemcpyAsync ops in the same direction (let’s say D->H) targeting a single GPU never overlap. That is a machine limit, and there is nothing you can do with streams or anything else to modify that behavior. They always serialize.

  2. kernel overlap (concurrency) is a function of machine limits also. If you launch enough threads in a kernel, so as to fully occupy the machine, it is not logical that you should expect another kernel to run concurrently - there is no room for it. Once you completely fill up a box, you cannot put more things into it - there is no room. The box has limits. A GPU has limits. One of those limits is the number of threads and blocks that it can be processing at any given instant in time (a concept related to occupancy).

I’ve noted that for some of your kernel launches, the grid size is data-dependent - I wasn’t able to quickly discover it by inspection. Likewise I note that you are using shared memory, which is another limit to occupancy and kernel concurrency.
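
If you want to gauge that occupancy limit directly, the occupancy API will report how many blocks of a given kernel can be resident on the device at once. Below is a minimal, self-contained sketch of that query; the kernel is just a stand-in for one of yours, and the block size (256) and dynamic shared memory amount ((num_expert + 1) ints) are assumptions copied from the ScatterMappingKernel launch above:

// Sketch: query how many blocks of a kernel can be resident per SM. If a
// single launch already needs more blocks than the device-wide capacity,
// there is no room left for kernels from other streams to run concurrently.
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in kernel that uses dynamic shared memory, like ScatterMappingKernel.
__global__ void DummyKernel(float* out, int n) {
  extern __shared__ int smem[];
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (threadIdx.x == 0) smem[0] = blockIdx.x;
  __syncthreads();
  if (idx < n) out[idx] = static_cast<float>(smem[0]);
}

int main() {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);

  const int block_size = 256;                  // assumed launch configuration
  const size_t dyn_smem = 33 * sizeof(int);    // (num_expert + 1) ints of shared memory

  int blocks_per_sm = 0;
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, DummyKernel,
                                                block_size, dyn_smem);

  printf("%s: %d SMs, %d resident blocks/SM -> %d blocks device-wide\n",
         prop.name, prop.multiProcessorCount, blocks_per_sm,
         blocks_per_sm * prop.multiProcessorCount);
  return 0;
}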

I didn’t observe any obvious errors in your code, and furthermore you are getting some overlap/concurrency, which could only come about if you were using streams more-or-less correctly. So without spending a lot more time on your code, and without knowing what GPU type you are running on (which you haven’t indicated, but is important), I wouldn’t be able to say anything else. My guess would be that you are running into machine limits that are preventing additional concurrency/overlap beyond what is already evident in your nvvp output.

After some additional checking, there don’t appear to be any cudaMemcpyAsync operations in your main work issuance loop (I was looking at the colors in the nvvp output but don’t have those memorized), so the primary area to focus on would probably be the limiters to kernel concurrency.

Also, you don’t appear to be using streams with your CUBLAS ops, so those are being issued into the NULL stream. That is going to introduce synchronization points in your timelines, which will be additional barriers to overlap/concurrency. I believe the cublas kernels may be the purple ones in your diagram, and those do not overlap with anything AFAICT, which is the behavior we expect for work issued into the NULL stream.

Issuing CUBLAS work into streams might be the first thing to look at.
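
As a minimal standalone check of that, the sketch below (no error checking; the matrix size is an arbitrary assumption, chosen small so that a single GEMM does not fill the GPU) binds two cuBLAS handles to two streams and issues independent SGEMMs into them. If the cuBLAS work is really landing in those streams, the profiler timeline should show the GEMM kernels from the two streams overlapping:

// Sketch: two cuBLAS handles, each bound to its own stream, issuing
// independent SGEMMs. Inspect the nvvp / Nsight Systems timeline for overlap.
#include <cstdio>
#include <cuda_runtime.h>
#include <cublas_v2.h>

int main() {
  const int n = 256;                        // small on purpose: leave room for overlap
  const float alpha = 1.0f, beta = 0.0f;

  float *A[2], *B[2], *C[2];
  cudaStream_t stream[2];
  cublasHandle_t handle[2];

  for (int i = 0; i < 2; ++i) {
    // Device buffers are left uninitialized; only timing/overlap matters here.
    cudaMalloc(&A[i], n * n * sizeof(float));
    cudaMalloc(&B[i], n * n * sizeof(float));
    cudaMalloc(&C[i], n * n * sizeof(float));
    cudaStreamCreate(&stream[i]);
    cublasCreate(&handle[i]);
    cublasSetStream(handle[i], stream[i]);  // all work on handle[i] now goes into stream[i]
  }

  for (int iter = 0; iter < 10; ++iter) {
    for (int i = 0; i < 2; ++i) {
      cublasSgemm(handle[i], CUBLAS_OP_N, CUBLAS_OP_N, n, n, n,
                  &alpha, A[i], n, B[i], n, &beta, C[i], n);
    }
  }
  cudaDeviceSynchronize();

  for (int i = 0; i < 2; ++i) {
    cublasDestroy(handle[i]);
    cudaStreamDestroy(stream[i]);
    cudaFree(A[i]); cudaFree(B[i]); cudaFree(C[i]);
  }
  printf("done; inspect the timeline in the profiler\n");
  return 0;
}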


Thank you for your reply!

I tried to greatly reduce the number of blocks and threads in my kernels by reducing the amount of data processed; however, the situation didn’t improve.

My device is a Tesla T4, and my environment is:
CUDA 10.2
cuDNN 8.0.4
Driver version: 450.102.04
OS: CentOS 7

I called cublasSetStream in my code and I thought it would work.

I have another question about memcpy. I see that the V100 has 6 copy engines. However, my test results show that cudaMemcpy ops in the same direction are serialized, so the six copy engines are not fully used. If parallel copying in the same direction is not possible, why do we need 6 copy engines?

I don’t know that this is documented anywhere, but a reasonable hypothesis is that additional copy engines are dedicated to communication across NVLink.

A hypothetical parallel copy across the same PCIe link in the same direction would not really achieve anything, as the uni-directional throughput of the link is fixed (roughly 12 GB/sec for PCIe3, 25 GB/sec for PCIe4, 50 GB/sec for PCIe5).

As Robert_Crovella pointed out, performing any particular activity in parallel only leads to a performance increase (higher throughput) if it exploits a resource that is not yet fully utilized. In addition, doing things in parallel often comes with some sort of coordination overhead for sharing a resource.

Hi Njuffa,

I have the same question. I’m running a test that tries to concurrently transfer data from device to host (single-direction transfers). The reason we want to do this is that the throughput we measured with a single process is far less than the specs you quoted in your post (roughly 12 GB/sec for PCIe3, 25 GB/sec for PCIe4, 50 GB/sec for PCIe5): it is only around 1.5 GB/sec, which is way below our expectation. That is why we think the resource is not fully utilized, and concurrent D2H copies came to mind. Two questions, though:

  1. What is the best practice for fully utilizing the bandwidth with a single process doing single-direction (D2H) transfers?
  2. What concurrency approach should we employ when a single process is not able to achieve the spec’d throughput?

(0) The maximum transfer rates mentioned pertain to x16 links. A PCIe throughput of only 1.5 GB/sec suggests that either the transfers are small (see next item), or that the GPU is not installed in a x16 slot. In most systems, only a few of the physical PCIe slots are of the x16 type (check the system documentation); others are x8 or x4 with correspondingly lower throughput. Also, if there are multiple GPUs in the system, the CPU, which contains the PCIe root complex(es), may not provide enough PCIe lanes to dedicate a x16 link to every GPU. This is something to watch out for when spec’ing a system.

(1) PCIe uses packetized transfer, so there is fixed overhead for every chunk sent. In order to achieve maximum throughput, individual transfers need to be >= 16 MB, or thereabouts. You can easily make a shmoo plot of throughput at various transfer sizes (a minimal sketch follows after this list) to see where the curve approximates 12 GB/sec (for PCIe3; correspondingly higher numbers for later PCIe generations). The bandwidthTest sample app for CUDA can get you started on measurements.

(2) For maximum throughput, host ↔ device transfers must use pinned host memory with asynchronous copy API calls, i.e. cudaMemcpyAsync(). Because pinned allocations are physically contiguous chunks of memory, the DMA mechanism (= “copy engine”, which is a marketing term) can deliver (or pull) data directly from the user’s host memory. Without pinned host memory, each host ↔ device transfer becomes two transfers: (a) DMA from/into a driver-allocated pinned buffer (b) system memory copy between user memory and driver’s pinned buffer.

(3) On systems with multiple CPU sockets, or with CPUs that internally comprise multiple processor complexes, the host system has NUMA (non-uniform memory access) characteristics. Use numactl and similar system utilities to control process and memory affinity such that each GPU communicates with the “near” CPU and its attached “near” memory. Otherwise host ↔ device transfers will take a detour through the inter-processor interconnect (one or even several “hops”) to reach the “far” CPU with its attached “far” memory.

(4) PCIe is a full-duplex interconnect. As long as a GPU has >= 2 copy engines, host → device and device → host transfers can therefore take place concurrently. At the HLL programming level, this requires the use of CUDA streams to independently queue copy operations and kernel launches.

With these items addressed, a single process will easily max out unidirectional PCIe x16 bandwidth. An exception to this could occur if system memory read or write speeds are lower than the PCIe transfer rate, but I would really hope that no vendor misconfigures systems in this fashion (basically, there should be enough DRAM channels provided by the CPU, populated with actual DRAM sticks).
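
Regarding items (1) and (2), here is a minimal sketch (no error checking; sizes and iteration counts are illustrative) that measures unidirectional D2H throughput over a range of transfer sizes, once with a pageable host buffer and once with a pinned one, which makes both the size effect and the pinned-memory effect visible:

// Sketch: D2H throughput vs. transfer size, pageable vs. pinned host memory.
#include <cstdio>
#include <cstdlib>
#include <chrono>
#include <cuda_runtime.h>

static double measure_d2h(void* dst, const void* src, size_t bytes, int iters) {
  cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost);   // warm-up transfer
  cudaDeviceSynchronize();
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iters; ++i)
    cudaMemcpy(dst, src, bytes, cudaMemcpyDeviceToHost);
  cudaDeviceSynchronize();
  std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - start;
  return (double)bytes * iters / elapsed.count() / 1.0e9;  // GB/sec
}

int main(void) {
  const size_t max_bytes = 64 * 1024 * 1024;
  void *dsrc, *h_pinned, *h_pageable;
  cudaMalloc(&dsrc, max_bytes);
  cudaMallocHost(&h_pinned, max_bytes);   // pinned (page-locked) host buffer
  h_pageable = malloc(max_bytes);         // ordinary pageable host buffer

  printf("%12s %16s %16s\n", "bytes", "pageable GB/s", "pinned GB/s");
  for (size_t sz = 64 * 1024; sz <= max_bytes; sz *= 4) {
    printf("%12zu %16.2f %16.2f\n", sz,
           measure_d2h(h_pageable, dsrc, sz, 50),
           measure_d2h(h_pinned, dsrc, sz, 50));
  }

  free(h_pageable);
  cudaFreeHost(h_pinned);
  cudaFree(dsrc);
  return 0;
}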

Below is a rudimentary CUDA program (without proper error checking!) that hammers the PCIe connection to the GPU specified with the #define DEVICE_ID in the code with bi-directional traffic. Sample output from my PCIe3-based workstation looks like this:

running on device 0 (Quadro RTX 4000)
PCIe throughput (both directions combined): 21.07 GB/sec
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <cuda_runtime.h>

#define BUF_SIZE  (16 * 1024 * 1024)
#define MAX_ITER  (1000)
#define DEVICE_ID (0)

#if defined(_WIN32)
#if !defined(WIN32_LEAN_AND_MEAN)
#define WIN32_LEAN_AND_MEAN
#endif
#include <windows.h>
double second (void)
{
    LARGE_INTEGER t;
    static double oofreq;
    static int checkedForHighResTimer;
    static BOOL hasHighResTimer;

    if (!checkedForHighResTimer) {
        hasHighResTimer = QueryPerformanceFrequency (&t);
        oofreq = 1.0 / (double)t.QuadPart;
        checkedForHighResTimer = 1;
    }
    if (hasHighResTimer) {
        QueryPerformanceCounter (&t);
        return (double)t.QuadPart * oofreq;
    } else {
        return (double)GetTickCount() * 1.0e-3;
    }
}
#elif defined(__linux__) || defined(__APPLE__)
#include <stddef.h>
#include <sys/time.h>
double second (void)
{
    struct timeval tv;
    gettimeofday(&tv, NULL);
    return (double)tv.tv_sec + (double)tv.tv_usec * 1.0e-6;
}
#else
#error unsupported platform
#endif

int main (void)
{
    double start, stop, elapsed;
    cudaStream_t stream[2];
    unsigned char * dbuf[2] = {0, 0};
    unsigned char * hbuf[2] = {0, 0};
    size_t totalSize;
    struct cudaDeviceProp props;

    cudaSetDevice (DEVICE_ID);
    cudaGetDeviceProperties (&props, DEVICE_ID);
    printf ("running on device %d (%s)\n", DEVICE_ID, props.name);
    cudaStreamCreate (&stream[0]);    
    cudaStreamCreate (&stream[1]);
    cudaMalloc ((void**)&dbuf[0], BUF_SIZE);
    cudaMalloc ((void**)&dbuf[1], BUF_SIZE);
    cudaMallocHost ((void**)&hbuf[0], BUF_SIZE);
    cudaMallocHost ((void**)&hbuf[1], BUF_SIZE);
    
    start = second();
    for (int i = 0; i < MAX_ITER; i++) {
        cudaMemcpyAsync (dbuf[0], hbuf[0], BUF_SIZE,
                         cudaMemcpyHostToDevice, stream [0]);
        cudaMemcpyAsync (hbuf[1], dbuf[1], BUF_SIZE,
                         cudaMemcpyDeviceToHost, stream [1]);
    }
    cudaDeviceSynchronize();
    totalSize = ((size_t)BUF_SIZE) * MAX_ITER * 2;
    stop = second();
    elapsed = stop - start;
    printf ("PCIe throughput (both directions combined): %.2f GB/sec\n", 
            totalSize / elapsed / 1024 / 1024 / 1024);
    cudaFreeHost (hbuf[0]);    
    cudaFreeHost (hbuf[1]);
    cudaFree (dbuf[0]);    
    cudaFree (dbuf[1]);
    cudaStreamDestroy (stream[0]);
    cudaStreamDestroy (stream[1]);
    return EXIT_SUCCESS;
}