FastDeploy  latest
Fast & Easy to Deploy!
iengine.h
1 // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 
17 #include <string>
18 
19 // from pytorch
20 #include "ATen/core/interned_strings.h" // NOLINT
21 #include "torch/csrc/jit/ir/ir.h" // NOLINT
22 #include "torch/script.h" // NOLINT
23 
24 #include "plugin_create.h" // NOLINT
25 
26 namespace baidu {
27 namespace mirana {
28 namespace poros {
29 
30 struct PorosGraph {
31  torch::jit::Graph* graph = NULL;
32  torch::jit::Node* node = NULL;
33 };
34 
// Integer type used for engine identifiers (see IEngine::_id).
using EngineID = uint64_t;
36 
37 class IEngine : public IPlugin, public torch::CustomClassHolder {
38  public:
39  virtual ~IEngine() {}
40 
46  virtual int init() = 0;
47 
54  virtual int transform(const PorosGraph& sub_graph) = 0;
55 
61  virtual std::vector<at::Tensor>
62  excute_engine(const std::vector<at::Tensor>& inputs) = 0;
63 
64  virtual void register_module_attribute(const std::string& name,
65  torch::jit::Module& module) = 0;
66 
67  // Logo
68  virtual const std::string who_am_i() = 0;
69 
70  // Whether the node is supported by the current engine
71  bool is_node_supported(const torch::jit::Node* node);
72 
73  public:
74  std::pair<uint64_t, uint64_t> _num_io; // Number of input/output parameters
75  EngineID _id;
76 };
77 
78 } // namespace poros
79 } // namespace mirana
80 } // namespace baidu
Definition: compile.h:26