04d73653dddbc3aee2dc58fc2f9d3d668b490bd0 — DaniĆ«l de Kok 3 months ago e367149
Store graph construction args to the graph

Also provide the sticker-graph-metadata utility to print the args.
M sticker-graph/sticker_graph/write_helper.py => sticker-graph/sticker_graph/write_helper.py +8 -3
@@ 13,7 13,6 @@
 
 
 def create_graph(model, args):
-    write_args(model, args)
 
     shapes = read_shapes(args)
     graph_filename = args.output_graph_file


@@ 22,6 21,8 @@
     tfconfig = tf.ConfigProto(gpu_options=gpuopts)
 
     with tf.Graph().as_default(), tf.Session(config=tfconfig) as session:
+        write_args(model, args)
+
         logdir = tf.placeholder(shape=[], name="logdir", dtype=tf.string)
         summary_writer = sticker_graph.vendored._create_file_writer_generic_type(logdir)
 


@@ 46,10 47,14 @@
 
 
 def write_args(model, args):
+    graph_metadata = 'Model = "{}"\n{}'.format(model.__name__,
+                                             toml.dumps(args.__dict__))
+
+    tf.constant(graph_metadata, name="graph_metadata")
+
     f = open(args.write_args, 'w') if args.write_args else sys.stdout
     try:
-        f.write('Model = "{}"\n{}'.format(model.__name__,
-                                          toml.dumps(args.__dict__)))
+        f.write(graph_metadata)
     finally:
         if f != sys.stdout:
             f.close()

A sticker-utils/src/bin/sticker-graph-metadata.rs => sticker-utils/src/bin/sticker-graph-metadata.rs +57 -0
@@ 0,0 1,57 @@
+use std::fs::File;
+use std::io::BufReader;
+
+use clap::{App, AppSettings, Arg};
+use stdinout::OrExit;
+use sticker::tensorflow::TaggerGraph;
+
+static GRAPH: &str = "GRAPH";
+
+static DEFAULT_CLAP_SETTINGS: &[AppSettings] = &[
+    AppSettings::DontCollapseArgsInUsage,
+    AppSettings::UnifiedHelpMessage,
+];
+
+pub struct GraphMetadataApp {
+    graph: String,
+}
+
+impl GraphMetadataApp {
+    fn new() -> Self {
+        let matches = App::new("sticker-graph-metadata")
+            .settings(DEFAULT_CLAP_SETTINGS)
+            .arg(
+                Arg::with_name(GRAPH)
+                    .help("Tensorflow graph")
+                    .index(1)
+                    .required(true),
+            )
+            .get_matches();
+
+        let graph = matches.value_of(GRAPH).unwrap().to_owned();
+
+        GraphMetadataApp { graph }
+    }
+}
+
+impl Default for GraphMetadataApp {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+fn main() {
+    let app = GraphMetadataApp::new();
+
+    let reader = BufReader::new(File::open(app.graph).or_exit("Cannot open graph for reading", 1));
+    let model_config = Default::default();
+    let graph = TaggerGraph::load_graph(reader, &model_config).expect("Cannot load graph");
+
+    match graph
+        .metadata()
+        .or_exit("Cannot retrieve graph metadata", 1)
+    {
+        Some(metadata) => println!("{}", metadata),
+        None => eprintln!("Graph does not contain metadata"),
+    }
+}

M sticker/src/tensorflow/tagger.rs => sticker/src/tensorflow/tagger.rs +50 -21
@@ 65,9 65,23 @@
     }
 }
 
+impl Default for ModelConfig {
+    fn default() -> Self {
+        ModelConfig {
+            batch_size: 128,
+            gpu_allow_growth: true,
+            graph: String::new(),
+            inter_op_parallelism_threads: 1,
+            intra_op_parallelism_threads: 1,
+            parameters: String::new(),
+        }
+    }
+}
+
 mod op_names {
-    pub const INIT_OP: &str = "init";
+    pub const GRAPH_METADATA_OP: &str = "graph_metadata";
 
+    pub const INIT_OP: &str = "init";
     pub const RESTORE_OP: &str = "save/restore_all";
     pub const SAVE_OP: &str = "save/control_dependency";
     pub const SAVE_PATH_OP: &str = "save/Const";


@@ 97,6 111,8 @@
     pub(crate) graph: Graph,
     pub(crate) model_config: ModelConfig,
 
+    pub(crate) graph_metadata_op: Option<Operation>,
+
     pub(crate) graph_write_op: Option<Operation>,
     pub(crate) logdir_op: Option<Operation>,
     pub(crate) summary_init_op: Option<Operation>,


@@ 135,6 151,8 @@
             .import_graph_def(&data, &opts)
             .map_err(status_to_error)?;
 
+        let graph_metadata_op = Self::add_op(&graph, op_names::GRAPH_METADATA_OP).ok();
+
         let restore_op = Self::add_op(&graph, op_names::RESTORE_OP)?;
         let save_op = Self::add_op(&graph, op_names::SAVE_OP)?;
         let save_path_op = Self::add_op(&graph, op_names::SAVE_PATH_OP)?;


@@ 165,6 183,8 @@
             graph,
             model_config: model_config.clone(),
 
+            graph_metadata_op,
+
             graph_write_op,
             logdir_op,
             summary_init_op,


@@ 196,6 216,23 @@
             .operation_by_name_required(name)
             .map_err(status_to_error)
     }
+
+    pub fn metadata(&self) -> Fallible<Option<String>> {
+        let metadata = match self.graph_metadata_op {
+            Some(ref graph_metadata_op) => {
+                let mut args = SessionRunArgs::new();
+                let metadata_token = args.request_fetch(graph_metadata_op, 0);
+                let session = new_session(self)?;
+                session.run(&mut args).map_err(status_to_error)?;
+                let metadata: Tensor<String> =
+                    args.fetch(metadata_token).map_err(status_to_error)?;
+                Some(metadata[0].clone())
+            }
+            None => None,
+        };
+
+        Ok(metadata)
+    }
 }
 
 pub struct Tagger<D>


@@ 232,7 269,7 @@
         let mut args = SessionRunArgs::new();
         args.add_feed(&graph.save_path_op, 0, &path_tensor);
         args.add_target(&graph.restore_op);
-        let session = Self::new_session(&graph)?;
+        let session = new_session(&graph)?;
         session.run(&mut args).map_err(status_to_error)?;
 
         Ok(Tagger {


@@ 243,15 280,6 @@
         })
     }
 
-    fn new_session(graph: &TaggerGraph) -> Result<Session, Error> {
-        let mut session_opts = SessionOptions::new();
-        session_opts
-            .set_config(&graph.model_config.to_protobuf()?)
-            .map_err(status_to_error)?;
-
-        Session::new(&session_opts, &graph.graph).map_err(status_to_error)
-    }
-
     fn prepare_batch(
         &self,
         sentences: &[impl Borrow<Sentence>],


@@ 360,6 388,15 @@
     }
 }
 
+fn new_session(graph: &TaggerGraph) -> Result<Session, Error> {
+    let mut session_opts = SessionOptions::new();
+    session_opts
+        .set_config(&graph.model_config.to_protobuf()?)
+        .map_err(status_to_error)?;
+
+    Session::new(&session_opts, &graph.graph).map_err(status_to_error)
+}
+
 #[cfg(test)]
 mod tests {
     use std::fs::File;


@@ 368,7 405,7 @@
 
     use flate2::read::GzDecoder;
 
-    use super::{ModelConfig, TaggerGraph};
+    use super::TaggerGraph;
 
     fn load_graph(path: impl AsRef<Path>) {
         let f = File::open(path).expect("Cannot open test graph.");


@@ 378,15 415,7 @@
             .read_to_end(&mut data)
             .expect("Cannot decompress test graph.");
 
-        let model_config = ModelConfig {
-            batch_size: 128,
-            gpu_allow_growth: true,
-            graph: String::new(),
-            inter_op_parallelism_threads: 1,
-            intra_op_parallelism_threads: 1,
-            parameters: String::new(),
-        };
-
+        let model_config = Default::default();
         TaggerGraph::load_graph(Cursor::new(data), &model_config).expect("Cannot load graph");
     }