Failed to use INT8 precision mode when using caffemodel on Xavier

Hi,
I am trying to run my RFCN caffemodel on Xavier with TensorRT. It works well in FP32 and FP16 mode, but doesn’t work in INT8 mode. My caffeToTRTModel function is as follows:

void caffeToTRTModelINT8(const std::string& deployFile,
                     const std::string& modelFile,
                     const std::vector<std::string>& outputs,
                     unsigned int maxBatchSize,
                     nvcaffeparser1::IPluginFactoryExt* pluginFactory,
                     IHostMemory*& trtModelStream,
                     const std::string& imageSetFile,
                     const std::string& cachePath
                     )
{

    IBuilder* builder = createInferBuilder(gLogger.getTRTLogger());
    assert(builder != nullptr);

    INetworkDefinition* network = builder->createNetwork();
    ICaffeParser* parser = createCaffeParser();

    parser->setPluginFactoryExt(pluginFactory);

// ----INT8----

    const IBlobNameToTensor* blobNameToTensor = parser->parse(locateFile(deployFile, gArgs.dataDirs).c_str(),
                                                              locateFile(modelFile, gArgs.dataDirs).c_str(),
                                                              *network,
                                                              DataType::kFLOAT);

// specify which tensors are outputs
    for (auto& s : outputs){
        network->markOutput(*blobNameToTensor->find(s.c_str()));
    }

    // Build the engine
    builder->setMaxBatchSize(maxBatchSize);
    builder->setMaxWorkspaceSize(10 << 20);

    // std::cout << "debug@tys: before setDummyInt8Scales" << std::endl;
    // samplesCommon::setDummyInt8Scales(builder, network);
    //std::cout << "debug@tys: before enableDLA" << std::endl;
    //samplesCommon::enableDLA(builder, gArgs.useDLACore);
    //builder->allowGPUFallback(true);

    builder->setAverageFindIterations(1);
    builder->setMinFindIterations(1);
    builder->setDebugSync(true);
    builder->setInt8Mode(true);

    DimsCHW mDims(INPUT_C, INPUT_H, INPUT_W);
    DataLoader dataLoader(imageSetFile,
                          BATCH_SIZE,mDims.w(),mDims.h(),
                          mDims.c());

    Int8EntropyCalibrator calibrator(&dataLoader,
                            mDims.c(),mDims.h(),mDims.w(),
                            false,cachePath);

    builder->setInt8Calibrator(&calibrator);
    builder->setDefaultDeviceType(DeviceType::kGPU);

ICudaEngine* engine = builder->buildCudaEngine(*network);
    assert(engine);

// we don't need the network any more, and we can destroy the parser
    network->destroy();
    parser->destroy();

    // serialize the engine, then close everything down
    // (*trtModelStream) = engine->serialize();

    trtModelStream = engine->serialize();

    std::ofstream ofs("/home/xxxx/tensorrt/samples/sampleRFCN_test/layerOut/serialData.txt", std::ios::out | std::ios::binary);
    ofs.write((char*)(trtModelStream->data()), trtModelStream->size());
    ofs.close();

    engine->destroy();
    builder->destroy();
    shutdownProtobufLibrary();
}

My DataLoader and Int8EntropyCalibrator classes are as follows:

class DataLoader
{
public:
    DataLoader(const std::string& imageSetFile,
               int batchSize,
               int width,
               int height,
               int channels)
    : batchSize(batchSize), width(width), height(height), channels(channels)
    {
        index = 0;
        batchData = new float[batchSize * width * height * channels];
        imInfo = new float[batchSize * 3];
        for(int i = 0; i < batchSize; i++)
        {
            imInfo[i*batchSize + 0] = height;
            imInfo[i*batchSize + 1] = width;
            imInfo[i*batchSize + 2] = 1;  //image scale
        }
        // read image list from imageSetFile
        std::ifstream infile(imageSetFile);
        std::string imageName;
        int count = 0;

         while(!infile.eof() && count < CALIBRATE_MAX_NUM)
        {
            count++;
            string sTmp;
            infile >> imageName;
            std::cout << "@debug: DataLodaer imageName:  " << imageName << std::endl;
            getline(infile,sTmp);
            imageList.push_back(imageName);
        }

        infile.close();
    }

    ~DataLoader()
    {
        delete[] batchData;
        delete[] imInfo;
    }

    float* getBatchData() { return this->batchData;}

    float* getImInfo() {  return this->imInfo; }

    bool next()
    {
        std::cout << "Generate batch data for calibration: " << index+1 << "/" 
                  << imageList.size() << std::endl;
        if((index + batchSize) >= imageList.size())
            return false;

        for(int i = 0; i < batchSize; i++)
        {
            std::string imageFile = imageList[index+i];
            cv::Mat im = cv::imread(imageFile);
            int im3d = height * width * channels;
            int im2d = height * width;
            if(!im.empty())
            {
                cv::resize(im, im, cv::Size(width, height));
                for(int ch=0; ch < channels; ch++)
                {
                    for(int r=0; r < height; r++)
                    {
                        for(int c=0; c < width; c++)
                        {
                            batchData[i*im3d + ch*im2d + r*width + c] =
                            (float)(im.at<cv::Vec3b>(r,c)[ch]) - pixelMean[ch];

                            // std::cout << "@debug: " << (float)(im.at<cv::Vec3b>(r,c)[ch]) <<
                            // " " << pixelMean[ch] << " " << batchData[i*im3d + ch*im2d + r*width + c] << std::endl;
                        }
                    }
                }
            }
            else
            {
                std::cout << "Can't open " << imageFile << std::endl;
            }
        }

        index += batchSize;
        return true;
    }

private:
    unsigned int index;
    std::vector<std::string> imageList;
    float* batchData;
    float* imInfo;
    int batchSize;
    int width;
    int height;
    int channels;
};

/**
 * Entropy calibrator that feeds DataLoader batches to TensorRT during INT8 calibration.
 *
 * Expects two network inputs:
 *  - INPUT_BLOB_NAME0: image tensor, BATCH_SIZE * channel * height * width floats.
 *  - INPUT_BLOB_NAME1: im_info, BATCH_SIZE * 3 floats.
 * The calibration table is cached at <cachePath>/CalibrationTable.
 * The class owns two device buffers, so it is non-copyable.
 */
class Int8EntropyCalibrator : public IInt8EntropyCalibrator
{
public:
    /**
     * @param dataLoader Batch source; not owned and must outlive the calibrator.
     * @param readCache  When true, an existing calibration table is returned to
     *                   TensorRT instead of recalibrating.
     * @param cachePath  Directory containing the calibration table.
     */
    Int8EntropyCalibrator(DataLoader* dataLoader,
        int channel, int height, int width,
        bool readCache, const std::string& cachePath)
        : dataLoader(dataLoader), mReadCache(readCache), mCacheDir(cachePath)
    {
        mDims = nvinfer1::DimsNCHW{ BATCH_SIZE, channel, height, width };
        mInputCount = mDims.n() * mDims.c() * mDims.h() * mDims.w();
        CHECK(cudaMalloc(&mDeviceInput1, mInputCount * sizeof(float)));
        CHECK(cudaMalloc(&mDeviceInput2, 3 * mDims.n() * sizeof(float)));
    }

    // Copying would make two objects cudaFree the same device pointers.
    Int8EntropyCalibrator(const Int8EntropyCalibrator&) = delete;
    Int8EntropyCalibrator& operator=(const Int8EntropyCalibrator&) = delete;

    virtual ~Int8EntropyCalibrator()
    {
        CHECK(cudaFree(mDeviceInput1));
        CHECK(cudaFree(mDeviceInput2));
    }

    int getBatchSize() const override { return mDims.n(); }

    /**
     * Upload the next batch and point each requested binding at its device buffer.
     * @return false when the DataLoader is exhausted or an unknown input is requested.
     */
    bool getBatch(void* bindings[], const char* names[], int nbBindings) override
    {
        if (!dataLoader->next())
            return false;

        CHECK(cudaMemcpy(mDeviceInput1, dataLoader->getBatchData(),
              mInputCount * sizeof(float), cudaMemcpyHostToDevice));
        CHECK(cudaMemcpy(mDeviceInput2, dataLoader->getImInfo(),
              3 * mDims.n() * sizeof(float), cudaMemcpyHostToDevice));

        // Match buffers by name rather than by position. The previous positional
        // asserts vanish under NDEBUG, which could silently bind the wrong buffer.
        for (int i = 0; i < nbBindings; ++i)
        {
            if (!strcmp(names[i], INPUT_BLOB_NAME0))
                bindings[i] = mDeviceInput1;
            else if (!strcmp(names[i], INPUT_BLOB_NAME1))
                bindings[i] = mDeviceInput2;
            else
            {
                std::cerr << "Int8EntropyCalibrator: unexpected input binding '"
                          << names[i] << "'" << std::endl;
                return false;
            }
        }
        return true;
    }

    /**
     * Return the cached calibration table when readCache was requested and the
     * file exists; otherwise return nullptr so TensorRT runs calibration.
     */
    const void* readCalibrationCache(size_t& length) override
    {
        mCalibrationCache.clear();
        if (mReadCache)
        {
            std::ifstream input(calibrationTableName(), std::ios::binary);
            if (input.good())
            {
                // istreambuf_iterator reads raw bytes, so no noskipws is needed.
                mCalibrationCache.assign(std::istreambuf_iterator<char>{input},
                                         std::istreambuf_iterator<char>{});
                std::cout << "Read Calibration Cache from " << calibrationTableName() << std::endl;
            }
        }

        length = mCalibrationCache.size();
        return length ? mCalibrationCache.data() : nullptr;
    }

    /// Persist the calibration table produced by TensorRT.
    void writeCalibrationCache(const void* cache, size_t length) override
    {
        std::ofstream output(calibrationTableName(), std::ios::binary);
        output.write(reinterpret_cast<const char*>(cache), static_cast<std::streamsize>(length));
        if (!output)
        {
            std::cerr << "WARNING: could not write calibration cache to "
                      << calibrationTableName() << std::endl;
            return;
        }
        std::cout << "Write Calibration Cache to file " << calibrationTableName() << std::endl;
    }

private:
    std::string calibrationTableName() const
    {
        return mCacheDir + std::string("/CalibrationTable");
    }

    DataLoader* dataLoader;                // not owned
    bool mReadCache{ true };
    std::string mCacheDir;                 // directory holding the calibration table
    size_t mInputCount;                    // floats per image batch
    void* mDeviceInput1{ nullptr };        // device buffer for INPUT_BLOB_NAME0
    void* mDeviceInput2{ nullptr };        // device buffer for INPUT_BLOB_NAME1
    std::vector<char> mCalibrationCache;
    nvinfer1::DimsNCHW mDims;

};

I found that it generates batch data for calibration and runs inference, but the rois and the rpn_cls_pred outputs are not correct. Could anyone give me some advice?

By the way, I compared the output feature map of each layer in INT8 and FP16 mode. The results were the same for the first 2 dense blocks, up to layer “conv2_3/x1/scale”. The first 3 blocks are as follows:

name: “DENSENET_121_rcnn”
input: “data”
input_shape {
dim: 1
dim: 3
dim: 300
dim: 1000
}

input: “im_info”
input_shape {
dim: 1
dim: 1
dim: 1
dim: 3
}

layer {
name: “conv1”
type: “Convolution”
bottom: “data”
top: “conv1”
param {
lr_mult: 1
decay_mult: 1
}
convolution_param {
num_output: 64
bias_term: false
pad: 3
kernel_size: 7
stride: 2
}
}
layer {
name: “conv1/bn”
type: “BatchNorm”
bottom: “conv1”
top: “conv1/bn”
batch_norm_param {
eps: 1e-5
}
}
layer {
name: “conv1/scale”
type: “Scale”
bottom: “conv1/bn”
top: “conv1/bn”
scale_param {
bias_term: true
}
}
layer {
name: “relu1”
type: “ReLU”
bottom: “conv1/bn”
top: “conv1/bn”
}
layer {
name: “pool1”
type: “Pooling”
bottom: “conv1/bn”
top: “pool1”
pooling_param {
pool: MAX
kernel_size: 3
stride: 2
pad: 1

ceil_mode: false

}
}
layer {
name: “conv2_1/x1/bn”
type: “BatchNorm”
bottom: “pool1”
top: “conv2_1/x1/bn”
batch_norm_param {
eps: 1e-5
}
}
layer {
name: “conv2_1/x1/scale”
type: “Scale”
bottom: “conv2_1/x1/bn”
top: “conv2_1/x1/bn”
scale_param {
bias_term: true
}
}
layer {
name: “relu2_1/x1”
type: “ReLU”
bottom: “conv2_1/x1/bn”
top: “conv2_1/x1/bn”
}
layer {
name: “conv2_1/x1”
type: “Convolution”
bottom: “conv2_1/x1/bn”
top: “conv2_1/x1”
param {
lr_mult: 1
decay_mult: 1
}
convolution_param {
num_output: 128
bias_term: false
kernel_size: 1
}
}
layer {
name: “conv2_1/x2/bn”
type: “BatchNorm”
bottom: “conv2_1/x1”
top: “conv2_1/x2/bn”
batch_norm_param {
eps: 1e-5
}
}
layer {
name: “conv2_1/x2/scale”
type: “Scale”
bottom: “conv2_1/x2/bn”
top: “conv2_1/x2/bn”
scale_param {
bias_term: true
}
}
layer {
name: “relu2_1/x2”
type: “ReLU”
bottom: “conv2_1/x2/bn”
top: “conv2_1/x2/bn”
}
layer {
name: “conv2_1/x2”
type: “Convolution”
bottom: “conv2_1/x2/bn”
top: “conv2_1/x2”
param {
lr_mult: 1
decay_mult: 1
}
convolution_param {
num_output: 32
bias_term: false
pad: 1
kernel_size: 3
}
}
layer {
name: “concat_2_1”
type: “Concat”
bottom: “pool1”
bottom: “conv2_1/x2”
top: “concat_2_1”
}
layer {
name: “conv2_2/x1/bn”
type: “BatchNorm”
bottom: “concat_2_1”
top: “conv2_2/x1/bn”
batch_norm_param {
eps: 1e-5
}
}
layer {
name: “conv2_2/x1/scale”
type: “Scale”
bottom: “conv2_2/x1/bn”
top: “conv2_2/x1/bn”
scale_param {
bias_term: true
}
}
layer {
name: “relu2_2/x1”
type: “ReLU”
bottom: “conv2_2/x1/bn”
top: “conv2_2/x1/bn”
}
layer {
name: “conv2_2/x1”
type: “Convolution”
bottom: “conv2_2/x1/bn”
top: “conv2_2/x1”
param {
lr_mult: 1
decay_mult: 1
}
convolution_param {
num_output: 128
bias_term: false
kernel_size: 1
}
}
layer {
name: “conv2_2/x2/bn”
type: “BatchNorm”
bottom: “conv2_2/x1”
top: “conv2_2/x2/bn”
batch_norm_param {
eps: 1e-5
}
}
layer {
name: “conv2_2/x2/scale”
type: “Scale”
bottom: “conv2_2/x2/bn”
top: “conv2_2/x2/bn”
scale_param {
bias_term: true
}
}
layer {
name: “relu2_2/x2”
type: “ReLU”
bottom: “conv2_2/x2/bn”
top: “conv2_2/x2/bn”
}
layer {
name: “conv2_2/x2”
type: “Convolution”
bottom: “conv2_2/x2/bn”
top: “conv2_2/x2”
param {
lr_mult: 1
decay_mult: 1
}
convolution_param {
num_output: 32
bias_term: false
pad: 1
kernel_size: 3
}
}
layer {
name: “concat_2_2”
type: “Concat”
bottom: “concat_2_1”
bottom: “conv2_2/x2”
top: “concat_2_2”
}
layer {
name: “conv2_3/x1/bn”
type: “BatchNorm”
bottom: “concat_2_2”
top: “conv2_3/x1/bn”
batch_norm_param {
eps: 1e-5
}
}
layer {
name: “conv2_3/x1/scale”
type: “Scale”
bottom: “conv2_3/x1/bn”
top: “conv2_3/x1/bn”
scale_param {
bias_term: true
}
}
layer {
name: “relu2_3/x1”
type: “ReLU”
bottom: “conv2_3/x1/bn”
top: “conv2_3/x1/bn”
}
layer {
name: “conv2_3/x1”
type: “Convolution”
bottom: “conv2_3/x1/bn”
top: “conv2_3/x1”
param {
lr_mult: 1
decay_mult: 1
}
convolution_param {
num_output: 128
bias_term: false
kernel_size: 1
}
}
layer {
name: “conv2_3/x2/bn”
type: “BatchNorm”
bottom: “conv2_3/x1”
top: “conv2_3/x2/bn”
batch_norm_param {
eps: 1e-5
}
}
layer {
name: “conv2_3/x2/scale”
type: “Scale”
bottom: “conv2_3/x2/bn”
top: “conv2_3/x2/bn”
scale_param {
bias_term: true
}
}
layer {
name: “relu2_3/x2”
type: “ReLU”
bottom: “conv2_3/x2/bn”
top: “conv2_3/x2/bn”
}
layer {
name: “conv2_3/x2”
type: “Convolution”
bottom: “conv2_3/x2/bn”
top: “conv2_3/x2”
param {
lr_mult: 1
decay_mult: 1
}
convolution_param {
num_output: 32
bias_term: false
pad: 1
kernel_size: 3
}
}
layer {
name: “concat_2_3”
type: “Concat”
bottom: “concat_2_2”
bottom: “conv2_3/x2”
top: “concat_2_3”
}

SDK: JetPack 4.2.2
CUDA version: 10.0.326
Python version: 3.6.8
Tensorflow version: 1.14.0
TensorRT version: 5.1.6.1

Thanks!

Hi,

Could you explain in more detail what “not working” means in INT8 mode?
Do you get an error, or does the app run but produce incorrect output?

Thanks.

Hi,

Thanks for your reply.
No error is raised. INT8 mode runs normally, but the output is incorrect.
Just like this:

debug@yx: bbox_cls_num = 2
debug@tys: after bboxTransformInvAndClip
debug@tys: numDetections is 0 1
debug@tys: numDetections is 0 2
debug@tys: numDetections is 0 3
debug@tys: numDetections is 0 4
debug@tys: numDetections is 0 5
debug@tys: numDetections is 0 6
debug@tys: numDetections is 0 7
debug@tys: numDetections is 0 8
debug@tys: numDetections is 0 9
debug@tys: numDetections is 0 10
debug@tys: numDetections is 0 11
debug@tys: numDetections is 0 12
debug@tys: numDetections is 0 13
debug@tys: numDetections is 0 14
debug@tys: numDetections is 0 15
debug@tys: numDetections is 0 16
debug@tys: numDetections is 0 17
debug@tys: numDetections is 0 18
debug@tys: numDetections is 0 19
debug@tys: numDetections is 0 20
debug@tys: numDetections is 0 21
debug@tys: numDetections is 0 22
debug@tys: numDetections is 0 23
debug@tys: numDetections is 0 24
debug@tys: numDetections is 0 25
debug@tys: numDetections is 0 26
debug@tys: numDetections is 0 27
debug@tys: numDetections is 0 28
debug@tys: numDetections is 0 29
debug@tys: numDetections is 0 30
debug@tys: numDetections is 0 31
debug@tys: numDetections is 0 32
debug@tys: numDetections is 0 33
debug@tys: numDetections is 0 34
debug@tys: numDetections is 0 35
debug@tys: numDetections is 0 36
debug@tys: numDetections is 0 37
debug@tys: numDetections is 0 38
debug@tys: numDetections is 0 39
debug@tys: numDetections is 0 40
debug@tys: numDetections is 0 41
debug@tys: numDetections is 0 42
debug@tys: numDetections is 0 43
debug@tys: numDetections is 0 44
debug@tys: numDetections is 0 45
debug@tys: numDetections is 0 46
debug@tys: numDetections is 0 47
debug@tys: numDetections is 0 48
debug@tys: numDetections is 0 49
debug@tys: numDetections is 0 50
debug@tys: numDetections is 0 51
debug@tys: numDetections is 0 52
debug@tys: numDetections is 0 53
debug@tys: numDetections is 0 54
debug@tys: numDetections is 0 55
debug@tys: numDetections is 0 56
debug@tys: numDetections is 0 57
debug@tys: numDetections is 0 58
debug@tys: numDetections is 0 59
debug@tys: numDetections is 0 60
debug@tys: numDetections is 0 61
debug@tys: numDetections is 0 62
debug@tys: numDetections is 0 63
debug@tys: numDetections is 0 64
debug@tys: numDetections is 0 65
debug@tys: numDetections is 0 66
debug@tys: numDetections is 0 67
debug@tys: numDetections is 0 68
debug@tys: numDetections is 0 69
debug@tys: numDetections is 0 70
debug@tys: numDetections is 0 71
debug@tys: numDetections is 0 72
debug@tys: numDetections is 0 73
debug@tys: numDetections is 0 74
debug@tys: numDetections is 0 75
debug@tys: numDetections is 0 76
debug@tys: numDetections is 0 77
debug@tys: numDetections is 0 78
debug@tys: numDetections is 0 79
debug@tys: numDetections is 0 80
debug@tys: numDetections is 0 81
debug@tys: numDetections is 0 82
debug@tys: numDetections is 0 83
debug@tys: numDetections is 0 84
debug@tys: numDetections is 0 85
debug@tys: numDetections is 0 86
debug@tys: numDetections is 0 87
debug@tys: numDetections is 0 88
debug@tys: numDetections is 0 89
debug@tys: numDetections is 0 90
debug@tys: numDetections is 0 91
debug@tys: numDetections is 0 92
debug@tys: numDetections is 0 93
debug@tys: numDetections is 0 94
debug@tys: numDetections is 0 95
debug@tys: numDetections is 0 96
debug@tys: numDetections is 0 97
debug@tys: numDetections is 0 98
debug@tys: numDetections is 0 99
debug@tys: numDetections is 0 100
debug@tys: numDetections is 0 101
debug@tys: numDetections is 0 102
debug@tys: numDetections is 0 103
debug@tys: numDetections is 0 104
debug@tys: numDetections is 0 105
debug@tys: numDetections is 0 106
debug@tys: numDetections is 0 107
debug@tys: numDetections is 0 108
debug@tys: numDetections is 0 109
debug@tys: numDetections is 0 110
debug@tys: numDetections is 0 111
debug@tys: numDetections is 0 112
debug@tys: numDetections is 0 113
debug@tys: numDetections is 0 114
debug@tys: numDetections is 0 115
debug@tys: numDetections is 0 116
debug@tys: numDetections is 0 117
debug@tys: numDetections is 0 118
debug@tys: numDetections is 0 119
debug@tys: numDetections is 0 120
debug@tys: numDetections is 0 121
debug@tys: numDetections is 0 122
debug@tys: numDetections is 0 123
debug@tys: numDetections is 0 124
debug@tys: numDetections is 0 125
debug@tys: numDetections is 0 126
debug@tys: numDetections is 0 127
debug@tys: numDetections is 0 128
debug@tys: numDetections is 0 129
debug@tys: numDetections is 0 130
debug@tys: numDetections is 0 131
debug@tys: numDetections is 0 132
debug@tys: numDetections is 0 133
debug@tys: numDetections is 0 134
debug@tys: numDetections is 0 135
debug@tys: numDetections is 0 136
debug@tys: numDetections is 0 137
debug@tys: numDetections is 0 138
debug@tys: numDetections is 0 139
debug@tys: numDetections is 0 140
debug@tys: numDetections is 0 141
debug@tys: numDetections is 0 142
debug@tys: numDetections is 0 143
debug@tys: numDetections is 0 144
debug@tys: numDetections is 0 145
debug@tys: numDetections is 0 146
debug@tys: numDetections is 0 147
debug@tys: numDetections is 0 148
debug@tys: numDetections is 0 149
debug@tys: numDetections is 0 150
debug@tys: numDetections is 0 151
debug@tys: numDetections is 0 152
debug@tys: numDetections is 0 153
debug@tys: numDetections is 0 154
debug@tys: numDetections is 0 155
debug@tys: numDetections is 0 156
debug@tys: numDetections is 0 157
debug@tys: numDetections is 0 158
debug@tys: numDetections is 0 159
debug@tys: numDetections is 0 160
debug@tys: numDetections is 0 161
debug@tys: numDetections is 0 162
debug@tys: numDetections is 0 163
debug@tys: numDetections is 0 164
debug@tys: numDetections is 0 165
debug@tys: numDetections is 0 166
debug@tys: numDetections is 0 167
debug@tys: numDetections is 0 168
debug@tys: numDetections is 0 169
debug@tys: numDetections is 0 170
debug@tys: numDetections is 0 171
debug@tys: numDetections is 0 172
debug@tys: numDetections is 0 173
debug@tys: numDetections is 0 174
debug@tys: numDetections is 0 175
debug@tys: numDetections is 0 176
debug@tys: numDetections is 0 177
debug@tys: numDetections is 0 178
debug@tys: numDetections is 0 179
debug@tys: numDetections is 0 180
debug@tys: numDetections is 0 181
debug@tys: numDetections is 0 182
debug@tys: numDetections is 0 183
debug@tys: rpn_cls_pred is
min -3.0699, max 3.6275, count0 is 0
debug@tys: rpn_bbox_pred is:
debug@tys: end of batch 0
debug@tys: end of write
debug@tys: after delete data
debug@tys: after destroy
&&&& FAILED TensorRT.sample_fasterRCNN # ./sample_RFCN -d /home/nvidia/fuji/detection/tensorrt/data/RFCN

The calibration cache file is generated, and the entries for the first few layers are as follows:

TRT-5106-EntropyCalibration
data: 3f99411a
im_info: 40fc07af
conv1: 41058db6
(Unnamed Layer* 1) [Scale]_output: 3d920920
conv1/bn: 3c15e829
pool1: 3c177b6b
(Unnamed Layer* 5) [Scale]_output: 3dacb7c8
conv2_1/x1/bn: 3ccfdd83
conv2_1/x1: 3c59b524
(Unnamed Layer* 9) [Scale]_output: 3dae534a
conv2_1/x2/bn: 3c2f270b
conv2_1/x2: 3ca2521d
(Unnamed Layer* 14) [Scale]_output: 3dacb7c8
conv2_2/x1/bn: 3c7bb4b0
conv2_2/x1: 3c2919be
(Unnamed Layer* 18) [Scale]_output: 3db18b67
conv2_2/x2/bn: 3c1f0a86
conv2_2/x2: 3c5f99ec
(Unnamed Layer* 23) [Scale]_output: 3dac7c3a
conv2_3/x1/bn: 3c432ab8
conv2_3/x1: 3bd9362f
(Unnamed Layer* 27) [Scale]_output: 3d8904fa
conv2_3/x2/bn: 3c2e7163
conv2_3/x2: 3c13901c
(Unnamed Layer* 32) [Scale]_output: 3dac7c3a
conv2_4/x1/bn: 3c6589ec
conv2_4/x2/bn: 3c0fb077
conv2_4/x2: 3bed2c00
(Unnamed Layer* 41) [Scale]_output: 3dac7c3a
conv2_5/x1/bn: 3c2cc7ab
conv2_5/x2/bn: 3bf46eff
conv2_5/x2: 3bd744a7
(Unnamed Layer* 50) [Scale]_output: 3dac7c3a
conv2_6/x1/bn: 3c29cdd6
conv2_6/x2/bn: 3c1c589a
conv2_6/x2: 3c05bd05
(Unnamed Layer* 59) [Scale]_output: 3dac7c3a
conv2_blk/bn: 3c81e075
conv2_blk: 3cae6515
pool2: 3cae68a1
(Unnamed Layer* 64) [Scale]_output: 3d7bda2e
conv3_1/x1/bn: 3c41ddd4
conv3_1/x2/bn: 3c04bfbc
conv3_1/x2: 3c887797
(Unnamed Layer* 73) [Scale]_output: 3d7bda2e
conv3_2/x1/bn: 3c288953
conv3_2/x2/bn: 3c17724d
conv3_2/x2: 3c18c5e0
(Unnamed Layer* 82) [Scale]_output: 3d7bda2e
conv3_3/x1/bn: 3bf1e14a
conv3_3/x2/bn: 3be2c436
conv3_3/x2: 3bde2f51
(Unnamed Layer* 91) [Scale]_output: 3d7bda2e
conv3_4/x1/bn: 3c26f4c2
conv3_4/x2/bn: 3c2f491a
conv3_4/x2: 3c010578
(Unnamed Layer* 100) [Scale]_output: 3d7bda2e
conv3_5/x1/bn: 3c1a13c4
conv3_5/x2/bn: 3c2846e4
conv3_5/x2: 3c04be83
(Unnamed Layer* 109) [Scale]_output: 3d7bda2e
conv3_6/x1/bn: 3becf226
conv3_6/x2/bn: 3c0d0006
conv3_6/x2: 3c0ec87b
(Unnamed Layer* 118) [Scale]_output: 3d7bda2e
conv3_7/x1/bn: 3becb6ef
conv3_7/x2/bn: 3c45c660
conv3_7/x2: 3bf828d1

Hi,

Is there an NMS or similar plugin layer in your model?
If so, please first check whether it supports INT8 mode.

In general, INT8 mode quantizes tensor values into a different data range.
Without corresponding handling, a plugin layer may not behave correctly.

Thanks.