FastDeploy latest
Fast & Easy to Deploy!
ort_backend.h
1 // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 
17 #include <iostream>
18 #include <memory>
19 #include <string>
20 #include <vector>
21 #include <map>
22 
23 #include "fastdeploy/runtime/backends/backend.h"
24 #include "fastdeploy/runtime/backends/ort/option.h"
25 #include "onnxruntime_cxx_api.h" // NOLINT
26 
27 #ifdef WITH_DIRECTML
28 #include "dml_provider_factory.h" // NOLINT
29 #endif
30 
31 namespace fastdeploy {
32 
33 struct OrtValueInfo {
34  std::string name;
35  std::vector<int64_t> shape;
36  ONNXTensorElementDataType dtype;
37 };
38 
39 class OrtBackend : public BaseBackend {
40  public:
41  OrtBackend() {}
42  virtual ~OrtBackend() = default;
43 
44  bool BuildOption(const OrtBackendOption& option);
45 
46  bool Init(const RuntimeOption& option);
47 
48  bool Infer(std::vector<FDTensor>& inputs, std::vector<FDTensor>* outputs,
49  bool copy_to_fd = true) override;
50 
51  int NumInputs() const override { return inputs_desc_.size(); }
52 
53  int NumOutputs() const override { return outputs_desc_.size(); }
54 
55  TensorInfo GetInputInfo(int index) override;
56  TensorInfo GetOutputInfo(int index) override;
57  std::vector<TensorInfo> GetInputInfos() override;
58  std::vector<TensorInfo> GetOutputInfos() override;
59  static std::vector<OrtCustomOp*> custom_operators_;
60  void InitCustomOperators();
61 
62  private:
63  bool InitFromPaddle(const std::string& model_buffer,
64  const std::string& params_buffer,
65  const OrtBackendOption& option = OrtBackendOption(),
66  bool verbose = false);
67 
68  bool InitFromOnnx(const std::string& model_buffer,
69  const OrtBackendOption& option = OrtBackendOption());
70 
71  Ort::Env env_;
72  Ort::Session session_{nullptr};
73  Ort::SessionOptions session_options_;
74  std::shared_ptr<Ort::IoBinding> binding_;
75  std::vector<OrtValueInfo> inputs_desc_;
76  std::vector<OrtValueInfo> outputs_desc_;
77 
78  // the ONNX model file name,
79  // when ONNX is bigger than 2G, we will set this name
80  std::string model_file_name;
81 #ifndef NON_64_PLATFORM
82  Ort::CustomOpDomain custom_op_domain_ = Ort::CustomOpDomain("Paddle");
83 #endif
84  OrtBackendOption option_;
85  void OrtValueToFDTensor(const Ort::Value& value, FDTensor* tensor,
86  const std::string& name, bool copy_to_fd);
87 };
88 } // namespace fastdeploy
All C++ FastDeploy APIs are defined inside this namespace.
Definition: option.h:16