Unexpected behavior of TopKLayer

Hello,
I am confused about the usage of the TopK layer:

  1. The documentation says "The TopK layer has two outputs of the same dimensions. The first contains data values, the second contains index positions for the values." (https://docs.nvidia.com/deeplearning/sdk/tensorrt-api/topics/classnvinfer1_1_1_i_network_definition.html#a384a409318bf416be3aa4442f2b0ce76). What I observe, though, is that getNbOutputs() returns 1, and calling getOutput(index) with index = 1 results in a segmentation fault (see the short sketch right after this list).
  2. For an input of size [BatchSize, NumChannels, Height, Width], I would expect the output size (for both the data values and the index positions) to be something like [BatchSize, NumChannels, Height, K], where K is the number of elements to keep and the reduction axis is the Width one. Instead, in my experience the output size is the same as the input size.
  3. When I try to execute the code attached below, I get the error
    ERROR: topk: input and output must have the same number of dimensions
    topkdebug: topKdbg.cpp:139: void APIToModel(unsigned int, nvinfer1::IHostMemory**): Assertion `engine != nullptr' failed.
    

     I tried both defining the output size as I expected it and defining it the same as the input; the problem is the same in either case.
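
To make points 1 and 2 concrete, this is roughly what I expected to be able to do based on the documentation quoted above (just a sketch of my expectation, reusing the network / data / K_TOP names from the code attached at the bottom; the getOutput(1) call is exactly where I get the segmentation fault):

    // sketch of what I expected, not working code
    auto topk = network->addTopK(*data, TopKOperation::kMAX, K_TOP, 0x4); // 0x4: reduce over the W axis of the CHW input
    assert(topk->getNbOutputs() == 2);     // the docs say two outputs; getNbOutputs() actually returns 1 for me
    ITensor* values  = topk->getOutput(0); // top-K data values, which I expected to have shape [C, H, K_TOP]
    ITensor* indices = topk->getOutput(1); // index positions of the values -> this call segfaults for me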

Is this the expected behaviour, or am I misunderstanding something?
Can anyone show me an example of the correct usage of the layer? I know it is used in the sampleCharRNN sample, but that code doesn't execute for me either.

I am using TensorRT 4.0.0.3 with CUDA 8.0.

Please find my toy example code below to reproduce the issue.

Thanks,

f

#include "NvInfer.h"
#include "NvCaffeParser.h"
#include "NvUtils.h"
#include "cuda_runtime_api.h"
#include <cassert>
#include <cmath>
#include <cstring>
#include <string>
#include <fstream>
#include <iostream>
#include <sstream>
#include <sys/stat.h>
#include <vector>
#include <algorithm>
#include <cstdlib>   // malloc/free, rand/srand
#include <ctime>     // time()

#define CHECK(status)									\
{														\
	if (status != 0)									\
	{													\
		std::cout << "Cuda failure: " << status;		\
		abort();										\
	}													\
}


// Logger for GIE info/warning/errors
class Logger : public nvinfer1::ILogger			
{
    public:
	void log(nvinfer1::ILogger::Severity severity, const char* msg) override
	{
		// suppress info-level messages
        if (severity == Severity::kINFO) return;

        switch (severity)
        {
            case Severity::kINTERNAL_ERROR: std::cerr << "INTERNAL_ERROR: "; break;
            case Severity::kERROR: std::cerr << "ERROR: "; break;
            case Severity::kWARNING: std::cerr << "WARNING: "; break;
            case Severity::kINFO: std::cerr << "INFO: "; break;
            default: std::cerr << "UNKNOWN: "; break;
        }
        std::cerr << msg << std::endl;
	}
};



// def constants
static const int K_TOP = 10; //3;
static const int INPUT_H = 1;
static const int INPUT_W = 10;
static const int INPUT_C = 5;
static const int OUTPUT_H = 1;
//static const int OUTPUT_W = 10;
static const int OUTPUT_W = K_TOP;
static const int OUTPUT_C = 5;

static const int BATCH_SIZE = 4;
static const int MAX_BATCH_SIZE = 4;

const char* INPUT_BLOB_NAME = "data";
const char* OUTPUT_BLOB_NAME = "out";

static void* buffers[2];
static cudaStream_t stream;
static int inputIndex, outputIndex;

using namespace nvinfer1;
static Logger gLogger;



// print tensor dimensions and data type
void printDims(ITensor* data)
{
    Dims dims = data->getDimensions();
    for (int d = 0; d < dims.nbDims; d++)
        std::cout << dims.d[d] << " ";

    std::string typeStr;
    if (data->getType() == DataType::kHALF)
        typeStr = "float16";
    if (data->getType() == DataType::kFLOAT)
        typeStr = "float32";
    std::cout << typeStr << std::endl;
}



void APIToModel(unsigned int maxBatchSize, IHostMemory **modelStream)
{
	// create the builder
	IBuilder* builder = createInferBuilder(gLogger);

///////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////
    INetworkDefinition* network = builder->createNetwork();
    

	// define input
	auto data = network->addInput(INPUT_BLOB_NAME, DataType::kFLOAT, DimsCHW{INPUT_C, INPUT_H, INPUT_W});
	assert(data != nullptr);
    std::cout << "input" << std::endl;
    printDims(data);
    

    // apply topK; reduceAxes is a bit mask over the input axes
    // (for a CHW input, bit 2 = 0x4 should select the W axis)
    uint32_t reduceAxes = 0x4;
    auto topk = network->addTopK(*data, TopKOperation::kMAX, K_TOP, reduceAxes);
    assert(topk != nullptr);
    topk->setName("topk");
    std::cout << "topk0" << std::endl;
    printDims(topk->getOutput(0));
    topk->getOutput(0)->setName(OUTPUT_BLOB_NAME);
    network->markOutput(*topk->getOutput(0));
//    std::cout << "topk1" << std::endl;
//    printDims(topk->getOutput(1));
//    std::cout << topk->getNbOutputs() << std::endl;



///////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////
///////////////////////////////////////////////////////////

	// Build the engine
	builder->setMaxBatchSize(maxBatchSize);
	builder->setMaxWorkspaceSize(1 << 30); // ~1 GB workspace (setMaxWorkspaceSize takes a size_t)

    std::cout << "building the engine..." << std::endl;
	auto engine = builder->buildCudaEngine(*network);
	assert(engine != nullptr);
    std::cout << "engine built!" << std::endl;

	// serialize the engine, then close everything down
	(*modelStream) = engine->serialize();

///////////////////////////////////////////////////////////

    network->destroy();
	engine->destroy();
	builder->destroy();
}


void setUpDevice(IExecutionContext& context, int batchSize)
{
    const ICudaEngine& engine = context.getEngine();
    // input and output buffer pointers that we pass to the engine - the engine requires exactly IEngine::getNbBindings(),
    // of these, but in this case we know that there is exactly one input and one output.
    assert(engine.getNbBindings() == 2);

    // In order to bind the buffers, we need to know the names of the input and output tensors.
    // note that indices are guaranteed to be less than IEngine::getNbBindings()
    inputIndex = engine.getBindingIndex(INPUT_BLOB_NAME);
    outputIndex = engine.getBindingIndex(OUTPUT_BLOB_NAME);

    // create GPU buffers and a stream
    CHECK(cudaMalloc(&buffers[inputIndex], batchSize * INPUT_H * INPUT_W * INPUT_C * sizeof(float)));
    CHECK(cudaMalloc(&buffers[outputIndex], batchSize * OUTPUT_W * OUTPUT_H * OUTPUT_C * sizeof(float)));

    // create cuda stream
    CHECK(cudaStreamCreate(&stream));
}

void cleanUp()
{
  	// release the stream and the buffers
	cudaStreamDestroy(stream);
	CHECK(cudaFree(buffers[inputIndex]));
    CHECK(cudaFree(buffers[outputIndex]));
}

void doInference(IExecutionContext& context, float* input, float* output, int batchSize)
{
	// DMA the input to the GPU, execute the batch asynchronously, and DMA it back:
	CHECK(cudaMemcpyAsync(buffers[inputIndex], input, batchSize * INPUT_H * INPUT_W * INPUT_C * sizeof(float), cudaMemcpyHostToDevice, stream));
	context.enqueue(batchSize, buffers, stream, nullptr);
    CHECK(cudaMemcpyAsync(output, buffers[outputIndex], batchSize * OUTPUT_W * OUTPUT_H * OUTPUT_C * sizeof(float), cudaMemcpyDeviceToHost, stream));
	cudaStreamSynchronize(stream);
}





// print a [batch_size, output_c, output_h, output_w] blob of floats
void printData(float *out, const int batch_size, const int output_c, const int output_h, const int output_w)
{
    const int output_size = output_c * output_h * output_w;

    std::cout << "=================" << std::endl;
    std::cout << "=================" << std::endl;
    for (int b = 0; b < batch_size; b++)
    {
        std::cout << "-----------------" << std::endl;
        for (int c = 0; c < output_c; c++)
        {
            for (int h = 0; h < output_h; h++)
            {
                for (int w = 0; w < output_w; w++)
                    std::cout << out[b * output_size + c * output_h * output_w + h * output_w + w] << " ";
                std::cout << std::endl;
            }
            std::cout << "-----------------" << std::endl;
        }
        std::cout << "=================" << std::endl;
        std::cout << "=================" << std::endl;
    }
}





int main(int argc, char** argv)
{
    // allocate CPU memory for input and output
    int inputSize = sizeof(float) * BATCH_SIZE * INPUT_C * INPUT_H * INPUT_W;
    int outputSize = sizeof(float) * BATCH_SIZE * OUTPUT_C * OUTPUT_W * OUTPUT_H;
    float *data = (float *)malloc(inputSize);
    float *out = (float *)malloc(outputSize);

    // init dummy input
    srand (time(NULL));
    for (int d = 0; d < BATCH_SIZE * INPUT_W * INPUT_H * INPUT_C; d++)
        data[d] = rand() % 100 + 1;

    // print input
    printData(data, BATCH_SIZE, INPUT_C, INPUT_H, INPUT_W);

	// create a model using the API directly and serialize it to a stream
    IHostMemory *modelStream{nullptr};
    APIToModel(MAX_BATCH_SIZE, &modelStream);
	IRuntime* runtime = createInferRuntime(gLogger);
    ICudaEngine* engine = runtime->deserializeCudaEngine(modelStream->data(), modelStream->size(), nullptr);
    if (modelStream) modelStream->destroy();
	IExecutionContext *context = engine->createExecutionContext();

    // allocate device memory, do bindings
    setUpDevice(*context, BATCH_SIZE);

    // run inference
    doInference(*context, data, out, BATCH_SIZE);

	// destroy the engine
	context->destroy();
	engine->destroy();
	runtime->destroy();

    // free device memory
    cleanUp();

    // print output
    printData(out, BATCH_SIZE, OUTPUT_C, OUTPUT_H, OUTPUT_W);
    
    //free host mem
    free(data);
    free(out);

    std::cout << "done!" << std::endl;

    return 0;
}