Hello,
I am trying to implement a network with the TensorRT 5 API that needs to run in half-precision (FP16) mode.
To do so, I implemented two IPluginV2Ext plugins that perform the float -> half and half -> float conversions. The network input is FP32; the float -> half conversion plugin is called first, all the computation should then be performed in FP16, and finally the half -> float plugin is called to convert the result back to FP32.
The problem is that, even after the first conversion, the data type remains float32 throughout the rest of the network.
I implemented the plugin with the IPluginV2Ext interface. I attach the code below.
Any clue about what is going wrong, or is there a different strategy for doing data-type conversion with the API?
Thanks,
f
class CastFloat2HalfLayer : public IPluginV2Ext
{
public:
CastFloat2HalfLayer(const int num_elements)
{
numElements = num_elements;
}
CastFloat2HalfLayer(const void* data, size_t length)
{
const char* d = static_cast<const char*>(data);
numElements = read<int>(d);
}
// It makes no sense to construct UffPoolPluginV2 without arguments.
CastFloat2HalfLayer() = delete;
virtual ~CastFloat2HalfLayer() {}
int getNbOutputs() const override
{
return 1;
}
Dims getOutputDimensions(int index, const Dims* inputs, int nbInputDims) override
{
assert(nbInputDims == 1);
assert(inputs[0].nbDims == 3);
assert(index == 0);
int out_c = inputs[0].d[0];
int out_h = inputs[0].d[1];
int out_w = inputs[0].d[2];
return DimsCHW(out_c, out_h, out_w);
}
void attachToContext(cudnnContext *, cublasContext * , IGpuAllocator *) override { ; }
bool canBroadcastInputAcrossBatch(int inputIndex) const override { return 0; }
void configurePlugin(const Dims * inputDims,
int nbInputs,
const Dims * outputDims,
int nbOutputs,
const DataType * inputTypes,
const DataType * outputTypes,
const bool * inputIsBroadcast,
const bool * outputIsBroadcast,
PluginFormat floatFormat,
int maxBatchSize
) override { numElements = inputDims[0].d[0] * inputDims[0].d[1] * inputDims[0].d[2]; }
void detachFromContext() override { ; }
DataType getOutputDataType(int index, const DataType* inputTypes, int nbInputs) const override { return DataType::kHALF; }
int getTensorRTVersion() const override { return 5; }
bool isOutputBroadcastAcrossBatch(int outputIndex, const bool * inputIsBroadcasted, int nbInputs ) const override { return 0; }
int initialize() override { return 0; }
void terminate() override { ; }
size_t getWorkspaceSize(int maxBatchSize) const override { return 0; }
int enqueue(int batch_size, const void*const *inputs, void** outputs, void*, cudaStream_t stream) override
{
int num_elements = numElements;
float *data_in = (float*)inputs[0];
__half *data_out = (__half*)outputs[0];
cudaCastFloatToHalf_device(data_in, data_out, num_elements, stream);
return 0;
}
size_t getSerializationSize() const { return sizeof(int); }
void serialize(void* buffer) const
{
char *d = reinterpret_cast<char*>(buffer);
write(d, numElements);
}
void configureWithFormat(const Dims* inputs, int nbInputs, const Dims* outputDims, int nbOutputs, nvinfer1::DataType type, nvinfer1::PluginFormat format, int maxBatchSize) override
{
assert(nbOutputs == 1);
assert(nbInputs == 1);
assert(inputs[0].nbDims == 3);
}
bool supportsFormat(DataType type, PluginFormat format) const override { return (type == DataType::kFLOAT && format == PluginFormat::kNCHW); }
const char* getPluginType() const override { return "CastFloat2Half_TRT"; }
const char* getPluginVersion() const override { return "1"; }
void destroy() override { delete this; }
IPluginV2Ext* clone() const override { return new CastFloat2HalfLayer(numElements); }
void setPluginNamespace(const char* libNamespace) override { mNamespace = libNamespace; }
const char* getPluginNamespace() const override { return mNamespace.c_str(); }
private:
template <typename T>
void write(char*& buffer, const T& val) const
{
*reinterpret_cast<T*>(buffer) = val;
buffer += sizeof(T);
}
template <typename T>
T read(const char*& buffer)
{
T val = *reinterpret_cast<const T*>(buffer);
buffer += sizeof(T);
return val;
}
int numElements;
std::string mNamespace;
};
// File-local identification strings for the float -> half cast plugin.
// They hold the same values that CastFloat2HalfLayer::getPluginVersion() and
// CastFloat2HalfLayer::getPluginType() return.
namespace
{
// Version string of the plugin.
const char* CASTFLOAT2HALFLAYER_PLUGIN_VERSION = "1";
// Type name of the plugin.
const char* CASTFLOAT2HALFLAYER_PLUGIN_NAME = "CastFloat2Half_TRT";
} // namespace