rivet is hosted by Hepforge, IPPP Durham
Rivet 4.0.0
RivetLWTNN.hh
1// -*- C++ -*-
2#ifndef RIVET_RivetLWTNN_HH
3#define RIVET_RivetLWTNN_HH
4
5#include "Rivet/Tools/RivetPaths.hh"
6#include "lwtnn/LightweightNeuralNetwork.hh"
7#include "lwtnn/LightweightGraph.hh"
8#include "lwtnn/Exceptions.hh"
9#include "lwtnn/parse_json.hh"
10#include <fstream>
11
12namespace Rivet {
13 using namespace std;
14
15
17 lwt::JSONConfig readLWTNNConfig(const string& jsonpath) {
18 ifstream input;
19 try {
20 // Note: a failed read here may fail quietly, and cause the filestream to
21 // go bad, making it look like the hepmc event-read has failed.
22 input = std::ifstream(jsonpath);
23 return lwt::parse_json(input);
24 } catch (lwt::LightweightNNException &e) {
25 input.close();
26 throw IOError("Error loading LWTNN JSON config");
27 }
28 }
29
30
34 lwt::GraphConfig readLWTNNGraphConfig(const string& jsonpath) {
35 ifstream input;
36 try {
37 // Note: a failed read here may fail quietly, and cause the filestream to
38 // go bad, making it look like the hepmc event-read has failed.
39 input = std::ifstream(jsonpath);
40 return lwt::parse_json_graph(input);
41 } catch (lwt::LightweightNNException &e) {
42 input.close();
43 throw IOError("Error loading LWTNN JSON config");
44 }
45 }
46
48 std::unique_ptr<lwt::LightweightNeuralNetwork> mkLWTNN(const lwt::JSONConfig& jsonconfig) {
49 try {
50 return std::make_unique<lwt::LightweightNeuralNetwork>(jsonconfig.inputs, jsonconfig.layers, jsonconfig.outputs);
51 } catch (lwt::LightweightNNException &e) {
52 throw IOError("Error initialising from LWTNN JSON config");
53 }
54 }
55
59 std::unique_ptr<lwt::LightweightGraph> mkGraphLWTNN(const lwt::GraphConfig& graphconfig) {
60 try {
61 return std::make_unique<lwt::LightweightGraph>(graphconfig);
62 } catch (lwt::LightweightNNException &e) {
63 throw IOError("Error initialising from LWTNN JSON config");
64 }
65 }
66
67
69 std::unique_ptr<lwt::LightweightNeuralNetwork> mkLWTNN(const string& jsonpath) {
70 lwt::JSONConfig config = readLWTNNConfig(jsonpath);
71 return mkLWTNN(config);
72 }
73
77 std::unique_ptr<lwt::LightweightGraph> mkGraphLWTNN(const string& jsonpath) {
78 lwt::GraphConfig config = readLWTNNGraphConfig(jsonpath);
79 return mkGraphLWTNN(config);
80 }
81
82}
83
84#endif
Definition MC_CENT_PPB_Projections.hh:10
lwt::JSONConfig readLWTNNConfig(const string &jsonpath)
Read a LWT DNN config from the JSON path.
Definition RivetLWTNN.hh:17
lwt::GraphConfig readLWTNNGraphConfig(const string &jsonpath)
Read a LWT Graph config from the JSON path.
Definition RivetLWTNN.hh:34
std::unique_ptr< lwt::LightweightNeuralNetwork > mkLWTNN(const lwt::JSONConfig &jsonconfig)
Make a LWT DNN from the JSON config object.
Definition RivetLWTNN.hh:48
std::unique_ptr< lwt::LightweightGraph > mkGraphLWTNN(const lwt::GraphConfig &graphconfig)
Make a LWT Graph from the JSON config object.
Definition RivetLWTNN.hh:59
STL namespace.
Error for I/O failures.
Definition Exceptions.hh:67