Inference with Keras Pretrained Models

Weird issues come up when running inference with Keras pretrained models. I have exported some Keras pretrained models to UFF files, parsed them, and serialized them into plan files. But inference with models other than resnet50 doesn't seem to give the right results. My setup is as follows:

  • TensorRT 4
  • CUDA 9.0
  • Ubuntu 16
  • TensorFlow 1.12.0

I first export the UFF files. The export script looks like this:

import tensorflow as tf
import uff

tf.keras.backend.set_learning_phase(0)
model = tf.keras.applications.VGG16(include_top=True)
model.load_weights('<.h5 file>')

# Following the TensorRT sample code: freeze the graph and convert it to a UFF file
def save(model, filename):
    output_names = model.output.op.name
    sess = tf.keras.backend.get_session()
    frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_names])
    uff.from_tensorflow(graphdef=frozen_graph,
                        output_filename=filename,
                        output_nodes=[output_names],
                        text=True)

save(model, 'vgg16.uff')

The converter prints the following information:

INFO:tensorflow:Froze 32 variables.
INFO:tensorflow:Converted 32 variables to const ops.
Using output node predictions/Softmax
Converting to UFF graph
DEBUG: convert reshape to flatten node
No. nodes: 88
UFF Output written to uff_models/vgg16.uff
UFF Text Output written to uff_models/vgg16.uff.pbtxt
Input: input_1
Output: predictions/Softmax
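
To have something to compare the TensorRT output against later, I also keep a Keras-side reference prediction for the same image. This is only a sanity-check sketch (the image path is a placeholder, and VGG16() downloads the ImageNet weights if they are not already cached):

import numpy as np
import tensorflow as tf

# Load one validation image as a 1x224x224x3 float array (path is a placeholder)
img = tf.keras.preprocessing.image.load_img('<image.jpg>', target_size=(224, 224))
x = np.expand_dims(tf.keras.preprocessing.image.img_to_array(img), axis=0)

# Same preprocessing Keras uses for VGG16 ('caffe' mode: RGB->BGR, ImageNet mean subtraction)
x = tf.keras.applications.vgg16.preprocess_input(x)

model = tf.keras.applications.VGG16(include_top=True)
preds = model.predict(x)
print(tf.keras.applications.vgg16.decode_predictions(preds, top=5))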

Then I parse the UFF file into an engine and serialize it for later use. The script looks like this:

ICudaEngine* UFFParser(const char* uff_file, int maxBatchSize, IUffParser* parser)
{
    IBuilder* builder = createInferBuilder(gLogger);
    INetworkDefinition* network = builder->createNetwork();
    
    if(!parser->parse(uff_file, *network, DataType::kFLOAT))
    {
        std::cout << "Fail to parse" << std::endl;
        exit(-1);
    }
    
    builder->setMaxBatchSize(maxBatchSize);
    builder->setMaxWorkspaceSize(8ULL << 30);  // 8 GiB; use a 64-bit literal so the shift does not overflow int
    ICudaEngine* engine = builder->buildCudaEngine(*network);
    if(!engine)
    {
        std::cout << "Unable to create engine" << std::endl;
        exit(-1);
    }
    return engine;
}
// ....
int main()
{
    auto parser = createUffParser();
    parser->registerInput(input_name, Dims3{3, 224, 224}, UffInputOrder::kNCHW);
    parser->registerOutput(output_name);
    ICudaEngine* engine = UFFParser(uff_file, 2, parser);
    // serialization
    return 0;
}

Then for the inference script, I do the following:

static const float pixelMean[3]{103.939f, 116.779f, 123.68f}; // ImageNet mean, in BGR order

// Convert the HWC RGB image buffer to CHW BGR floats and subtract the per-channel mean
void image_preprocess(PPM* img, float* data)
{
    for(int c = 0; c < INPUT_C; ++c)
    {
        for(int j = 0, volChl = INPUT_H * INPUT_W; j < volChl; ++j)
        {
            data[c*volChl + j] = float(img->buffer[j*INPUT_C + 2 - c]) - pixelMean[c];
        }
    }
}

void doInference(nvinfer1::IExecutionContext& context, float* input, float* output, int batchSize)
{
    const nvinfer1::ICudaEngine& engine = context.getEngine();
    assert(engine.getNbBindings() == 2);
    void* buffers[2];

    int inputIndex, outputIndex;
    for(int b = 0; b < engine.getNbBindings(); ++b)
    {
        if(engine.bindingIsInput(b))
            inputIndex = b;
        else
            outputIndex = b;
    }

    // Create GPU buffers on device
    CHECK(cudaMalloc(&buffers[inputIndex], batchSize * INPUT_C * INPUT_H * INPUT_W * sizeof(float)));
    CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float)));

    // Create stream
    cudaStream_t stream;
    CHECK(cudaStreamCreate(&stream));

    // DMA input batch data to device, infer on the batch asynchronously, and DMA output back to host
    CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * INPUT_C * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream));
    context.enqueue(batchSize, buffers, stream, nullptr);
    CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
    cudaStreamSynchronize(stream);

    // Release stream and buffers
    cudaStreamDestroy(stream);
    CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex]));    
}
int main()
{
    //....
    // PPM* input_img = read_img(...)
    image_preprocess(input_img, data);
    doInference(*context, data, prob, 1);
    //...
}

I read the .ppm image and do the preprocessing, which for resnet50 is subtracting the ImageNet mean and converting to BGR order, then run inference. I run the whole ImageNet validation set. Only resnet50 gives me correct results; the other models classify images completely wrong. It seems like the problem shouldn't be in how I run inference with TensorRT, but in the preprocessing steps. However, I have checked the Keras pretrained model scripts, and the preprocessing they use seems to be just subtracting the ImageNet mean and converting to BGR order. Why does only inference with resnet50 work? vgg16, vgg19, inceptionV3 and densenet121 all fail to give correct results.
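
To rule out a bug in my own mean-subtraction/BGR code, I cross-check it in Python against Keras's preprocess_input for resnet50 ('caffe' mode). This is just a sanity-check sketch on a random array; my C++ loop additionally lays the result out as CHW, so I transpose the Keras result the same way before comparing:

import numpy as np
import tensorflow as tf

rgb = np.random.randint(0, 256, size=(224, 224, 3)).astype(np.float32)  # fake HxWxC RGB image

# Reference: Keras 'caffe'-style preprocessing (RGB->BGR, subtract ImageNet mean, stays HWC)
ref = tf.keras.applications.resnet50.preprocess_input(rgb.copy())

# Manual version mirroring the C++ loop: BGR conversion, mean subtraction, laid out as CHW
mean_bgr = np.array([103.939, 116.779, 123.68], dtype=np.float32)
mine = (rgb[..., ::-1] - mean_bgr).transpose(2, 0, 1)

print(np.allclose(mine, ref.transpose(2, 0, 1)))  # expect True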

Is there any problem in my implementation of the TensorRT inference? Where could the problem be? It really confuses me.

I have successfully fixed the problem for inceptionV3 and densenet121; it was indeed the preprocessing step, which turns out to differ from resnet50's for those two models. But it is still confusing for the VGG16 and VGG19 models, because their preprocessing steps really are the same as resnet50's (this time I have checked it!).
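
For reference, the fix was to use each model's own Keras preprocessing instead of always applying the resnet50-style mean subtraction. A rough sketch of the differences as I understand them from keras_applications (the per-model preprocess_input functions, applied to a dummy batch; InceptionV3 would normally use 299x299 inputs, but preprocess_input itself is size-agnostic):

import numpy as np
import tensorflow as tf

x = np.random.randint(0, 256, size=(1, 224, 224, 3)).astype(np.float32)  # dummy batch

# resnet50 / vgg16 / vgg19: 'caffe' mode -> RGB->BGR plus ImageNet mean subtraction
print(tf.keras.applications.resnet50.preprocess_input(x.copy())[0, 0, 0])

# inceptionV3: 'tf' mode -> scale to [-1, 1] (x / 127.5 - 1), no BGR conversion
print(tf.keras.applications.inception_v3.preprocess_input(x.copy())[0, 0, 0])

# densenet121: 'torch' mode -> scale to [0, 1], then normalize with the ImageNet mean/std
print(tf.keras.applications.densenet.preprocess_input(x.copy())[0, 0, 0])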

I found an interesting issue that could be the problem: when parsing resnet50, inceptionV3 or densenet121, the parser automatically inserts a layer that does a reshape or shuffle (to use the TensorRT term), like layer 311 below (I print out the layer information after parsing):

310: mixed10/concat: Layer Type: 8 Output dim (2048, 5, 5)
311: (Unnamed Layer* 310) [Shuffle]: Layer Type: 14 Output dim (5, 5, 2048)
312: avg_pool_1/Mean: Layer Type: 15 Output dim (1, 1, 2048)
313: predictions/MatMul: Layer Type: 1 Output dim (1000, 1, 1)
314: predictions/BiasAdd: Layer Type: 5 Output dim (1000, 1, 1)
315: predictions/Softmax_HL_1804289383: Layer Type: 6 Output dim (1000, 1, 1)
316: predictions/Softmax: Layer Type: 14 Output dim (1, 1, 1000)

It seems the parser automatically inserts a shuffle layer before the fully connected layer and before the Softmax layer.
For vgg16 and vgg19, however, there is no such automatically inserted shuffle layer (a small illustration of why I think this matters follows the listing):

49: fc2/BiasAdd: Layer Type: 5 Output dim (4096, 1, 1)
50: fc2/Relu: Layer Type: 2 Output dim (4096, 1, 1)
51: predictions/MatMul: Layer Type: 1 Output dim (1000, 1, 1)
52: predictions/BiasAdd: Layer Type: 5 Output dim (1000, 1, 1)
53: predictions/Softmax_HL_1804289383: Layer Type: 6 Output dim (1000, 1, 1)
54: predictions/Softmax: Layer Type: 14 Output dim (1, 1, 1000)
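
My suspicion is that this shuffle matters because flattening a CHW tensor yields a different element order than flattening the HWC tensor that the TensorFlow FC weights were trained on. A small numpy illustration of the ordering difference (nothing TensorRT-specific here):

import numpy as np

chw = np.arange(3 * 2 * 2).reshape(3, 2, 2)  # toy tensor in CHW layout
hwc = chw.transpose(1, 2, 0)                 # the same values in HWC layout

print(chw.flatten())  # [ 0  1  2  3  4  5  6  7  8  9 10 11]
print(hwc.flatten())  # [ 0  4  8  1  5  9  2  6 10  3  7 11]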

What’s the problem here?

Hello,

We are triaging and will keep you updated on what we find. Also, can you point me to the pretrained Keras models you are using? Are they from the Keras model zoo (Keras Applications)?

regards,
NVES.

Per engineering:

The parser will insert a transpose (or shuffle) before the FC layer if the order is not the same as what the FC layer expects. Is there a difference in the order fields for the FC layer between these 2 networks?
Another transpose inserted by the parser is at the output. If it realizes that the network has modified the order, it tries to restore the original order. I see that this is inserted in both cases, irrespective of whether or not the initial transpose before the FC was inserted.

How are the outputs incorrect? If you transpose the final output, does it give you the correct result?

Yes, exactly.

The predictions are just not correct. It seems some operation is being computed incorrectly, but the network can still be parsed and run through…

I found an interesting thing during parsing. I print out the parsing information:

41: block5_conv3/Conv2D: Layer Type: 0 Output dim (512, 14, 14)
42: block5_conv3/BiasAdd: Layer Type: 5 Output dim (512, 14, 14)
43: block5_conv3/Relu: Layer Type: 2 Output dim (512, 14, 14)
44: flatten/Reshape: Layer Type: 3 Output dim (512, 7, 7)
45: fc1/MatMul: Layer Type: 1 Output dim (4096, 1, 1)
46: fc1/BiasAdd: Layer Type: 5 Output dim (4096, 1, 1)
47: fc1/Relu: Layer Type: 2 Output dim (4096, 1, 1)

For layer 44, the layer type is 3, which is a Pooling layer, and that matches what is done in the source code: there should first be a max pooling layer, then a flatten layer. But according to the parsing information, the output dimension doesn't make sense: how can the reshape layer output dim (512, 7, 7) from an input dim of (512, 7, 7)?
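
For comparison, the shapes in the original Keras model show what I would expect the flatten to produce (a single 25088-wide vector, not another (512, 7, 7) tensor). A quick check on the Keras side:

import tensorflow as tf

model = tf.keras.applications.VGG16(include_top=True)
print(model.get_layer('block5_pool').output_shape)  # (None, 7, 7, 512)
print(model.get_layer('flatten').output_shape)      # (None, 25088)
print(model.get_layer('fc1').input_shape)           # (None, 25088)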

Hello,

Can you share a .pb and a UFF file with us? It will help us debug.

Hello,

I didn't download it from any source; I just exported it using the code below:

import tensorflow as tf
import uff

model = tf.keras.applications.VGG16(include_top=True)
model.load_weights('~/.keras/models/vgg16_weights_tf_dim_ordering_tf_kernels.h5')

# Following the TensorRT sample code: freeze the graph and convert it to a UFF file
def save(model, filename):
    output_names = model.output.op.name
    sess = tf.keras.backend.get_session()
    frozen_graph = tf.graph_util.convert_variables_to_constants(sess, sess.graph.as_graph_def(), [output_names])
    uff.from_tensorflow(graphdef=frozen_graph,
                        output_filename=filename,
                        output_nodes=[output_names],
                        text=True)

save(model, 'vgg16.uff')

Simply run this script on your host and you will get the UFF file, provided you have already installed TensorRT 4 and TensorFlow 1.12.0.