17 #include "fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h" 21 class AdaptivePool2d :
public BasePlugin {
23 AdaptivePool2d(std::vector<int32_t> output_size, std::string pooling_type);
25 AdaptivePool2d(
const void* buffer,
size_t length);
27 ~AdaptivePool2d()
override =
default;
29 int getNbOutputs() const noexcept override;
32 getOutputDimensions(
int outputIndex, const nvinfer1::DimsExprs* inputs,
34 nvinfer1::IExprBuilder& exprBuilder) noexcept override;
36 nvinfer1::DataType getOutputDataType(
int index,
37 const nvinfer1::DataType* inputType,
38 int nbInputs) const noexcept override;
40 bool supportsFormatCombination(
int pos,
41 const nvinfer1::PluginTensorDesc* inOut,
42 int nbInputs,
int nbOutputs) noexcept override;
44 int initialize() noexcept override;
46 void terminate() noexcept override;
48 size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
50 const nvinfer1::PluginTensorDesc* outputs,
51 int nbOutputs) const noexcept override;
53 int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
54 const nvinfer1::PluginTensorDesc* outputDesc,
55 const
void* const* inputs,
void* const* outputs,
void* workspace,
56 cudaStream_t stream) noexcept override;
58 size_t getSerializationSize() const noexcept override;
60 void serialize(
void* buffer) const noexcept override;
62 const
char* getPluginType() const noexcept override;
64 const
char* getPluginVersion() const noexcept override;
65 void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
67 const nvinfer1::DynamicPluginTensorDesc* out,
68 int nbOutputs) noexcept override;
69 void destroy() noexcept override;
71 nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
74 std::vector<int32_t> output_size_;
75 std::
string pooling_type_;
78 class AdaptivePool2dPluginCreator : public BaseCreator {
80 AdaptivePool2dPluginCreator();
82 ~AdaptivePool2dPluginCreator()
override =
default;
84 const char* getPluginName() const noexcept override;
86 const
char* getPluginVersion() const noexcept override;
88 const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
90 nvinfer1::IPluginV2DynamicExt*
91 createPlugin(const
char* name,
92 const nvinfer1::PluginFieldCollection* fc) noexcept override;
94 nvinfer1::IPluginV2DynamicExt*
95 deserializePlugin(const
char* name, const
void* serialData,
96 size_t serialLength) noexcept override;
99 static nvinfer1::PluginFieldCollection mFC;
100 static
std::vector<nvinfer1::PluginField> mPluginAttributes;
101 std::vector<int32_t> output_size_;
102 std::
string pooling_type_;
105 REGISTER_TENSORRT_PLUGIN(AdaptivePool2dPluginCreator);
// (Doxygen tooltip residue from the rendered source listing)
// Definition: float16.h:572
// All C++ FastDeploy APIs are defined inside this namespace.
// Definition: option.h:16