Recurrent networks like LSTM and GRU are powerful sequence models. I will explain how to create recurrent networks in TensorFlow and use them for sequence classification and labelling tasks.
If you are not familiar with recurrent networks, I suggest you take a look at Christopher Olah’s great article first. For the TensorFlow part, I also expect some basic knowledge. The official tutorials are a good place to start.
Defining the Network
To use recurrent networks in TensorFlow, we first need to define the network architecture, consisting of one or more layers, the cell type, and possibly dropout between the layers. In TensorFlow, we build recurrent networks out of so-called cells that wrap each other.
import tensorflow as tf

num_units = 200
num_layers = 3
dropout = tf.placeholder(tf.float32)

cells = []
for _ in range(num_layers):
    cell = tf.contrib.rnn.GRUCell(num_units)  # Or LSTMCell(num_units)
    cell = tf.contrib.rnn.DropoutWrapper(
        cell, output_keep_prob=1.0 - dropout)
    cells.append(cell)
cell = tf.contrib.rnn.MultiRNNCell(cells)
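To make the idea of cells wrapping each other concrete, here is a framework-free sketch in plain Python. The class names ToyCell, Dropout and Stack are hypothetical stand-ins for GRUCell, DropoutWrapper and MultiRNNCell. The toy cell is a scalar tanh recurrence, not a GRU, and the dropout wrapper assumes inverted dropout on the outputs only (kept values scaled by 1 / keep_prob), which is how I understand output_keep_prob to behave.

```python
import math
import random


class ToyCell:
    """Hypothetical scalar tanh RNN cell: h_t = tanh(w_x * x_t + w_h * h_prev + b)."""

    def __init__(self, w_x, w_h, b=0.0):
        self.w_x, self.w_h, self.b = w_x, w_h, b

    def zero_state(self):
        return 0.0

    def __call__(self, x, state):
        h = math.tanh(self.w_x * x + self.w_h * state + self.b)
        return h, h  # (output, new_state)


class Dropout:
    """Wraps a cell and applies inverted dropout to its output, not its state."""

    def __init__(self, cell, keep_prob, rng=None):
        self.cell, self.keep_prob = cell, keep_prob
        self.rng = rng or random.Random(0)

    def zero_state(self):
        return self.cell.zero_state()

    def __call__(self, x, state):
        out, new_state = self.cell(x, state)
        if self.keep_prob < 1.0:
            keep = self.rng.random() < self.keep_prob
            out = out / self.keep_prob if keep else 0.0
        return out, new_state


class Stack:
    """Feeds each layer's output into the next; the state is a tuple of layer states."""

    def __init__(self, cells):
        self.cells = cells

    def zero_state(self):
        return tuple(c.zero_state() for c in self.cells)

    def __call__(self, x, state):
        new_states = []
        for cell, s in zip(self.cells, state):
            x, s = cell(x, s)
            new_states.append(s)
        return x, tuple(new_states)


# Three separate cells, as in the TensorFlow code; keep_prob = 1.0 disables dropout.
dropout = 0.0
cells = [Dropout(ToyCell(0.5, 0.9), keep_prob=1.0 - dropout) for _ in range(3)]
cell = Stack(cells)
out, state = cell(1.0, cell.zero_state())
```

The stacked cell exposes the same call signature as a single cell, which is why the whole stack can be handed to the unrolling function as one object.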
Simulating Time Steps
We can now add the operations to the graph that simulate the recurrent network over the time steps of the input. We do this using TensorFlow’s tf.nn.dynamic_rnn() operation. It takes a tensor block holding the input sequences and returns the output activations and the last hidden state as tensors.
# Batch size x time steps x features.
data = tf.placeholder(tf.float32, [None, None, 28])
output, state = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32)
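Conceptually, the unrolling applies the cell once per time step, threads the state through, and collects the per-step outputs. Here is a plain-Python sketch of those semantics, assuming a hypothetical cell(x, state) -> (output, new_state) callable, scalar features instead of 28, and equal-length sequences. It ignores the sequence_length argument and the symbolic loop TensorFlow actually builds.

```python
import math


def tanh_cell(x, state):
    """Hypothetical scalar cell: a vanilla tanh RNN with fixed weights."""
    h = math.tanh(0.5 * x + 0.9 * state)
    return h, h


def unroll(cell, data, initial_state=0.0):
    """Run cell over data shaped [batch][time] and return (outputs, final_states).

    outputs is shaped [batch][time]; final_states holds the last state of each
    sequence, mirroring the (output, state) pair returned by dynamic_rnn.
    """
    outputs, final_states = [], []
    for sequence in data:            # each example in the batch
        state, seq_out = initial_state, []
        for x in sequence:           # each time step
            out, state = cell(x, state)
            seq_out.append(out)
        outputs.append(seq_out)
        final_states.append(state)
    return outputs, final_states


data = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.5]]  # batch size 2 x 3 time steps
output, state = unroll(tanh_cell, data)
```

For this simple cell the output and the state coincide, so the last output of each sequence equals its final state. For an LSTM the state would additionally carry the cell memory.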
For classification, you might only care about the output activation at the last time step. We transpose so that the time axis is first and use tf.gather() to select the last frame. We can’t just use output[-1] because, unlike Python lists, TensorFlow doesn’t support negative indexing yet.
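output, _ = tf.nn.dynamic_rnn(cell, data, dtype=tf.float32)
output = tf.transpose(output, [1, 0, 2])  # Time steps x batch size x units.
# Index of the last frame; tf.shape() also works when the
# number of time steps is None, as in the placeholder above.
last = tf.gather(output, tf.shape(output)[0] - 1)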
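The index arithmetic can be checked without TensorFlow. This plain-Python sketch uses nested lists as a stand-in for tensors: it swaps the batch and time axes like tf.transpose(output, [1, 0, 2]) and then picks index time - 1 like tf.gather. The helper names are hypothetical.

```python
def transpose_batch_time(tensor):
    """Swap the first two axes: [batch][time][units] -> [time][batch][units]."""
    return [list(step) for step in zip(*tensor)]


def last_frame(output):
    """Mimic the transpose followed by gathering index time - 1."""
    time_major = transpose_batch_time(output)
    return time_major[len(time_major) - 1]


# Batch size 2 x 3 time steps x 2 units.
output = [
    [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]],
]
last = last_frame(output)
```

The result is the last element of every sequence in the batch, shaped batch size x units.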
The code below adds a softmax classifier on top of the last activation and defines the cross entropy loss function. Here is the complete gist for sequence classification.
# target holds one-hot labels, shaped batch size x classes.
out_size = target.get_shape()[1].value
logit = tf.contrib.layers.fully_connected(
    last, out_size, activation_fn=None)
prediction = tf.nn.softmax(logit)
loss = tf.losses.softmax_cross_entropy(target, logit)
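To show what the classifier head computes, here is a plain-Python sketch of a linear layer on the last activation, a softmax, and the cross entropy averaged over the batch. The weights, biases and helper names are made up for illustration, and I am assuming the TensorFlow loss averages over the batch, which is its default behaviour as far as I know.

```python
import math


def dense(x, weights, bias):
    """Linear layer without activation: logits[j] = sum_i x[i] * W[i][j] + b[j]."""
    return [sum(xi * w_row[j] for xi, w_row in zip(x, weights)) + bias[j]
            for j in range(len(bias))]


def softmax(logits):
    """Softmax with the max subtracted first so exp() cannot overflow."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(targets, predictions):
    """Mean over the batch of -sum_k target_k * log(prediction_k).

    Zero entries of the one-hot targets are skipped so log() only sees
    the probability of the true class.
    """
    losses = [-sum(t * math.log(p) for t, p in zip(tgt, pred) if t > 0)
              for tgt, pred in zip(targets, predictions)]
    return sum(losses) / len(losses)


# Hypothetical last activations (batch size 2 x 3 units) and a 3 -> 2 class layer.
last = [[1.0, 0.0, -1.0], [0.0, 2.0, 0.0]]
weights = [[1.0, -1.0], [0.0, 1.0], [0.5, 0.5]]
bias = [0.0, 0.0]
target = [[1.0, 0.0], [0.0, 1.0]]  # One-hot labels.

logit = [dense(x, weights, bias) for x in last]
prediction = [softmax(z) for z in logit]
loss = cross_entropy(target, prediction)
```

Both examples put probability 1 / (1 + e^-2) on the correct class here, so the mean loss is log(1 + e^-2).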
For now, we assume all sequences to have the same length. Please refer to my other post on handling sequences of different lengths.
For sequence labelling, we want a prediction for each time step. However, we share the weights of the softmax layer across all time steps. How do we do that? By flattening the first two dimensions of the output tensor. This way, time steps look the same as examples in the batch to the weight matrix. Afterwards, we reshape back to the desired shape.
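# output comes straight from dynamic_rnn: batch size x time steps x units.
# target holds one-hot labels, shaped batch size x time steps x classes.
out_size = target.get_shape()[2].value
# The same weights are applied at every time step, so logit has
# shape batch size x time steps x out_size.
logit = tf.contrib.layers.fully_connected(
    output, out_size, activation_fn=None)
prediction = tf.nn.softmax(logit)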
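Here is the flattening trick spelled out in plain Python, with nested lists standing in for tensors and hypothetical helper names. One weight matrix is applied to every frame of the flattened batch, and the result is reshaped back to batch x time x classes. It matches applying the same layer separately at every time step.

```python
def dense(x, weights, bias):
    """Linear layer: logits[j] = sum_i x[i] * W[i][j] + b[j]."""
    return [sum(xi * row[j] for xi, row in zip(x, weights)) + bias[j]
            for j in range(len(bias))]


def flatten_batch_time(tensor):
    """[batch][time][units] -> [batch * time][units]."""
    return [frame for sequence in tensor for frame in sequence]


def unflatten_batch_time(flat, time_steps):
    """[batch * time][classes] -> [batch][time][classes]."""
    return [flat[i:i + time_steps] for i in range(0, len(flat), time_steps)]


# Batch size 2 x 3 time steps x 2 units, projected to 2 classes.
output = [
    [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    [[2.0, 0.0], [0.0, 2.0], [-1.0, 1.0]],
]
weights = [[1.0, 0.0], [0.5, -1.0]]  # One weight matrix shared by all time steps.
bias = [0.0, 1.0]

flat_logit = [dense(frame, weights, bias) for frame in flatten_batch_time(output)]
logit = unflatten_batch_time(flat_logit, time_steps=len(output[0]))
```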
Let’s say we predict a class for each frame, so we keep using cross entropy as our loss function. Here we have a prediction and target for every time step. We thus compute the cross entropy for every time step and sequence in the batch, and then average along these two dimensions. Here is the complete gist for sequence labelling.
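# Merge the batch and time axes so every frame counts as one example.
flat_target = tf.reshape(target, [-1] + target.shape.as_list()[2:])
flat_logit = tf.reshape(logit, [-1] + logit.shape.as_list()[2:])
# The loss is already averaged over all frames and returned as a scalar,
# so the reduce_mean below leaves it unchanged.
loss = tf.losses.softmax_cross_entropy(flat_target, flat_logit)
loss = tf.reduce_mean(loss)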
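The averaging over both the batch and the time dimension can also be sketched in plain Python. Assumptions: nested lists shaped [batch][time][classes], one-hot targets, and a hypothetical helper sequence_cross_entropy. It flattens the first two axes like the reshape above, computes one loss per frame via a log-softmax, and takes a single mean.

```python
import math


def log_softmax(logits):
    """log(softmax(z)) computed as z - logsumexp(z)."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def sequence_cross_entropy(target, logit):
    """Average one-hot cross entropy over every (example, time step) pair.

    target and logit are nested lists shaped [batch][time][classes].
    """
    flat_target = [t for seq in target for t in seq]  # [batch * time][classes]
    flat_logit = [z for seq in logit for z in seq]
    losses = [-sum(t * lp for t, lp in zip(tgt, log_softmax(z)))
              for tgt, z in zip(flat_target, flat_logit)]
    return sum(losses) / len(losses)


# Batch size 2 x 2 time steps x 2 classes.
target = [[[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [1.0, 0.0]]]
logit = [[[0.0, 0.0], [0.0, 0.0]], [[2.0, 0.0], [0.0, 2.0]]]
loss = sequence_cross_entropy(target, logit)
```

Because all sequences have the same length, this single mean over frames equals the mean of the per-sequence means.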
That’s all. We have learned how to construct recurrent networks in TensorFlow and use them for sequence learning tasks. Please ask any questions below if you couldn’t follow.
Updated 2016-08-17: TensorFlow 0.10 moved the recurrent network operations from tf.models.rnn into the tf.nn package, where they now live alongside the other neural network operations. Cells can now be found in tf.nn.rnn_cell.
Updated 2016-05-20: TensorFlow 0.8 introduced dynamic_rnn(), which uses a symbolic loop instead of creating a subgraph for each time step. This results in a more compact graph. The function also expects and returns tensors directly, so we no longer need to convert to and from Python lists.
Updated 2017-06-07: TensorFlow 1.0 moved recurrent cells into tf.contrib.rnn. From TensorFlow 1.2 on, recurrent cells reuse their weights, so we need to create multiple separate GRUCells in the first code block. Moreover, I switched to using the existing implementation of the cross entropy loss, which is numerically stable and has a more efficient gradient computation.
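To illustrate why computing the loss directly from the logits matters, here is a plain-Python comparison with hypothetical function names. The naive version exponentiates the logits as they are, so a large logit overflows. The stable version uses the standard log-sum-exp formulation with the maximum logit subtracted. This shows the general technique; I am assuming, not asserting, that TensorFlow's implementation works the same way internally.

```python
import math


def naive_cross_entropy(logits, label):
    """-log(softmax(logits)[label]) computed literally; exp() can overflow."""
    exps = [math.exp(z) for z in logits]
    return -math.log(exps[label] / sum(exps))


def stable_cross_entropy(logits, label):
    """The same quantity as logsumexp(logits) - logits[label], shifted by the max."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


logits = [1000.0, 0.0]
try:
    naive_cross_entropy(logits, 1)
    naive_ok = True
except OverflowError:  # math.exp(1000.0) does not fit in a float.
    naive_ok = False

stable = stable_cross_entropy(logits, 1)  # Finite, about 1000.
```

For moderate logits both versions agree; the difference only shows up once the logits grow large, which can happen during training.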