Jetson AGX Orin - Transformer Neural Network Token Removal - Latency vs. Workload Size Behavior

Good afternoon, all. I am evaluating a variety of vision transformers on a Jetson AGX Orin Developer Kit (32 GB). My evaluations involve progressively removing input tokens (reducing the workload size) from the transformer, then measuring the resulting latency reduction. I am seeing an odd trend when I plot forward-pass latency versus the percentage of tokens kept, and I would like to know what underlying GPU behavior could produce it.
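For reference, here is a sketch of my measurement loop. The slicing-based token removal and the helper names are illustrative (my actual pruning method differs in the details), but the warmup + `torch.cuda.synchronize()` timing pattern is what I use:

```python
# Hedged sketch of the latency measurement, not my exact code.
import time

def kept_counts(num_patch_tokens, keep_fractions):
    """Patch tokens retained at each keep fraction (CLS token always kept)."""
    return [max(1, round(num_patch_tokens * f)) for f in keep_fractions]

def benchmark_deit(keep_fractions, iters=50, warmup=10):
    # torch/timm imported lazily so the pure helper above stays dependency-free
    import torch
    import timm

    model = timm.create_model("deit_small_patch16_224", pretrained=False)
    model = model.cuda().eval()
    img = torch.randn(1, 3, 224, 224, device="cuda")

    results = {}
    with torch.no_grad():
        tokens = model.patch_embed(img)                 # (1, 196, 384)
        cls = model.cls_token.expand(1, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + model.pos_embed

        ks = kept_counts(tokens.shape[1] - 1, keep_fractions)
        for frac, k in zip(keep_fractions, ks):
            x = tokens[:, : k + 1]                      # CLS + first k patch tokens
            for _ in range(warmup):                     # let clocks/caches settle
                y = x
                for blk in model.blocks:
                    y = blk(y)
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            for _ in range(iters):
                y = x
                for blk in model.blocks:
                    y = blk(y)
            torch.cuda.synchronize()                    # wait for all GPU work
            results[frac] = (time.perf_counter() - t0) / iters * 1e3  # ms/pass
    return results
```

I time only the transformer blocks (not patch embedding), which is the part whose workload shrinks as tokens are removed.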

I expected this latency plot to look more like a staircase (from the tail effect) or a straight line.

Below is a plot of this behavior. For reference, I am using PyTorch 2.1.2 (built from source with CUDA 11.4) with the PyTorch Image Models (timm) implementation of DeiT-Small, a popular vision transformer:

I find this latency vs. workload relationship odd. I am aware of the GPU tail effect, which could partially explain this behavior. However, if anyone has any other insight into how GPU hardware characteristics could be causing this trend, I'm all ears.
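To clarify what I mean by the tail effect: if a kernel launches B thread blocks on a GPU with S SMs, latency scales roughly with the number of scheduling "waves", ceil(B / (S × blocks-per-SM)), which is why I expected a staircase. A toy model (the 14-SM figure is my assumption for the 32 GB AGX Orin's 1792-core Ampere GPU, and one resident block per SM is purely illustrative):

```python
# Toy wave-quantization model of the tail effect (illustrative assumptions).
import math

def waves(num_blocks, num_sms=14, blocks_per_sm=1):
    """Scheduling waves needed to run all thread blocks of one kernel."""
    return math.ceil(num_blocks / (num_sms * blocks_per_sm))
```

Under this model, 14 and 15 blocks fall in different waves (1 vs. 2), but 15 through 28 blocks all cost the same 2 waves, so latency should stay flat between the steps. That is not the shape I am seeing.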