I created a plugin layer with IPluginV2.
It has 3 inputs of types int32, int32 and float. I expect the output to be float, but when I check the output type, it is int32.
IPluginV2 has no interface for setting the output type. How can I set the output type when using IPluginV2?
/// TensorRT IPluginV2 plugin with three inputs and one output of shape
/// (kFeatureChannel, grid_size_, grid_size_). The output is produced by
/// cuda::doScatterCuda (defined elsewhere) from the three inputs, which
/// enqueue() refers to as pillar_count, pillar2grid and pfe_ouput.
///
/// NOTE(review): IPluginV2 negotiates ONE DataType for all inputs and outputs
/// (configureWithFormat receives a single `type`), so this interface cannot
/// express "int32, int32, float in -> float out". To declare a per-output
/// type, derive from IPluginV2Ext and override getOutputDataType() (which
/// also means replacing configureWithFormat with configurePlugin), or derive
/// from IPluginV2IOExt and accept each tensor's type individually in
/// supportsFormatCombination().
class ScatterPlugin : public IPluginV2 {
    // Number of input tensors. static constexpr: no per-instance storage and
    // does not delete the implicit copy-assignment operator.
    static constexpr int kNbInputs = 3;
    // Declaration order == initialization order; the constructor's
    // initializer list below follows the same order.
    int kFeatureChannel = 64;
    int kMaxPillarsCount = 30000;
    int grid_size_ = 0;

protected:
    /// Restores the plugin parameters from a blob written by serialize().
    /// @param data   Buffer produced by serialize().
    /// @param length Size of @p data in bytes; expected to equal the number of
    ///               bytes serialize() writes.
    void deserialize(void const* data, size_t length) {
        const char* const begin = static_cast<const char*>(data);
        const char* d = begin;
        // Field order must match serialize().
        read(d, kFeatureChannel);
        read(d, kMaxPillarsCount);
        read(d, grid_size_);
        // Detect truncated or foreign blobs in debug builds.
        assert(static_cast<size_t>(d - begin) == length);
        (void)length;  // only used by the assert above
    }

    /// Number of bytes written by serialize().
    size_t getSerializationSize() const override {
        return sizeof(kFeatureChannel) + sizeof(kMaxPillarsCount) + sizeof(grid_size_);
    }

    /// Writes kFeatureChannel, kMaxPillarsCount and grid_size_ (in that order)
    /// into @p buffer, which must hold at least getSerializationSize() bytes.
    void serialize(void* buffer) const override {
        char* d = static_cast<char*>(buffer);
        write(d, kFeatureChannel);
        write(d, kMaxPillarsCount);
        write(d, grid_size_);
    }

public:
    /// @param grid_size       Height and width of the output (output is
    ///                        kFeatureChannel x grid_size x grid_size).
    /// @param pillars_count   Expected size of the pillar dimension of inputs 1 and 2.
    /// @param feature_channel Expected last dimension of input 2 and channel count of the output.
    ScatterPlugin(int grid_size, int pillars_count, int feature_channel)
        : kFeatureChannel(feature_channel), kMaxPillarsCount(pillars_count), grid_size_(grid_size) {
        std::cout << "ScatterPlugin " << grid_size << "," << pillars_count << "," << feature_channel << std::endl;
        assert(kFeatureChannel > 0);
        assert(kMaxPillarsCount > 0);
        assert(grid_size_ > 0);
    }

    /// Reconstructs a plugin from a blob produced by serialize().
    ScatterPlugin(void const* data, size_t length) {
        std::cout << "ScatterPlugin " << length << std::endl;
        this->deserialize(data, length);
        // Same invariants the parameter constructor enforces.
        assert(kFeatureChannel > 0);
        assert(kMaxPillarsCount > 0);
        assert(grid_size_ > 0);
    }

    const char* getPluginType() const override {
        return POINTPILLARS_PLUGIN_NAME;
    }

    const char* getPluginVersion() const override {
        return POINTPILLARS_PLUGIN_VERSION;
    }

    /// The plugin produces a single output tensor.
    int getNbOutputs() const override {
        return 1;
    }

    /// Output 0 has shape (kFeatureChannel, grid_size_, grid_size_),
    /// independent of the input shapes.
    Dims getOutputDimensions(int index,
                             const Dims* inputs, int nbInputDims) override {
        std::cout << "getOutputDimensions " << index << std::endl;
        assert(nbInputDims == kNbInputs);
        assert(index < this->getNbOutputs());
        (void)inputs;
        return Dims3(kFeatureChannel, grid_size_, grid_size_);
    }

    /// Only FP32 in NCHW layout is accepted. Under IPluginV2 this single type
    /// applies to every input AND output (see the class-level NOTE).
    bool supportsFormat(DataType type, PluginFormat format) const override {
        return type == DataType::kFLOAT && format == PluginFormat::kNCHW;
    }

    /// Validates the negotiated type/format and the input shapes against the
    /// plugin parameters (debug builds only).
    void configureWithFormat(const Dims* inputDims, int nbInputs, const Dims* outputDims,
                             int nbOutputs, DataType type, PluginFormat format, int maxBatchSize) override {
        std::cout << "configureWithFormat " << nbInputs << "," << nbOutputs << "," << kFeatureChannel << std::endl;
        assert(type == nvinfer1::DataType::kFLOAT && format == nvinfer1::PluginFormat::kNCHW);
        assert(nbInputs == kNbInputs);
        assert(nbOutputs == 1);
        assert(inputDims[1].d[0] == kMaxPillarsCount);
        assert(inputDims[0].d[0] > 0);
        assert(inputDims[2].d[0] == 1
               && inputDims[2].d[1] == kMaxPillarsCount
               && inputDims[2].d[2] == kFeatureChannel);
        (void)inputDims;
        (void)outputDims;
        (void)type;
        (void)format;
        (void)maxBatchSize;
    }

    int initialize() override {
        std::cout << "initialize " << std::endl;
        return 0;
    }

    void terminate() override {}

    /// enqueue() never passes the workspace pointer to the kernel, so no
    /// scratch memory is requested.
    size_t getWorkspaceSize(int maxBatchSize) const override {
        (void)maxBatchSize;
        return 0;
    }

    /// Forwards the three inputs and the single output to cuda::doScatterCuda
    /// on @p stream.
    /// NOTE(review): batchSize is not forwarded to the kernel; presumably only
    /// batch size 1 is supported — confirm.
    int enqueue(int batchSize,
                const void* const* inputs, void** outputs,
                void* workspace, cudaStream_t stream) override {
        (void)batchSize;
        (void)workspace;
        const void* const pillar_count = inputs[0];
        const void* const pillar2grid = inputs[1];
        const void* const pfe_ouput = inputs[2];
        void* scattered_feature = outputs[0];
        cuda::doScatterCuda(grid_size_ * grid_size_, kFeatureChannel,
                            kMaxPillarsCount, pillar_count, pillar2grid, pfe_ouput,
                            scattered_feature, stream);
        return 0;
    }

    void destroy() override {
        delete this;
    }

    /// The namespace is fixed at compile time; setPluginNamespace() is a no-op.
    const char* getPluginNamespace() const override {
        return POINTPILLARS_PLUGIN_NAMESPACE;
    }

    void setPluginNamespace(const char* N) override {
        (void)N;
    }

    /// Returns a new plugin with the same parameters.
    IPluginV2* clone() const override {
        std::cout << "Clone\n";
        return new ScatterPlugin(grid_size_, kMaxPillarsCount, kFeatureChannel);
    }

private:
    /// Copies the object representation of @p val into @p buffer and advances
    /// it. Byte-wise copy (equivalent to memcpy) instead of a reinterpret_cast
    /// store, so an unaligned or char-typed buffer is not accessed as T.
    template <typename T> static void write(char*& buffer, const T& val) {
        const char* const src = reinterpret_cast<const char*>(&val);
        for (size_t i = 0; i < sizeof(T); ++i) {
            buffer[i] = src[i];
        }
        buffer += sizeof(T);
    }

    /// Inverse of write(): fills @p val from @p buffer and advances it.
    template <typename T> static void read(const char*& buffer, T& val) {
        char* const dst = reinterpret_cast<char*>(&val);
        for (size_t i = 0; i < sizeof(T); ++i) {
            dst[i] = buffer[i];
        }
        buffer += sizeof(T);
    }
};