Problem with custom layers and Python UFF parser in TensorRT 3.0 RC

The plugin is the same as the one from samplePlugin (Caffe); I was just trying to use a .uff file instead of a Caffe model. I have since moved to TensorRT 4GA, because without using graphsurgeon the plugin is never called. But now I am getting an assertion failure at nbWeights, which according to the documentation should be 2 (kernel, bias), but for my UFF model it is 1. I am using “fc2/add”, which should be equivalent to “ip2”.

Hi,

Please remember to update the custom layer name with the UFF operation name.
Thanks.

Yes, I did. I used the collapse_namespaces function from graphsurgeon to do so. I am still getting the wrong output.

import graphsurgeon as gs
import tensorflow as tf

fc2 = gs.create_node("fc2")
namespace_plugin_map = {
    "fc2/Relu":fc2,
    "fc2/add":fc2,
    "fc2/MatMul":fc2
    }

def preprocess(dynamic_graph):
    # Now create a new graph by collapsing namespaces
    dynamic_graph.collapse_namespaces(namespace_plugin_map)
    # Remove the outputs, so we just have a single output node (NMS).
    # dynamic_graph.remove(dynamic_graph.graph_outputs, remove_exclusive_dependencies=False)

Hi,

Sorry, we don’t currently have an example that demonstrates a UFF plugin.

Please check the source of NvUffParser.h.
Try to parse the operation arguments, which are described with FieldMap.
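For example, a minimal sketch (not from an official sample) of reading the FieldCollection inside createPlugin. The field name "nbOutputChannels" is hypothetical and would have to match a keyword argument passed to gs.create_node in config.py; the FieldMap/FieldType member names follow NvUffParser.h:

// Hedged sketch: iterate over the FieldCollection that the UFF parser passes to
// createPlugin and read an integer attribute by name.
int nbOutputChannels = 10; // fallback default
for (int i = 0; i < fc.nbFields; ++i)
{
    const nvuffparser::FieldMap& field = fc.fields[i];
    if (!strcmp(field.name, "nbOutputChannels") && field.type == nvuffparser::FieldType::kINT32)
        nbOutputChannels = *static_cast<const int*>(field.data);
}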

Thanks.

I am looking at TRT 4GA, and it does have a UFF sample (sampleUFFSSD). Can you help me pinpoint what is causing the wrong output with the FC layer plugin, based on the code I shared above? I have followed the exact workflow from the sample.

Hi,

The sampleUFFSSD is a C++ based sample, while the program you shared in comment #23 is Python.
Could you share the complete modified source with us for checking?

Thanks.

See #19 for the C++ part and #23 for the Python part.

Hi,

You probably need to pass some layer information inside gs.create_node.

For example,

PriorBox = gs.create_node("PriorBox",
    numLayers=6,
    minScale=0.2,
    maxScale=0.95,
    aspectRatios=[1.0, 2.0, 0.5, 3.0, 0.33],
    layerVariances=[0.1,0.1,0.2,0.2],
    featureMapShapes=[19, 10, 5, 3, 2, 1])

Could you check if there is anything missing in your config.py?
Thanks.

I don’t believe so (but that’s why I posted the question). The original SSD sample is trying to convert the whole “Postprocessor” block into a single NMS node; that’s why it needs those intermediate attributes during create_node. In my example, it’s just an FC layer.

Hi.

Sorry for the late reply.

Which mode are you using?
If you are using INT8 mode, re-calibration is required for a custom model.
Thanks.

Yes, I know; I have been using TRT since 3.0.4. This is not INT8, as INT8 mode is not set in the code. Also, this just replaces the relevant lines of code from samplePlugin that used the .caffemodel (TRT 3) with the .uff usage from sampleUFFSSD (4GA); everything else is the same.

Hi,

How many fully-connected layers are inside your model?
Please remember that you need to create multiple mPlugin instances, since the weights are different for each layer.
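For illustration only (a sketch, not the original sample code, and assuming the FCPlugin class from the earlier comments), a factory that keeps one plugin instance per layer name could look roughly like this; the class name and the node names "_fc1"/"_fc2" are assumptions:

#include <cassert>
#include <cstring>
#include <map>
#include <memory>
#include <string>

// Hedged sketch: one FCPlugin per layer name, so each FC layer keeps its own weights.
class MultiPluginFactory : public nvuffparser::IPluginFactoryExt
{
public:
    bool isPlugin(const char* name) override { return isPluginExt(name); }
    bool isPluginExt(const char* name) override
    {
        return !strcmp(name, "_fc1") || !strcmp(name, "_fc2"); // hypothetical node names
    }

    nvinfer1::IPlugin* createPlugin(const char* layerName, const nvinfer1::Weights* weights,
                                    int nbWeights, const nvuffparser::FieldCollection fc) override
    {
        assert(isPlugin(layerName));
        // Each call receives that particular layer's weights; the output channel
        // count (10 here) would differ per layer in practice.
        mPlugins[layerName] = std::unique_ptr<FCPlugin>(new FCPlugin(weights, nbWeights, 10));
        return mPlugins[layerName].get();
    }

    void destroyPlugins() { mPlugins.clear(); }

private:
    std::map<std::string, std::unique_ptr<FCPlugin>> mPlugins;
};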

Thanks.

It is the exact same model as the one in samplePlugin, as I mentioned in #21; the only difference is that it is a TensorFlow model, not Caffe. It has 2 FC layers, and I mark “fc2” as the unsupported layer, as I mentioned in #23. I do not understand why I need to create multiple mPlugin instances, as I am only replacing fc2 with a custom plugin. Also, how would one achieve something like this?

Model: model.uff - Google Drive

Hi,

We tried to run your program from comment #19, but found that the implementation is not compatible with our TensorRT package.

[Application]

registerInput(const char*&, nvinfer1::DimsCHW)

[Expected]

registerInput(const char* inputName, nvinfer1::Dims inputDims, UffInputOrder inputOrder)

Could you share with us which TensorRT version you are using?
Thanks.

As I mentioned before in #25 and #31, I am working with TRT 4GA (4.0.1.6). Regarding the line of code you mentioned, I did have this in one of my iterations of the code, and I am still getting the wrong output.
The following is the latest iteration of the code.

#include <assert.h>
#include <chrono>
#include <fstream>
#include <sstream>
#include <iostream>
#include <cmath>
#include <sys/stat.h>
#include <cmath>
#include <time.h>
#include <cudnn.h>
#include <cublas_v2.h>
#include <memory>
#include <string.h>

#include "NvInfer.h"
#include "NvUffParser.h"
#include "common.h"
#include "fp16.h"

using namespace nvuffparser;
using namespace nvinfer1;
// stuff we know about the network and the input/output blobs
static const int INPUT_H = 28;
static const int INPUT_W = 28;
static const int OUTPUT_SIZE = 10;
static Logger gLogger;

#define MAX_WORKSPACE (1 << 30)

#define RETURN_AND_LOG(ret, severity, message)                                 \
    do                                                                         \
    {                                                                          \
        std::string error_message = "sample_uff_ssd: " + std::string(message); \
        gLogger.log(ILogger::Severity::k##severity, error_message.c_str());    \
        return (ret);                                                          \
    } while (0)

const char* INPUT_BLOB_NAME = "Input/input"; // Refer graphsurgeon conf script
const char* OUTPUT_BLOB_NAME = "fc2";


std::string locateFile(const std::string& input)
{
    std::vector<std::string> dirs{"data/samples/mnist/", "data/mnist/"};
    return locateFile(input, dirs);
}


// simple PGM (portable greyscale map) reader
void readPGMFile(const std::string& filename,  uint8_t buffer[INPUT_H*INPUT_W])
{
    readPGMFile(locateFile(filename), buffer, INPUT_H, INPUT_W);
}

std::vector<std::pair<int64_t, DataType>>
calculateBindingBufferSizes(const ICudaEngine& engine, int nbBindings, int batchSize) // Unchanged
{
    std::vector<std::pair<int64_t, DataType>> sizes;
    for (int i = 0; i < nbBindings; ++i)
    {
        Dims dims = engine.getBindingDimensions(i);
        DataType dtype = engine.getBindingDataType(i);

        int64_t eltCount = samples_common::volume(dims) * batchSize;
        sizes.push_back(std::make_pair(eltCount, dtype));
    }

    return sizes;
}

ICudaEngine* UFFtoGIEModel(const char* uffFile,
                            int maxBatchSize,IUffParser* parser,
                            nvuffparser::IPluginFactory* pluginFactory,
                            IHostMemory *&gieModelStream)
{
    IBuilder* builder = createInferBuilder(gLogger);
    INetworkDefinition* network = builder->createNetwork();
    parser->setPluginFactory(pluginFactory);

    std::cout << "Begin parsing model..." << std::endl;
    if (!parser->parse(uffFile, *network, nvinfer1::DataType::kFLOAT))
        RETURN_AND_LOG(nullptr, ERROR, "Fail to parse");

    std::cout << "End parsing model..." << std::endl;
    // specify which tensors are outputs
	/* we create the engine */
    builder->setMaxBatchSize(maxBatchSize);
    builder->setMaxWorkspaceSize(1 << 20);
    builder->setHalf2Mode(false);

    std::cout << "Begin building engine..." << std::endl;
    ICudaEngine* engine = builder->buildCudaEngine(*network);

    if (!engine)
        RETURN_AND_LOG(nullptr, ERROR, "Unable to create engine");
    std::cout << "End building engine..." << std::endl;

    // serialize the engine, then close everything down
    network->destroy();
    parser->destroy();
	gieModelStream = engine->serialize();
    builder->destroy();
    shutdownProtobufLibrary();
    return engine;
}


void doInference(IExecutionContext& context, float* input, float* output, int batchSize)
{
	const ICudaEngine& engine = context.getEngine();
	// input and output buffer pointers that we pass to the engine - the engine requires exactly IEngine::getNbBindings(),
	// of these, but in this case we know that there is exactly one input and one output.
    assert(engine.getNbBindings() == 2);
    void* buffers[2];
	// In order to bind the buffers, we need to know the names of the input and output tensors.
	// note that indices are guaranteed to be less than IEngine::getNbBindings()
	int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME),
		outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);

	// // create GPU buffers and a stream
    CHECK(cudaMalloc(&buffers[inputIndex], batchSize * INPUT_H * INPUT_W * sizeof(float)));
    CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_SIZE * sizeof(float)));

	cudaStream_t stream;
	CHECK(cudaStreamCreate(&stream));

    // DMA the input to the GPU,  execute the batch asynchronously, and DMA it back:
    CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * INPUT_H * INPUT_W * sizeof(float), cudaMemcpyHostToDevice, stream));
    context.enqueue(batchSize, buffers, stream, nullptr);
    CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_SIZE*sizeof(float), cudaMemcpyDeviceToHost, stream));
    cudaStreamSynchronize(stream);

    // release the stream and the buffers
    cudaStreamDestroy(stream);
    CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex]));
}


class FCPlugin: public IPluginExt
{
public:
    FCPlugin(const Weights *weights, int nbWeights, int nbOutputChannels): mNbOutputChannels(nbOutputChannels)
    {   std::cout << "nbweightsout " << weights[0].count << std::endl; //DBG
        assert(nbWeights == 2);

        mKernelWeights = weights[0];
        assert(mKernelWeights.type == DataType::kFLOAT || mKernelWeights.type == DataType::kHALF);

        mBiasWeights = weights[1];
        std::cout << "mBiasWeights: " << mBiasWeights.count << " nbOutputChannels: " << nbOutputChannels << " mKernelWeights: " << mKernelWeights.count << std::endl; //DBG
        assert(mBiasWeights.count == 0 || mBiasWeights.count == nbOutputChannels);
        assert(mBiasWeights.type == DataType::kFLOAT || mBiasWeights.type == DataType::kHALF);

        mKernelWeights.values = malloc(mKernelWeights.count*type2size(mKernelWeights.type));
        memcpy(const_cast<void*>(mKernelWeights.values), weights[0].values, mKernelWeights.count*type2size(mKernelWeights.type));
        mBiasWeights.values = malloc(mBiasWeights.count*type2size(mBiasWeights.type));
        memcpy(const_cast<void*>(mBiasWeights.values), weights[1].values, mBiasWeights.count*type2size(mBiasWeights.type));

        mNbInputChannels = int(weights[0].count / nbOutputChannels);
    }

    // create the plugin at runtime from a byte stream
    FCPlugin(const void* data, size_t length)
    {
        const char* d = static_cast<const char*>(data), *a = d;
        read(d, mNbInputChannels);
        read(d, mNbOutputChannels);

        mKernelWeights.count = mNbInputChannels * mNbOutputChannels;
        mKernelWeights.values = nullptr;

        read(d, mBiasWeights.count);
        mBiasWeights.values = nullptr;

        read(d, mDataType);

        deserializeToDevice(d, mDeviceKernel, mKernelWeights.count*type2size(mDataType));
        deserializeToDevice(d, mDeviceBias, mBiasWeights.count*type2size(mDataType));
        assert(d == a + length);
    }

    ~FCPlugin()
    {
        if (mKernelWeights.values)
        {
            free(const_cast<void*>(mKernelWeights.values));
            mKernelWeights.values = nullptr;
        }
        if (mBiasWeights.values)
        {
            free(const_cast<void*>(mBiasWeights.values));
            mBiasWeights.values = nullptr;
        }
    }

    int getNbOutputs() const override
    {
        return 1;
    }

    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
    {
        assert(index == 0 && nbInputDims == 1 && inputs[0].nbDims == 3);
        assert(mNbInputChannels == inputs[0].d[0] * inputs[0].d[1] * inputs[0].d[2]);
        return Dims3(mNbOutputChannels, 1, 1);
    }

    bool supportsFormat(DataType type, PluginFormat format) const override { return (type == DataType::kFLOAT || type == DataType::kHALF) && format == PluginFormat::kNCHW; }

    void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims, int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override
    {
        assert((type == DataType::kFLOAT || type == DataType::kHALF) && format == PluginFormat::kNCHW);
        mDataType = type;
    }

    int initialize() override
    {
        CHECK(cudnnCreate(&mCudnn));// initialize cudnn and cublas
        CHECK(cublasCreate(&mCublas));
        CHECK(cudnnCreateTensorDescriptor(&mSrcDescriptor));// create cudnn tensor descriptors we need for bias addition
        CHECK(cudnnCreateTensorDescriptor(&mDstDescriptor));
        if (mKernelWeights.values)
            convertAndCopyToDevice(mDeviceKernel, mKernelWeights);
        if (mBiasWeights.values)
            convertAndCopyToDevice(mDeviceBias, mBiasWeights);

        return 0;
    }

    virtual void terminate() override
    {
        CHECK(cudnnDestroyTensorDescriptor(mSrcDescriptor));
        CHECK(cudnnDestroyTensorDescriptor(mDstDescriptor));
        CHECK(cublasDestroy(mCublas));
        CHECK(cudnnDestroy(mCudnn));
        if (mDeviceKernel)
        {
            cudaFree(mDeviceKernel);
            mDeviceKernel = nullptr;
        }
        if (mDeviceBias)
        {
            cudaFree(mDeviceBias);
            mDeviceBias = nullptr;
        }
    }

    virtual size_t getWorkspaceSize(int maxBatchSize) const override
    {
        return 0;
    }

    virtual int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream) override
    {
        float onef{1.0f}, zerof{0.0f};
        __half oneh = fp16::__float2half(1.0f), zeroh = fp16::__float2half(0.0f);

        cublasSetStream(mCublas, stream);
        cudnnSetStream(mCudnn, stream);

        if (mDataType == DataType::kFLOAT)
        {
            CHECK(cublasSgemm(mCublas, CUBLAS_OP_T, CUBLAS_OP_N, mNbOutputChannels, batchSize, mNbInputChannels, &onef, //alpha = 1
                              reinterpret_cast<const float*>(mDeviceKernel), mNbInputChannels,
                              reinterpret_cast<const float*>(inputs[0]), mNbInputChannels, &zerof,
                              reinterpret_cast<float*>(outputs[0]), mNbOutputChannels));
        }
        else
        {
            CHECK(cublasHgemm(mCublas, CUBLAS_OP_T, CUBLAS_OP_N, mNbOutputChannels, batchSize, mNbInputChannels, &oneh,
                              reinterpret_cast<const __half*>(mDeviceKernel), mNbInputChannels,
                              reinterpret_cast<const __half*>(inputs[0]), mNbInputChannels, &zeroh,
                              reinterpret_cast<__half*>(outputs[0]), mNbOutputChannels));
        }
        if (mBiasWeights.count)
        {
            cudnnDataType_t cudnnDT = mDataType == DataType::kFLOAT ? CUDNN_DATA_FLOAT : CUDNN_DATA_HALF;
            CHECK(cudnnSetTensor4dDescriptor(mSrcDescriptor, CUDNN_TENSOR_NCHW, cudnnDT, 1, mNbOutputChannels, 1, 1));
            CHECK(cudnnSetTensor4dDescriptor(mDstDescriptor, CUDNN_TENSOR_NCHW, cudnnDT, batchSize, mNbOutputChannels, 1, 1));
            CHECK(cudnnAddTensor(mCudnn, &onef, mSrcDescriptor, mDeviceBias, &onef, mDstDescriptor, outputs[0]));
        }

        return 0;
    }

    virtual size_t getSerializationSize() override
    {
        return sizeof(mNbInputChannels) + sizeof(mNbOutputChannels) + sizeof(mBiasWeights.count) + sizeof(mDataType) +
               (mKernelWeights.count + mBiasWeights.count) * type2size(mDataType);
    }

    virtual void serialize(void* buffer) override
    {
        char* d = static_cast<char*>(buffer), *a = d;

        write(d, mNbInputChannels);
        write(d, mNbOutputChannels);
        write(d, mBiasWeights.count);
        write(d, mDataType);
        convertAndCopyToBuffer(d, mKernelWeights);
        convertAndCopyToBuffer(d, mBiasWeights);
        assert(d == a + getSerializationSize());
    }

private:
    size_t type2size(DataType type) { return type == DataType::kFLOAT ? sizeof(float) : sizeof(__half); }

    template<typename T> void write(char*& buffer, const T& val)
    {
        *reinterpret_cast<T*>(buffer) = val;
        buffer += sizeof(T);
    }

    template<typename T> void read(const char*& buffer, T& val)
    {
        val = *reinterpret_cast<const T*>(buffer);
        buffer += sizeof(T);
    }

    void* copyToDevice(const void* data, size_t count)
    {
        void* deviceData;
        CHECK(cudaMalloc(&deviceData, count));
        CHECK(cudaMemcpy(deviceData, data, count, cudaMemcpyHostToDevice));
        return deviceData;
    }

    void convertAndCopyToDevice(void*& deviceWeights, const Weights& weights)
    {
        if (weights.type != mDataType) // Weights are converted in host memory first, if the type does not match
        {
            size_t size = weights.count*(mDataType == DataType::kFLOAT ? sizeof(float) : sizeof(__half));
            void* buffer = malloc(size);
            for (int64_t v = 0; v < weights.count; ++v)
                if (mDataType == DataType::kFLOAT)
                    static_cast<float*>(buffer)[v] = fp16::__half2float(static_cast<const __half*>(weights.values)[v]);
                else
                    static_cast<__half*>(buffer)[v] = fp16::__float2half(static_cast<const float*>(weights.values)[v]);

            deviceWeights = copyToDevice(buffer, size);
            free(buffer);
        }
        else
            deviceWeights = copyToDevice(weights.values, weights.count * type2size(mDataType));
    }

    void convertAndCopyToBuffer(char*& buffer, const Weights& weights)
    {
        if (weights.type != mDataType)
            for (int64_t v = 0; v < weights.count; ++v)
                if (mDataType == DataType::kFLOAT)
                    reinterpret_cast<float*>(buffer)[v] = fp16::__half2float(static_cast<const __half*>(weights.values)[v]);
                else
                    reinterpret_cast<__half*>(buffer)[v] = fp16::__float2half(static_cast<const float*>(weights.values)[v]);
        else
            memcpy(buffer, weights.values, weights.count * type2size(mDataType));
        buffer += weights.count * type2size(mDataType);
    }

    void deserializeToDevice(const char*& hostBuffer, void*& deviceWeights, size_t size)
    {
        deviceWeights = copyToDevice(hostBuffer, size);
        hostBuffer += size;
    }

    int mNbOutputChannels, mNbInputChannels;
    Weights mKernelWeights, mBiasWeights;

    DataType mDataType{DataType::kFLOAT};
    void* mDeviceKernel{nullptr};
    void* mDeviceBias{nullptr};

    cudnnHandle_t mCudnn;
    cublasHandle_t mCublas;
    cudnnTensorDescriptor_t mSrcDescriptor, mDstDescriptor;
};

// integration for serialization
class PluginFactory : public nvinfer1::IPluginFactory, public nvuffparser::IPluginFactoryExt
{
public:
    // UFF parser plugin implementation
    bool isPlugin(const char* name) override
    {
        return isPluginExt(name);
    }

    bool isPluginExt(const char* name) override
    {
        return !strcmp(name, "_fc2");
    }

	virtual nvinfer1::IPlugin* createPlugin(const char* layerName, const nvinfer1::Weights* weights, int nbWeights, const nvuffparser::FieldCollection fc) override
	{
		// there's no way to pass parameters through from the model definition, so we have to define it here explicitly
        assert(isPlugin(layerName));
        const nvuffparser::FieldMap* fields = fc.fields;
        int nbFields = fc.nbFields;
        std::cout << "Layer Name: " << layerName << std::endl; //DBG
        std::cout << "fc.fields: " << fc.fields << std::endl ; //DBG
        std::cout << "fc.nbFields: " << fc.nbFields << std::endl; //DBG
        std::cout << "nbWeights: " << nbWeights << std::endl; //DBG
        static const int NB_OUTPUT_CHANNELS = 10;
		assert(isPlugin(layerName));// && nbWeights == 2 && weights[0].type == DataType::kFLOAT && weights[1].type == DataType::kFLOAT);
        if(!strcmp(layerName, "_fc2")){
            assert(mPlugin.get() == nullptr);
            mPlugin = std::unique_ptr<FCPlugin>(new FCPlugin(weights, nbWeights, NB_OUTPUT_CHANNELS));
            return mPlugin.get();
        }
        else{
            assert(0);
            return nullptr;
        }


	}

	// deserialization plugin implementation
	IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) override
	{
		assert(isPlugin(layerName));
        if(!strcmp(layerName, "_fc2")){
            assert(mPlugin.get() == nullptr);
            mPlugin = std::unique_ptr<FCPlugin>(new FCPlugin(serialData, serialLength));
            return mPlugin.get();
        }
        else{
            assert(0);
            return nullptr;
        }

	}
	// the application has to destroy the plugin when it knows it's safe to do so
	void destroyPlugin()
	{
		mPlugin.release();
	}

	std::unique_ptr<FCPlugin> mPlugin{ nullptr };
};


int main(int argc, char** argv)
{
  auto fileName = locateFile("model.pb.uff");
  std::cout << fileName << std::endl;
  // int maxBatchSize = 1;
   auto parser = createUffParser();
   parser->registerInput(INPUT_BLOB_NAME, DimsCHW(1, 28, 28), UffInputOrder::kNCHW);
   parser->registerOutput(OUTPUT_BLOB_NAME);
	// create a GIE model from the UFF model and serialize it to a stream
	IHostMemory *gieModelStream{ nullptr };
    PluginFactory pluginFactorySerialize;

	// caffeToGIEModel("mnist.prototxt", "mnist.caffemodel", std::vector < std::string > { OUTPUT_BLOB_NAME }, 1, &pluginFactory, gieModelStream);

    ICudaEngine* tmpEngine = UFFtoGIEModel(fileName.c_str(),  1,  parser,  &pluginFactorySerialize,  gieModelStream);
    assert(tmpEngine != nullptr);
    assert(gieModelStream != nullptr);
    tmpEngine->destroy();
    pluginFactorySerialize.destroyPlugin();

	// read a random digit file
	srand(unsigned(time(nullptr)));
	uint8_t fileData[INPUT_H*INPUT_W];
    int num{rand()%10};
    std::cout << "Input file num: " << num << std::endl; //DBG
	readPGMFile(std::to_string(num) + ".pgm", fileData);

    // Deserialize the engine.
    std::cout << "*** deserializing" << std::endl;
    IRuntime* runtime = createInferRuntime(gLogger);
    assert(runtime != nullptr);
    PluginFactory pluginFactory;
    ICudaEngine* engine = runtime->deserializeCudaEngine(gieModelStream->data(), gieModelStream->size(), &pluginFactory);
    assert(engine != nullptr);
    gieModelStream->destroy();
    IExecutionContext* context = engine->createExecutionContext();
    assert(context != nullptr);

	// convert the uint8 image to float input; passing the raw uint8 buffer cast to
	// float* would feed garbage to the engine. The normalization below follows the
	// UFF MNIST sample; adjust it to match the model's training preprocessing.
	float data[INPUT_H*INPUT_W];
	for (int i = 0; i < INPUT_H*INPUT_W; i++)
		data[i] = 1.0f - float(fileData[i]) / 255.0f;

	// run inference
	float prob[OUTPUT_SIZE];
	doInference(*context, data, prob, 1);

    std::vector<double> probs;
    int maxIdx = 0;
    for (int i = 0; i < 10; ++i){
        // std::cout << "Prob:" << i<< ": "<<prob[i] << std::endl; //DBG
        probs.push_back(prob[i]);
        if (prob[i] > prob[maxIdx]){
           maxIdx = i;
       }
    }
    std::cout << "maxIdx: " << maxIdx << std::endl;
   // destroy the engine

   context->destroy();
   engine->destroy();
   runtime->destroy();
   pluginFactory.destroyPlugin();
  // print a histogram of the output distribution
    bool pass{false};
	for (int i = 0; i < 10; i++)
    {
        int res = std::floor(prob[i] * 10 + 0.5);
        // std::cout << prob[i] <<std::endl;
        // // std::cout << "Res: " << res << std::endl;
        // std::cout << "i: " <<i << std::endl;
        if (res == 10 && i == num) {pass = true;
		std::cout << i << ": " << std::string(res, '*') << "\n";}
    }
	std::cout << std::endl;

	return pass ? EXIT_SUCCESS : EXIT_FAILURE;
}

convert-to-uff tensorflow --input-file model.pb -O fc2/Relu -p config.py
Loading model.pb
Traceback (most recent call last):
  File "/home/dhingratul/.virtualenvs/trt4/bin/convert-to-uff", line 11, in <module>
    sys.exit(main())
  File "/home/dhingratul/.virtualenvs/trt4/local/lib/python2.7/site-packages/uff/bin/convert_to_uff.py", line 105, in main
    output_filename=args.output
  File "/home/dhingratul/.virtualenvs/trt4/local/lib/python2.7/site-packages/uff/converters/tensorflow/conversion_helpers.py", line 149, in from_tensorflow_frozen_model
    return from_tensorflow(graphdef, output_nodes, preprocessor, **kwargs)
  File "/home/dhingratul/.virtualenvs/trt4/local/lib/python2.7/site-packages/uff/converters/tensorflow/conversion_helpers.py", line 55, in from_tensorflow
    pre = importlib.import_module(preprocessor.replace(".py", ""))
  File "/usr/lib/python2.7/importlib/__init__.py", line 37, in import_module
    __import__(name)
  File "tensorrt_samples/samples/samplePlugin_tf/config.py", line 8, in <module>
    "fc2/Relu":fc2_relu,
NameError: name 'fc2_relu' is not defined

Hi,

Sorry, my previous comment was incorrect.

We are trying to write a sample that demonstrates a TensorFlow (MNIST) plugin (fc2).
We will share the information here once it is done.

Thanks, and sorry for any inconvenience.

Hi,

Note that you collapse the ReLU node into the plugin, but the plugin doesn’t perform a ReLU op on its output; it only handles MatMul and BiasAdd.
Our internal team has tested this model without the plugin and it works correctly.
So either the plugin or the preprocessing script is incorrect.
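Regarding the ReLU point: if fc2/Relu stays collapsed into the plugin node, one option (a sketch under that assumption, not part of the original sample) is to apply ReLU inside the plugin's enqueue after the bias addition, e.g. with cuDNN:

// Hedged sketch: in-place ReLU on the plugin output, reusing the existing cuDNN handle
// and mDstDescriptor (already set to [batchSize, mNbOutputChannels, 1, 1]).
// Creating the descriptor every call is wasteful; it could be moved to initialize().
cudnnActivationDescriptor_t actDesc;
CHECK(cudnnCreateActivationDescriptor(&actDesc));
CHECK(cudnnSetActivationDescriptor(actDesc, CUDNN_ACTIVATION_RELU, CUDNN_NOT_PROPAGATE_NAN, 0.0));
CHECK(cudnnActivationForward(mCudnn, actDesc, &onef, mDstDescriptor, outputs[0],
                             &zerof, mDstDescriptor, outputs[0]));
CHECK(cudnnDestroyActivationDescriptor(actDesc));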

However, we couldn’t find the exact reason why the given plugin implementation isn’t working.
The inputs and weights to the plugin are correct and match what the TensorRT engine sees when the same layer is run as a native FC layer.
So there is no bug in TensorRT; it looks like an issue with the plugin implementation.

We would recommend using the native TensorRT FC layer rather than implementing a plugin for it.
For a demonstration of how to use plugins with TensorFlow/UFF, please use TRT 5.0 GA.
It will have both C++ and Python samples demonstrating this.
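For reference, a native FC layer can also be added directly through the builder API. A minimal sketch (assuming a `network` definition, an `input` tensor, and `kernelWeights`/`biasWeights` already loaded; not taken from any sample):

// Hedged sketch: build the second FC layer natively instead of via a plugin.
nvinfer1::IFullyConnectedLayer* fc2 =
    network->addFullyConnected(*input, 10, kernelWeights, biasWeights);
fc2->getOutput(0)->setName("fc2");
network->markOutput(*fc2->getOutput(0));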

Thanks

I didn’t write a single line of code for this one; even the plugin is from one of the TensorRT 4RC samples. I just wanted to check the usage, so I converted the 4RC sample to 4GA. This was also a test of backward compatibility between versions: as most of my models have more than one custom operation, I want to make sure that upgrading to a new version doesn’t break my code. Also, the latest version I see is TRT 4GA; can you give me the link for 5GA? Also, can you help me with this as well:
https://devtalk.nvidia.com/default/topic/1036974/