I’m trying to implement BranchyNet on some models, testing with the CIFAR-10 dataset on a Jetson Orin Nano 8GB. Basically, I split the model into a first subgraph (common) that is always executed, and at a certain point I introduce a conditional that checks whether the result is already good enough; if it is, the model finishes prematurely (branch1), saving time. If the condition is not met, the output of the common subgraph goes through the rest of the model (branch2). I’ve been creating the models in TFLite like so:
import tensorflow as tf
from tensorflow.keras.models import Model

model = tf.keras.models.load_model(r"resnet8.h5")
model.trainable = False
# Common path
common = Model(inputs=model.input, outputs=model.layers[18].output)
# Conditional branches
branch1 = Model(inputs=model.layers[18].output, outputs=model.layers[-2].output)
branch2 = Model(inputs=model.layers[18].output, outputs=model.layers[-1].output)
# Custom layer to choose between branches
class ChooseBranchLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ChooseBranchLayer, self).__init__()
        self.branch1 = branch1
        self.branch2 = branch2

    def call(self, inputs):
        common_output = inputs
        output1 = self.branch1(common_output)
        # Exit early if the early-exit head is confident enough
        condition = tf.reduce_max(output1) > 0.90
        return tf.cond(condition, lambda: output1, lambda: self.branch2(common_output))
# Input layer
inputs = tf.keras.layers.Input(shape=(32, 32, 3))
# Common output
common_output = common(inputs)
# Use the custom layer to choose output based on condition
final_output = ChooseBranchLayer()(common_output)
model_EE = tf.keras.Model(inputs=inputs, outputs=final_output)
spec = (tf.TensorSpec((None, 32, 32, 3), tf.float32, name="input"),)
converter = tf.lite.TFLiteConverter.from_keras_model(model_EE)
tflite_model = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
with open(r"EE_resnet8.tflite", 'wb') as f:
f.write(tflite_model)
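A quick sanity check of the exported TFLite model with a random input looks roughly like this (the dummy input and printout are just illustrative):
import numpy as np

# Feed a random CIFAR-10-shaped image through the interpreter created above
dummy = np.random.rand(1, 32, 32, 3).astype(np.float32)
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
tflite_out = interpreter.get_tensor(output_details[0]['index'])
print(tflite_out.shape, tflite_out)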
Then I convert the model to ONNX with python -m tf2onnx.convert --tflite EE_resnet8.tflite --output EE_resnet8.onnx --opset 17. For testing, I need to know which branch the model has taken, so I tried returning two values, the output and a flag signaling which branch was taken: return tf.cond(condition, lambda: [output1, 0], lambda: [self.branch2(common_output), 1]), but when I build the engine with that, I get this error:
[E] [TRT] ModelImporter.cpp:771: --- End node ---
[E] [TRT] ModelImporter.cpp:773: ERROR: ModelImporter.cpp:222 In function parseGraph:
[5] Assertion failed: (node.output().size() <= static_cast<int32_t>(outputs.size())) && "Node has more output tensors than TRT expected."
[E] Failed to parse onnx file
onnxruntime, however, runs fine and gives the same results as TFLite. I’ve read in the developer guide that ONNX is more flexible with the outputs of the conditional construct, but I don’t see what I am doing wrong here. A workaround I found is to concatenate the flag to the output tensor: instead of one vector of 10 elements with the predicted classes plus a separate constant indicating the executed branch, I have a single output of 11 elements where the last element is the flag:
class ChooseBranchLayer(tf.keras.layers.Layer):
    def __init__(self):
        super(ChooseBranchLayer, self).__init__()
        self.branch1 = branch1
        self.branch2 = branch2

    def call(self, inputs):
        common_output = inputs
        output1 = self.branch1(common_output)
        # Per-sample flag columns: 0.0 marks branch1, 1.0 marks branch2
        flag0 = tf.broadcast_to(tf.constant([[0.0]]), [tf.shape(output1)[0], 1])
        flag1 = tf.broadcast_to(tf.constant([[1.0]]), [tf.shape(output1)[0], 1])
        condition = tf.reduce_max(output1) > 0.9
        return tf.cond(condition,
                       lambda: tf.concat([output1, flag0], axis=-1),
                       lambda: tf.concat([self.branch2(common_output), flag1], axis=-1))
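At inference time I then split the 11-element output back into class scores and branch flag; with onnxruntime that looks roughly like this (the model path, input layout, and random input are just for illustration, assuming the converted model keeps the NHWC input):
import numpy as np
import onnxruntime as ort

# Assumes the converted model keeps the (N, 32, 32, 3) NHWC input and the
# 11-element output described above: 10 class scores + 1 branch flag
sess = ort.InferenceSession("EE_resnet8.onnx")
input_name = sess.get_inputs()[0].name
x = np.random.rand(1, 32, 32, 3).astype(np.float32)

out = sess.run(None, {input_name: x})[0]  # shape (1, 11)
scores = out[:, :10]                      # class scores
branch_flag = out[:, 10]                  # 0.0 = early exit (branch1), 1.0 = branch2
print(np.argmax(scores, axis=-1), branch_flag)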
This does work and produces the expected results. However, I have some problems. First, I measure latency, so I use CUDA graphs, because the original ResNet-8 has an average execution time of 0.15 ms and this is clearly an enqueue-bound workload. However, I’ve seen that CUDA graphs cannot be used in this context due to the conditional flow of the model. I’ve read in the blog post Dynamic Control Flow in CUDA Graphs with Conditional Nodes that "Beginning in CUDA 12.4, CUDA Graphs supports conditional nodes, which enable the conditional or repeated execution of portions of a graph without returning control to the CPU", so I guess that until CUDA 12.4 ships in JetPack we are out of luck.
When I measure the latencies of this BranchyNet model, it is way slower than the original model, which I assume is due to the lack of CUDA graphs. What I don’t understand is why this implementation does not scale well with model quantization: for some reason, INT8 is noticeably slower than FP16, whereas in the original model the latencies follow the expected order FP32 > FP16 > INT8. I’ve tested this for ResNet-8, ResNet-56 and AlexNet, and all of them show this problem. I have no idea why this is happening, and I would like to know whether it is caused by a poor implementation of the conditional flow on my part. I used Nsight Systems to check whether the BranchyNet models were failing to use Tensor Cores, but they seem to be active. I have also tried changing the threshold in the conditional so it always takes the same branch (see the sketch below), and I get the same latency for both; sometimes the second branch (which is longer) even beats the first.
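For clarity, by "changing the threshold" I mean hard-wiring the comparison inside the custom layer, roughly like this (illustrative values, assuming branch1 ends in a softmax so its maximum is always in (0, 1]):
import tensorflow as tf

output1 = tf.constant([[0.1, 0.7, 0.2]])      # stand-in for the branch1 output
force_branch1 = tf.reduce_max(output1) > 0.0  # always True  -> early exit only
force_branch2 = tf.reduce_max(output1) > 1.1  # always False -> full model only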
I wanted to know the cause of this unexpected behavior and possible ways to address it. Thank you in advance for your attention.
Some details of my set-up:
Jetson Orin Nano 8GB
JetPack 6.0
TensorRT version: 8.6.2