Rivet API documentation

Rivet 4.1.3
RivetONNXrt.hh
1// -*- C++ -*-
2#ifndef RIVET_RivetONNXrt_HH
3#define RIVET_RivetONNXrt_HH
4
5#include <iostream>
6#include <functional>
7#include <numeric>
8
9#include "Rivet/Tools/RivetPaths.hh"
10#include "Rivet/Tools/Utils.hh"
11#include "onnxruntime/onnxruntime_cxx_api.h"
12
13namespace Rivet {
14
15
21 class RivetONNXrt {
22 public:
23
24 // Suppress default constructor
25 RivetONNXrt() = delete;
26
28 RivetONNXrt(const string& filename, const string& runname="RivetONNXrt") {
29
30 // Set some ORT variables that need to be kept in memory
31 _env = std::make_unique<Ort::Env>(ORT_LOGGING_LEVEL_WARNING, runname.c_str());
32
33 // Load the model
34 Ort::SessionOptions sessionopts;
35 try {
36 _session = std::make_unique<Ort::Session> (*_env, filename.c_str(), sessionopts);
37 } catch (const std::exception & e) {
38 MSG_ERROR("Failure loading onnx file: " << e.what());
39 }
40
41 // Store network hyperparameters (input/output shape, etc.)
42 getNetworkInfo();
43
44 MSG_DEBUG(*this);
45 }
46
47
51 template <typename T=float>
52 vector<vector<T>> compute(const vector<vector<T>>& inputs) const {
53
54 // Check that number of input nodes matches what the model expects
55 if (inputs.size() != _inDims.size()) {
56 throw DataError("Expected " + to_string(_inDims.size()) + " input nodes, " +
57 "received " + to_string(inputs.size()));
58 }
59
60 // Create input tensor objects from input data
61 vector<Ort::Value> ort_input;
62 ort_input.reserve(_inDims.size());
63 auto memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
64 for (size_t i = 0; i < _inDims.size(); ++i) {
65
66 // Check that input data matches expected input node dimension
67 if (inputs[i].size() != (size_t)_inDimsFlat[i]) {
68 throw DataError("Expected flattened dimension " + to_string(_inDimsFlat[i]) +
69 " for input node " + to_string(i) +
70 ", received " + to_string(inputs[i].size()));
71 }
72
73 // Check that input data matches expected input node type
74 _checkTypes(inputs[i].data(), i); //< bit hacky, but minimises duplication
75
76 ort_input.emplace_back(Ort::Value::CreateTensor<T>(memory_info,
77 const_cast<T*>(inputs[i].data()), inputs[i].size(),
78 _inDims[i].data(), _inDims[i].size()));
79 }
80
81 // Retrieve output tensors
82 auto ort_output = _session->Run(Ort::RunOptions{nullptr}, _inNames.data(),
83 ort_input.data(), ort_input.size(),
84 _outNames.data(), _outNames.size());
85
86 // Construct flattened values and return
87 vector<vector<T>> outputs; outputs.resize(_outDims.size());
88 for (size_t i = 0; i < _outDims.size(); ++i) {
89 T* floatarr = ort_output[i].GetTensorMutableData<T>();
90 outputs[i].assign(floatarr, floatarr + _outDimsFlat[i]);
91 }
92 return outputs;
93 }
94
95
97 template <typename T=float>
98 vector<T> compute(const vector<T>& inputs) const {
99 if (_inDims.size() != 1 || _outDims.size() != 1) {
100 throw("This method assumes a single input/output node!");
101 }
102 vector<vector<T>> wrapped_inputs = { inputs };
103 vector<vector<T>> outputs = compute(wrapped_inputs);
104 return outputs[0];
105 }
106
107
109 bool hasKey(const std::string& key) const {
110 Ort::AllocatorWithDefaultOptions allocator;
111 return (bool)_metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
112 }
113
114
117 template <typename T, typename std::enable_if_t<!is_iterable_v<T> | is_cstring_v<T> >>
118 T retrieve(const std::string& key) const {
119 Ort::AllocatorWithDefaultOptions allocator;
120 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
121 if (!res) {
122 throw("Key '"+key+"' not found in network metadata!");
123 }
124 /*if constexpr (std::is_same<T, std::string>::value) {
125 return res.get();
126 }*/
127 return lexical_cast<T>(res.get());
128 }
129
131 std::string retrieve(const std::string& key) const {
132 Ort::AllocatorWithDefaultOptions allocator;
133 Ort::AllocatedStringPtr res = _metadata->LookupCustomMetadataMapAllocated(key.c_str(), allocator);
134 if (!res) {
135 throw("Key '"+key+"' not found in network metadata!");
136 }
137 return res.get();
138 }
139
141 template <typename T>
142 vector<T> retrieve(const std::string & key) const {
143 const vector<string> stringvec = split(retrieve(key), ",");
144 vector<T> returnvec = {};
145 for (const string & s : stringvec){
146 returnvec.push_back(lexical_cast<T>(s));
147 }
148 return returnvec;
149 }
150
152 template <typename T>
153 vector<T> retrieve(const std::string & key, const vector<T> & defaultreturn) const {
154 try {
155 return retrieve<T>(key);
156 } catch (...) {
157 return defaultreturn;
158 }
159 }
160
161 std::string retrieve(const std::string& key, const std::string& defaultreturn) const {
162 try {
163 return retrieve(key);
164 } catch (...) {
165 return defaultreturn;
166 }
167 }
168
171 template <typename T, typename std::enable_if_t<!is_iterable_v<T> | is_cstring_v<T> >>
172 T retrieve(const std::string& key, const T& defaultreturn) const {
173 try {
174 return retrieve<T>(key);
175 } catch (...) {
176 return defaultreturn;
177 }
178 }
179
181 friend std::ostream& operator << (std::ostream& os, const RivetONNXrt& rort) {
182 os << "RivetONNXrt Network Summary: \n";
183 for (size_t i=0; i < rort._inNames.size(); ++i) {
184 os << "- Input node " << i << " name: " << rort._inNames[i];
185 os << ", dimensions: (";
186 for (size_t j=0; j < rort._inDims[i].size(); ++j){
187 if (j) os << ", ";
188 os << rort._inDims[i][j];
189 }
190 os << "), type (as ONNX enums): " << rort._inTypes[i] << "\n";
191 }
192 for (size_t i=0; i < rort._outNames.size(); ++i) {
193 os << "- Output node " << i << " name: " << rort._outNames[i];
194 os << ", dimensions: (";
195 for (size_t j=0; j < rort._outDims[i].size(); ++j){
196 if (j) os << ", ";
197 os << rort._outDims[i][j];
198 }
199 os << "), type (as ONNX enums): (" << rort._outTypes[i] << "\n";
200 }
201 return os;
202 }
203
205 Log& getLog() const {
206 string logname = "Rivet.RivetONNXrt";
207 return Log::getLog(logname);
208 }
209
210
211 private:
212
214 void getNetworkInfo() {
215
216 Ort::AllocatorWithDefaultOptions allocator;
217
218 // Retrieve network metadata
219 _metadata = std::make_unique<Ort::ModelMetadata>(_session->GetModelMetadata());
220
221 // Find out how many input nodes the model expects
222 const size_t num_input_nodes = _session->GetInputCount();
223 _inDimsFlat.reserve(num_input_nodes);
224 _inTypes.reserve(num_input_nodes);
225 _inDims.reserve(num_input_nodes);
226 _inNames.reserve(num_input_nodes);
227 _inNamesPtr.reserve(num_input_nodes);
228 for (size_t i = 0; i < num_input_nodes; ++i) {
229 // Retrieve input node name
230 auto input_name = _session->GetInputNameAllocated(i, allocator);
231 _inNames.push_back(input_name.get());
232 _inNamesPtr.push_back(std::move(input_name));
233
234 // Retrieve input node type
235 auto in_type_info = _session->GetInputTypeInfo(i);
236 auto in_tensor_info = in_type_info.GetTensorTypeAndShapeInfo();
237 _inTypes.push_back(in_tensor_info.GetElementType());
238 _inDims.push_back(in_tensor_info.GetShape());
239 }
240
241 // Fix negative shape values - appears to be an artefact of batch size issues.
242 for (auto& dims : _inDims) {
243 int64_t n = 1;
244 for (auto& dim : dims) {
245 if (dim < 0) dim = abs(dim);
246 n *= dim;
247 }
248 _inDimsFlat.push_back(n);
249 }
250
251 // Find out how many output nodes the model expects
252 const size_t num_output_nodes = _session->GetOutputCount();
253 _outDimsFlat.reserve(num_output_nodes);
254 _outTypes.reserve(num_output_nodes);
255 _outDims.reserve(num_output_nodes);
256 _outNames.reserve(num_output_nodes);
257 _outNamesPtr.reserve(num_output_nodes);
258 for (size_t i = 0; i < num_output_nodes; ++i) {
259 // Retrieve output node name
260 auto output_name = _session->GetOutputNameAllocated(i, allocator);
261 _outNames.push_back(output_name.get());
262 _outNamesPtr.push_back(std::move(output_name));
263
264 // Retrieve output node type
265 auto out_type_info = _session->GetOutputTypeInfo(i);
266 auto out_tensor_info = out_type_info.GetTensorTypeAndShapeInfo();
267 _outTypes.push_back(out_tensor_info.GetElementType());
268 _outDims.push_back(out_tensor_info.GetShape());
269 }
270
271 // Fix negative shape values - appears to be an artefact of batch size issues.
272 for (auto& dims : _outDims) {
273 int64_t n = 1;
274 for (auto& dim : dims) {
275 if (dim < 0) dim = abs(dim);
276 n *= dim;
277 }
278 _outDimsFlat.push_back(n);
279 }
280 }
281
282
284 void _checkTypes(const float*, size_t inode) const {
285 if (_inTypes[inode] != ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
286 throw DataError("ONNX network provided wrong input type (float)");
287 }
289 void _checkTypes(const double*, size_t inode) const {
290 if (_inTypes[inode] != ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
291 throw DataError("ONNX network provided wrong input type (double)");
292 }
293
294 private:
295
297 std::unique_ptr<Ort::Env> _env;
298
300 std::unique_ptr<Ort::Session> _session;
301
303 std::unique_ptr<Ort::ModelMetadata> _metadata;
304
308 vector<vector<int64_t>> _inDims, _outDims;
309
311 vector<int64_t> _inDimsFlat, _outDimsFlat;
312
314 vector<ONNXTensorElementDataType> _inTypes, _outTypes;
315
317 vector<Ort::AllocatedStringPtr> _inNamesPtr, _outNamesPtr;
318
320 vector<const char*> _inNames, _outNames;
321 };
322
323
325 using RivetONNXrtPtr = unique_ptr<RivetONNXrt>;
326
327
331 inline string getONNXFilePath(const string& filename) {
333 const string path1 = findAnalysisDataFile(filename);
334 if (!path1.empty()) return path1;
335 throw Rivet::Error("Couldn't find an ONNX data file for '" + filename + "' " +
336 "in the path " + toString(getRivetDataPath()));
337 }
338
339
348 inline RivetONNXrtPtr getONNX(const string& analysisname, const string& suffix="", const string& extn="onnx") {
349 const string fname = analysisname + (suffix.empty() ? "" : "-") + suffix + "." + extn;
350 return make_unique<RivetONNXrt>(getONNXFilePath(fname));
351 }
352
353
354
/// Shorter alias for RivetONNXrtPtr.
358 using ONNXrtPtr = RivetONNXrtPtr;
360
361
362}
363
364#endif
Logging system for controlled & formatted writing to stdout.
Definition Logging.hh:10
static Log & getLog(const std::string &name)
Simple interface class to take care of basic ONNX networks.
Definition RivetONNXrt.hh:21
Log & getLog() const
Logger.
Definition RivetONNXrt.hh:205
vector< T > compute(const vector< T > &inputs) const
Given a single-node input vector, populate and return the single-node output vector.
Definition RivetONNXrt.hh:98
T retrieve(const std::string &key, const T &defaultreturn) const
Definition RivetONNXrt.hh:172
std::string retrieve(const std::string &key) const
Overload of retrieve for std::string.
Definition RivetONNXrt.hh:131
friend std::ostream & operator<<(std::ostream &os, const RivetONNXrt &rort)
Printing function for debugging.
Definition RivetONNXrt.hh:181
vector< vector< T > > compute(const vector< vector< T > > &inputs) const
Definition RivetONNXrt.hh:52
vector< T > retrieve(const std::string &key, const vector< T > &defaultreturn) const
Overload of retrieve for vector<T>, with a default return.
Definition RivetONNXrt.hh:153
RivetONNXrt(const string &filename, const string &runname="RivetONNXrt")
Constructor.
Definition RivetONNXrt.hh:28
bool hasKey(const std::string &key) const
Method to check if key exists in network metadata.
Definition RivetONNXrt.hh:109
T retrieve(const std::string &key) const
Definition RivetONNXrt.hh:118
vector< T > retrieve(const std::string &key) const
Overload of retrieve for vector<T>.
Definition RivetONNXrt.hh:142
#define MSG_DEBUG(x)
Debug messaging, not enabled by default, using MSG_LVL.
Definition Logging.hh:182
#define MSG_ERROR(x)
Highest level messaging for serious problems, using MSG_LVL.
Definition Logging.hh:189
std::string findAnalysisDataFile(const std::string &filename, const std::vector< std::string > &pathprepend=std::vector< std::string >(), const std::vector< std::string > &pathappend=std::vector< std::string >())
Find the first file of the given name in the general data file search dirs.
std::string getRivetDataPath()
Get Rivet data install path.
T lexical_cast(const U &in)
Convert between any types via stringstream.
Definition Utils.hh:62
vector< string > split(const string &s, const string &sep)
Split a string on a specified separator string.
Definition Utils.hh:242
Definition MC_CENT_PPB_Projections.hh:10
string getONNXFilePath(const string &filename)
Useful function for getting ONNX file paths.
Definition RivetONNXrt.hh:331
RivetONNXrtPtr getONNX(const string &analysisname, const string &suffix="", const string &extn="onnx")
Definition RivetONNXrt.hh:348
unique_ptr< RivetONNXrt > RivetONNXrtPtr
Typedef for a handle to an ONNXrt object.
Definition RivetONNXrt.hh:325
std::string toString(const AnalysisInfo &ai)
String representation.
RivetONNXrt ONNXrt
Definition RivetONNXrt.hh:357
Error relating to provided data mismatching expectations.
Definition Exceptions.hh:79
Generic runtime Rivet error.
Definition Exceptions.hh:12