Rivet is hosted by Hepforge, IPPP Durham
Rivet 4.0.2
RivetONNXrt.hh
1// -*- C++ -*-
2#ifndef RIVET_RivetONNXrt_HH
3#define RIVET_RivetONNXrt_HH
4
5#include <iostream>
6#include <functional>
7#include <numeric>
8
9#include "Rivet/Tools/RivetPaths.hh"
10#include "Rivet/Tools/Utils.hh"
11#include "onnxruntime/onnxruntime_cxx_api.h"
12
13namespace Rivet {
14
15
22
23 public:
24
25 // Suppress default constructor
26 RivetONNXrt() = delete;
27
29 RivetONNXrt(const string& filename, const string& runname = "RivetONNXrt") {
30
31 // Set some ORT variables that need to be kept in memory
32 _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());
33
34 // Load the model
35 Ort::SessionOptions sessionopts;
36 _session = std::make_unique<Ort::Session> (*_env, filename.c_str(), sessionopts);
37
38 // Store network hyperparameters (input/output shape, etc.)
39 getNetworkInfo();
40
41 MSG_DEBUG(*this);
42 }
43
45 vector<vector<float>> compute(vector<vector<float>>& inputs) const {
46
48 if (inputs.size() != _inDims.size()) {
49 throw("Expected " + to_string(_inDims.size())
50 + " input nodes, received " + to_string(inputs.size()));
51 }
52
53 // Create input tensor objects from input data
54 vector<Ort::Value> ort_input;
55 ort_input.reserve(_inDims.size());
56 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
57 for (size_t i=0; i < _inDims.size(); ++i) {
58
59 // Check that input data matches expected input node dimension
60 if (inputs[i].size() != _inDimsFlat[i]) {
61 throw("Expected flattened input node dimension " + to_string(_inDimsFlat[i])
62 + ", received " + to_string(inputs[i].size()));
63 }
64
65 ort_input.emplace_back(Ort::Value::CreateTensor<float>(memory_info,
66 inputs[i].data(), inputs[i].size(),
67 _inDims[i].data(), _inDims[i].size()));
68 }
69
70 // retrieve output tensors
71 auto ort_output = _session->Run(Ort::RunOptions{nullptr}, _inNames.data(),
72 ort_input.data(), ort_input.size(),
73 _outNames.data(), _outNames.size());
74
75 // construct flattened values and return
76 vector<vector<float>> outputs; outputs.resize(_outDims.size());
77 for (size_t i = 0; i < _outDims.size(); ++i) {
78 float* floatarr = ort_output[i].GetTensorMutableData<float>();
79 outputs[i].assign(floatarr, floatarr + _outDimsFlat[i]);
80 }
81 return outputs;
82 }
83
85 vector<float> compute(const vector<float>& inputs) const {
86 if (_inDims.size() != 1 || _outDims.size() != 1) {
87 throw("This method assumes a single input/output node!");
88 }
89 vector<vector<float>> wrapped_inputs = { inputs };
90 vector<vector<float>> outputs = compute(wrapped_inputs);
91 return outputs[0];
92 }
93
95 bool hasKey(const std::string& key) const {
96 Ort::AllocatorWithDefaultOptions allocator;
97 return (bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
98 }
99
102 template <typename T,
103 typename std::enable_if_t<!is_iterable_v<T> | is_cstring_v<T> >>
104 T retrieve(const std::string& key) const {
105 Ort::AllocatorWithDefaultOptions allocator;
106 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
107 if (!res) {
108 throw("Key '"+key+"' not found in network metadata!");
109 }
110 /*if constexpr (std::is_same<T, std::string>::value) {
111 return res.get();
112 }*/
113 return lexical_cast<T>(res.get());
114 }
115
117 std::string retrieve(const std::string& key) const {
118 Ort::AllocatorWithDefaultOptions allocator;
119 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
120 if (!res) {
121 throw("Key '"+key+"' not found in network metadata!");
122 }
123 return res.get();
124 }
125
127 template <typename T>
128 vector<T> retrieve(const std::string & key) const {
129 const vector<string> stringvec = split(retrieve(key), ",");
130 vector<T> returnvec = {};
131 for (const string & s : stringvec){
132 returnvec.push_back(lexical_cast<T>(s));
133 }
134 return returnvec;
135 }
136
138 template <typename T>
139 vector<T> retrieve(const std::string & key, const vector<T> & defaultreturn) const {
140 try {
141 return retrieve<T>(key);
142 } catch (...) {
143 return defaultreturn;
144 }
145 }
146
147 std::string retrieve(const std::string& key, const std::string& defaultreturn) const {
148 try {
149 return retrieve(key);
150 } catch (...) {
151 return defaultreturn;
152 }
153 }
154
157 template <typename T,
158 typename std::enable_if_t<!is_iterable_v<T> | is_cstring_v<T> >>
159 T retrieve(const std::string& key, const T& defaultreturn) const {
160 try {
161 return retrieve<T>(key);
162 } catch (...) {
163 return defaultreturn;
164 }
165 }
166
168 friend std::ostream& operator <<(std::ostream& os, const RivetONNXrt& rort){
169 os << "RivetONNXrt Network Summary: \n";
170 for (size_t i=0; i < rort._inNames.size(); ++i) {
171 os << "- Input node " << i << " name: " << rort._inNames[i];
172 os << ", dimensions: (";
173 for (size_t j=0; j < rort._inDims[i].size(); ++j){
174 if (j) os << ", ";
175 os << rort._inDims[i][j];
176 }
177 os << "), type (as ONNX enums): " << rort._inTypes[i] << "\n";
178 }
179 for (size_t i=0; i < rort._outNames.size(); ++i) {
180 os << "- Output node " << i << " name: " << rort._outNames[i];
181 os << ", dimensions: (";
182 for (size_t j=0; j < rort._outDims[i].size(); ++j){
183 if (j) os << ", ";
184 os << rort._outDims[i][j];
185 }
186 os << "), type (as ONNX enums): (" << rort._outTypes[i] << "\n";
187 }
188 return os;
189 }
190
192 Log& getLog() const {
193 string logname = "Rivet.RivetONNXrt";
194 return Log::getLog(logname);
195 }
196
197
198 private:
199
200 void getNetworkInfo() {
201
202 Ort::AllocatorWithDefaultOptions allocator;
203
204 // Retrieve network metadat
205 _metadata = std::make_unique<Ort::ModelMetadata>(_session->GetModelMetadata());
206
207 // find out how many input nodes the model expects
208 const size_t num_input_nodes = _session->GetInputCount();
209 _inDimsFlat.reserve(num_input_nodes);
210 _inTypes.reserve(num_input_nodes);
211 _inDims.reserve(num_input_nodes);
212 _inNames.reserve(num_input_nodes);
213 _inNamesPtr.reserve(num_input_nodes);
214 for (size_t i = 0; i < num_input_nodes; ++i) {
215 // retrieve input node name
216 auto input_name = _session->GetInputNameAllocated(i, allocator);
217 _inNames.push_back(input_name.get());
218 _inNamesPtr.push_back(std::move(input_name));
219
220 // retrieve input node type
221 auto in_type_info = _session->GetInputTypeInfo(i);
222 auto in_tensor_info = in_type_info.GetTensorTypeAndShapeInfo();
223 _inTypes.push_back(in_tensor_info.GetElementType());
224 _inDims.push_back(in_tensor_info.GetShape());
225 }
226
227 // Fix negative shape values - appears to be an artefact of batch size issues.
228 for (auto& dims : _inDims) {
229 int64_t n = 1;
230 for (auto& dim : dims) {
231 if (dim < 0) dim = abs(dim);
232 n *= dim;
233 }
234 _inDimsFlat.push_back(n);
235 }
236
237 // find out how many output nodes the model expects
238 const size_t num_output_nodes = _session->GetOutputCount();
239 _outDimsFlat.reserve(num_output_nodes);
240 _outTypes.reserve(num_output_nodes);
241 _outDims.reserve(num_output_nodes);
242 _outNames.reserve(num_output_nodes);
243 _outNamesPtr.reserve(num_output_nodes);
244 for (size_t i = 0; i < num_output_nodes; ++i) {
245 // retrieve output node name
246 auto output_name = _session->GetOutputNameAllocated(i, allocator);
247 _outNames.push_back(output_name.get());
248 _outNamesPtr.push_back(std::move(output_name));
249
250 // retrieve input node type
251 auto out_type_info = _session->GetOutputTypeInfo(i);
252 auto out_tensor_info = out_type_info.GetTensorTypeAndShapeInfo();
253 _outTypes.push_back(out_tensor_info.GetElementType());
254 _outDims.push_back(out_tensor_info.GetShape());
255 }
256
257 // Fix negative shape values - appears to be an artefact of batch size issues.
258 for (auto& dims : _outDims) {
259 int64_t n = 1;
260 for (auto& dim : dims) {
261 if (dim < 0) dim = abs(dim);
262 n *= dim;
263 }
264 _outDimsFlat.push_back(n);
265 }
266 }
267
268 private:
269
271 std::unique_ptr<Ort::Env> _env;
272
274 std::unique_ptr<Ort::Session> _session;
275
277 std::unique_ptr<Ort::ModelMetadata> _metadata;
278
282 vector<vector<int64_t>> _inDims, _outDims;
283
285 vector<int64_t> _inDimsFlat, _outDimsFlat;
286
288 vector<ONNXTensorElementDataType> _inTypes, _outTypes;
289
291 vector<Ort::AllocatedStringPtr> _inNamesPtr, _outNamesPtr;
292
294 vector<const char*> _inNames, _outNames;
295 };
296
297
301 inline string getONNXFilePath(const string& filename) {
303 const string path1 = findAnalysisDataFile(filename);
304 if (!path1.empty()) return path1;
305 throw Rivet::Error("Couldn't find an ONNX data file for '" + filename + "' " +
306 "in the path " + toString(getRivetDataPath()));
307 }
308
309
315 inline unique_ptr<RivetONNXrt> getONNX(const string& analysisname, const string& suffix = ".onnx"){
316 return make_unique<RivetONNXrt>(getONNXFilePath(analysisname+suffix));
317 }
318
319
320}
321
322#endif
Logging system for controlled & formatted writing to stdout.
Definition Logging.hh:10
static Log & getLog(const std::string &name)
Simple interface class to take care of basic ONNX networks.
Definition RivetONNXrt.hh:21
Log & getLog() const
Logger.
Definition RivetONNXrt.hh:192
T retrieve(const std::string &key, const T &defaultreturn) const
Definition RivetONNXrt.hh:159
std::string retrieve(const std::string &key) const
Non-template overload of retrieve for std::string.
Definition RivetONNXrt.hh:117
friend std::ostream & operator<<(std::ostream &os, const RivetONNXrt &rort)
Printing function for debugging.
Definition RivetONNXrt.hh:168
vector< T > retrieve(const std::string &key, const vector< T > &defaultreturn) const
Overload of retrieve for vector<T>, with a default return.
Definition RivetONNXrt.hh:139
vector< float > compute(const vector< float > &inputs) const
Given a single-node input vector, populate and return the single-node output vector.
Definition RivetONNXrt.hh:85
RivetONNXrt(const string &filename, const string &runname="RivetONNXrt")
Constructor.
Definition RivetONNXrt.hh:29
vector< vector< float > > compute(vector< vector< float > > &inputs) const
Given a multi-node input vector, populate and return the multi-node output vector.
Definition RivetONNXrt.hh:45
bool hasKey(const std::string &key) const
Method to check if key exists in network metadata.
Definition RivetONNXrt.hh:95
T retrieve(const std::string &key) const
Definition RivetONNXrt.hh:104
vector< T > retrieve(const std::string &key) const
Overload of retrieve for vector<T>
Definition RivetONNXrt.hh:128
#define MSG_DEBUG(x)
Debug messaging, not enabled by default, using MSG_LVL.
Definition Logging.hh:182
std::string findAnalysisDataFile(const std::string &filename, const std::vector< std::string > &pathprepend=std::vector< std::string >(), const std::vector< std::string > &pathappend=std::vector< std::string >())
Find the first file of the given name in the general data file search dirs.
std::string getRivetDataPath()
Get Rivet data install path.
vector< string > split(const string &s, const string &sep)
Split a string on a specified separator string.
Definition Utils.hh:214
Definition MC_CENT_PPB_Projections.hh:10
string getONNXFilePath(const string &filename)
Useful function for getting ONNX file paths.
Definition RivetONNXrt.hh:301
unique_ptr< RivetONNXrt > getONNX(const string &analysisname, const string &suffix=".onnx")
Definition RivetONNXrt.hh:315
std::string toString(const AnalysisInfo &ai)
String representation.
Generic runtime Rivet error.
Definition Exceptions.hh:12