2#ifndef RIVET_RivetONNXrt_HH
3#define RIVET_RivetONNXrt_HH
9#include "Rivet/Tools/RivetPaths.hh"
10#include "Rivet/Tools/Utils.hh"
11#include "onnxruntime/onnxruntime_cxx_api.h"
/// Constructor: creates an ORT environment (logging at warning level, tagged
/// with `runname`) and loads the ONNX model in `filename` into an inference
/// session with default session options.
// NOTE(review): this listing omits some lines of the constructor body (e.g.
// whatever fills the node/metadata members); confirm against the full header.
28 RivetONNXrt(
const string& filename,
const string& runname =
"RivetONNXrt") {
31 _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());
// Env and session are stored in members so they outlive this constructor;
// the session is built against *_env.
34 Ort::SessionOptions sessionopts;
35 _session = std::make_unique<Ort::Session> (*_env, filename.c_str(), sessionopts);
/// Run the network on one flattened float vector per input node and return
/// one flattened float vector per output node.
/// Throws a std::string (not a Rivet::Error) if the number of input vectors,
/// or the length of any of them, does not match the network's input nodes.
// NOTE(review): `inputs` is non-const because its buffers are passed straight
// to Ort::Value::CreateTensor (presumably wrapped without copying - confirm).
44 vector<vector<float>>
compute(vector<vector<float>>& inputs)
const {
47 if (inputs.size() != _inDims.size()) {
48 throw(
"Expected " + to_string(_inDims.size())
49 +
" input nodes, received " + to_string(inputs.size()));
// Build one ORT tensor per input node, using the node's stored shape and a
// CPU memory descriptor.
53 vector<Ort::Value> ort_input;
54 ort_input.reserve(_inDims.size());
55 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
56 for (
size_t i=0; i < _inDims.size(); ++i) {
// NOTE(review): compares size_t with int64_t (sign-compare warning); values
// are non-negative after getNetworkInfo() so the result is unaffected.
59 if (inputs[i].size() != _inDimsFlat[i]) {
60 throw(
"Expected flattened input node dimension " + to_string(_inDimsFlat[i])
61 +
", received " + to_string(inputs[i].size()));
// NOTE(review): the tensor is always created as float; _inTypes is not checked.
64 ort_input.emplace_back(Ort::Value::CreateTensor<float>(memory_info,
65 inputs[i].data(), inputs[i].size(),
66 _inDims[i].data(), _inDims[i].size()));
// Run inference on all input nodes and request every output node by name.
70 auto ort_output = _session->Run(Ort::RunOptions{
nullptr}, _inNames.data(),
71 ort_input.data(), ort_input.size(),
72 _outNames.data(), _outNames.size());
// Copy each output tensor into a float vector of the stored flattened size.
// NOTE(review): assumes every output tensor is float and holds at least
// _outDimsFlat[i] elements; the actual output shape is not re-queried.
75 vector<vector<float>> outputs; outputs.resize(_outDims.size());
76 for (
size_t i = 0; i < _outDims.size(); ++i) {
77 float* floatarr = ort_output[i].GetTensorMutableData<
float>();
78 outputs[i].assign(floatarr, floatarr + _outDimsFlat[i]);
/// Convenience overload for networks with exactly one input and one output
/// node: wraps `inputs` into a one-element vector and forwards it to the
/// multi-node compute(); the single output vector is presumably returned on
/// lines not shown in this listing.
/// Throws a C-string literal if the network does not have exactly one input
/// node and one output node.
84 vector<float>
compute(
const vector<float>& inputs)
const {
85 if (_inDims.size() != 1 || _outDims.size() != 1) {
86 throw(
"This method assumes a single input/output node!");
// Copy into a local so the multi-node overload (non-const reference) can be called.
88 vector<vector<float>> wrapped_inputs = { inputs };
89 vector<vector<float>> outputs =
compute(wrapped_inputs);
/// Check whether `key` is present in the model's custom metadata map.
/// Returns true if the metadata lookup for `key` yields a non-null string pointer.
94 bool hasKey(
const std::string& key)
const {
95 Ort::AllocatorWithDefaultOptions allocator;
96 return (
bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
/// Retrieve the custom-metadata value stored under `key` and convert it to T
/// with lexical_cast. The enable_if restricts this overload to non-iterable T
/// or C-string T.
/// Throws a std::string ("Key '...' not found in network metadata!") when the
/// key is missing; the guarding condition is on a line not shown here.
// NOTE(review): the constraint combines the two traits with bitwise `|`;
// for bool operands this is equivalent to `||`.
101 template <
typename T,
102 typename std::enable_if_t<!is_iterable_v<T> | is_cstring_v<T> >>
104 Ort::AllocatorWithDefaultOptions allocator;
105 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
107 throw(
"Key '"+key+
"' not found in network metadata!");
// Convert the stored metadata string to the requested type.
112 return lexical_cast<T>(res.get());
/// Overload of retrieve for std::string: looks up the raw metadata string for
/// `key` (presumably returned as-is on lines not shown in this listing).
/// Throws a std::string ("Key '...' not found in network metadata!") when the
/// key is missing.
116 std::string
retrieve(
const std::string& key)
const {
117 Ort::AllocatorWithDefaultOptions allocator;
118 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
120 throw(
"Key '"+key+
"' not found in network metadata!");
/// Overload of retrieve for vector<T>: converts each element of `stringvec`
/// to T with lexical_cast and collects the results in order.
// NOTE(review): `stringvec` is defined on a line not shown here; presumably it
// is the metadata string for `key` split into pieces (e.g. via split()) - confirm.
126 template <
typename T>
127 vector<T>
retrieve(
const std::string & key)
const {
129 vector<T> returnvec = {};
130 for (
const string & s : stringvec){
131 returnvec.push_back(lexical_cast<T>(s));
/// Overload of retrieve for vector<T> with a fallback: returns retrieve<T>(key)
/// when that succeeds, otherwise `defaultreturn`.
// NOTE(review): the try/catch or condition selecting between the two returns
// is on lines not shown in this listing.
137 template <
typename T>
138 vector<T>
retrieve(
const std::string & key,
const vector<T> & defaultreturn)
const {
140 return retrieve<T>(key);
142 return defaultreturn;
/// Overload of retrieve for std::string with a fallback: `defaultreturn` is
/// returned when the key cannot be retrieved.
// NOTE(review): the lookup path preceding the fallback return is on lines not
// shown in this listing.
146 std::string
retrieve(
const std::string& key,
const std::string& defaultreturn)
const {
150 return defaultreturn;
/// Scalar retrieve with a fallback: returns retrieve<T>(key) when that
/// succeeds, otherwise `defaultreturn`. Restricted by the same enable_if as
/// the scalar retrieve<T>(key).
// NOTE(review): the try/catch or condition selecting between the two returns
// is on lines not shown in this listing.
156 template <
typename T,
157 typename std::enable_if_t<!is_iterable_v<T> | is_cstring_v<T> >>
158 T
retrieve(
const std::string& key,
const T& defaultreturn)
const {
160 return retrieve<T>(key);
162 return defaultreturn;
/// Debug printout of the network: a header line, then one line per input node
/// and one per output node giving its index, name, dimensions in parentheses,
/// and element type as the raw ONNX enum value.
168 os <<
"RivetONNXrt Network Summary: \n";
169 for (
size_t i=0; i < rort._inNames.size(); ++i) {
170 os <<
"- Input node " << i <<
" name: " << rort._inNames[i];
171 os <<
", dimensions: (";
172 for (
size_t j=0; j < rort._inDims[i].size(); ++j){
174 os << rort._inDims[i][j];
176 os <<
"), type (as ONNX enums): " << rort._inTypes[i] <<
"\n";
178 for (
size_t i=0; i < rort._outNames.size(); ++i) {
179 os <<
"- Output node " << i <<
" name: " << rort._outNames[i];
180 os <<
", dimensions: (";
181 for (
size_t j=0; j < rort._outDims[i].size(); ++j){
183 os << rort._outDims[i][j];
// Same format as the input-node line: no "(" after the colon, since it was
// never closed.
185 os <<
"), type (as ONNX enums): " << rort._outTypes[i] <<
"\n";
// Name of this class's logging channel (presumably passed to Log::getLog()
// inside getLog(); the surrounding lines are not shown in this listing).
192 string logname =
"Rivet.RivetONNXrt";
/// Query the loaded session for its model metadata and, for every input and
/// output node, record its name, element type and shape, plus its flattened
/// size in _inDimsFlat / _outDimsFlat.
199 void getNetworkInfo() {
201 Ort::AllocatorWithDefaultOptions allocator;
// Keep the metadata object so hasKey()/retrieve() can query it later.
204 _metadata = std::make_unique<Ort::ModelMetadata>(_session->GetModelMetadata());
207 const size_t num_input_nodes = _session->GetInputCount();
208 _inDimsFlat.reserve(num_input_nodes);
209 _inTypes.reserve(num_input_nodes);
210 _inDims.reserve(num_input_nodes);
211 _inNames.reserve(num_input_nodes);
212 _inNamesPtr.reserve(num_input_nodes);
213 for (
size_t i = 0; i < num_input_nodes; ++i) {
// _inNames keeps the raw C-string for Session::Run; the owning smart pointer is
// moved into _inNamesPtr, which keeps the characters alive (moving the smart
// pointer does not move the pointed-to string), so the raw pointer stays valid.
215 auto input_name = _session->GetInputNameAllocated(i, allocator);
216 _inNames.push_back(input_name.get());
217 _inNamesPtr.push_back(std::move(input_name));
// Record the element type and shape of this input tensor.
220 auto in_type_info = _session->GetInputTypeInfo(i);
221 auto in_tensor_info = in_type_info.GetTensorTypeAndShapeInfo();
222 _inTypes.push_back(in_tensor_info.GetElementType());
223 _inDims.push_back(in_tensor_info.GetShape());
// Replace negative dimensions by their absolute value and store each node's
// flattened size.
// NOTE(review): negative entries are presumably ONNX's -1 for dynamic/symbolic
// dimensions, which abs() turns into 1; `n` is computed on lines not shown,
// presumably as the product of the dimensions - confirm.
227 for (
auto& dims : _inDims) {
229 for (
auto& dim : dims) {
230 if (dim < 0) dim = abs(dim);
233 _inDimsFlat.push_back(n);
// Same bookkeeping for the output nodes.
237 const size_t num_output_nodes = _session->GetOutputCount();
238 _outDimsFlat.reserve(num_output_nodes);
239 _outTypes.reserve(num_output_nodes);
240 _outDims.reserve(num_output_nodes);
241 _outNames.reserve(num_output_nodes);
242 _outNamesPtr.reserve(num_output_nodes);
243 for (
size_t i = 0; i < num_output_nodes; ++i) {
// Raw name for Session::Run; the owning pointer in _outNamesPtr keeps it alive.
245 auto output_name = _session->GetOutputNameAllocated(i, allocator);
246 _outNames.push_back(output_name.get());
247 _outNamesPtr.push_back(std::move(output_name));
// Record the element type and shape of this output tensor.
250 auto out_type_info = _session->GetOutputTypeInfo(i);
251 auto out_tensor_info = out_type_info.GetTensorTypeAndShapeInfo();
252 _outTypes.push_back(out_tensor_info.GetElementType());
253 _outDims.push_back(out_tensor_info.GetShape());
// Make output dimensions non-negative and store flattened sizes
// (see the NOTE on the input loop above about `n` and dynamic dimensions).
257 for (
auto& dims : _outDims) {
259 for (
auto& dim : dims) {
260 if (dim < 0) dim = abs(dim);
263 _outDimsFlat.push_back(n);
// ORT environment; declared before _session, so it is destroyed after it.
270 std::unique_ptr<Ort::Env> _env;
// Inference session for the loaded model (built against *_env).
273 std::unique_ptr<Ort::Session> _session;
// Model metadata, queried by hasKey() and retrieve().
276 std::unique_ptr<Ort::ModelMetadata> _metadata;
// Per-node tensor shapes; negative entries are made positive in getNetworkInfo().
281 vector<vector<int64_t>> _inDims, _outDims;
// Per-node flattened sizes, used by compute() to check inputs and size outputs.
284 vector<int64_t> _inDimsFlat, _outDimsFlat;
// Per-node element types as ONNX enums.
287 vector<ONNXTensorElementDataType> _inTypes, _outTypes;
// Owning handles for the node-name strings; they keep the pointers in
// _inNames/_outNames valid.
290 vector<Ort::AllocatedStringPtr> _inNamesPtr, _outNamesPtr;
// Raw node-name pointers in the form Session::Run takes.
293 vector<const char*> _inNames, _outNames;
// Tail of getONNXFilePath(): returns `path1` when it is non-empty, otherwise
// throws a Rivet::Error naming `filename`.
// NOTE(review): `path1` is computed on lines not shown (presumably via
// findAnalysisRefFile); unlike this function, the RivetONNXrt methods throw
// bare strings rather than Rivet::Error.
301 if (!path1.empty())
return path1;
302 throw Rivet::Error(
"Couldn't find a ref data file for '" + filename +
// Factory returning a RivetONNXrt for an analysis' network file.
// NOTE(review): the body is not shown in this listing; presumably it resolves
// analysisname + suffix (default ".onnx") via getONNXFilePath() and constructs
// the wrapper from the result - confirm.
311 unique_ptr<RivetONNXrt>
getONNX(
const string& analysisname,
const string& suffix =
".onnx"){
Logging system for controlled & formatted writing to stdout.
Definition Logging.hh:10
static Log & getLog(const std::string &name)
Simple interface class to take care of basic ONNX networks.
Definition RivetONNXrt.hh:20
Log & getLog() const
Logger.
Definition RivetONNXrt.hh:191
T retrieve(const std::string &key, const T &defaultreturn) const
Definition RivetONNXrt.hh:158
std::string retrieve(const std::string &key) const
Overload of retrieve for std::string.
Definition RivetONNXrt.hh:116
friend std::ostream & operator<<(std::ostream &os, const RivetONNXrt &rort)
Printing function for debugging.
Definition RivetONNXrt.hh:167
vector< T > retrieve(const std::string &key, const vector< T > &defaultreturn) const
Overload of retrieve for vector<T>, with a default return.
Definition RivetONNXrt.hh:138
vector< float > compute(const vector< float > &inputs) const
Given a single-node input vector, populate and return the single-node output vector.
Definition RivetONNXrt.hh:84
RivetONNXrt(const string &filename, const string &runname="RivetONNXrt")
Constructor.
Definition RivetONNXrt.hh:28
vector< vector< float > > compute(vector< vector< float > > &inputs) const
Given a multi-node input vector, populate and return the multi-node output vector.
Definition RivetONNXrt.hh:44
bool hasKey(const std::string &key) const
Method to check if key exists in network metadata.
Definition RivetONNXrt.hh:94
T retrieve(const std::string &key) const
Definition RivetONNXrt.hh:103
vector< T > retrieve(const std::string &key) const
Overload of retrieve for vector<T>
Definition RivetONNXrt.hh:127
#define MSG_DEBUG(x)
Debug messaging, not enabled by default, using MSG_LVL.
Definition Logging.hh:182
std::string findAnalysisRefFile(const std::string &filename, const std::vector< std::string > &pathprepend=std::vector< std::string >(), const std::vector< std::string > &pathappend=std::vector< std::string >())
Find the first file of the given name in the ref data file search dirs.
std::string getRivetDataPath()
Get Rivet data install path.
vector< string > split(const string &s, const string &sep)
Split a string on a specified separator string.
Definition Utils.hh:214
Definition MC_CENT_PPB_Projections.hh:10
string getONNXFilePath(const string &filename)
Definition RivetONNXrt.hh:298
unique_ptr< RivetONNXrt > getONNX(const string &analysisname, const string &suffix=".onnx")
Definition RivetONNXrt.hh:311
Generic runtime Rivet error.
Definition Exceptions.hh:12