FastDeploy  latest
Fast & Easy to Deploy!
multiclass_nms.h
1 // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 
#include <cstdint>
#include <map>
#include <vector>
18 
19 #ifndef NON_64_PLATFORM
20 #include "onnxruntime_cxx_api.h" // NOLINT
21 
22 namespace fastdeploy {
23 
24 struct MultiClassNmsKernel {
25  protected:
26  int64_t background_label = -1;
27  int64_t keep_top_k = -1;
28  float nms_eta;
29  float nms_threshold = 0.7;
30  int64_t nms_top_k;
31  bool normalized;
32  float score_threshold;
33  Ort::CustomOpApi ort_;
34 
35  public:
36  MultiClassNmsKernel(Ort::CustomOpApi ort, const OrtKernelInfo* info)
37  : ort_(ort) {
38  GetAttribute(info);
39  }
40 
41  void GetAttribute(const OrtKernelInfo* info);
42 
43  void Compute(OrtKernelContext* context);
44  void FastNMS(const float* boxes, const float* scores, const int& num_boxes,
45  std::vector<int>* keep_indices);
46  int NMSForEachSample(const float* boxes, const float* scores, int num_boxes,
47  int num_classes,
48  std::map<int, std::vector<int>>* keep_indices);
49 };
50 
51 struct MultiClassNmsOp
52  : Ort::CustomOpBase<MultiClassNmsOp, MultiClassNmsKernel> {
53  void* CreateKernel(Ort::CustomOpApi api, const OrtKernelInfo* info) const {
54  return new MultiClassNmsKernel(api, info);
55  }
56 
57  const char* GetName() const { return "MultiClassNMS"; }
58 
59  size_t GetInputTypeCount() const { return 2; }
60 
61  ONNXTensorElementDataType GetInputType(size_t index) const {
62  return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
63  }
64 
65  size_t GetOutputTypeCount() const { return 3; }
66 
67  ONNXTensorElementDataType GetOutputType(size_t index) const {
68  if (index == 0) {
69  return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
70  }
71  return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
72  }
73 
74  const char* GetExecutionProviderType() const {
75  return "CPUExecutionProvider";
76  }
77 };
78 
79 } // namespace fastdeploy
80 
81 #endif
All C++ FastDeploy APIs are defined inside this namespace.
Definition: option.h:16