Plugin to convert to and from half precision within the network

Hello,

I am trying to implement a network with the TensorRT 5 API, which I need to work in half precision mode.

To do so, I implemented two IPluginV2Ext plugins to run the conversions float->half and half->float. The input to the network is then FP32, then the float->half conversion plugin is called, all the computation is performed in FP16, and finally the half->float plugin is called.

The problem is that after the first conversion, the data type remains float32 all the way through the network.

I implemented the plugin with the IPluginV2Ext interface. I attach the code below.

Any clue about what is going wrong, or whether there is a different strategy for doing data conversion with the API?

Thanks,

f

class CastFloat2HalfLayer : public IPluginV2Ext
{
public:
    CastFloat2HalfLayer(const int num_elements)
    {
        numElements = num_elements;
    }

    CastFloat2HalfLayer(const void* data, size_t length)
    {
        const char* d = static_cast<const char*>(data);
        numElements = read<int>(d);
    }

    // It makes no sense to construct UffPoolPluginV2 without arguments.
    CastFloat2HalfLayer() = delete;

    virtual ~CastFloat2HalfLayer() {}

    int getNbOutputs() const override
    {
        return 1;
    }

    Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
    {
        assert(nbInputDims == 1);
        assert(inputs[0].nbDims == 3);
        assert(index == 0);

        int out_c = inputs[0].d[0];
        int out_h = inputs[0].d[1];
        int out_w = inputs[0].d[2];

        return DimsCHW(out_c, out_h, out_w);
    }

    void attachToContext(cudnnContext *, cublasContext * , IGpuAllocator *) override { ; } 

    bool canBroadcastInputAcrossBatch(int inputIndex) const override { return 0; } 

    void configurePlugin(const Dims *  inputDims,
                          int   nbInputs,
                          const Dims *   outputDims,
                          int   nbOutputs,
                          const DataType *   inputTypes,
                          const DataType *   outputTypes,
                          const bool *   inputIsBroadcast,
                          const bool *   outputIsBroadcast,
                          PluginFormat   floatFormat,
                          int   maxBatchSize 
                          )  override { numElements = inputDims[0].d[0] * inputDims[0].d[1] * inputDims[0].d[2]; }

    void detachFromContext() override { ; }

    DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const override { return DataType::kHALF; }

    int getTensorRTVersion() const override { return 5; }

    bool isOutputBroadcastAcrossBatch(int outputIndex, const bool * inputIsBroadcasted, int nbInputs ) const override { return 0; }

    int initialize() override { return 0; }

    void terminate() override { ; }

    size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }

    int enqueue(int batch_size, const void*const *inputs, void** outputs, void*, cudaStream_t stream) override
    {
        int num_elements = numElements;
        float *data_in = (float*)inputs[0];
        __half *data_out = (__half*)outputs[0];

        cudaCastFloatToHalf_device(data_in, data_out, num_elements, stream);

        return 0;
    }

    size_t getSerializationSize() const { return sizeof(int); }

    void serialize(void* buffer) const
    {
        char *d = reinterpret_cast<char*>(buffer);
        write(d, numElements);
    }

    void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputDims, int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) override
    {
        assert(nbOutputs == 1);
        assert(nbInputs == 1);
        assert(inputs[0].nbDims == 3);
    }

    bool supportsFormat(DataType type, PluginFormat format) const override { return (type == DataType::kFLOAT && format == PluginFormat::kNCHW); }

    const char* getPluginType() const override { return "CastFloat2Half_TRT"; }

    const char* getPluginVersion() const override { return "1"; }

    void destroy() override { delete this; }

    IPluginV2Ext* clone() const override { return new CastFloat2HalfLayer(numElements); }

    void setPluginNamespace(const char* libNamespace) override { mNamespace = libNamespace; }

    const char* getPluginNamespace() const override { return mNamespace.c_str(); }

private:
    template <typename T>
    void write(char*& buffer, const T& val) const
    {
        *reinterpret_cast<T*>(buffer) = val;
        buffer += sizeof(T);
    }

    template <typename T>
    T read(const char*& buffer)
    {
        T val = *reinterpret_cast<const T*>(buffer);
        buffer += sizeof(T);
        return val;
    }

    int numElements;
    std::string mNamespace;
};

namespace
{
// Identity strings for the plugin registry; must match getPluginVersion()
// and getPluginType() on CastFloat2HalfLayer.
constexpr char CASTFLOAT2HALFLAYER_PLUGIN_VERSION[] = "1";
constexpr char CASTFLOAT2HALFLAYER_PLUGIN_NAME[] = "CastFloat2Half_TRT";
} // namespace

Could you please let us know if you are still facing this issue?

Thanks