FastDeploy  latest
Fast & Easy to Deploy!
poros_module.h
1 // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #pragma once
16 
#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "torch/csrc/jit/jit_log.h"  // NOLINT
#include "torch/script.h"  // NOLINT
// #include "ATen/Context.h"
21 
22 namespace baidu {
23 namespace mirana {
24 namespace poros {
25 
/// Target backend for a Poros module, stored in a single signed byte.
///
/// The enumerators carry explicit values so their numbering is visible at a
/// glance. `UNKNOW` is spelled as in the original API; it is kept that way so
/// existing callers keep compiling.
enum Device : int8_t {
  GPU = 0,
  CPU = 1,
  XPU = 2,
  UNKNOW = 3,
};
27 
28 struct PorosOptions {
29  Device device = GPU;
30  bool debug = false;
31  bool use_fp16 = false;
32  bool is_dynamic = false;
33  bool long_to_int = true;
34  uint64_t max_workspace_size = 1ULL << 30;
35  int32_t device_id = -1;
36  int32_t unconst_ops_thres = -1;
37  bool use_nvidia_tf32 = false;
38 };
39 
40 
41 class PorosModule : public torch::jit::Module {
42  public:
43  PorosModule(torch::jit::Module module) : torch::jit::Module(module) {} // NOLINT
44  ~PorosModule() = default;
45 
46  void to_device(Device device) { _options.device = device; }
47 
48  // c10::IValue forward(std::vector<c10::IValue> inputs);
49  // void save(const std::string& filename);
50  public:
51  PorosOptions _options;
52 };
53 
54 // via porosmodule.save
55 std::unique_ptr<PorosModule> Load(const std::string& filename,
56  const PorosOptions& options);
57 
58 } // namespace poros
59 } // namespace mirana
60 } // namespace baidu
Definition: compile.h:26