Structuring Deep Learning Models
In TensorFlow 2, there is no need to worry about graphs anymore. This makes your code easier to understand and debug, and it also changes how you can structure it. This is a good opportunity to provide an update on my previous post on structuring models. Below, I describe a simple way to structure your models in TensorFlow 2. You can find two full code examples below.
Model Classes
A simple and common design is to have one class per model, with different methods implementing the functions the model performs. For example, an auto-encoder is a class with encode() and decode() methods. In the simplest case, the layers are created in the constructor and used afterwards:
import tensorflow as tf
tfl = tf.keras.layers

class AutoEncoder:

  def __init__(self, code_size=64):
    self._enc1 = tfl.Conv2D(32, 4, 2)
    self._enc2 = tfl.Conv2D(64, 4, 2)
    self._enc3 = tfl.Conv2D(128, 4, 2)
    self._enc4 = tfl.Dense(code_size)
    self._dec1 = tfl.Conv2DTranspose(128, 5, 2)
    self._dec2 = tfl.Conv2DTranspose(64, 5, 2)
    self._dec3 = tfl.Conv2DTranspose(1, 4, 2)
    self._flatten = tfl.Flatten()

  def encode(self, image):
    hidden = self._enc1(image)
    hidden = tf.nn.elu(hidden)
    hidden = self._enc2(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self._enc3(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self._flatten(hidden)
    code = self._enc4(hidden)
    return code

  def decode(self, code):
    hidden = code[:, None, None, :]  # Restore spatial dimensions.
    hidden = self._dec1(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self._dec2(hidden)
    hidden = tf.nn.elu(hidden)
    recon = self._dec3(hidden)
    return recon
model = AutoEncoder()
code = model.encode(image)  # image: a batch of input images.
recon = model.decode(code)
We could add further methods to the class, for example for computing the loss or training the model. If all your models are trained in a similar way, it can be better to keep that logic outside of the model class, though. In TensorFlow 2, methods that will be called many times should be decorated with @tf.function so that they get compiled into a graph for faster execution.
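As a minimal sketch, assuming the AutoEncoder class from above and a batch of 28x28 grayscale images (the input size the layer parameters above happen to fit), we could compile the full forward pass:

model = AutoEncoder()

@tf.function
def reconstruct(image):
  # Traced and compiled on the first call; the compiled graph is
  # reused for subsequent calls with the same input shapes.
  return model.decode(model.encode(image))

image = tf.zeros((8, 28, 28, 1))  # Hypothetical batch of images.
recon = reconstruct(image)
print(recon.shape)  # (8, 28, 28, 1)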
Accessing Parameters
We often need the list of model parameters, for example for training or checkpointing. TensorFlow 2 provides this functionality for us. Inheriting from tf.Module adds the properties model.variables and model.trainable_variables, which automatically search for any variables stored in self:
class AutoEncoder(tf.Module):
  # ... same constructor and methods as above ...

model = AutoEncoder()
print(model.variables)            # [...]
print(model.trainable_variables)  # [...]
Alternatively, we can manually create a list of variables:
class AutoEncoder:

  @property
  def variables(self):
    variables = []
    variables += self._enc1.variables
    variables += self._enc2.variables
    # ... and so on for the remaining layers ...
    return variables

  # ...
The resulting list of variables can then be used to compute gradients using a tf.GradientTape, as described in the documentation.
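For instance, here is a minimal sketch of a single training step, assuming the tf.Module version of the model from above, a mean squared reconstruction error, and an Adam optimizer (the loss and optimizer are my assumptions, not part of the model class):

model = AutoEncoder()
optimizer = tf.keras.optimizers.Adam(1e-3)

def train_step(image):
  with tf.GradientTape() as tape:
    # Variables created during the forward pass are watched by the tape.
    code = model.encode(image)
    recon = model.decode(code)
    loss = tf.reduce_mean(tf.square(image - recon))  # Assumed loss.
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))
  return loss

loss = train_step(tf.zeros((8, 28, 28, 1)))  # Hypothetical batch.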
Layers Dictionary
So far, we have created layers in the constructor and used them in the later methods. However, this means we need to mention each layer twice: once to create it and once to use it. In more complex models, this can make it difficult to keep track of what's going on. Instead, we can use a dictionary to hold all layers. A get() method then lets us use the layers, creating them on the fly when they are accessed for the first time:
class AutoEncoder:

  def __init__(self, code_size):
    self._code_size = code_size
    self._layers = {}

  def get(self, name, ctor, *args, **kwargs):
    # Create the layer on first access and cache it for reuse.
    if name not in self._layers:
      self._layers[name] = ctor(*args, **kwargs)
    return self._layers[name]

  def encode(self, image):
    hidden = self.get('enc1', tfl.Conv2D, 32, 4, 2)(image)
    hidden = tf.nn.elu(hidden)
    hidden = self.get('enc2', tfl.Conv2D, 64, 4, 2)(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self.get('enc3', tfl.Conv2D, 128, 4, 2)(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self.get('flatten', tfl.Flatten)(hidden)
    code = self.get('enc4', tfl.Dense, self._code_size)(hidden)
    return code

  def decode(self, code):
    hidden = code[:, None, None, :]
    hidden = self.get('dec1', tfl.Conv2DTranspose, 128, 5, 2)(hidden)
    hidden = tf.nn.elu(hidden)
    hidden = self.get('dec2', tfl.Conv2DTranspose, 64, 5, 2)(hidden)
    hidden = tf.nn.elu(hidden)
    recon = self.get('dec3', tfl.Conv2DTranspose, 1, 4, 2)(hidden)
    return recon
With a layer dictionary, you need to mention each layer only once, at the place where it is used. To use the same layer in multiple locations, which shares its weights, we can simply refer to it by the same name, as shown in the sketch below. To play around with the ideas discussed in this post and structure your own models, check out the complete colab for the auto-encoder example.
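As an illustrative sketch of such weight sharing, consider a hypothetical siamese-style encoder (not part of the example above) that uses the same get() helper; both calls resolve the same names to the same cached layers, so the two images are embedded with shared weights:

class SiameseEncoder:

  def __init__(self, code_size=64):
    self._code_size = code_size
    self._layers = {}

  def get(self, name, ctor, *args, **kwargs):
    if name not in self._layers:
      self._layers[name] = ctor(*args, **kwargs)
    return self._layers[name]

  def embed_pair(self, left, right):
    return self._embed(left), self._embed(right)

  def _embed(self, image):
    hidden = self.get('flatten', tfl.Flatten)(image)
    hidden = self.get('emb1', tfl.Dense, 256)(hidden)
    hidden = tf.nn.elu(hidden)
    # Both calls to _embed() resolve 'emb2' to the same cached layer,
    # so the two images share one set of weights.
    return self.get('emb2', tfl.Dense, self._code_size)(hidden)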
You can use this post under the open CC BY-SA 3.0 license and cite it as:
@misc{hafner2019reuselayers,
  author = {Hafner, Danijar},
  title = {Structuring Deep Learning Models},
  year = {2019},
  howpublished = {Blog post},
  url = {https://danijar.com/structuring-models/}
}