I recently wrote a guide on recurrent networks in TensorFlow. That post covered the basics, but often we want to learn on sequences of variable lengths, possibly even within the same batch of training examples. In this post, I will explain how to use variable length sequences in TensorFlow and what implications they have on your model.

## Computing the Sequence Length

Since TensorFlow unfolds our recurrent network for a given number of steps, we can only feed sequences of that shape to the network. We also want the input to have a fixed size so that we can represent a training batch as a single tensor of shape batch size x max length x features.
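As a minimal NumPy sketch of that padding step, assume each example arrives as a `[time, features]` array. `pad_batch` is a hypothetical helper for illustration, not part of TensorFlow. It stacks the examples into one zero-padded `batch size x max length x features` array:

```python
import numpy as np

def pad_batch(sequences, max_length=None):
    """Hypothetical helper: stack [time, features] arrays into one
    zero-padded [batch, max_length, features] array."""
    if max_length is None:
        max_length = max(len(seq) for seq in sequences)
    num_features = sequences[0].shape[1]
    batch = np.zeros((len(sequences), max_length, num_features), dtype=np.float32)
    for i, seq in enumerate(sequences):
        batch[i, :len(seq)] = seq  # trailing frames stay zero vectors
    return batch

sequences = [np.ones((2, 3)), np.ones((4, 3)) * 2.0]
batch = pad_batch(sequences)
print(batch.shape)  # (2, 4, 3)
```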

I will assume that the sequences are padded with zero vectors to fill up the remaining time steps in the batch. To pass sequence lengths to TensorFlow, we have to compute them from the batch. While we could do this in NumPy in a pre-processing step, let's do it on the fly as part of the compute graph!

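```python
import tensorflow as tf

def length(sequence):
    # 1 for frames containing any non-zero value, 0 for zero-padded frames.
    used = tf.sign(tf.reduce_max(tf.abs(sequence), 2))
    # Number of used frames per example.
    length = tf.reduce_sum(used, 1)
    length = tf.cast(length, tf.int32)
    return length
```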


We first collapse each frame vector (the third dimension of the batch) into a scalar by taking the maximum of its absolute values. Each sequence is now a vector of scalars that will be zero for the padded frames at the end. We then use tf.sign() to convert the actual frames from their maximum values to values of one. This gives us a binary mask of ones for used frames and zeros for unused frames that we can just sum to get the sequence length.
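The same computation can be mirrored in NumPy to check it on a toy batch. This is a sketch, with `length_np` as a hypothetical name. The `abs` is what lets frames with only negative values still count as used. One assumption to keep in mind: a real frame that happens to be all zeros would be counted as padding.

```python
import numpy as np

def length_np(sequence):
    # Max absolute value per frame: [batch, time, features] -> [batch, time].
    used = np.sign(np.max(np.abs(sequence), axis=2))
    # Count used frames per example: [batch, time] -> [batch].
    return np.sum(used, axis=1).astype(np.int32)

batch = np.zeros((2, 4, 3), dtype=np.float32)
batch[0, :2] = 1.0   # first example has 2 real frames
batch[1, :4] = -0.5  # second example has 4 real frames (negative values count too)
print(length_np(batch))  # [2 4]
```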

## Using the Length Information

Now that we have a vector holding the sequence lengths, we can pass that to dynamic_rnn(), the function that unfolds our network, using the optional sequence_length parameter. When running the model later, TensorFlow will return zero vectors as outputs after these sequence lengths and copy the last valid state through the remaining steps, so the returned state is the state at each sequence's final frame. Therefore, the weights do not affect the outputs at padding frames and do not get trained on them.
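To make these semantics concrete, here is a toy NumPy emulation. It assumes a plain tanh RNN cell in place of the GRU, and `toy_dynamic_rnn` is a hypothetical helper rather than TensorFlow's implementation. Past its length, each example emits zero outputs and keeps its previous state:

```python
import numpy as np

def toy_dynamic_rnn(inputs, lengths, w_in, w_rec):
    """Hypothetical emulation of sequence_length masking for a tanh RNN."""
    batch_size, max_length, _ = inputs.shape
    num_hidden = w_rec.shape[0]
    state = np.zeros((batch_size, num_hidden))
    outputs = np.zeros((batch_size, max_length, num_hidden))
    for t in range(max_length):
        new_state = np.tanh(inputs[:, t] @ w_in + state @ w_rec)
        active = (t < lengths)[:, None]  # [batch, 1] boolean mask
        # Finished examples keep their old state and emit zero vectors.
        state = np.where(active, new_state, state)
        outputs[:, t] = np.where(active, new_state, 0.0)
    return outputs, state

rng = np.random.default_rng(0)
inputs = rng.normal(size=(2, 5, 3))
lengths = np.array([2, 5])
outputs, state = toy_dynamic_rnn(
    inputs, lengths, rng.normal(size=(3, 4)), rng.normal(size=(4, 4)))
```

Because nothing past a sequence's length depends on the weights, gradients from those frames are zero.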

```python
max_length = 100
frame_size = 64
num_hidden = 200

# Batch of zero-padded sequences: batch_size x max_length x frame_size.
sequence = tf.placeholder(
    tf.float32, [None, max_length, frame_size])
output, state = tf.nn.dynamic_rnn(
    tf.contrib.rnn.GRUCell(num_hidden),
    sequence,
    dtype=tf.float32,
    sequence_length=length(sequence),
)
```


Note that our output will still be of size batch_size x max_length x out_size, but the frames after each sequence's length will be zero vectors. When we use the outputs at each time step, as in sequence labeling, we don't want to consider these padding frames in our cost function. We mask out the unused frames and compute the mean error over the sequence by dividing by its actual length. Using tf.reduce_mean() does not work here because it would divide by the maximum sequence length.
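A small NumPy comparison shows the difference. This is a sketch assuming the per-frame losses have already been computed and are zero at padding frames. Averaging over the padded time axis shrinks the loss of short sequences, while dividing the sum by the actual length does not:

```python
import numpy as np

# Per-frame losses for two sequences padded to max_length = 4.
frame_loss = np.array([[1.0, 1.0, 0.0, 0.0],
                       [2.0, 2.0, 2.0, 2.0]])
lengths = np.array([2, 4])

naive = frame_loss.mean(axis=1)             # divides by max_length
correct = frame_loss.sum(axis=1) / lengths  # divides by actual length
# naive -> [0.5, 2.0], correct -> [1.0, 2.0]
```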

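```python
def cost(output, target):
    # Compute cross entropy for each frame.
    cross_entropy = target * tf.log(output)
    cross_entropy = -tf.reduce_sum(cross_entropy, 2)
    # Mask out padding frames (their targets are zero vectors).
    mask = tf.sign(tf.reduce_max(tf.abs(target), 2))
    cross_entropy *= mask
    # Average over actual sequence lengths.
    cross_entropy = tf.reduce_sum(cross_entropy, 1)
    cross_entropy /= tf.reduce_sum(mask, 1)
    return tf.reduce_mean(cross_entropy)
```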


You can compute the average of your error function the same way. Strictly speaking, masking the cross entropy is redundant, because the targets are zero vectors for the padding frames, so those frames add nothing to the sum. Dividing by the actual length is still necessary. Anyway, it's nice to be explicit in code. Here is a full example of variable-length sequence labeling.
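A quick NumPy check of that redundancy, under the assumption that predictions come from a softmax and are therefore strictly positive. If a prediction were exactly zero at a padding frame, `0 * log(0)` would give NaN and the mask would no longer be optional. The softmax and data here are illustrative stand-ins:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
prediction = softmax(rng.normal(size=(2, 4, 3)))  # never exactly zero
target = np.zeros((2, 4, 3))
target[0, 0, 1] = target[0, 1, 2] = 1.0           # first example: 2 frames
target[1, np.arange(4), [0, 1, 2, 0]] = 1.0       # second example: 4 frames

frame_ce = -np.sum(target * np.log(prediction), axis=2)
mask = np.sign(np.max(np.abs(target), axis=2))
lengths = mask.sum(axis=1)
unmasked = frame_ce.sum(axis=1) / lengths
masked = (frame_ce * mask).sum(axis=1) / lengths
print(np.allclose(unmasked, masked))  # True
```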

## Select the Last Relevant Output

For sequence classification, we want to feed the last output of the recurrent network into a predictor, e.g. a softmax layer. While taking the last frame worked well for fixed-sized sequences, we now have to select the last relevant frame. This is a bit cumbersome in TensorFlow since it doesn't support advanced slicing yet. In NumPy this would just be output[np.arange(batch_size), length - 1]. But we need the indexing to be part of the compute graph in order to train the whole system end-to-end.
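For reference, here is a NumPy sketch of the selection we want the graph to compute, on a toy `[batch, max_length, out_size]` array. Pairing a row index per example with `length - 1` picks one frame per example. Plain `output[:, length - 1]` would instead pick every listed time step for every example, giving shape `[batch, batch, out_size]`:

```python
import numpy as np

output = np.arange(2 * 4 * 3).reshape(2, 4, 3)  # [batch, max_length, out_size]
length = np.array([2, 4])

last = output[np.arange(len(length)), length - 1]  # one frame per example
print(last)
# [[ 3  4  5]
#  [21 22 23]]
```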

```python
def last_relevant(output, length):
    batch_size = tf.shape(output)[0]
    max_length = tf.shape(output)[1]
    out_size = int(output.get_shape()[2])
    # Flat row index of each example's last valid frame.
    index = tf.range(0, batch_size) * max_length + (length - 1)
    # Merge batch and time dimensions: (batch_size * max_length) x out_size.
    flat = tf.reshape(output, [-1, out_size])
    relevant = tf.gather(flat, index)
    return relevant
```


What happens here? We flatten the output tensor to shape (frames in all examples) x output size. Then we construct an index into that by creating a tensor with the start index of each example, tf.range(0, batch_size) * max_length, and adding each example's sequence length minus one to it. tf.gather() then performs the actual indexing. Let's hope the TensorFlow guys can provide proper indexing soon so this gets much easier.
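The flat-index trick can be checked in NumPy, with fancy indexing on the flattened array standing in for `tf.gather()`. `last_relevant_np` is a hypothetical mirror of the TensorFlow function, and the test compares it against direct advanced indexing:

```python
import numpy as np

def last_relevant_np(output, length):
    batch_size, max_length, out_size = output.shape
    # Row of example i's last valid frame in the flattened array.
    index = np.arange(batch_size) * max_length + (length - 1)
    flat = output.reshape(-1, out_size)  # [batch * max_length, out_size]
    return flat[index]                   # plays the role of tf.gather(flat, index)

output = np.random.default_rng(0).normal(size=(3, 5, 2))
length = np.array([1, 5, 3])
print(last_relevant_np(output, length).shape)  # (3, 2)
```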

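On a side note: a one-layer GRU network outputs its full state, and because dynamic_rnn() stops updating the state after each sequence's length, the returned state equals the last relevant output. In that case, we can use the state returned by tf.nn.dynamic_rnn() directly. Similarly, we can use state.h for a one-layer LSTM network, whose state is a tuple of the cell state c and the output h. For more complex architectures, that doesn't work, or at least results in a large number of parameters.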
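A toy NumPy LSTM illustrates the LSTM case. It is simplified (no biases or peepholes, random weights), so treat it as a sketch of the state layout rather than TensorFlow's cell. With copy-through masking, the final `h` matches the last relevant output, while the cell state `c` does not:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(1)
batch_size, max_length, frame_size, num_hidden = 3, 6, 2, 4
inputs = rng.normal(size=(batch_size, max_length, frame_size))
lengths = np.array([6, 2, 4])
w = rng.normal(size=(frame_size + num_hidden, 4 * num_hidden)) * 0.5

c = np.zeros((batch_size, num_hidden))
h = np.zeros((batch_size, num_hidden))
outputs = np.zeros((batch_size, max_length, num_hidden))
for t in range(max_length):
    gates = np.concatenate([inputs[:, t], h], axis=1) @ w
    i, f, g, o = np.split(gates, 4, axis=1)
    new_c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)
    new_h = sigmoid(o) * np.tanh(new_c)
    active = (t < lengths)[:, None]
    c = np.where(active, new_c, c)  # copy state through padding frames
    h = np.where(active, new_h, h)
    outputs[:, t] = np.where(active, new_h, 0.0)  # output is h, zeroed in padding

last = outputs[np.arange(batch_size), lengths - 1]
print(np.allclose(h, last))  # True
```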

We got the last relevant output and can feed that into a simple softmax layer to predict the class of each sequence:

```python
num_classes = 10

last = last_relevant(output, length(sequence))
weight = tf.Variable(
    tf.truncated_normal([num_hidden, num_classes], stddev=0.1))
bias = tf.Variable(tf.constant(0.1, shape=[num_classes]))
prediction = tf.nn.softmax(tf.matmul(last, weight) + bias)
```


You can of course use more complex predictors with multiple layers as well. Here is the working example for variable-length sequence classification.

I explained how to use recurrent networks on variable-length sequences and how to use their outputs. Feel free to comment with questions and remarks.

Updated 2016-08-17: TensorFlow 0.10 moved the recurrent network operations from tf.models.rnn into the tf.nn package, where they now live alongside the other neural network operations. Cells can now be found in tf.nn.rnn_cell.

Updated 2017-07-12: TensorFlow 1.0 moved recurrent network cells from tf.nn.rnn_cell to tf.contrib.rnn.

You can use this post under the open CC BY-SA 3.0 license and cite it as:

```bibtex
@misc{hafner2016tfrnnvarlength,
  author = {Hafner, Danijar},
  title = {Variable Sequence Lengths in TensorFlow},
  year = {2016},
  howpublished = {Blog post},
  url = {https://danijar.com/variable-sequence-lengths-in-tensorflow/}
}
```