Hi,
Here is a sample showing how to run the network_api_pytorch_mnist sample with INT8 calibration.
First, please copy the calibrator.py file from the int8_caffe_mnist folder to the network_api_pytorch_mnist folder.
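Assuming your working directory is samples/python inside the TensorRT repository (as the paths in the patch below suggest), that is just:

$ cp int8_caffe_mnist/calibrator.py network_api_pytorch_mnist/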
Then apply the patch below to sample.py:
diff --git a/samples/python/network_api_pytorch_mnist/sample.py b/samples/python/network_api_pytorch_mnist/sample.py
index e5e95de2..3a5d47f8 100644
--- a/samples/python/network_api_pytorch_mnist/sample.py
+++ b/samples/python/network_api_pytorch_mnist/sample.py
@@ -24,9 +24,12 @@ import numpy as np
import pycuda.autoinit
import tensorrt as trt
+from calibrator import load_mnist_data, load_mnist_labels, MNISTEntropyCalibrator
+
sys.path.insert(1, os.path.join(sys.path[0], ".."))
import common
+
# You can set the logger severity higher to suppress messages (or lower to display more messages).
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
@@ -102,51 +105,98 @@ def populate_network(network, weights):
network.mark_output(tensor=fc2.get_output(0))
-def build_engine(weights):
+def build_int8_engine(weights, calib, batch_size=32):
# For more information on TRT basics, refer to the introductory samples.
builder = trt.Builder(TRT_LOGGER)
- network = builder.create_network(common.EXPLICIT_BATCH)
+ builder.max_batch_size = batch_size
+
+ network = builder.create_network()
config = builder.create_builder_config()
runtime = trt.Runtime(TRT_LOGGER)
config.max_workspace_size = common.GiB(1)
+ config.set_flag(trt.BuilderFlag.INT8)
+ config.int8_calibrator = calib
+
# Populate the network using weights from the PyTorch model.
populate_network(network, weights)
# Build and return an engine.
plan = builder.build_serialized_network(network, config)
+
+ #with open("sample.engine", "wb") as f:
+ # f.write(plan)
return runtime.deserialize_cuda_engine(plan)
-# Loads a random test case from pytorch's DataLoader
-def load_random_test_case(model, pagelocked_buffer):
- # Select an image at random to be the test case.
- img, expected_output = model.get_random_testcase()
- # Copy to the pagelocked input buffer
- np.copyto(pagelocked_buffer, img)
- return expected_output
+def check_accuracy(context, batch_size, test_set, test_labels):
+ inputs, outputs, bindings, stream = common.allocate_buffers(context.engine)
+
+ num_correct = 0
+ num_total = 0
+
+ batch_num = 0
+ for start_idx in range(0, test_set.shape[0], batch_size):
+ batch_num += 1
+ if batch_num % 10 == 0:
+ print("Validating batch {:}".format(batch_num))
+ # If the number of images in the test set is not divisible by the batch size, the last batch will be smaller.
+ # This logic is used for handling that case.
+ end_idx = min(start_idx + batch_size, test_set.shape[0])
+ effective_batch_size = end_idx - start_idx
+
+ # Do inference for every batch.
+ inputs[0].host = test_set[start_idx : start_idx + effective_batch_size]
+ [output] = common.do_inference(
+ context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream, batch_size=effective_batch_size
+ )
+
+ # Use argmax to get predictions and then check accuracy
+ preds = np.argmax(output.reshape(batch_size, 10)[0:effective_batch_size], axis=1)
+ labels = test_labels[start_idx : start_idx + effective_batch_size]
+ num_total += effective_batch_size
+ num_correct += np.count_nonzero(np.equal(preds, labels))
+
+ percent_correct = 100 * num_correct / float(num_total)
+ print("Total Accuracy: {:}%".format(percent_correct))
def main():
common.add_help(description="Runs an MNIST network using a PyTorch model")
+
# Train the PyTorch model
mnist_model = model.MnistModel()
mnist_model.learn()
weights = mnist_model.get_weights()
- # Do inference with TensorRT.
- engine = build_engine(weights)
-
- # Build an engine, allocate buffers and create a stream.
- # For more information on buffer allocation, refer to the introductory samples.
- inputs, outputs, bindings, stream = common.allocate_buffers(engine)
- context = engine.create_execution_context()
-
- case_num = load_random_test_case(mnist_model, pagelocked_buffer=inputs[0].host)
- # For more information on performing inference, refer to the introductory samples.
- # The common.do_inference function will return a list of outputs - we only have one in this case.
- [output] = common.do_inference_v2(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
- pred = np.argmax(output)
- print("Test Case: " + str(case_num))
- print("Prediction: " + str(pred))
+
+ _, data_files = common.find_sample_data(
+ description="Runs a Caffe MNIST network in Int8 mode",
+ subfolder="mnist",
+ find_files=[
+ "t10k-images-idx3-ubyte",
+ "t10k-labels-idx1-ubyte",
+ "train-images-idx3-ubyte",
+ ],
+ err_msg="Please follow the README to download the MNIST dataset",
+ )
+ [test_set, test_labels, train_set] = data_files
+
+ # Now we create a calibrator and give it the location of our calibration data.
+ # We also allow it to cache calibration data for faster engine building.
+ calibration_cache = "mnist_calibration.cache"
+ calib = MNISTEntropyCalibrator(train_set, cache_file=calibration_cache)
+
+ # Inference batch size can be different from calibration batch size.
+ batch_size = 32
+ #with open('sample.engine', 'rb') as f:
+ # plan = f.read()
+
+ with build_int8_engine(
+ weights, calib, batch_size
+ ) as engine, engine.create_execution_context() as context:
+ check_accuracy(
+ context, batch_size, test_set=load_mnist_data(test_set), test_labels=load_mnist_labels(test_labels)
+ )
if __name__ == "__main__":
    main()
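As a side note, the commented-out lines in the patch hint at caching the built engine so that later runs can skip the (relatively slow) build and calibration steps. A minimal sketch of that idea, with hypothetical helper names, assuming the serialized plan is kept in sample.engine:

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# Persist the serialized plan returned by builder.build_serialized_network().
def save_engine(plan, path="sample.engine"):
    with open(path, "wb") as f:
        f.write(plan)

# On later runs, deserialize the cached plan instead of rebuilding.
def load_engine(path="sample.engine"):
    runtime = trt.Runtime(TRT_LOGGER)
    with open(path, "rb") as f:
        return runtime.deserialize_cuda_engine(f.read())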
We can get about 97.5% accuracy with INT8 calibration:
$ python3 sample.py
Train Epoch: 1 [0/60000 (0%)] Loss: 2.288751
...
Test set: Average loss: 0.0649, Accuracy: 9804/10000 (98%)
...
Validating batch 310
Total Accuracy: 97.5%
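For reference, the MNISTEntropyCalibrator you copied from int8_caffe_mnist implements TensorRT's trt.IInt8EntropyCalibrator2 interface and feeds batches of training images to the builder during calibration. Below is a condensed sketch of what it does; this is approximate and trimmed for brevity, so treat the actual calibrator.py as the authoritative version:

import os

import numpy as np
import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt

def load_mnist_data(filepath):
    # Parse the raw idx3-ubyte file: a 16-byte header, then one uint8 per pixel.
    raw = np.fromfile(filepath, dtype=np.uint8)
    num_images = raw[4:8].view(">i4")[0]
    # Scale pixels to [0, 1] and lay the images out as NCHW float32 (MNIST is 28x28).
    return np.ascontiguousarray((raw[16:] / 255.0).astype(np.float32).reshape(num_images, 1, 28, 28))

# load_mnist_labels is analogous: an 8-byte header, then one uint8 label per image.

class MNISTEntropyCalibrator(trt.IInt8EntropyCalibrator2):
    def __init__(self, training_data, cache_file, batch_size=64):
        trt.IInt8EntropyCalibrator2.__init__(self)
        self.cache_file = cache_file
        self.data = load_mnist_data(training_data)
        self.batch_size = batch_size
        self.current_index = 0
        # Device memory for one calibration batch.
        self.device_input = cuda.mem_alloc(self.data[0].nbytes * self.batch_size)

    def get_batch_size(self):
        return self.batch_size

    def get_batch(self, names):
        # Returning None tells TensorRT there are no more calibration batches.
        if self.current_index + self.batch_size > self.data.shape[0]:
            return None
        batch = self.data[self.current_index : self.current_index + self.batch_size].ravel()
        cuda.memcpy_htod(self.device_input, batch)
        self.current_index += self.batch_size
        return [int(self.device_input)]

    def read_calibration_cache(self):
        # If a cache exists, TensorRT uses it and skips calibration entirely.
        if os.path.exists(self.cache_file):
            with open(self.cache_file, "rb") as f:
                return f.read()

    def write_calibration_cache(self, cache):
        with open(self.cache_file, "wb") as f:
            f.write(cache)

The cache file (mnist_calibration.cache above) is what lets repeated engine builds skip the calibration pass.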
Thanks.