problem adding custom TensorRT layer to a network defined using TensorRT API

Hello,
I’m trying to add a custom layer to an architecture defined using the TensorRT API. Before implementing more complicated layers, I’ve been trying to add to my network the same Reshape layer that is implemented in the sampleCharRNN sample; this results in a segmentation fault.

I add the layer to the network this way, and I define it as an output:

Reshape reshape(PREVIOUS_LAYER_NCHANNELS * PREVIOUS_LAYER_H * PREVIOUS_LAYER_W);
ITensor *ptr = previous_layer->getOutput(0);
auto plugin = network->addPlugin(&ptr, 1, reshape);
assert(plugin != nullptr);
plugin->setName("reshape");

plugin->getOutput(0)->setName(OUTPUT_BLOB_NAME);
network->markOutput(*plugin->getOutput(0));

The layer class and the plugin factory implementations are as follows (same as the sampleCharRNN sample):

// Reshape plugin to feed RNN into FC layer correctly.
class Reshape : public IPlugin
{
public:
	Reshape(size_t size) : mSize(size) {} 
	Reshape(const void*buf, size_t size)
    {
        assert(size == sizeof(mSize));
        mSize = *static_cast<const size_t*>(buf);
    }
	int getNbOutputs() const override													{	return 1;	}
	int initialize() override															{	return 0;	}
	void terminate() override															{}
	size_t getWorkspaceSize(int) const override											{	return 0;	}
	int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
    {
        CHECK(cudaMemcpyAsync(static_cast<float*>(outputs[0]),
                   static_cast<const float*>(inputs[0]),
                   sizeof(float) * mSize * batchSize, cudaMemcpyDefault, stream));
        return 0;
    }
	size_t getSerializationSize() override
    {
        return sizeof(mSize);
    }
	void serialize(void* buffer) override
    {
        (*static_cast<size_t*>(buffer)) = mSize;

    }
	void configure(const Dims*, int, const Dims*, int, int)	override					{ }
    // The RNN outputs in {L, N, C}, but FC layer needs {C, 1, 1}, so we can convert RNN
    // output to {L*N, C, 1, 1} and TensorRT will handle the rest.
	Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
	{
        assert(nbInputDims == 1);
        assert(index == 0);
        assert(inputs[index].nbDims == 3);
		return DimsNCHW(inputs[index].d[1] * inputs[index].d[0], inputs[index].d[2], 1, 1);
	}
    private:
    size_t mSize{0};
};
class PluginFactory : public nvinfer1::IPluginFactory
{
public:
	// deserialization plugin implementation
	IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) override
	{
        assert(!strncmp(layerName, "reshape", 7));
        if (!mPlugin) mPlugin = new Reshape(serialData, serialLength);
        return mPlugin;
    }
    void destroyPlugin()
    {
        if (mPlugin) delete mPlugin;
        mPlugin = nullptr;
    }
private:
    Reshape *mPlugin{nullptr};
}; // PluginFactory

I manage to create the engine

auto engine = builder->buildCudaEngine(*network);

but as soon as I call the serialize method

(*modelStream) = engine->serialize();

I get a segmentation fault.

Without the custom plugin the architecture compiles and works as expected.
Unfortunately I cannot inspect the code of the serialize() function to see where exactly it goes wrong.
Any clue as to what is happening?

thanks,

f

I attach a simplified version of my code which reproduces the problem.
thanks,

f

#include "NvInfer.h"
#include "NvCaffeParser.h"
#include "NvUtils.h"
#include "cuda_runtime_api.h"
#include <cassert>
#include <cmath>
#include <ctime>
#include <time.h>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <sstream>
#include <sys/stat.h>
#include <vector>
#include <algorithm>
#include <opencv2/opencv.hpp>
#include <opencv2/core/core.hpp>

#define CHECK(status)                                                                                           \
    {                                                                                                                           \
        if (status != 0)                                                                                                \
        {                                                                                                                               \
            std::cout << "Cuda failure: " << cudaGetErrorString(status) \
                      << " at line " << __LINE__                                                        \
                      << std::endl;                                                                     \
            abort();                                                                                                    \
        }                                                                                                                               \
    }

// stuff we know about the network and the input/output blobs
static const int INPUT_H = 216;
static const int INPUT_W = 384;
static const int INPUT_C = 3;
static const int OUTPUT_TEST_SIZE = 3 * 216 * 384;

static const int BATCH_SIZE = 4;
static const int MAX_BATCH_SIZE = 32;

const char* INPUT_BLOB_NAME = "data";
const char* OUTPUT_BLOB_NAME = "out";

using namespace nvinfer1;
using namespace nvcaffeparser1;

// Logger for GIE info/warning/errors
class Logger : public nvinfer1::ILogger			
{
    public:
	void log(nvinfer1::ILogger::Severity severity, const char* msg) override
	{
		// suppress info-level messages
        if (severity == Severity::kINFO) return;

        switch (severity)
        {
            case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
            case Severity::kERROR: std::cerr << "ERROR: "; break;
            case Severity::kWARNING: std::cerr << "WARNING: "; break;
            case Severity::kINFO: std::cerr << "INFO: "; break;
            default: std::cerr << "UNKNOWN: "; break;
        }
        std::cerr << msg << std::endl;
	}
};

static Logger gLogger;

// print tensor dimensions
void printDims(ITensor* data)
{
    Dims dims = data->getDimensions();
    int nbDims = dims.nbDims;
    for (int d = 0; d < nbDims; d++)
        std::cout << dims.d[d] << " ";
    std::cout << std::endl;
}

///////////////////////////////////////////
///////////////////////////////////////////
///////////////////////////////////////////
class Reshape : public IPlugin
{
public:
	Reshape(size_t size) : mSize(size) {} 
	Reshape(const void*buf, size_t size)
    {
        assert(size == sizeof(mSize));
        mSize = *static_cast<const size_t*>(buf);
    }
	int getNbOutputs() const override													{	return 1;	}
	int initialize() override															{	return 0;	}
	void terminate() override															{}
	size_t getWorkspaceSize(int) const override											{	return 0;	}
	int enqueue(int batchSize, const void*const * inputs, void** outputs, void* workspace, cudaStream_t stream)
    {
        CHECK(cudaMemcpyAsync(static_cast<float*>(outputs[0]),
                   static_cast<const float*>(inputs[0]),
                   sizeof(float) * mSize * batchSize, cudaMemcpyDefault, stream));
        return 0;
    }
	size_t getSerializationSize() override
    {
        return sizeof(mSize);
    }
	void serialize(void* buffer) override
    {
        (*static_cast<size_t*>(buffer)) = mSize;

    }
	void configure(const Dims*, int, const Dims*, int, int)	override					{ }
    // The RNN outputs in {L, N, C}, but FC layer needs {C, 1, 1}, so we can convert RNN
    // output to {L*N, C, 1, 1} and TensorRT will handle the rest.
	Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
	{
        assert(nbInputDims == 1);
        assert(index == 0);
        assert(inputs[index].nbDims == 3);
		return DimsNCHW(inputs[index].d[1] * inputs[index].d[0], inputs[index].d[2], 1, 1);
	}
    private:
    size_t mSize{0};
};
class PluginFactory : public nvinfer1::IPluginFactory
{
public:
	// deserialization plugin implementation
	IPlugin* createPlugin(const char* layerName, const void* serialData, size_t serialLength) override
	{
        assert(!strncmp(layerName, "reshape", 7));
        if (!mPlugin) mPlugin = new Reshape(serialData, serialLength);
        return mPlugin;
    }
    void destroyPlugin()
    {
        if (mPlugin) delete mPlugin;
        mPlugin = nullptr;
    }
private:
    Reshape *mPlugin{nullptr};
}; // PluginFactory

/////////////////////////////////////////////
////////////////////////////////////////////
/////////////////////////////////////////

// Create the Engine using only the API and not any parser.
ICudaEngine *
createMyEngine(unsigned int maxBatchSize, IBuilder *builder, DataType dt)
{
	INetworkDefinition* network = builder->createNetwork();
    
	// define input
	auto data = network->addInput(INPUT_BLOB_NAME, dt, DimsCHW{INPUT_C, INPUT_H, INPUT_W});
	assert(data != nullptr);
    std::cout << "input" << std::endl;
    printDims(data);

    std::cout << "=========" << std::endl;

    Reshape reshape(OUTPUT_TEST_SIZE);
    ITensor *ptr = data;
    auto plugin = network->addPlugin(&ptr, 1, reshape);
    assert(plugin != nullptr);
    plugin->setName("reshape");

    plugin->getOutput(0)->setName(OUTPUT_BLOB_NAME);
    network->markOutput(*plugin->getOutput(0));

    std::cout << "Resize in" << std::endl;
    printDims(ptr);
    std::cout << "Resize out" << std::endl;
    printDims(plugin->getOutput(0));

	// Build the engine
	builder->setMaxBatchSize(maxBatchSize);
	builder->setMaxWorkspaceSize(1 << 20);

    std::cout << "***Building the engine***" << std::endl;

	auto engine = builder->buildCudaEngine(*network);

    std::cout << "***Built the engine***" << std::endl;

	// we don't need the network any more
	network->destroy();

	return engine;
}

void APIToModel(unsigned int maxBatchSize, // batch size - NB must be at least as large as the batch we want to run with
		     IHostMemory **modelStream)
{
	// create the builder
	IBuilder* builder = createInferBuilder(gLogger);

	// create the model to populate the network, then set the outputs and create an engine
    ICudaEngine* engine = createMyEngine(maxBatchSize, builder, DataType::kFLOAT);

    std::cout << "******" << std::endl; 
    std::cout << "******" << std::endl; 
    std::cout << "******" << std::endl; 

	assert(engine != nullptr);

	// serialize the engine, then close everything down
	(*modelStream) = engine->serialize();
    
	engine->destroy();
	builder->destroy();
}

void doInference(IExecutionContext& context, float* input, float* output, int batchSize)
{
	const ICudaEngine& engine = context.getEngine();
	// input and output buffer pointers that we pass to the engine - the engine requires exactly IEngine::getNbBindings(),
	// of these, but in this case we know that there is exactly one input and one output.
	assert(engine.getNbBindings() == 2);
	void* buffers[2];

	// In order to bind the buffers, we need to know the names of the input and output tensors.
	// note that indices are guaranteed to be less than IEngine::getNbBindings()
	int inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME),
        outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);

	// create GPU buffers and a stream
	CHECK(cudaMalloc(&buffers[inputIndex], batchSize * INPUT_H * INPUT_W * INPUT_C * sizeof(float)));
    std::cout << "Allocating output " << batchSize * OUTPUT_TEST_SIZE * sizeof(float) << std::endl;
    CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_TEST_SIZE * sizeof(float)));

	cudaStream_t stream;
	CHECK(cudaStreamCreate(&stream));

	// DMA the input to the GPU,  execute the batch asynchronously, and DMA it back:
	CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * INPUT_H * INPUT_W * INPUT_C * sizeof(float), cudaMemcpyHostToDevice, stream));
	context.enqueue(batchSize, buffers, stream, nullptr);
    CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_TEST_SIZE * sizeof(float), cudaMemcpyDeviceToHost, stream));
    std::cout << "H2D " << batchSize * INPUT_H * INPUT_W * INPUT_C << std::endl; 
    std::cout << "D2H " << batchSize * OUTPUT_TEST_SIZE << std::endl; 
	cudaStreamSynchronize(stream);

	// release the stream and the buffers
	cudaStreamDestroy(stream);
	CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex]));
}

// rearrange image data to [N, C, H, W] order
void prepareDataBatch(float *data, std::vector<cv::Mat> &frames)
{
     assert(data && !frames.empty());
     unsigned int volChl = INPUT_H * INPUT_W;
     unsigned int volImg = INPUT_H * INPUT_W * INPUT_C;
     
     for (int b = 0; b < BATCH_SIZE; b++)
         for (int c = 0; c < INPUT_C; c++)
         {
              // the color image to input should be in BGR order
              for (unsigned j = 0; j < volChl; j++)
                   data[b * volImg + c * volChl + j] = float(frames[b].data[j * INPUT_C + c]) / 255.0;
         }
     
     return;
}

void preprocessData(float *data)
{
    int N = INPUT_H * INPUT_W * INPUT_C;
    float mean = 0.0;
    float stddev = 0.0;

    for (int n = 0; n < N; n++)
        mean += data[n];
    mean /= N;

    for (int n = 0; n < N; n++)
        stddev += (data[n] - mean) * (data[n] - mean);
    stddev = sqrt(stddev / (N - 1));

    float min_stddev = (1.0 / sqrt(N * 1.0));
    float adjusted_stddev = stddev;
    if (min_stddev > stddev)
        adjusted_stddev = min_stddev;

    for (int n = 0; n < N; n++)
        data[n] = (data[n] - mean) / adjusted_stddev;

    return;  
}

void preprocessDataBatch(float *data)
{
    int N = INPUT_H * INPUT_W * INPUT_C;
    
    for (int b = 0; b < BATCH_SIZE; b++)
    {
        float *data_img = data + N * b;
        preprocessData(data_img);
    }
    
    return;  
}

int main(int argc, char** argv)
{
    //read input, convert, resize
    std::vector<std::string> image_paths;
    image_paths.push_back("~/images/313129722_faa438c3fe_o.jpg");
    image_paths.push_back("~/images/3403366735_bef4d39487_o.jpg");
    image_paths.push_back("~/images/100338072_1910x1000.jpg");
    image_paths.push_back("~/images/K1KP72.jpg");

    std::cout << image_paths.size() << std::endl;
    std::vector<cv::Mat> images(image_paths.size());

    int im_id(0); 
    for (std::vector<std::string>::iterator it = image_paths.begin(); it != image_paths.end(); ++it)
    {
        std::cout << *it << std::endl;
        cv::Mat image_bgr = cv::imread(*it);
        cv::Mat image_rgb, image;
        cv::cvtColor(image_bgr, image_rgb, CV_BGR2RGB);
        cv::resize(image_rgb, image, cv::Size(INPUT_W, INPUT_H), 0, 0, CV_INTER_LINEAR);
        images[im_id] = image;
        im_id++;
    }

    // allocate CPU memory for input and output
    int inputSize = sizeof(float) * BATCH_SIZE * INPUT_C * INPUT_H * INPUT_W;
    int outputSize = sizeof(float) * BATCH_SIZE * OUTPUT_TEST_SIZE;
    float *data = (float *)malloc(inputSize);
    float *out = (float *)malloc(outputSize);

    // flatten image Mat, convert to float, do TF-style whitening
    prepareDataBatch(data, images);
    preprocessDataBatch(data);

	// create a model using the API directly and serialize it to a stream
    IHostMemory *modelStream{nullptr};
    APIToModel(MAX_BATCH_SIZE, &modelStream);
	IRuntime* runtime = createInferRuntime(gLogger);

    PluginFactory pluginFactory; // needed by the plugin

//	ICudaEngine* engine = runtime->deserializeCudaEngine(modelStream->data(), modelStream->size(), nullptr); // when only native layers are used
    ICudaEngine* engine = runtime->deserializeCudaEngine(modelStream->data(), modelStream->size(), &pluginFactory); // needed by the plugin

    if (modelStream) modelStream->destroy();
	IExecutionContext *context = engine->createExecutionContext();

    // run inference
    doInference(*context, data, out, BATCH_SIZE);
    
    // destroy the engine
    context->destroy();
    engine->destroy();
    runtime->destroy();
    pluginFactory.destroyPlugin(); // needed by the plugin

//free mem
    free(data);
    free(out);

    return 0;
}

Is it possible for you to run the application under GDB and get the backtrace when it crashes? Also, to get a better idea of what is wrong once we have a backtrace, could you provide the version of TensorRT that you are using?

Hi,
thanks for replying.

From GDB:

(gdb) bt
#0  0x00007fffee6df474 in nvinfer1::cudnn::PluginLayer::serializeParams(flatbuffers::FlatBufferBuilder&) const () from /usr/lib/x86_64-linux-gnu/libnvinfer.so.4
#1  0x00007fffee65b12d in nvinfer1::cudnn::Engine::serialize() const () from /usr/lib/x86_64-linux-gnu/libnvinfer.so.4
#2  0x00000000004035af in APIToModel(unsigned int, nvinfer1::IHostMemory**) ()
#3  0x000000000040274b in main ()

I use TensorRT 3.0.0

thanks again,

f

The plugin object is created in the function createMyEngine (line 169), and its lifetime is bound to that function. The engine, however, is passed back to APIToModel, so by the time you serialize it the plugin is already gone. In general, the plugin instance must live at least as long as the engine.

A solution could be to move the plugin to APIToModel and pass a reference to it into createMyEngine.
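
For example, a minimal sketch of that change, reusing the Reshape class and constants from the repro above (untested, just to illustrate the ownership change), could look like this:

// Sketch of the suggested fix: the Reshape plugin object now lives in
// APIToModel, so it is still alive when engine->serialize() is called.
ICudaEngine *createMyEngine(unsigned int maxBatchSize, IBuilder *builder,
                            DataType dt, Reshape &reshape)
{
    INetworkDefinition *network = builder->createNetwork();

    auto data = network->addInput(INPUT_BLOB_NAME, dt, DimsCHW{INPUT_C, INPUT_H, INPUT_W});
    assert(data != nullptr);

    // the plugin object is owned by the caller and only referenced here
    ITensor *ptr = data;
    auto plugin = network->addPlugin(&ptr, 1, reshape);
    assert(plugin != nullptr);
    plugin->setName("reshape");

    plugin->getOutput(0)->setName(OUTPUT_BLOB_NAME);
    network->markOutput(*plugin->getOutput(0));

    builder->setMaxBatchSize(maxBatchSize);
    builder->setMaxWorkspaceSize(1 << 20);

    auto engine = builder->buildCudaEngine(*network);
    network->destroy();
    return engine;
}

void APIToModel(unsigned int maxBatchSize, IHostMemory **modelStream)
{
    IBuilder *builder = createInferBuilder(gLogger);

    // plugin instance declared here, so it outlives buildCudaEngine() and serialize()
    Reshape reshape(OUTPUT_TEST_SIZE);
    ICudaEngine *engine = createMyEngine(maxBatchSize, builder, DataType::kFLOAT, reshape);
    assert(engine != nullptr);

    (*modelStream) = engine->serialize(); // reshape is still in scope here

    engine->destroy();
    builder->destroy();
}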

In fact, that was exactly the problem;
thanks for the help,

f