TensorFlow 2 executes code eagerly by default, so you no longer need to build graphs and sessions by hand. This makes your code easier to understand and debug, and it also changes how you can structure it. That is a good opportunity to update my previous post on structuring models. Below, I describe a simple way to structure your models in TensorFlow 2. You can find two full code examples here:

Model Classes

A simple and common design is to have one class per model, with separate methods for the functions the model performs. For example, an auto-encoder is a class with encode() and decode() methods. In the simplest case, the layers are created in the constructor and used afterwards:

import tensorflow as tf

tfl = tf.keras.layers  # Short alias for the Keras layers used below.


class AutoEncoder:

  def __init__(self, code_size=64):
    self._enc1 = tfl.Conv2D(32, 4, 2)
    self._enc2 = tfl.Conv2D(64, 4, 2)
    self._enc3 = tfl.Conv2D(128, 4, 2)
    self._enc4 = tfl.Dense(code_size)
    self._dec1 = tfl.Conv2DTranspose(128, 5, 2)
    self._dec2 = tfl.Conv2DTranspose(64, 5, 2)
    self._dec3 = tfl.Conv2DTranspose(1, 4, 2)
    self._flatten = tfl.Flatten()

  def encode(self, image):
    hidden = self._enc1(image)
    hidden = tf.nn.elu(hidden)
    hidden = self._enc2(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self._enc3(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self._flatten(hidden)
    code = self._enc4(hidden)
    return code

  def decode(self, code):
    hidden = code[:, None, None, :]
    hidden = self._dec1(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self._dec2(hidden)
    hidden = tf.nn.elu(hidden)
    recon = self._dec3(hidden)
    return recon


image = tf.zeros([16, 28, 28, 1])  # Placeholder batch; 28x28 grayscale inputs are an assumption.
model = AutoEncoder()
code = model.encode(image)
recon = model.decode(code)

We could add further methods to the class, for example for computing the loss or training the model. If all your models are trained in a similar way, though, it can be better to keep that logic outside of the model class. In TensorFlow 2, methods that are called many times should be decorated with @tf.function, which traces them into a graph so that later calls run faster.
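As a minimal sketch of the decorator, the hypothetical Scaler class below compiles its apply() method with @tf.function. The class and all of its names are illustrative and not part of the auto-encoder. The Python-side counter is only there to show that the method body is traced once for a given input shape and dtype and that the resulting graph is reused afterwards:

```python
import tensorflow as tf


class Scaler(tf.Module):
  """Hypothetical toy model, used only to illustrate @tf.function."""

  def __init__(self, factor=2.0):
    super().__init__()
    self._factor = tf.Variable(factor)
    self.trace_count = 0  # Plain Python counter, updated only while tracing.

  @tf.function
  def apply(self, inputs):
    self.trace_count += 1  # Python side effect: runs during tracing only.
    return self._factor * inputs


scaler = Scaler()
first = scaler.apply(tf.ones([4, 3]))    # First call traces the method.
second = scaler.apply(tf.zeros([4, 3]))  # Same shape and dtype: graph is reused.
```

Calling apply() with a different input shape or dtype would trigger another trace, so it usually pays off to keep the inputs of a compiled method consistent.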

Accessing Parameters

We often need the list of model parameters, for example for training or checkpointing. TensorFlow 2 provides this functionality for us. Inheriting from tf.Module adds the properties model.variables and model.trainable_variables, which automatically search the attributes of self, including nested modules, for variables:

class AutoEncoder(tf.Module):

  # ...

model = AutoEncoder()
print(model.variables)  # [...]
print(model.trainable_variables)  # [...]
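To make the automatic search concrete, here is a small sketch that uses plain tf.Variable objects instead of Keras layers. The Linear and TinyModel classes are hypothetical and exist only for this illustration. Variables inside nested modules are found as well, and non-trainable variables appear only in model.variables:

```python
import tensorflow as tf


class Linear(tf.Module):
  """Hypothetical layer that stores its parameters as attributes."""

  def __init__(self, in_size, out_size):
    super().__init__()
    self.weight = tf.Variable(tf.random.normal([in_size, out_size]))
    self.bias = tf.Variable(tf.zeros([out_size]))
    # Listed in variables but not in trainable_variables.
    self.calls = tf.Variable(0, trainable=False)


class TinyModel(tf.Module):
  """Hypothetical model that nests two Linear modules."""

  def __init__(self):
    super().__init__()
    self._first = Linear(8, 4)
    self._second = Linear(4, 2)


model = TinyModel()
num_all = len(model.variables)                  # Three variables per Linear.
num_trainable = len(model.trainable_variables)  # Weight and bias per Linear.
```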

Alternatively, we can manually create a list of variables:

class AutoEncoder:

  @property
  def variables(self):
    variables = []
    variables += self._enc1.variables
    variables += self._enc2.variables
    # ...
    return variables

  # ...

The resulting list of variables can then be used to compute gradients using a tf.GradientTape() as described in the documentation.
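A training step built on such a variable list could look like the following sketch. The LinearModel class, the squared-error loss, and the learning rate are assumptions made for illustration. The update is plain gradient descent written by hand with assign_sub(), which stands in for whatever optimizer you would normally use:

```python
import tensorflow as tf


class LinearModel(tf.Module):
  """Hypothetical stand-in for a real model, with a single weight matrix."""

  def __init__(self):
    super().__init__()
    self._weight = tf.Variable(tf.zeros([3, 1]))

  def __call__(self, inputs):
    return tf.matmul(inputs, self._weight)


model = LinearModel()
inputs = tf.ones([16, 3])
targets = tf.ones([16, 1])


def train_step(inputs, targets, learning_rate=0.1):
  # Record the forward pass so that gradients can be computed afterwards.
  with tf.GradientTape() as tape:
    loss = tf.reduce_mean((model(inputs) - targets) ** 2)
  variables = model.trainable_variables
  gradients = tape.gradient(loss, variables)
  # Hand-written gradient descent; an optimizer would normally do this.
  for variable, gradient in zip(variables, gradients):
    variable.assign_sub(learning_rate * gradient)
  return loss


losses = [float(train_step(inputs, targets)) for _ in range(5)]
```

Because the step only relies on model.trainable_variables, the same function works for any model that exposes this property, which is one reason to keep training logic outside of the model class.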

Layers Dictionary

So far, we have created the layers in the constructor and used them in later methods. However, this means we need to mention each layer twice, once to create it and once to use it. In more complex models, this can make it difficult to keep track of what’s going on. Instead, we can use a dictionary to hold all layers. A get() method then lets us use the layers, creating each one on the fly when it is first accessed:

class AutoEncoder:

  def __init__(self, code_size=64):
    self._code_size = code_size
    self._layers = {}

  def get(self, name, ctor, *args, **kwargs):
    if name not in self._layers:
      self._layers[name] = ctor(*args, **kwargs)
    return self._layers[name]

  def encode(self, image):
    hidden = self.get('enc1', tfl.Conv2D, 32, 4, 2)(image)
    hidden = tf.nn.elu(hidden)
    hidden = self.get('enc2', tfl.Conv2D, 64, 4, 2)(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self.get('enc3', tfl.Conv2D, 128, 4, 2)(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self.get('flatten', tfl.Flatten)(hidden)
    code = self.get('enc4', tfl.Dense, self._code_size)(hidden)
    return code

  def decode(self, code):
    hidden = code[:, None, None, :]
    hidden = self.get('dec1', tfl.Conv2DTranspose, 128, 5, 2)(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self.get('dec2', tfl.Conv2DTranspose, 64, 5, 2)(hidden)
    hidden = tf.nn.elu(hidden)
    recon = self.get('dec3', tfl.Conv2DTranspose, 1, 4, 2)(hidden)
    return recon

With a layer dictionary, you need to mention each layer only once, where it is used. To use the same layer in multiple locations, simply refer to it by the same name. To play around with the ideas discussed in this post and structure your own models, check out the complete colab for the auto-encoder example.
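Reusing a layer by name, as mentioned above, can be sketched as follows. The SharedEncoder class and its embed() and compare() methods are hypothetical. Only the get() method is taken from the auto-encoder example. Both inputs pass through the layer named 'embed', so they share one set of weights:

```python
import tensorflow as tf

tfl = tf.keras.layers


class SharedEncoder:
  """Hypothetical model that applies one layer to two inputs."""

  def __init__(self):
    self._layers = {}

  def get(self, name, ctor, *args, **kwargs):
    # Create the layer on first access and return the cached one afterwards.
    if name not in self._layers:
      self._layers[name] = ctor(*args, **kwargs)
    return self._layers[name]

  def embed(self, inputs):
    return self.get('embed', tfl.Dense, 16)(inputs)

  def compare(self, first, second):
    # Both calls resolve to the same Dense layer and thus the same weights.
    return tf.reduce_sum(self.embed(first) * self.embed(second), -1)


model = SharedEncoder()
inputs = tf.ones([2, 8])
score = model.compare(inputs, inputs)
num_layers = len(model._layers)  # One layer, although embed() ran twice.
```

Note that the constructor arguments are only used the first time a name is requested. Later calls with the same name return the existing layer and ignore any different arguments.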

You can use this post under the open CC BY-SA 3.0 license and cite it as:

@misc{hafner2019reuselayers,
  author = {Hafner, Danijar},
  title = {Structuring Deep Learning Models},
  year = {2019},
  howpublished = {Blog post},
  url = {https://danijar.com/structuring-models/}
}