7d781cff2090887ded476ad8eca14dcaeb4302c8 — Daniël de Kok 3 months ago 4284afd
Add support for RNN-based subword representations

When subword representations are enabled, sticker builds a string
tensor containing each form in the input. Inside the TensorFlow graph,
the forms are converted to (truncated) byte representations and each
byte is mapped to a byte embedding. The subword representation of a
form is then obtained by applying a bidirectional RNN to its sequence
of byte embeddings.
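
For reference, a condensed TF 1.x sketch of that pipeline (this is not
the sticker graph code itself: the function name, the default sizes,
and the clamp that keeps the decoded units inside a 256-entry table
are purely illustrative):

    import tensorflow as tf

    def subword_reprs_sketch(forms, max_len=20, embed_size=25, hidden_size=25):
        # forms: string tensor of shape [batch_size, seq_len] holding the forms.
        decoded = tf.strings.unicode_decode(forms, input_encoding='UTF-8')
        units = decoded.to_tensor(0)[:, :, :max_len]
        lens = tf.minimum(decoded.row_lengths(axis=-1).to_tensor(0), max_len)

        # Embed each unit; the clamp keeps this toy 256-entry table in range.
        table = tf.get_variable("unit_embeds", [256, embed_size])
        embedded = tf.nn.embedding_lookup(table, tf.minimum(units, 255))

        # Fold tokens into the batch dimension and run a bidirectional LSTM
        # over the (truncated) unit sequences.
        shape = tf.shape(units)
        embedded = tf.reshape(embedded, [-1, shape[2], embed_size])
        _, (fw_state, bw_state) = tf.nn.bidirectional_dynamic_rnn(
            tf.nn.rnn_cell.LSTMCell(hidden_size),
            tf.nn.rnn_cell.LSTMCell(hidden_size),
            embedded,
            sequence_length=tf.reshape(tf.cast(lens, tf.int32), [-1]),
            dtype=tf.float32)

        # One fixed-size vector per token: the final forward and backward
        # hidden states, reshaped back to [batch_size, seq_len, 2 * hidden].
        token_reprs = tf.concat([fw_state.h, bw_state.h], axis=-1)
        return tf.reshape(token_reprs, [shape[0], shape[1], 2 * hidden_size])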

To make subword representations configurable, the sticker
configuration format has been adjusted: a new `input` section holds
the `subwords` boolean option, and the `embeddings` section is now a
subsection of `input`. Whether subword representations are used is
communicated to the graph writers through the shapes file.
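
As a rough illustration of that handshake: the shapes file written by
`sticker-prepare` is plain `key = value` pairs (the updated
testdata/sticker.shapes below parses as TOML), so a graph writer can
branch on the new `subwords` flag when building the graph. The use of
the third-party `toml` module here is an assumption for illustration,
not sticker's actual parsing code.

    import toml

    # e.g. {'n_labels': 13, 'subwords': True, 'token_embed_dims': 300, ...}
    shapes = toml.load("testdata/sticker.shapes")

    if shapes.get("subwords", False):
        # Only in this case does the written graph contain the string
        # "subwords" placeholder and the byte RNN sketched above.
        pass
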
M .travis.yml => .travis.yml +2 -2
@@ 19,8 19,8 @@
       - ./sticker-write-rnn-graph --rnn_layers 2 --hidden_size 100 testdata/sticker.shapes rnn.graph
       - ./sticker-write-conv-graph --levels 6 --hidden_size 100 testdata/sticker.shapes conv.graph
       - ./sticker-write-transformer-graph testdata/sticker.shapes trans.graph
-      - ./sticker-write-transformer-graph --outer_hsize 350 --pass_inputs --num_heads 7 testdata/sticker.shapes trans-pass-input.graph
-      - ./sticker-write-transformer-graph --outer_hsize 350 --pass_inputs --num_heads 7 --embed_time --max_time 150 testdata/sticker.shapes trans-time.graph
+      - ./sticker-write-transformer-graph --outer_hsize 400 --pass_inputs --num_heads 8 testdata/sticker.shapes trans-pass-input.graph
+      - ./sticker-write-transformer-graph --outer_hsize 400 --pass_inputs --num_heads 8 --embed_time --max_time 150 testdata/sticker.shapes trans-time.graph
   - language: rust
     os: linux
     rust: 1.34.2

M sticker-graph/sticker_graph/conv_model.py => sticker-graph/sticker_graph/conv_model.py +1 -6
@@ 162,13 162,8 @@
 
         self.setup_placeholders()
 
-        inputs = tf.contrib.layers.dropout(
-            self.inputs,
-            keep_prob=args.keep_prob_input,
-            is_training=self.is_training)
-
         hidden_states = dilated_convolution(
-            inputs,
+            self.inputs,
             args.hidden_size,
             kernel_size=args.kernel_size,
             n_levels=args.levels,

M sticker-graph/sticker_graph/model.py => sticker-graph/sticker_graph/model.py +84 -8
@@ 1,5 1,7 @@
 import tensorflow as tf
 
+from sticker_graph.rnn import bidi_rnn_layers
+
 
 class Model:
     def __init__(self, args, shapes):


@@ 34,6 36,7 @@
                 x, w, b), [
                 batch_size, -1, n_outputs])
 
+
     def masked_softmax_loss(self, prefix, logits, labels, mask):
         # Compute losses
         losses = tf.nn.sparse_softmax_cross_entropy_with_logits(


@@ 92,14 95,9 @@
 
         return probs, labels
 
-    def setup_placeholders(self):
-        self._is_training = tf.placeholder(tf.bool, [], "is_training")
-
-        self._tags = tf.placeholder(
-            tf.int32, name="tags", shape=[
-                None, None])
-
-        self._inputs = tf.placeholder(
+    def setup_inputs(self):
+        # Word/tag embeddings
+        inputs = tf.placeholder(
             tf.float32,
             shape=[
                 None,


@@ 107,6 105,31 @@
                 self.shapes['token_embed_dims'] +
                 self.shapes['tag_embed_dims']],
             name="inputs")
+        inputs = tf.contrib.layers.dropout(
+            inputs,
+            keep_prob=self.args.keep_prob_input,
+            is_training=self.is_training)
+
+        if self.shapes['subwords']:
+            # Forms for subwords
+            self._subwords = tf.placeholder(
+                tf.string,
+                shape=[
+                    None,
+                    None],
+                name="subwords")
+            inputs = tf.concat([inputs, self.subword_reprs()], axis=-1)
+
+        self._inputs = inputs
+
+    def setup_placeholders(self):
+        self._is_training = tf.placeholder(tf.bool, [], "is_training")
+
+        self._tags = tf.placeholder(
+            tf.int32, name="tags", shape=[
+                None, None])
+
+        self.setup_inputs()
 
         self._seq_lens = tf.placeholder(
             tf.int32, [None], name="seq_lens")


@@ 116,6 139,55 @@
             self.seq_lens, maxlen=tf.shape(
                 self.inputs)[1], dtype=tf.float32)
 
+    def subword_reprs(self):
+        # Convert strings to a byte tensor.
+        #
+        # Shape: [batch_size, seq_len, subword_len]
+        subword_bytes = tf.strings.unicode_decode(
+            self.subwords, input_encoding='UTF-8')
+        subword_bytes_padded = subword_bytes.to_tensor(
+            0)[:, :, :self.args.subword_len]
+
+        # Get the lengths of the subwords. Only the last dimension should
+        # be ragged, so no actual padding should happen.
+        #
+        # Shape: [batch_size, seq_len]
+        subword_lens = tf.math.minimum(
+            subword_bytes.row_lengths(
+                axis=-1).to_tensor(0),
+            self.args.subword_len)
+
+        # Look up the byte embeddings; the result has the following shape.
+        #
+        # Shape: [batch_size, seq_len, max_bytes_len, byte_embed_size]
+        byte_embeds = tf.get_variable(
+            "byte_embeds", [
+                256, self.args.byte_embed_size])
+        byte_reprs = tf.nn.embedding_lookup(byte_embeds, subword_bytes_padded)
+
+        byte_reprs = tf.contrib.layers.dropout(
+            byte_reprs,
+            keep_prob=self.args.keep_prob_input,
+            is_training=self.is_training)
+
+        # Prepare shape for applying the RNN:
+        #
+        # Shape: [batch_size * seq_len, max_bytes_len, byte_embed_size]
+        bytes_shape = tf.shape(subword_bytes_padded)
+        byte_reprs = tf.reshape(
+            byte_reprs, [-1, bytes_shape[2], self.args.byte_embed_size])
+        byte_lens = tf.reshape(subword_lens, [-1])
+
+        with tf.variable_scope("byte_rnn"):
+            _, fw, bw = bidi_rnn_layers(self.is_training, byte_reprs, num_layers=self.args.subword_layers, output_size=self.args.subword_hidden_size,
+                                        output_keep_prob=self.args.subword_keep_prob, seq_lens=byte_lens, gru=self.args.subword_gru, residual_connections=self.args.subword_residual)
+
+        # Concat forward/backward states.
+        subword_reprs = tf.concat([fw[-1].h, bw[-1].h], axis=-1)
+
+        return tf.reshape(subword_reprs, [bytes_shape[0], bytes_shape[1],
+                                          subword_reprs.shape[-1]])
+
     def create_summary_ops(self, acc, grad_norm, loss, lr):
         step = tf.train.get_or_create_global_step()
 


@@ 171,6 243,10 @@
         return self._seq_lens
 
     @property
+    def subwords(self):
+        return self._subwords
+
+    @property
     def tags(self):
         return self._tags
 

A sticker-graph/sticker_graph/rnn.py => sticker-graph/sticker_graph/rnn.py +59 -0
@@ 0,0 1,59 @@
+import tensorflow as tf
+
+import sticker_graph.vendored
+
+
+def dropout_wrapper(
+        cell,
+        is_training,
+        output_keep_prob=1.0,
+        state_keep_prob=1.0):
+    output_keep_prob = tf.cond(
+        is_training,
+        lambda: tf.constant(output_keep_prob),
+        lambda: tf.constant(1.0))
+    state_keep_prob = tf.cond(
+        is_training,
+        lambda: tf.constant(state_keep_prob),
+        lambda: tf.constant(1.0))
+    return tf.contrib.rnn.DropoutWrapper(
+        cell,
+        output_keep_prob=output_keep_prob,
+        state_keep_prob=state_keep_prob)
+
+
+def bidi_rnn_layers(
+        is_training,
+        inputs,
+        num_layers=1,
+        output_size=50,
+        output_keep_prob=1.0,
+        state_keep_prob=1.0,
+        seq_lens=None,
+        gru=False,
+        residual_connections=False):
+    if gru:
+        cell = tf.nn.rnn_cell.GRUCell
+    else:
+        cell = tf.contrib.rnn.LSTMCell
+
+    fw_cells = [
+        dropout_wrapper(
+            cell=cell(output_size),
+            is_training=is_training,
+            state_keep_prob=state_keep_prob,
+            output_keep_prob=output_keep_prob) for i in range(num_layers)]
+
+    bw_cells = [
+        dropout_wrapper(
+            cell=cell(output_size),
+            is_training=is_training,
+            state_keep_prob=state_keep_prob,
+            output_keep_prob=output_keep_prob) for i in range(num_layers)]
+    return sticker_graph.vendored.stack_bidirectional_dynamic_rnn(
+        fw_cells,
+        bw_cells,
+        inputs,
+        dtype=tf.float32,
+        sequence_length=seq_lens,
+        residual_connections=residual_connections)

M sticker-graph/sticker_graph/rnn_model.py => sticker-graph/sticker_graph/rnn_model.py +2 -62
@@ 2,62 2,7 @@
 from tensorflow.contrib.layers import batch_norm
 
 from sticker_graph.model import Model
-import sticker_graph.vendored
-
-def dropout_wrapper(
-        cell,
-        is_training,
-        output_keep_prob=1.0,
-        state_keep_prob=1.0):
-    output_keep_prob = tf.cond(
-        is_training,
-        lambda: tf.constant(output_keep_prob),
-        lambda: tf.constant(1.0))
-    state_keep_prob = tf.cond(
-        is_training,
-        lambda: tf.constant(state_keep_prob),
-        lambda: tf.constant(1.0))
-    return tf.contrib.rnn.DropoutWrapper(
-        cell,
-        output_keep_prob=output_keep_prob,
-        state_keep_prob=state_keep_prob)
-
-
-def bidi_rnn_layers(
-        is_training,
-        inputs,
-        num_layers=1,
-        output_size=50,
-        output_keep_prob=1.0,
-        state_keep_prob=1.0,
-        seq_lens=None,
-        gru=False,
-        residual_connections=False):
-    if gru:
-        cell = tf.nn.rnn_cell.GRUCell
-    else:
-        cell = tf.contrib.rnn.LSTMCell
-
-    fw_cells = [
-        dropout_wrapper(
-            cell=cell(output_size),
-            is_training=is_training,
-            state_keep_prob=state_keep_prob,
-            output_keep_prob=output_keep_prob) for i in range(num_layers)]
-
-    bw_cells = [
-        dropout_wrapper(
-            cell=cell(output_size),
-            is_training=is_training,
-            state_keep_prob=state_keep_prob,
-            output_keep_prob=output_keep_prob) for i in range(num_layers)]
-    return sticker_graph.vendored.stack_bidirectional_dynamic_rnn(
-        fw_cells,
-        bw_cells,
-        inputs,
-        dtype=tf.float32,
-        sequence_length=seq_lens,
-        residual_connections=residual_connections)
+from sticker_graph.rnn import bidi_rnn_layers
 
 
 class RNNModel(Model):


@@ 69,14 14,9 @@
 
         self.setup_placeholders()
 
-        inputs = tf.contrib.layers.dropout(
-            self.inputs,
-            keep_prob=args.keep_prob_input,
-            is_training=self.is_training)
-
         hidden_states, _, _ = bidi_rnn_layers(
             self.is_training,
-            inputs,
+            self.inputs,
             num_layers=args.rnn_layers,
             output_size=args.hidden_size,
             output_keep_prob=args.keep_prob,

M sticker-graph/sticker_graph/transformer_model.py => sticker-graph/sticker_graph/transformer_model.py +1 -5
@@ 259,11 259,7 @@
             raise NotImplementedError('Activation %s is not available.'
                                       % args.activation)
 
-        inputs = tf.contrib.layers.dropout(
-            self.inputs,
-            keep_prob=args.keep_prob_input,
-            is_training=self.is_training)
-
+        inputs = self.inputs
         if not args.pass_inputs:
             inputs = tf.layers.dense(inputs, args.outer_hsize, activation)
         else:

M sticker-graph/sticker_graph/write_helper.py => sticker-graph/sticker_graph/write_helper.py +34 -0
@@ 75,10 75,44 @@
         type=str,
         help='output graph file')
     parser.add_argument(
+        "--byte_embed_size",
+        type=int,
+        help="size of character embeddings",
+        default=25)
+    parser.add_argument(
         "--crf",
         help="use CRF layer for classification",
         action="store_true")
     parser.add_argument(
+        "--subword_gru",
+        help="use GRU RNN cells in the character RNN",
+        action="store_true")
+    parser.add_argument(
+        "--subword_hidden_size",
+        type=int,
+        help="character RNN hidden size per direction",
+        default=25)
+    parser.add_argument(
+        "--subword_keep_prob",
+        type=float,
+        help="character RNN dropout keep probability",
+        default=0.6)
+    parser.add_argument(
+        "--subword_layers",
+        type=int,
+        help="character RNN hidden layers",
+        default=1)
+    parser.add_argument(
+        "--subword_len",
+        type=int,
+        help="number of characters in character-based representations",
+        default=20)
+    parser.add_argument(
+        "--subword_residual",
+        action='store_true',
+        help="use character RNN residual skip connections"
+    )
+    parser.add_argument(
         "--top_k",
         type=int,
         help="number of predictions to return",

M sticker-graph/testdata/sticker.shapes => sticker-graph/testdata/sticker.shapes +1 -0
@@ 1,3 1,4 @@
 n_labels = 13
+subwords = true
 token_embed_dims = 300
 tag_embed_dims = 50

M sticker-utils/src/bin/sticker-prepare.rs => sticker-utils/src/bin/sticker-prepare.rs +6 -3
@@ 24,6 24,7 @@
 #[derive(Serialize)]
 struct Shapes {
     n_labels: usize,
+    subwords: bool,
     token_embed_dims: usize,
     tag_embed_dims: usize,
 }


@@ 81,10 82,11 @@
     let shapes_write = output.write().or_exit("Cannot create shapes file", 1);
 
     let embeddings = config
+        .input
         .embeddings
         .load_embeddings()
         .or_exit("Cannot load embeddings", 1);
-    let vectorizer = SentVectorizer::new(embeddings);
+    let vectorizer = SentVectorizer::new(embeddings, config.input.subwords);
 
     match config.labeler.labeler_type {
         LabelerType::Sequence(ref layer) => prepare_with_encoder(


@@ 128,7 130,7 @@
     let mut collector = NoopCollector::new(encoder, labels, vectorizer);
     collect_sentences(&mut collector, read);
     write_labels(&config, collector.labels()).or_exit("Cannot write labels", 1);
-    write_shapes(shapes_write, &collector);
+    write_shapes(&config, shapes_write, &collector);
 }
 
 fn collect_sentences<E, R>(collector: &mut NoopCollector<E>, reader: R)


@@ 155,7 157,7 @@
     labels.to_cbor_write(&mut f)
 }
 
-fn write_shapes<W, E>(mut shapes_write: W, collector: &NoopCollector<E>)
+fn write_shapes<W, E>(config: &Config, mut shapes_write: W, collector: &NoopCollector<E>)
 where
     W: Write,
     E: SentenceEncoder,


@@ 163,6 165,7 @@
 {
     let shapes = Shapes {
         n_labels: collector.labels().len(),
+        subwords: config.input.subwords,
         token_embed_dims: collector
             .vectorizer()
             .layer_embeddings()

M sticker-utils/src/bin/sticker-pretrain.rs => sticker-utils/src/bin/sticker-pretrain.rs +17 -5
@@ 240,10 240,11 @@
     let mut categorical_encoder = CategoricalEncoder::new(encoder, labels);
 
     let embeddings = config
+        .input
         .embeddings
         .load_embeddings()
         .or_exit("Cannot load embeddings", 1);
-    let vectorizer = SentVectorizer::new(embeddings);
+    let vectorizer = SentVectorizer::new(embeddings, config.input.subwords);
 
     let mut best_epoch = 0;
     let mut best_acc = 0.0;


@@ 339,20 340,31 @@
         .batches(encoder, vectorizer, config.model.batch_size, app.max_len)
         .or_exit("Cannot read batches", 1)
     {
-        let (inputs, seq_lens, labels) = batch.or_exit("Cannot read batch", 1).into_parts();
+        let tensors = batch.or_exit("Cannot read batch", 1).into_parts();
 
         let batch_perf = if is_training {
             let bytes_done = (epoch * train_size) + file.seek(SeekFrom::Current(0))? as usize;
             let lr_scale = 1f32 - (bytes_done as f32 / (app.epochs * train_size) as f32);
             let lr = lr_scale * app.initial_lr.into_inner();
-            let batch_perf = trainer.train(&seq_lens, &inputs, &labels, lr);
+            let batch_perf = trainer.train(
+                &tensors.seq_lens,
+                &tensors.inputs,
+                tensors.subwords.as_ref(),
+                &tensors.labels,
+                lr,
+            );
             progress_bar.set_message(&format!(
                 "lr: {:.6}, loss: {:.4}, accuracy: {:.4}",
                 lr, batch_perf.loss, batch_perf.accuracy
             ));
             batch_perf
         } else {
-            let batch_perf = trainer.validate(&seq_lens, &inputs, &labels);
+            let batch_perf = trainer.validate(
+                &tensors.seq_lens,
+                &tensors.inputs,
+                tensors.subwords.as_ref(),
+                &tensors.labels,
+            );
             progress_bar.set_message(&format!(
                 "batch loss: {:.4}, batch accuracy: {:.4}",
                 batch_perf.loss, batch_perf.accuracy


@@ 360,7 372,7 @@
             batch_perf
         };
 
-        let n_tokens = seq_lens.view().iter().sum::<i32>();
+        let n_tokens = tensors.seq_lens.view().iter().sum::<i32>();
         loss += n_tokens as f32 * batch_perf.loss;
         acc += n_tokens as f32 * batch_perf.accuracy;
         instances += n_tokens;

M sticker-utils/src/bin/sticker-train.rs => sticker-utils/src/bin/sticker-train.rs +17 -5
@@ 271,10 271,11 @@
     let mut categorical_encoder = CategoricalEncoder::new(encoder, labels);
 
     let embeddings = config
+        .input
         .embeddings
         .load_embeddings()
         .or_exit("Cannot load embeddings", 1);
-    let vectorizer = SentVectorizer::new(embeddings);
+    let vectorizer = SentVectorizer::new(embeddings, config.input.subwords);
 
     let mut best_epoch = 0;
     let mut best_acc = 0.0;


@@ 380,15 381,26 @@
         .batches(encoder, vectorizer, config.model.batch_size, app.max_len)
         .or_exit("Cannot read batches", 1)
     {
-        let (inputs, seq_lens, labels) = batch.or_exit("Cannot read batch", 1).into_parts();
+        let tensors = batch.or_exit("Cannot read batch", 1).into_parts();
 
         let batch_perf = if is_training {
-            trainer.train(&seq_lens, &inputs, &labels, lr)
+            trainer.train(
+                &tensors.seq_lens,
+                &tensors.inputs,
+                tensors.subwords.as_ref(),
+                &tensors.labels,
+                lr,
+            )
         } else {
-            trainer.validate(&seq_lens, &inputs, &labels)
+            trainer.validate(
+                &tensors.seq_lens,
+                &tensors.inputs,
+                tensors.subwords.as_ref(),
+                &tensors.labels,
+            )
         };
 
-        let n_tokens = seq_lens.view().iter().sum::<i32>();
+        let n_tokens = tensors.seq_lens.view().iter().sum::<i32>();
         loss += n_tokens as f32 * batch_perf.loss;
         acc += n_tokens as f32 * batch_perf.accuracy;
         instances += n_tokens;

M sticker-utils/src/config.rs => sticker-utils/src/config.rs +12 -4
@@ 17,7 17,7 @@
 #[serde(deny_unknown_fields)]
 pub struct Config {
     pub labeler: Labeler,
-    pub embeddings: Embeddings,
+    pub input: Input,
     pub model: ModelConfig,
 }
 


@@ 30,9 30,9 @@
         let config_path = config_path.as_ref();
 
         self.labeler.labels = relativize_path(config_path, &self.labeler.labels)?;
-        self.embeddings.word.filename =
-            relativize_path(config_path, &self.embeddings.word.filename)?;
-        if let Some(ref mut embeddings) = self.embeddings.tag {
+        self.input.embeddings.word.filename =
+            relativize_path(config_path, &self.input.embeddings.word.filename)?;
+        if let Some(ref mut embeddings) = self.input.embeddings.tag {
             embeddings.filename = relativize_path(config_path, &embeddings.filename)?;
         }
         self.model.graph = relativize_path(config_path, &self.model.graph)?;


@@ 94,6 94,14 @@
     RelativePOS,
 }
 
+/// Input configuration
+#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
+#[serde(deny_unknown_fields)]
+pub struct Input {
+    pub embeddings: Embeddings,
+    pub subwords: bool,
+}
+
 #[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
 #[serde(deny_unknown_fields)]
 pub struct Labeler {

M sticker-utils/src/config_tests.rs => sticker-utils/src/config_tests.rs +12 -9
@@ 4,7 4,7 @@
 use sticker::tensorflow::ModelConfig;
 use sticker::Layer;
 
-use super::{Config, Embedding, EmbeddingAlloc, Embeddings, Labeler, LabelerType, TomlRead};
+use super::{Config, Embedding, EmbeddingAlloc, Embeddings, Input, Labeler, LabelerType, TomlRead};
 
 lazy_static! {
     static ref BASIC_LABELER_CHECK: Config = Config {


@@ 13,15 13,18 @@
             labels: "sticker.labels".to_owned(),
             read_ahead: 10,
         },
-        embeddings: Embeddings {
-            word: Embedding {
-                filename: "word-vectors.bin".into(),
-                alloc: EmbeddingAlloc::Mmap,
+        input: Input {
+            embeddings: Embeddings {
+                word: Embedding {
+                    filename: "word-vectors.bin".into(),
+                    alloc: EmbeddingAlloc::Mmap,
+                },
+                tag: Some(Embedding {
+                    filename: "tag-vectors.bin".into(),
+                    alloc: EmbeddingAlloc::Read,
+                }),
             },
-            tag: Some(Embedding {
-                filename: "tag-vectors.bin".into(),
-                alloc: EmbeddingAlloc::Read,
-            }),
+            subwords: true,
         },
         model: ModelConfig {
             batch_size: 128,

M sticker-utils/src/lib.rs => sticker-utils/src/lib.rs +1 -1
@@ 3,7 3,7 @@
 
 mod config;
 pub use crate::config::{
-    Config, Embedding, EmbeddingAlloc, Embeddings, EncoderType, Labeler, LabelerType,
+    Config, Embedding, EmbeddingAlloc, Embeddings, EncoderType, Input, Labeler, LabelerType,
 };
 
 mod progress;

M sticker-utils/src/tagger_wrapper.rs => sticker-utils/src/tagger_wrapper.rs +2 -1
@@ 45,10 45,11 @@
     /// Create a tagger from the given configuration.
     pub fn new(config: &Config) -> Fallible<Self> {
         let embeddings = config
+            .input
             .embeddings
             .load_embeddings()
             .with_context(|e| format!("Cannot load embeddings: {}", e))?;
-        let vectorizer = SentVectorizer::new(embeddings);
+        let vectorizer = SentVectorizer::new(embeddings, config.input.subwords);
 
         let graph_reader = File::open(&config.model.graph).with_context(|e| {
             format!(

M sticker-utils/testdata/sticker.conf => sticker-utils/testdata/sticker.conf +5 -2
@@ 3,11 3,14 @@
   labels = "sticker.labels"
   read_ahead = 10
 
-[embeddings.word]
+[input]
+  subwords = true
+
+[input.embeddings.word]
   filename = "word-vectors.bin"
   alloc = "mmap"
 
-[embeddings.tag]
+[input.embeddings.tag]
   filename = "tag-vectors.bin"
   alloc = "read"
 

M sticker/src/input.rs => sticker/src/input.rs +31 -4
@@ 74,11 74,17 @@
     }
 }
 
+pub struct InputVector {
+    pub sequence: Vec<f32>,
+    pub subwords: Option<Vec<String>>,
+}
+
 /// Vectorizer for sentences.
 ///
 /// An `SentVectorizer` vectorizes sentences.
 pub struct SentVectorizer {
     layer_embeddings: LayerEmbeddings,
+    subwords: bool,
 }
 
 impl SentVectorizer {


@@ 87,8 93,16 @@
     /// The vectorizer is constructed from the embedding matrices. The layer
     /// embeddings are used to find the indices into the embedding matrix for
     /// layer values.
-    pub fn new(layer_embeddings: LayerEmbeddings) -> Self {
-        SentVectorizer { layer_embeddings }
+    pub fn new(layer_embeddings: LayerEmbeddings, subwords: bool) -> Self {
+        SentVectorizer {
+            layer_embeddings,
+            subwords,
+        }
+    }
+
+    /// Does the vectorizer produce representations for subwords?
+    pub fn has_subwords(&self) -> bool {
+        self.subwords
     }
 
     /// Get the length of the input representation.


@@ 108,7 122,7 @@
     }
 
     /// Vectorize a sentence.
-    pub fn realize(&self, sentence: &Sentence) -> Result<Vec<f32>, Error> {
+    pub fn realize(&self, sentence: &Sentence) -> Result<InputVector, Error> {
         let input_size = self.layer_embeddings.token_embeddings.dims()
             + self
                 .layer_embeddings


@@ 118,9 132,19 @@
                 .unwrap_or_default();
         let mut input = Vec::with_capacity(sentence.len() * input_size);
 
+        let mut subwords = if self.subwords {
+            Some(Vec::with_capacity(sentence.len()))
+        } else {
+            None
+        };
+
         for token in sentence.iter().filter_map(Node::token) {
             let form = token.form();
 
+            if let Some(ref mut subwords) = subwords {
+                subwords.push(form.to_owned());
+            }
+
             input.extend_from_slice(
                 &self
                     .layer_embeddings


@@ 146,6 170,9 @@
             }
         }
 
-        Ok(input)
+        Ok(InputVector {
+            sequence: input,
+            subwords,
+        })
     }
 }

M sticker/src/lib.rs => sticker/src/lib.rs +1 -1
@@ 4,7 4,7 @@
 pub mod encoder;
 
 mod input;
-pub use crate::input::{Embeddings, LayerEmbeddings, SentVectorizer};
+pub use crate::input::{Embeddings, InputVector, LayerEmbeddings, SentVectorizer};
 
 mod numberer;
 pub use crate::numberer::Numberer;

M sticker/src/tensorflow/dataset.rs => sticker/src/tensorflow/dataset.rs +2 -1
@@ 129,6 129,7 @@
             batch_sentences.len(),
             max_seq_len,
             self.vectorizer.input_len(),
+            self.vectorizer.has_subwords(),
         );
 
         for sentence in batch_sentences {


@@ 144,7 145,7 @@
                     .collect::<Vec<_>>(),
                 Err(err) => return Some(Err(err)),
             };
-            builder.add_with_labels(&inputs, &labels);
+            builder.add_with_labels(inputs, &labels);
         }
 
         Some(Ok(builder))

M sticker/src/tensorflow/tagger.rs => sticker/src/tensorflow/tagger.rs +25 -5
@@ 6,9 6,9 @@
 use std::path::Path;
 
 use conllx::graph::Sentence;
-use failure::{Error, Fallible};
+use failure::{err_msg, Error, Fallible};
 use itertools::Itertools;
-use ndarray::{Ix1, Ix3};
+use ndarray::{Ix1, Ix2, Ix3};
 use ndarray_tensorflow::NdTensor;
 use protobuf::Message;
 use serde_derive::{Deserialize, Serialize};


@@ 90,6 90,7 @@
     pub const LR_OP: &str = "model/lr";
 
     pub const INPUTS_OP: &str = "model/inputs";
+    pub const SUBWORDS_OP: &str = "model/subwords";
     pub const SEQ_LENS_OP: &str = "model/seq_lens";
 
     pub const LOSS_OP: &str = "model/tag_loss";


@@ 126,6 127,7 @@
     pub(crate) lr_op: Operation,
     pub(crate) is_training_op: Operation,
     pub(crate) inputs_op: Operation,
+    pub(crate) subwords_op: Option<Operation>,
     pub(crate) seq_lens_op: Operation,
 
     pub(crate) loss_op: Operation,


@@ 162,6 164,7 @@
         let lr_op = Self::add_op(&graph, op_names::LR_OP)?;
 
         let inputs_op = Self::add_op(&graph, op_names::INPUTS_OP)?;
+        let subwords_op = Self::add_op(&graph, op_names::SUBWORDS_OP).ok();
         let seq_lens_op = Self::add_op(&graph, op_names::SEQ_LENS_OP)?;
 
         let loss_op = Self::add_op(&graph, op_names::LOSS_OP)?;


@@ 199,6 202,7 @@
             is_training_op,
             lr_op,
             inputs_op,
+            subwords_op,
             seq_lens_op,
 
             loss_op,


@@ 291,13 295,17 @@
             .max()
             .unwrap_or(0);
 
-        let mut builder =
-            TensorBuilder::new(sentences.len(), max_seq_len, self.vectorizer.input_len());
+        let mut builder = TensorBuilder::new(
+            sentences.len(),
+            max_seq_len,
+            self.vectorizer.input_len(),
+            self.vectorizer.has_subwords(),
+        );
 
         // Fill the batch.
         for sentence in sentences {
             let input = self.vectorizer.realize(sentence.borrow())?;
-            builder.add_without_labels(&input);
+            builder.add_without_labels(input);
         }
 
         Ok(builder)


@@ 313,6 321,7 @@
         let (tag_tensor, probs_tensor) = self.tag_sequences(
             builder.seq_lens(),
             builder.inputs(),
+            builder.subwords(),
             &self.graph.top_k_predicted_op,
             &self.graph.top_k_probs_op,
         )?;


@@ 351,6 360,7 @@
         &self,
         seq_lens: &NdTensor<i32, Ix1>,
         inputs: &NdTensor<f32, Ix3>,
+        subwords: Option<&NdTensor<String, Ix2>>,
         predicted_op: &Operation,
         probs_op: &Operation,
     ) -> Result<(Tensor<i32>, Tensor<f32>), Error> {


@@ 365,6 375,16 @@
         args.add_feed(&self.graph.seq_lens_op, 0, seq_lens.inner_ref());
         args.add_feed(&self.graph.inputs_op, 0, inputs.inner_ref());
 
+        if let Some(subwords) = subwords {
+            args.add_feed(
+                self.graph.subwords_op.as_ref().ok_or_else(|| {
+                    err_msg("Subwords used in a graph without support for subwords")
+                })?,
+                0,
+                subwords.inner_ref(),
+            );
+        }
+
         let probs_token = args.request_fetch(probs_op, 0);
         let predictions_token = args.request_fetch(predicted_op, 0);
 

M sticker/src/tensorflow/tensor.rs => sticker/src/tensorflow/tensor.rs +68 -18
@@ 1,8 1,10 @@
 use std::cmp::min;
 
-use ndarray::{s, ArrayView2, Ix1, Ix2, Ix3};
+use ndarray::{s, Array1, Array2, Ix1, Ix2, Ix3};
 use ndarray_tensorflow::NdTensor;
 
+use crate::InputVector;
+
 mod labels {
     pub trait Labels {
         fn from_shape(batch_size: usize, time_steps: usize) -> Self;


@@ 30,8 32,9 @@
 /// Build `Tensor`s from Rust slices.
 pub struct TensorBuilder<L> {
     sequence: usize,
-    sequence_lens: NdTensor<i32, Ix1>,
+    seq_lens: NdTensor<i32, Ix1>,
     inputs: NdTensor<f32, Ix3>,
+    subwords: Option<NdTensor<String, Ix2>>,
     labels: L,
 }
 


@@ 43,11 46,23 @@
     ///
     /// Creates a new builder with the given batch size, number of time steps,
     /// and input size.
-    pub fn new(batch_size: usize, time_steps: usize, inputs_size: usize) -> Self {
+    pub fn new(
+        batch_size: usize,
+        time_steps: usize,
+        inputs_size: usize,
+        use_subwords: bool,
+    ) -> Self {
+        let subwords = if use_subwords {
+            Some(NdTensor::zeros([batch_size, time_steps]))
+        } else {
+            None
+        };
+
         TensorBuilder {
             sequence: 0,
-            sequence_lens: NdTensor::zeros([batch_size]),
+            seq_lens: NdTensor::zeros([batch_size]),
             inputs: NdTensor::zeros([batch_size, time_steps, inputs_size]),
+            subwords,
             labels: L::from_shape(batch_size, time_steps),
         }
     }


@@ 55,23 70,37 @@
 
 impl<L> TensorBuilder<L> {
     /// Add an input.
-    fn add_input(&mut self, input: &[f32]) {
+    fn add_input(&mut self, input_vector: InputVector) {
         assert!(self.sequence < self.inputs.view().shape()[0]);
 
         let token_repr_size = self.inputs.view().shape()[2];
+        let input_timesteps = input_vector.sequence.len() / token_repr_size;
+
+        let input =
+            Array2::from_shape_vec([input_timesteps, token_repr_size], input_vector.sequence)
+                .unwrap();
 
-        let input = ArrayView2::from_shape([input.len() / token_repr_size, token_repr_size], input)
-            .unwrap();
+        let subwords = input_vector.subwords.map(Array1::from_vec);
 
         let timesteps = min(self.inputs.view().shape()[1], input.shape()[0]);
 
-        self.sequence_lens.view_mut()[self.sequence] = timesteps as i32;
+        self.seq_lens.view_mut()[self.sequence] = timesteps as i32;
 
         #[allow(clippy::deref_addrof)]
         self.inputs
             .view_mut()
             .slice_mut(s![self.sequence, 0..timesteps, ..])
             .assign(&input.slice(s![0..timesteps, ..]));
+
+        if let Some(subwords) = subwords {
+            #[allow(clippy::deref_addrof)]
+            self.subwords
+                .as_mut()
+                .unwrap()
+                .view_mut()
+                .slice_mut(s![self.sequence, 0..timesteps])
+                .assign(&subwords.slice(s![..timesteps]));
+        }
     }
 
     /// Get the constructed tensors.


@@ 81,8 110,13 @@
     /// * The input tensor.
     /// * The sequence lengths tensor.
     /// * The labels.
-    pub fn into_parts(self) -> (NdTensor<f32, Ix3>, NdTensor<i32, Ix1>, L) {
-        (self.inputs, self.sequence_lens, self.labels)
+    pub fn into_parts(self) -> Tensors<L> {
+        Tensors {
+            inputs: self.inputs,
+            subwords: self.subwords,
+            seq_lens: self.seq_lens,
+            labels: self.labels,
+        }
     }
 
     /// Get the inputs lengths tensor.


@@ 98,19 132,27 @@
 
     /// Get the sequence lengths tensor.
     pub fn seq_lens(&self) -> &NdTensor<i32, Ix1> {
-        &self.sequence_lens
+        &self.seq_lens
+    }
+
+    /// Get subwords tensor.
+    pub fn subwords(&self) -> Option<&NdTensor<String, Ix2>> {
+        self.subwords.as_ref()
     }
 }
 
 impl TensorBuilder<LabelTensor> {
     /// Add an instance with labels.
-    pub fn add_with_labels(&mut self, input: &[f32], labels: &[i32]) {
-        self.add_input(input);
-
+    pub fn add_with_labels(&mut self, input_vector: InputVector, labels: &[i32]) {
+        // Number of sequence time steps.
         let token_repr_size = self.inputs.view().shape()[2] as usize;
+        let timesteps = min(
+            self.inputs.view().shape()[1],
+            input_vector.sequence.len() / token_repr_size,
+        );
+
+        self.add_input(input_vector);
 
-        // Number of time steps to copy
-        let timesteps = min(self.inputs.view().shape()[1], input.len() / token_repr_size);
         #[allow(clippy::deref_addrof)]
         self.labels
             .view_mut()


@@ 123,8 165,16 @@
 
 impl TensorBuilder<NoLabels> {
     /// Add an instance without labels.
-    pub fn add_without_labels(&mut self, input: &[f32]) {
-        self.add_input(input);
+    pub fn add_without_labels(&mut self, input_vector: InputVector) {
+        self.add_input(input_vector);
         self.sequence += 1;
     }
 }
+
+/// Tensors constructed by `TensorBuilder`.
+pub struct Tensors<L> {
+    pub inputs: NdTensor<f32, Ix3>,
+    pub subwords: Option<NdTensor<String, Ix2>>,
+    pub seq_lens: NdTensor<i32, Ix1>,
+    pub labels: L,
+}

M sticker/src/tensorflow/trainer.rs => sticker/src/tensorflow/trainer.rs +16 -2
@@ 128,6 128,7 @@
         &self,
         seq_lens: &NdTensor<i32, Ix1>,
         inputs: &NdTensor<f32, Ix3>,
+        subwords: Option<&NdTensor<String, Ix2>>,
         labels: &NdTensor<i32, Ix2>,
         learning_rate: f32,
     ) -> ModelPerformance {


@@ 150,7 151,7 @@
                     .expect("Summaries requested from a graph that does not support summaries."),
             );
         }
-        self.validate_(seq_lens, inputs, labels, args)
+        self.validate_(seq_lens, inputs, subwords, labels, args)
     }
 
     /// Perform validation using a batch of inputs and labels.


@@ 158,6 159,7 @@
         &self,
         seq_lens: &NdTensor<i32, Ix1>,
         inputs: &NdTensor<f32, Ix3>,
+        subwords: Option<&NdTensor<String, Ix2>>,
         labels: &NdTensor<i32, Ix2>,
     ) -> ModelPerformance {
         let mut is_training = Tensor::new(&[]);


@@ 173,13 175,14 @@
                     .expect("Summaries requested from a graph that does not support summaries."),
             );
         }
-        self.validate_(seq_lens, inputs, labels, args)
+        self.validate_(seq_lens, inputs, subwords, labels, args)
     }
 
     fn validate_<'l>(
         &'l self,
         seq_lens: &'l NdTensor<i32, Ix1>,
         inputs: &'l NdTensor<f32, Ix3>,
+        subwords: Option<&'l NdTensor<String, Ix2>>,
         labels: &'l NdTensor<i32, Ix2>,
         mut args: SessionRunArgs<'l>,
     ) -> ModelPerformance {


@@ 187,6 190,17 @@
         args.add_feed(&self.graph.inputs_op, 0, inputs.inner_ref());
         args.add_feed(&self.graph.seq_lens_op, 0, seq_lens.inner_ref());
 
+        if let Some(subwords) = subwords {
+            args.add_feed(
+                self.graph
+                    .subwords_op
+                    .as_ref()
+                    .expect("Subwords used in a graph without support for subwords"),
+                0,
+                subwords.inner_ref(),
+            );
+        }
+
         // Add gold labels.
         args.add_feed(&self.graph.labels_op, 0, labels.inner_ref());