Overheads monitored by NCU for profiling DNN workloads

I am trying to profile a simple fully-connected layer in PyTorch using NCU (Nsight Compute), and I am seeing some observations that I cannot explain.

import torch
import torch.nn as nn

# Define a simple linear model
class SimpleLinearModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(SimpleLinearModel, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

# Parameters
input_dim = 256
output_dim = 29423
batch_size = 32
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('[DEBUG]:', device)
input_ = torch.randn(batch_size, input_dim).to(device)
epochs = 20

# Initialize the model
linear_model = SimpleLinearModel(input_dim, output_dim)
linear_model.to(device)
linear_model.eval()
print(sum(p.numel() for p in linear_model.parameters()))  # total parameter count
# Repeated forward passes (inference only, no backward pass)
for _ in range(epochs):
    pred = linear_model(input_)

In this code, our expectation is that the number of floating point operations and the number of memory accesses should increase linearly with the batch size of the input data. Instead, we observe the following:

BS      Mem Accesses (MB)    FLOPs (G)
1       32.44                0.4852
8       32.46                0.4851
16      32.44                0.4852
32      35.68                0.4852
64      32.45                0.4856
1024    35.60                0.4852
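
For reference, the analytic count we expect for this layer: a Linear layer computing y = x W^T + b performs about 2 * batch_size * input_dim * output_dim floating point operations for the matrix multiply (one multiply and one add per multiply-accumulate), plus batch_size * output_dim bias additions, so the count should grow linearly with the batch size. A quick sketch of that expectation (plain arithmetic, nothing from the profiler):

# Expected FLOPs for nn.Linear(256, 29423) as a function of batch size
input_dim, output_dim = 256, 29423

def expected_linear_flops(batch_size):
    matmul = 2 * batch_size * input_dim * output_dim  # one mul + one add per MAC
    bias = batch_size * output_dim                    # one add per output element
    return matmul + bias

for bs in (1, 8, 16, 32, 64, 1024):
    print(f"BS={bs:5d}: expected {expected_linear_flops(bs) / 1e9:.4f} GFLOPs")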

I am using the following code to compute the memory accesses:

col = "gpu__time_duration.sum"
filtered_df = df[["dram__bytes.sum.per_second [Gbyte/s]", col]]
filtered_df["mem"] = df["dram__bytes.sum.per_second [Gbyte/s]"] * df[col]

mem = np.sum(filtered_df["mem"].values)
print(f"mem: {mem/1000/20} MB")  # Divide 20 to get per epoch

I am using the following code to compute the floating point operations:

def _ncu_get_flops_single(kernel_data: dict, col: str) -> float:
    # FP32 activity (fadd + fmul + FFMA counted twice) scaled by SM clock rate and kernel duration
    flops = (
        (kernel_data['smsp__sass_thread_inst_executed_op_fadd_pred_on.sum.per_cycle_elapsed [inst/cycle]']
         + kernel_data['smsp__sass_thread_inst_executed_op_fmul_pred_on.sum.per_cycle_elapsed [inst/cycle]']
         + kernel_data['derived__smsp__sass_thread_inst_executed_op_ffma_pred_on_x2 [inst]'])
        * kernel_data['smsp__cycles_elapsed.avg.per_second [Ghz]']
        * kernel_data[col]
    )
    return flops

def ncu_get_flops(kernel_data: dict, data_width: int, col: str):
    """Return (total FLOPs, tensor-core FLOPs) for one kernel: double + single + half + tensor."""
    # _ncu_get_flops_double/_half/_tensor are defined analogously (not shown here)
    double = _ncu_get_flops_double(kernel_data, col)
    single = _ncu_get_flops_single(kernel_data, col)
    half = _ncu_get_flops_half(kernel_data, col)
    tensor = _ncu_get_flops_tensor(kernel_data, col)

    flops = double + single + half + tensor
    return flops, tensor

all_flops, all_tensors = 0, 0
for i in range(len(df)):
    f, tensors = ncu_get_flops(df.iloc[i], 32, col)
    all_flops += f
    all_tensors += tensors
    
print('FLOPS:', all_flops/1e6/20)  # in GFLOPs

My assumption is that there are some heavy overheads which make the effect of batch size negligible. Can somebody explain why this is happening?

Is FLOPs here

  1. the number of floating point operations, or
  2. the number of floating point operations per second?

(1) will scale with batch size. (2) is an efficiency metric. For the FP32 operation count, use the executed-instruction counters directly:

fp32_flop_count = 0 \
  + kernel_data['smsp__sass_thread_inst_executed_op_fadd_pred_on.sum'] \
  + kernel_data['smsp__sass_thread_inst_executed_op_fmul_pred_on.sum'] \
  + kernel_data['derived__smsp__sass_thread_inst_executed_op_ffma_pred_on_x2']

If you want FLOP/sec:

fp32_flops_per_second = fp32_flop_count / (kernel_data['gpu__time_duration.sum'] / 1e9)  # assuming gpu__time_duration.sum is reported in nanoseconds
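
Aggregated over the whole report and the 20 forward passes from the question, that would look roughly like this (a sketch; df and the epoch count come from the question, and the .sum columns are assumed to hold raw instruction counts, possibly with a unit suffix in the exported column names):

epochs = 20
fp32_flop_total = 0.0
for _, kernel_data in df.iterrows():
    # executed FP32 instructions; the derived FFMA metric already counts each FFMA twice (mul + add)
    fp32_flop_total += (
        kernel_data['smsp__sass_thread_inst_executed_op_fadd_pred_on.sum']
        + kernel_data['smsp__sass_thread_inst_executed_op_fmul_pred_on.sum']
        + kernel_data['derived__smsp__sass_thread_inst_executed_op_ffma_pred_on_x2']
    )
print(f"FP32 FLOPs per forward pass: {fp32_flop_total / epochs / 1e9:.4f} G")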

For memory accesses, just use the summed byte count per kernel; the bandwidth is an efficiency metric.

dram_bytes = kernel_data['dram__bytes.sum']
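
The DRAM traffic can be aggregated the same way (again a sketch; it assumes the column holds raw bytes, so adjust the scaling if your export reports Mbyte or Gbyte):

epochs = 20
total_dram_bytes = df['dram__bytes.sum'].sum()  # total bytes moved to/from DRAM across all kernels
print(f"DRAM traffic per forward pass: {total_dram_bytes / epochs / 1e6:.2f} MB")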