FastDeploy  latest
Fast & Easy to Deploy!
adaptive_pool2d.h
1 // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 #include "common.h" // NOLINT
17 #include "fastdeploy/runtime/backends/common/cuda/adaptive_pool2d_kernel.h"
18 
19 namespace fastdeploy {
20 
21 class AdaptivePool2d : public BasePlugin {
22  public:
23  AdaptivePool2d(std::vector<int32_t> output_size, std::string pooling_type);
24 
25  AdaptivePool2d(const void* buffer, size_t length);
26 
27  ~AdaptivePool2d() override = default;
28 
29  int getNbOutputs() const noexcept override;
30 
31  nvinfer1::DimsExprs
32  getOutputDimensions(int outputIndex, const nvinfer1::DimsExprs* inputs,
33  int nbInputs,
34  nvinfer1::IExprBuilder& exprBuilder) noexcept override;
35 
36  nvinfer1::DataType getOutputDataType(int index,
37  const nvinfer1::DataType* inputType,
38  int nbInputs) const noexcept override;
39 
40  bool supportsFormatCombination(int pos,
41  const nvinfer1::PluginTensorDesc* inOut,
42  int nbInputs, int nbOutputs) noexcept override;
43 
44  int initialize() noexcept override;
45 
46  void terminate() noexcept override;
47 
48  size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
49  int nbInputs,
50  const nvinfer1::PluginTensorDesc* outputs,
51  int nbOutputs) const noexcept override;
52 
53  int enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
54  const nvinfer1::PluginTensorDesc* outputDesc,
55  const void* const* inputs, void* const* outputs, void* workspace,
56  cudaStream_t stream) noexcept override;
57 
58  size_t getSerializationSize() const noexcept override;
59 
60  void serialize(void* buffer) const noexcept override;
61 
62  const char* getPluginType() const noexcept override;
63 
64  const char* getPluginVersion() const noexcept override;
65  void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
66  int nbInputs,
67  const nvinfer1::DynamicPluginTensorDesc* out,
68  int nbOutputs) noexcept override;
69  void destroy() noexcept override;
70 
71  nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
72 
73  private:
74  std::vector<int32_t> output_size_;
75  std::string pooling_type_;
76 };
77 
78 class AdaptivePool2dPluginCreator : public BaseCreator {
79  public:
80  AdaptivePool2dPluginCreator();
81 
82  ~AdaptivePool2dPluginCreator() override = default;
83 
84  const char* getPluginName() const noexcept override;
85 
86  const char* getPluginVersion() const noexcept override;
87 
88  const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
89 
90  nvinfer1::IPluginV2DynamicExt*
91  createPlugin(const char* name,
92  const nvinfer1::PluginFieldCollection* fc) noexcept override;
93 
94  nvinfer1::IPluginV2DynamicExt*
95  deserializePlugin(const char* name, const void* serialData,
96  size_t serialLength) noexcept override;
97 
98  private:
99  static nvinfer1::PluginFieldCollection mFC;
100  static std::vector<nvinfer1::PluginField> mPluginAttributes;
101  std::vector<int32_t> output_size_;
102  std::string pooling_type_;
103 };
104 
105 REGISTER_TENSORRT_PLUGIN(AdaptivePool2dPluginCreator);
106 
107 } // namespace fastdeploy
Definition: float16.h:572
All C++ FastDeploy APIs are defined inside this namespace.
Definition: option.h:16