The landscape of deep learning libraries is developing rapidly, and with it the ways to structure machine learning models. This post introduces a functional approach to structuring your code that is geared towards research. It works with any deep learning library, with or without a compute graph. Unlike many alternatives, it only requires you to mention each layer once.

Find the code examples for this post that train a Variational Auto-Encoder:

Please let me know if you’d like to add examples for other frameworks.

Reusing Elements

The idea for handling model state is a simple ReuseElements class. It either records or returns elements, depending on whether the list it receives starts out empty.

class ReuseElements:

  def __init__(self, elements):
    self._elements = elements
    self._adding = (len(elements) == 0)
    self._index = 0

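  def __call__(self, provided):
    if self._adding:
      # First pass: record the new element and use it.
      self._elements.append(provided)
      return provided
    # Later passes: ignore the new element and return the stored one.
    existing = self._elements[self._index]
    self._index += 1
    assert isinstance(existing, type(provided))
    return existing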

# If you give it an empty list, it adds elements to that list.
elements = []
reuse = ReuseElements(elements)
assert reuse(42) == 42
assert reuse(12) == 12
assert elements == [42, 12]

# If you give it a populated list, it returns its existing elements.
elements = [42, 12]
reuse = ReuseElements(elements)
assert reuse(17) == 42
assert reuse(34) == 12
assert elements == [42, 12]
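ReuseElements matches elements purely by position. A module therefore has to request its elements in the same order on every call. The following plain-Python sketch uses no deep learning library, and its string and integer stand-ins are only illustrative. It shows the isinstance check catching a call made out of order.

```python
class ReuseElements:
  """Records elements on the first pass and returns them on later passes."""

  def __init__(self, elements):
    self._elements = elements
    self._adding = (len(elements) == 0)
    self._index = 0

  def __call__(self, provided):
    if self._adding:
      self._elements.append(provided)
      return provided
    existing = self._elements[self._index]
    self._index += 1
    assert isinstance(existing, type(provided))
    return existing


elements = []
reuse = ReuseElements(elements)
reuse(42)       # Position 0 holds an int.
reuse('hello')  # Position 1 holds a str.

# A second pass in the same order returns the stored elements.
reuse = ReuseElements(elements)
first = reuse(0)
second = reuse('')

# A pass in a different order trips the type check.
reuse = ReuseElements(elements)
try:
  reuse('oops')  # Position 0 holds an int, not a str.
  mismatch_detected = False
except AssertionError:
  mismatch_detected = True
```

The check relies on `assert`, so Python skips it when run with `-O`. It also only compares types, so two layers of the same class that have been swapped would go unnoticed.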

Building Neural Networks

We can use this to create the layers of a network on the first call and reuse them on later calls.

import tensorflow as tf
import tensorflow.keras.layers as tfl

def network(inputs, layers):
  reuse = ReuseElements(layers)
  hidden = reuse(tfl.Dense(100, tf.nn.relu))(inputs)
  output = reuse(tfl.Dense(100, tf.nn.relu))(hidden)
  return output

# Initialize an instance of the network.
layers = []
inputs = tf.zeros((32, 100))
output = network(inputs, layers)
print(layers)  # Output (abbreviated): [Dense(), Dense()]

# Reuse the same layers.
output = network(inputs, layers)
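Every call of network constructs new layer objects, even when ReuseElements throws them away in favor of the stored ones. This stays cheap because Keras layers create their weights lazily, on their first call. The plain-Python sketch below imitates that behavior. `LazyDense` is a hypothetical stand-in, not a real library class, and it counts how many weight matrices get allocated.

```python
import random


class ReuseElements:
  def __init__(self, elements):
    self._elements = elements
    self._adding = (len(elements) == 0)
    self._index = 0

  def __call__(self, provided):
    if self._adding:
      self._elements.append(provided)
      return provided
    existing = self._elements[self._index]
    self._index += 1
    assert isinstance(existing, type(provided))
    return existing


class LazyDense:
  """Hypothetical layer that allocates its kernel on the first call."""

  allocations = 0  # Kernels created so far, across all instances.

  def __init__(self, units):
    self.units = units
    self.kernel = None  # Unknown until the input size is seen.

  def __call__(self, inputs):
    if self.kernel is None:
      LazyDense.allocations += 1
      self.kernel = [[random.gauss(0.0, 1.0) for _ in range(self.units)]
                     for _ in inputs]
    # Matrix-vector product: out[j] = sum_i inputs[i] * kernel[i][j].
    return [sum(x * row[j] for x, row in zip(inputs, self.kernel))
            for j in range(self.units)]


def network(inputs, layers):
  reuse = ReuseElements(layers)
  hidden = reuse(LazyDense(4))(inputs)
  output = reuse(LazyDense(2))(hidden)
  return output


layers = []
first = network([1.0, 2.0, 3.0], layers)   # Builds two kernels.
second = network([1.0, 2.0, 3.0], layers)  # New instances are discarded unbuilt.
```

The second call produces identical outputs and allocates no new kernels, because the freshly constructed layers are never called.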

We can use partial functions to instantiate networks with their own layers.

import functools

# Bind the layers list so we don't have to pass it manually.
my_network = functools.partial(network, layers=[])
output = my_network(inputs)  # Creates layers.
output = my_network(inputs)  # Reuses layers.

# The partial function lets you access the list of layers.
print(my_network.keywords['layers'])  # Output (abbreviated): [Dense(), Dense()]

# Binding another layer list creates an independent network.
my_other_network = functools.partial(network, layers=[])
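The next sketch checks that two partials really own separate lists. It reads them through `functools.partial.keywords`, which is a standard attribute of partial objects. `Layer` is a hypothetical weightless stand-in that only records its name.

```python
import functools


class ReuseElements:
  def __init__(self, elements):
    self._elements = elements
    self._adding = (len(elements) == 0)
    self._index = 0

  def __call__(self, provided):
    if self._adding:
      self._elements.append(provided)
      return provided
    existing = self._elements[self._index]
    self._index += 1
    assert isinstance(existing, type(provided))
    return existing


class Layer:
  """Hypothetical stand-in layer that appends its name to the input list."""

  def __init__(self, name):
    self.name = name

  def __call__(self, inputs):
    return inputs + [self.name]


def network(inputs, layers):
  reuse = ReuseElements(layers)
  hidden = reuse(Layer('hidden'))(inputs)
  return reuse(Layer('output'))(hidden)


# Each partial gets its own list, created when the partial is created.
my_network = functools.partial(network, layers=[])
my_other_network = functools.partial(network, layers=[])
my_network([])
my_network([])
my_other_network([])

mine = my_network.keywords['layers']
theirs = my_other_network.keywords['layers']
```

This differs from a mutable default argument such as `def network(inputs, layers=[])`, which would share a single list across every caller.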

Composing Modules

A parent module can reuse its submodules simply by reusing the layer lists it passes to them. Low-level modules can apply the same idea to individual weights instead of submodules.
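For the low-level case, a module can pass raw initial weight values through ReuseElements instead of layer objects. In this minimal sketch, plain Python lists stand in for weight tensors, and `dense` is a hypothetical function rather than a library API.

```python
import random


class ReuseElements:
  def __init__(self, elements):
    self._elements = elements
    self._adding = (len(elements) == 0)
    self._index = 0

  def __call__(self, provided):
    if self._adding:
      self._elements.append(provided)
      return provided
    existing = self._elements[self._index]
    self._index += 1
    assert isinstance(existing, type(provided))
    return existing


def dense(inputs, weights, units):
  """Hypothetical low-level module that reuses raw weights, not layers."""
  reuse = ReuseElements(weights)
  # Each reuse(...) either stores this fresh initial value or returns the
  # stored one, in which case the fresh value is discarded.
  kernel = reuse([[random.gauss(0.0, 0.1) for _ in range(units)]
                  for _ in inputs])
  bias = reuse([0.0] * units)
  return [b + sum(x * row[j] for x, row in zip(inputs, kernel))
          for j, b in enumerate(bias)]


weights = []
out1 = dense([1.0, -1.0], weights, units=3)  # Creates kernel and bias.
out2 = dense([1.0, -1.0], weights, units=3)  # Reuses them.
```

Unlike Keras layers, these initial values are computed eagerly. The discarded initializers therefore still cost a little work on every call.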

def encoder(data, layers, code_size):
  reuse = ReuseElements(layers)
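  # Flatten the 64x64x3 images so the dense layer sees one vector per image.
  hidden = reuse(tfl.Dense(100, tf.nn.relu))(tf.reshape(data, [-1, 64 * 64 * 3]))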
  code = reuse(tfl.Dense(code_size))(hidden)
  return code

def decoder(code, layers):
  reuse = ReuseElements(layers)
  hidden = reuse(tfl.Dense(100, tf.nn.relu))(code)
  hidden = reuse(tfl.Dense(64 * 64 * 3))(hidden)
  recon = tf.reshape(hidden, [-1, 64, 64, 3])
  return recon

def auto_encoder(data, layers, code_size):
  reuse = ReuseElements(layers)
  # reuse([]) stores a fresh list on the first call and returns the stored
  # submodule list on later calls.
  my_encoder = functools.partial(
      encoder, layers=reuse([]), code_size=code_size)
  my_decoder = functools.partial(decoder, layers=reuse([]))
  code = my_encoder(data)
  recon = my_decoder(code)
  loss = tf.reduce_mean((recon - data) ** 2)
  return recon, loss

images = tf.zeros((32, 64, 64, 3))
model = functools.partial(auto_encoder, layers=[], code_size=32)
recon, loss = model(images)
print(model.keywords['layers'])
# Output (abbreviated): [[Dense(), Dense()], [Dense(), Dense()]]

Because the parent's list contains the nested lists of all its submodules, it gives you access to every layer, for example to pass their weights to an optimizer. The code examples linked above do this.
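Gathering the weights amounts to walking the nested lists. Here is a minimal sketch of that walk. It assumes that each layer exposes a `variables` list, as Keras layers do. `collect_variables` and the `Layer` stand-in are hypothetical, and strings take the place of real weight tensors.

```python
class Layer:
  """Hypothetical layer exposing a `variables` list, like Keras layers."""

  def __init__(self, *variables):
    self.variables = list(variables)


def collect_variables(tree):
  """Recursively gathers `variables` from nested lists of layers."""
  if isinstance(tree, list):
    result = []
    for subtree in tree:
      result += collect_variables(subtree)
    return result
  return tree.variables


# Same shape as the auto-encoder's layer tree: [encoder, decoder].
tree = [[Layer('w1', 'b1'), Layer('w2', 'b2')],
        [Layer('w3', 'b3'), Layer('w4', 'b4')]]
variables = collect_variables(tree)
```

The traversal is depth-first, so variables come out in the order the modules were first called.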

Grouping Functions into a Class

For models that expose more than one function, you may prefer a class over functions. In this case, create the layer lists in the constructor instead of binding them with functools.partial.

class AutoEncoder:

  def __init__(self, code_size):
    self._code_size = code_size
    self._encoder_layers = []
    self._decoder_layers = []

  def weights(self):
    weights = []
    for layer in self._encoder_layers:
      weights += layer.variables
    for layer in self._decoder_layers:
      weights += layer.variables
    return weights

  def encode(self, data):
    reuse = ReuseElements(self._encoder_layers)
    flat = tf.reshape(data, [-1, 64 * 64 * 3])  # One vector per image.
    hidden = reuse(tfl.Dense(100, tf.nn.relu))(flat)
    code = reuse(tfl.Dense(self._code_size))(hidden)
    return code

  def decode(self, code):
    reuse = ReuseElements(self._decoder_layers)
    hidden = reuse(tfl.Dense(100, tf.nn.relu))(code)
    hidden = reuse(tfl.Dense(64 * 64 * 3))(hidden)
    recon = tf.reshape(hidden, [-1, 64, 64, 3])
    return recon

  def loss(self, data):
    code = self.encode(data)
    recon = self.decode(code)
    loss = tf.reduce_mean((recon - data) ** 2)
    return loss

model = AutoEncoder(code_size=32)
recon = model.decode(code=tf.zeros((100, 32)))
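Each method needs its own list. If encode and decode shared one list, each method's ReuseElements would receive the other method's layers at the same positions. The sketch below uses hypothetical stand-in layers without TensorFlow. It shows that with separate lists, the methods can run in any order and `weights()` still gathers everything.

```python
class ReuseElements:
  def __init__(self, elements):
    self._elements = elements
    self._adding = (len(elements) == 0)
    self._index = 0

  def __call__(self, provided):
    if self._adding:
      self._elements.append(provided)
      return provided
    existing = self._elements[self._index]
    self._index += 1
    assert isinstance(existing, type(provided))
    return existing


class Layer:
  """Hypothetical stand-in with a single named variable."""

  def __init__(self, name):
    self.variables = [name]

  def __call__(self, trace):
    return trace + self.variables


class ToyAutoEncoder:

  def __init__(self):
    # One list per method, so each method only reuses its own layers.
    self._encoder_layers = []
    self._decoder_layers = []

  def weights(self):
    layers = self._encoder_layers + self._decoder_layers
    return [v for layer in layers for v in layer.variables]

  def encode(self, data):
    reuse = ReuseElements(self._encoder_layers)
    return reuse(Layer('enc'))(data)

  def decode(self, code):
    reuse = ReuseElements(self._decoder_layers)
    return reuse(Layer('dec'))(code)


model = ToyAutoEncoder()
model.decode([])  # Decoder layers exist before the encoder ever runs.
recon = model.decode(model.encode([]))
```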

Classes can be composed and mix well with functions.
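Mixing works in both directions. A class can hold the layer list of a functional module, and because ReuseElements stores objects of any type, a functional module can store a class instance in its own list. In this sketch, `Classifier`, `pipeline`, and the `Layer` stand-in are all hypothetical.

```python
import functools


class ReuseElements:
  def __init__(self, elements):
    self._elements = elements
    self._adding = (len(elements) == 0)
    self._index = 0

  def __call__(self, provided):
    if self._adding:
      self._elements.append(provided)
      return provided
    existing = self._elements[self._index]
    self._index += 1
    assert isinstance(existing, type(provided))
    return existing


class Layer:
  """Hypothetical stand-in layer that appends its name to a trace."""

  def __init__(self, name):
    self.name = name

  def __call__(self, trace):
    return trace + [self.name]


def encoder(data, layers):
  reuse = ReuseElements(layers)
  return reuse(Layer('enc'))(data)


class Classifier:
  """Class that wraps a functional encoder."""

  def __init__(self):
    self._encoder_layers = []
    self._head_layers = []
    # Bind the functional module to a list owned by this instance.
    self.encode = functools.partial(encoder, layers=self._encoder_layers)

  def predict(self, data):
    reuse = ReuseElements(self._head_layers)
    return reuse(Layer('head'))(self.encode(data))


def pipeline(data, layers):
  reuse = ReuseElements(layers)
  # A class instance is stored and reused like any layer.
  classifier = reuse(Classifier())
  return classifier.predict(data)


layers = []
out = pipeline([], layers)
pipeline([], layers)  # Reuses the stored Classifier and its layers.
```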


We pass the state into network modules. Passing an empty list asks a module to populate it with its initial state, and passing a populated list asks the module to reuse that state. This lets us instantiate and compose modules simply by binding function arguments. The lightweight approach combines the benefits of compute graphs (each layer is mentioned only once) with Python control flow, independent of the deep learning library.

You can use this post under the open CC BY-SA 3.0 license and cite it as:

  author = {Hafner, Danijar},
  title = {Structuring Deep Learning Models},
  year = {2019},
  howpublished = {Blog post},
  url = {}