19 #ifndef NON_64_PLATFORM 20 #include "onnxruntime_cxx_api.h" 24 struct MultiClassNmsKernel {
26 int64_t background_label = -1;
27 int64_t keep_top_k = -1;
29 float nms_threshold = 0.7;
32 float score_threshold;
33 Ort::CustomOpApi ort_;
36 MultiClassNmsKernel(Ort::CustomOpApi ort,
const OrtKernelInfo* info)
41 void GetAttribute(
const OrtKernelInfo* info);
43 void Compute(OrtKernelContext* context);
44 void FastNMS(
const float* boxes,
const float* scores,
const int& num_boxes,
45 std::vector<int>* keep_indices);
46 int NMSForEachSample(
const float* boxes,
const float* scores,
int num_boxes,
48 std::map<
int, std::vector<int>>* keep_indices);
51 struct MultiClassNmsOp
52 : Ort::CustomOpBase<MultiClassNmsOp, MultiClassNmsKernel> {
53 void* CreateKernel(Ort::CustomOpApi api,
const OrtKernelInfo* info)
const {
54 return new MultiClassNmsKernel(api, info);
57 const char* GetName()
const {
return "MultiClassNMS"; }
59 size_t GetInputTypeCount()
const {
return 2; }
61 ONNXTensorElementDataType GetInputType(
size_t index)
const {
62 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
65 size_t GetOutputTypeCount()
const {
return 3; }
67 ONNXTensorElementDataType GetOutputType(
size_t index)
const {
69 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
71 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
74 const char* GetExecutionProviderType()
const {
75 return "CPUExecutionProvider";
All C++ FastDeploy APIs are defined inside this namespace.
Definition: option.h:16