Overview
This notebook shows how to do lossy data compression using neural networks and TensorFlow Compression.
Lossy compression involves making a trade-off between rate, the expected number of bits needed to encode a sample, and distortion, the expected error in the reconstruction of the sample.
The examples below use an autoencoder-like model to compress images from the MNIST dataset. The method is based on the paper End-to-end Optimized Image Compression.
More background on learned data compression can be found in this paper targeted at people familiar with classical data compression, or this survey targeted at a machine learning audience.
Setup
Install TensorFlow Compression via pip.
# Installs the latest version of TFC compatible with the installed TF version.
read MAJOR MINOR <<< "$(pip show tensorflow | perl -p -0777 -e 's/.*Version: (\d+)\.(\d+).*/\1 \2/sg')"
pip install "tensorflow-compression<$MAJOR.$(($MINOR+1))"
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts. tf-keras 2.17.0 requires tensorflow<2.18,>=2.17, but you have tensorflow 2.14.1 which is incompatible.
Import library dependencies.
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds
2024-07-19 01:53:11.077097: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-19 01:53:11.077144: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-19 01:53:11.077190: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Define the trainer model.
Because the model resembles an autoencoder, and we need to perform a different set of functions during training and inference, the setup is a little different from, say, a classifier.
The training model consists of three parts:
- the analysis (or encoder) transform, converting from the image into a latent space,
- the synthesis (or decoder) transform, converting from the latent space back into image space, and
- a prior and entropy model, modeling the marginal probabilities of the latents.
First, define the transforms:
def make_analysis_transform(latent_dims):
  """Creates the analysis (encoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Conv2D(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2D(
          50, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          latent_dims, use_bias=True, activation=None, name="fc_2"),
  ], name="analysis_transform")
def make_synthesis_transform():
  """Creates the synthesis (decoder) transform."""
  return tf.keras.Sequential([
      tf.keras.layers.Dense(
          500, use_bias=True, activation="leaky_relu", name="fc_1"),
      tf.keras.layers.Dense(
          2450, use_bias=True, activation="leaky_relu", name="fc_2"),
      tf.keras.layers.Reshape((7, 7, 50)),
      tf.keras.layers.Conv2DTranspose(
          20, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_1"),
      tf.keras.layers.Conv2DTranspose(
          1, 5, use_bias=True, strides=2, padding="same",
          activation="leaky_relu", name="conv_2"),
  ], name="synthesis_transform")
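As a quick sanity check of the transform shapes (a sketch using the definitions above, with an assumed latent size of 10): a batch of latents should decode to 28x28 grayscale images, which in turn encode back to latents of the same size.
# Hypothetical shape check, not part of the original tutorial.
latents = tf.random.normal((2, 10))
images = make_synthesis_transform()(latents)
print(images.shape)  # (2, 28, 28, 1)
print(make_analysis_transform(10)(images).shape)  # (2, 10)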
The trainer holds an instance of both transforms, as well as the parameters of the prior.
Its call method is set up to compute:
- rate, an estimate of the number of bits needed to represent the batch of digits, and
- distortion, the mean absolute difference between the pixels of the original digits and their reconstructions.
class MNISTCompressionTrainer(tf.keras.Model):
  """Model that trains a compressor/decompressor for MNIST."""

  def __init__(self, latent_dims):
    super().__init__()
    self.analysis_transform = make_analysis_transform(latent_dims)
    self.synthesis_transform = make_synthesis_transform()
    self.prior_log_scales = tf.Variable(tf.zeros((latent_dims,)))

  @property
  def prior(self):
    return tfc.NoisyLogistic(loc=0., scale=tf.exp(self.prior_log_scales))

  def call(self, x, training):
    """Computes rate and distortion losses."""
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    x = tf.reshape(x, (-1, 28, 28, 1))

    # Compute latent space representation y, perturb it and model its entropy,
    # then compute the reconstructed pixel-level representation x_tilde.
    y = self.analysis_transform(x)
    entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.prior, coding_rank=1, compression=False)
    y_tilde, rate = entropy_model(y, training=training)
    x_tilde = self.synthesis_transform(y_tilde)

    # Average number of bits per MNIST digit.
    rate = tf.reduce_mean(rate)

    # Mean absolute difference across pixels.
    distortion = tf.reduce_mean(abs(x - x_tilde))

    return dict(rate=rate, distortion=distortion)
Compute rate and distortion.
Let's walk through this step by step, using one image from the validation set. Load the MNIST dataset for training and validation:
training_dataset, validation_dataset = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=False,
)
2024-07-19 01:53:15.049496: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform. Skipping registering GPU devices...
And extract one image \(x\):
(x, _), = validation_dataset.take(1)
plt.imshow(tf.squeeze(x))
print(f"Data type: {x.dtype}")
print(f"Shape: {x.shape}")
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)
2024-07-19 01:53:15.383276: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
To get the latent representation \(y\), we need to cast it to float32, add a batch dimension, and pass it through the analysis transform.
x = tf.cast(x, tf.float32) / 255.
x = tf.reshape(x, (-1, 28, 28, 1))
y = make_analysis_transform(10)(x)
print("y:", y)
y: tf.Tensor( [[-0.01224455 -0.09719235 -0.08213592 0.0354024 -0.01443382 0.02162577 0.02967148 0.00232092 0.00181769 0.00430147]], shape=(1, 10), dtype=float32)
The latents will be quantized at test time. To model this in a differentiable way during training, we add uniform noise in the interval \((-.5, .5)\) and call the result \(\tilde y\). This is the same terminology as used in the paper End-to-end Optimized Image Compression.
y_tilde = y + tf.random.uniform(y.shape, -.5, .5)
print("y_tilde:", y_tilde)
y_tilde: tf.Tensor( [[ 0.38850513 -0.3627419 -0.10774756 0.03274522 -0.4638402 0.2084787 0.00465386 0.45559373 0.45222282 -0.46006057]], shape=(1, 10), dtype=float32)
The "prior" is a probability density that we train to model the marginal distribution of the noisy latents. For example, it could be a set of independent logistic distributions with different scales for each latent dimension. tfc.NoisyLogistic
accounts for the fact that the latents have additive noise. As the scale approaches zero, a logistic distribution approaches a dirac delta (spike), but the added noise causes the "noisy" distribution to approach the uniform distribution instead.
prior = tfc.NoisyLogistic(loc=0., scale=tf.linspace(.01, 2., 10))
_ = tf.linspace(-6., 6., 501)[:, None]
plt.plot(_, prior.prob(_));
During training, tfc.ContinuousBatchedEntropyModel adds uniform noise, and uses the noise and the prior to compute a (differentiable) upper bound on the rate (the average number of bits necessary to encode the latent representation). That bound can be minimized as a loss.
entropy_model = tfc.ContinuousBatchedEntropyModel(
    prior, coding_rank=1, compression=False)
y_tilde, rate = entropy_model(y, training=True)
print("rate:", rate)
print("y_tilde:", y_tilde)
rate: tf.Tensor([18.526083], shape=(1,), dtype=float32)
y_tilde: tf.Tensor( [[ 0.0090554 -0.38909417 -0.4069785 0.18274103 0.2406526 -0.11575054 0.28057152 0.30737367 -0.13117756 -0.22494133]], shape=(1, 10), dtype=float32)
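During training, this rate is essentially the negative base-2 log-likelihood of the noisy latents under the prior. As a rough sanity check (a sketch that ignores implementation details such as likelihood bounds), it can be reproduced directly from the prior defined above:
# Approximate the training rate as -log2 p(y_tilde) under the prior.
approx_rate = -tf.reduce_sum(
    tf.math.log(prior.prob(y_tilde)), axis=-1) / tf.math.log(2.)
print("approximate rate:", approx_rate)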
Lastly, the noisy latents are passed back through the synthesis transform to produce an image reconstruction \(\tilde x\). Distortion is the error between the original image and its reconstruction. Obviously, with the transforms untrained, the reconstruction is not very useful.
x_tilde = make_synthesis_transform()(y_tilde)
# Mean absolute difference across pixels.
distortion = tf.reduce_mean(abs(x - x_tilde))
print("distortion:", distortion)
x_tilde = tf.saturate_cast(x_tilde[0] * 255, tf.uint8)
plt.imshow(tf.squeeze(x_tilde))
print(f"Data type: {x_tilde.dtype}")
print(f"Shape: {x_tilde.shape}")
distortion: tf.Tensor(0.17078552, shape=(), dtype=float32)
Data type: <dtype: 'uint8'>
Shape: (28, 28, 1)
For every batch of digits, calling the MNISTCompressionTrainer produces the rate and distortion, each averaged over that batch:
(example_batch, _), = validation_dataset.batch(32).take(1)
trainer = MNISTCompressionTrainer(10)
example_output = trainer(example_batch)
print("rate: ", example_output["rate"])
print("distortion: ", example_output["distortion"])
rate: tf.Tensor(20.296253, shape=(), dtype=float32)
distortion: tf.Tensor(0.14659302, shape=(), dtype=float32)
2024-07-19 01:53:16.195986: W tensorflow/core/kernels/data/cache_dataset_ops.cc:854] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
In the next section, we set up the model to do gradient descent on these two losses.
Train the model.
We compile the trainer so that it optimizes the rate–distortion Lagrangian, that is, a sum of rate and distortion, where the distortion term is weighted by the Lagrange parameter \(\lambda\).
This loss function affects the different parts of the model differently:
- The analysis transform is trained to produce a latent representation that achieves the desired trade-off between rate and distortion.
- The synthesis transform is trained to minimize distortion, given the latent representation.
- The parameters of the prior are trained to minimize the rate given the latent representation. This is identical to fitting the prior to the marginal distribution of latents in a maximum likelihood sense.
def pass_through_loss(_, x):
  # Since rate and distortion are unsupervised, the loss doesn't need a target.
  return x

def make_mnist_compression_trainer(lmbda, latent_dims=50):
  trainer = MNISTCompressionTrainer(latent_dims)
  trainer.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
      # Just pass through rate and distortion as losses/metrics.
      loss=dict(rate=pass_through_loss, distortion=pass_through_loss),
      metrics=dict(rate=pass_through_loss, distortion=pass_through_loss),
      loss_weights=dict(rate=1., distortion=lmbda),
  )
  return trainer
Next, train the model. The class labels are not needed here, since we only want to compress the images, so we drop them using a map and instead add "dummy" targets for rate and distortion.
def add_rd_targets(image, label):
  # Training is unsupervised, so labels aren't necessary here. However, we
  # need to add "dummy" targets for rate and distortion.
  return image, dict(rate=0., distortion=0.)

def train_mnist_model(lmbda):
  trainer = make_mnist_compression_trainer(lmbda)
  trainer.fit(
      training_dataset.map(add_rd_targets).batch(128).prefetch(8),
      epochs=15,
      validation_data=validation_dataset.map(add_rd_targets).batch(128).cache(),
      validation_freq=1,
      verbose=1,
  )
  return trainer
trainer = train_mnist_model(lmbda=2000)
Epoch 1/15 469/469 [==============================] - ETA: 0s - loss: 216.9254 - distortion_loss: 0.0584 - rate_loss: 100.1568 - distortion_pass_through_loss: 0.0584 - rate_pass_through_loss: 100.1521 WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive. 469/469 [==============================] - 13s 22ms/step - loss: 216.9254 - distortion_loss: 0.0584 - rate_loss: 100.1568 - distortion_pass_through_loss: 0.0584 - rate_pass_through_loss: 100.1521 - val_loss: 176.1546 - val_distortion_loss: 0.0419 - val_rate_loss: 92.4025 - val_distortion_pass_through_loss: 0.0419 - val_rate_pass_through_loss: 92.4103 Epoch 2/15 469/469 [==============================] - 10s 20ms/step - loss: 165.6864 - distortion_loss: 0.0409 - rate_loss: 83.9722 - distortion_pass_through_loss: 0.0409 - rate_pass_through_loss: 83.9679 - val_loss: 155.8378 - val_distortion_loss: 0.0399 - val_rate_loss: 76.1234 - val_distortion_pass_through_loss: 0.0399 - val_rate_pass_through_loss: 76.1243 Epoch 3/15 469/469 [==============================] - 10s 20ms/step - loss: 150.9844 - distortion_loss: 0.0400 - rate_loss: 71.0840 - distortion_pass_through_loss: 0.0399 - rate_pass_through_loss: 71.0809 - val_loss: 144.5865 - val_distortion_loss: 0.0402 - val_rate_loss: 64.2065 - val_distortion_pass_through_loss: 0.0402 - val_rate_pass_through_loss: 64.2150 Epoch 4/15 469/469 [==============================] - 9s 20ms/step - loss: 142.6878 - distortion_loss: 0.0398 - rate_loss: 63.0357 - distortion_pass_through_loss: 0.0398 - rate_pass_through_loss: 63.0338 - val_loss: 136.4241 - val_distortion_loss: 0.0403 - val_rate_loss: 55.7424 - val_distortion_pass_through_loss: 0.0403 - val_rate_pass_through_loss: 55.7691 Epoch 5/15 469/469 [==============================] - 9s 20ms/step - loss: 137.4838 - distortion_loss: 0.0396 - rate_loss: 58.2308 - distortion_pass_through_loss: 0.0396 - rate_pass_through_loss: 58.2295 - val_loss: 132.1830 - val_distortion_loss: 0.0412 - val_rate_loss: 49.8589 - val_distortion_pass_through_loss: 0.0412 - val_rate_pass_through_loss: 49.8711 Epoch 6/15 469/469 [==============================] - 9s 20ms/step - loss: 133.8861 - distortion_loss: 0.0394 - rate_loss: 55.1402 - distortion_pass_through_loss: 0.0394 - rate_pass_through_loss: 55.1388 - val_loss: 127.9782 - val_distortion_loss: 0.0415 - val_rate_loss: 45.0345 - val_distortion_pass_through_loss: 0.0415 - val_rate_pass_through_loss: 45.0470 Epoch 7/15 469/469 [==============================] - 9s 20ms/step - loss: 130.8392 - distortion_loss: 0.0389 - rate_loss: 52.9458 - distortion_pass_through_loss: 0.0389 - rate_pass_through_loss: 52.9443 - val_loss: 124.2168 - val_distortion_loss: 0.0408 - val_rate_loss: 42.7138 - val_distortion_pass_through_loss: 0.0408 - val_rate_pass_through_loss: 42.7179 Epoch 8/15 469/469 [==============================] - 9s 20ms/step - loss: 128.3935 - distortion_loss: 0.0386 - rate_loss: 51.1929 - distortion_pass_through_loss: 0.0386 - rate_pass_through_loss: 51.1917 - val_loss: 121.7899 - val_distortion_loss: 0.0406 - val_rate_loss: 40.6796 - val_distortion_pass_through_loss: 0.0405 - val_rate_pass_through_loss: 40.6837 Epoch 9/15 469/469 [==============================] - 9s 20ms/step - loss: 125.8994 - distortion_loss: 0.0381 - rate_loss: 49.6562 - distortion_pass_through_loss: 0.0381 - rate_pass_through_loss: 49.6556 - 
val_loss: 119.4453 - val_distortion_loss: 0.0391 - val_rate_loss: 41.2929 - val_distortion_pass_through_loss: 0.0391 - val_rate_pass_through_loss: 41.3038 Epoch 10/15 469/469 [==============================] - 9s 20ms/step - loss: 123.7347 - distortion_loss: 0.0377 - rate_loss: 48.2651 - distortion_pass_through_loss: 0.0377 - rate_pass_through_loss: 48.2641 - val_loss: 117.1507 - val_distortion_loss: 0.0387 - val_rate_loss: 39.7558 - val_distortion_pass_through_loss: 0.0387 - val_rate_pass_through_loss: 39.7709 Epoch 11/15 469/469 [==============================] - 9s 20ms/step - loss: 121.6866 - distortion_loss: 0.0373 - rate_loss: 47.0960 - distortion_pass_through_loss: 0.0373 - rate_pass_through_loss: 47.0950 - val_loss: 115.9093 - val_distortion_loss: 0.0379 - val_rate_loss: 40.1509 - val_distortion_pass_through_loss: 0.0379 - val_rate_pass_through_loss: 40.1661 Epoch 12/15 469/469 [==============================] - 9s 20ms/step - loss: 119.8163 - distortion_loss: 0.0369 - rate_loss: 46.0885 - distortion_pass_through_loss: 0.0369 - rate_pass_through_loss: 46.0875 - val_loss: 115.0018 - val_distortion_loss: 0.0372 - val_rate_loss: 40.5832 - val_distortion_pass_through_loss: 0.0372 - val_rate_pass_through_loss: 40.6000 Epoch 13/15 469/469 [==============================] - 9s 20ms/step - loss: 118.4900 - distortion_loss: 0.0366 - rate_loss: 45.2273 - distortion_pass_through_loss: 0.0366 - rate_pass_through_loss: 45.2263 - val_loss: 113.9260 - val_distortion_loss: 0.0372 - val_rate_loss: 39.5021 - val_distortion_pass_through_loss: 0.0372 - val_rate_pass_through_loss: 39.5158 Epoch 14/15 469/469 [==============================] - 9s 19ms/step - loss: 116.9452 - distortion_loss: 0.0362 - rate_loss: 44.5765 - distortion_pass_through_loss: 0.0362 - rate_pass_through_loss: 44.5763 - val_loss: 113.4185 - val_distortion_loss: 0.0365 - val_rate_loss: 40.5167 - val_distortion_pass_through_loss: 0.0365 - val_rate_pass_through_loss: 40.5147 Epoch 15/15 469/469 [==============================] - 9s 19ms/step - loss: 115.8957 - distortion_loss: 0.0359 - rate_loss: 44.0527 - distortion_pass_through_loss: 0.0359 - rate_pass_through_loss: 44.0523 - val_loss: 111.9197 - val_distortion_loss: 0.0360 - val_rate_loss: 39.8947 - val_distortion_pass_through_loss: 0.0360 - val_rate_pass_through_loss: 39.8986
Compress some MNIST images.
For compression and decompression at test time, we split the trained model in two parts:
- The encoder side consists of the analysis transform and the entropy model.
- The decoder side consists of the synthesis transform and the same entropy model.
At test time, the latents will not have additive noise; instead, they will be quantized and then losslessly compressed, so we give them new names. We call the quantized latents \(\hat y\) and the image reconstruction \(\hat x\) (following End-to-end Optimized Image Compression).
class MNISTCompressor(tf.keras.Model):
  """Compresses MNIST images to strings."""

  def __init__(self, analysis_transform, entropy_model):
    super().__init__()
    self.analysis_transform = analysis_transform
    self.entropy_model = entropy_model

  def call(self, x):
    # Ensure inputs are floats in the range (0, 1).
    x = tf.cast(x, self.compute_dtype) / 255.
    y = self.analysis_transform(x)
    # Also return the exact information content of each digit.
    _, bits = self.entropy_model(y, training=False)
    return self.entropy_model.compress(y), bits


class MNISTDecompressor(tf.keras.Model):
  """Decompresses MNIST images from strings."""

  def __init__(self, entropy_model, synthesis_transform):
    super().__init__()
    self.entropy_model = entropy_model
    self.synthesis_transform = synthesis_transform

  def call(self, string):
    y_hat = self.entropy_model.decompress(string, ())
    x_hat = self.synthesis_transform(y_hat)
    # Scale and cast back to 8-bit integer.
    return tf.saturate_cast(tf.round(x_hat * 255.), tf.uint8)
When instantiated with compression=True, the entropy model converts the learned prior into tables for a range coding algorithm. Calling compress() invokes this algorithm to convert the latent space vector into bit sequences. The length of each binary string approximates the information content of the latent (the negative log likelihood of the latent under the prior).
The entropy model for compression and decompression must be the same instance, because the range coding tables need to be exactly identical on both sides. Otherwise, decoding errors can occur.
def make_mnist_codec(trainer, **kwargs):
  # The entropy model must be created with `compression=True` and the same
  # instance must be shared between compressor and decompressor.
  entropy_model = tfc.ContinuousBatchedEntropyModel(
      trainer.prior, coding_rank=1, compression=True, **kwargs)
  compressor = MNISTCompressor(trainer.analysis_transform, entropy_model)
  decompressor = MNISTDecompressor(entropy_model, trainer.synthesis_transform)
  return compressor, decompressor
compressor, decompressor = make_mnist_codec(trainer)
Grab 16 images from the validation dataset. You can select a different subset by changing the argument to skip.
(originals, _), = validation_dataset.batch(16).skip(3).take(1)
Compress them to strings, and keep track of their information content in bits.
strings, entropies = compressor(originals)
print(f"String representation of first digit in hexadecimal: 0x{strings[0].numpy().hex()}")
print(f"Number of bits actually needed to represent it: {entropies[0]:0.2f}")
String representation of first digit in hexadecimal: 0x0d1866b8f1
Number of bits actually needed to represent it: 37.49
Decompress the images back from the strings.
reconstructions = decompressor(strings)
Display each of the 16 original digits together with its compressed binary representation, and the reconstructed digit.
def display_digits(originals, strings, entropies, reconstructions):
  """Visualizes 16 digits together with their reconstructions."""
  fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(12.5, 5))
  axes = axes.ravel()
  for i in range(len(axes)):
    image = tf.concat([
        tf.squeeze(originals[i]),
        tf.zeros((28, 14), tf.uint8),
        tf.squeeze(reconstructions[i]),
    ], 1)
    axes[i].imshow(image)
    axes[i].text(
        .5, .5, f"→ 0x{strings[i].numpy().hex()} →\n{entropies[i]:0.2f} bits",
        ha="center", va="top", color="white", fontsize="small",
        transform=axes[i].transAxes)
    axes[i].axis("off")
  plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)
display_digits(originals, strings, entropies, reconstructions)
Note that the length of the encoded string differs from the information content of each digit.
This is because the range coding process works with discrete probabilities, and has a small amount of overhead. So, especially for short strings, the correspondence is only approximate. However, range coding is asymptotically optimal: in the limit, the expected bit count will approach the cross entropy (the expected information content), for which the rate term in the training model is an upper bound.
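To see the overhead concretely, you can compare the coded string lengths with the information content estimates computed above (a small sketch using the strings and entropies returned by the compressor):
# Each coded string occupies a whole number of bytes, so its length in bits
# is always slightly above the estimated information content.
for string, bits in zip(strings[:4], entropies[:4]):
  print(f"coded: {len(string.numpy()) * 8:3d} bits, "
        f"information content: {bits:6.2f} bits")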
The rate–distortion trade-off
Above, the model was trained for a specific trade-off (given by lmbda=2000) between the average number of bits used to represent each digit and the incurred error in the reconstruction.
What happens when we repeat the experiment with different values?
Let's start by reducing \(\lambda\) to 500.
def train_and_visualize_model(lmbda):
  trainer = train_mnist_model(lmbda=lmbda)
  compressor, decompressor = make_mnist_codec(trainer)
  strings, entropies = compressor(originals)
  reconstructions = decompressor(strings)
  display_digits(originals, strings, entropies, reconstructions)
train_and_visualize_model(lmbda=500)
Epoch 1/15 469/469 [==============================] - ETA: 0s - loss: 127.8276 - distortion_loss: 0.0705 - rate_loss: 92.5567 - distortion_pass_through_loss: 0.0705 - rate_pass_through_loss: 92.5504 WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive. 469/469 [==============================] - 12s 21ms/step - loss: 127.8276 - distortion_loss: 0.0705 - rate_loss: 92.5567 - distortion_pass_through_loss: 0.0705 - rate_pass_through_loss: 92.5504 - val_loss: 108.3089 - val_distortion_loss: 0.0574 - val_rate_loss: 79.6299 - val_distortion_pass_through_loss: 0.0574 - val_rate_pass_through_loss: 79.6331 Epoch 2/15 469/469 [==============================] - 9s 20ms/step - loss: 97.6987 - distortion_loss: 0.0548 - rate_loss: 70.2879 - distortion_pass_through_loss: 0.0548 - rate_pass_through_loss: 70.2826 - val_loss: 86.5739 - val_distortion_loss: 0.0598 - val_rate_loss: 56.6536 - val_distortion_pass_through_loss: 0.0598 - val_rate_pass_through_loss: 56.6590 Epoch 3/15 469/469 [==============================] - 9s 20ms/step - loss: 81.6310 - distortion_loss: 0.0570 - rate_loss: 53.1313 - distortion_pass_through_loss: 0.0570 - rate_pass_through_loss: 53.1278 - val_loss: 72.4750 - val_distortion_loss: 0.0688 - val_rate_loss: 38.0917 - val_distortion_pass_through_loss: 0.0687 - val_rate_pass_through_loss: 38.1038 Epoch 4/15 469/469 [==============================] - 9s 20ms/step - loss: 71.8922 - distortion_loss: 0.0601 - rate_loss: 41.8549 - distortion_pass_through_loss: 0.0601 - rate_pass_through_loss: 41.8529 - val_loss: 64.1137 - val_distortion_loss: 0.0785 - val_rate_loss: 24.8831 - val_distortion_pass_through_loss: 0.0784 - val_rate_pass_through_loss: 24.8904 Epoch 5/15 469/469 [==============================] - 9s 20ms/step - loss: 66.2340 - distortion_loss: 0.0629 - rate_loss: 34.7989 - distortion_pass_through_loss: 0.0629 - rate_pass_through_loss: 34.7976 - val_loss: 58.3210 - val_distortion_loss: 0.0801 - val_rate_loss: 18.2739 - val_distortion_pass_through_loss: 0.0801 - val_rate_pass_through_loss: 18.2635 Epoch 6/15 469/469 [==============================] - 9s 20ms/step - loss: 62.6940 - distortion_loss: 0.0649 - rate_loss: 30.2491 - distortion_pass_through_loss: 0.0649 - rate_pass_through_loss: 30.2479 - val_loss: 54.5973 - val_distortion_loss: 0.0814 - val_rate_loss: 13.8960 - val_distortion_pass_through_loss: 0.0813 - val_rate_pass_through_loss: 13.9088 Epoch 7/15 469/469 [==============================] - 9s 20ms/step - loss: 60.2058 - distortion_loss: 0.0663 - rate_loss: 27.0736 - distortion_pass_through_loss: 0.0663 - rate_pass_through_loss: 27.0724 - val_loss: 51.4746 - val_distortion_loss: 0.0775 - val_rate_loss: 12.7405 - val_distortion_pass_through_loss: 0.0775 - val_rate_pass_through_loss: 12.7322 Epoch 8/15 469/469 [==============================] - 9s 20ms/step - loss: 58.0009 - distortion_loss: 0.0666 - rate_loss: 24.7138 - distortion_pass_through_loss: 0.0666 - rate_pass_through_loss: 24.7133 - val_loss: 49.4696 - val_distortion_loss: 0.0734 - val_rate_loss: 12.7863 - val_distortion_pass_through_loss: 0.0734 - val_rate_pass_through_loss: 12.7662 Epoch 9/15 469/469 [==============================] - 9s 20ms/step - loss: 55.9776 - distortion_loss: 0.0662 - rate_loss: 22.8791 - distortion_pass_through_loss: 0.0662 - rate_pass_through_loss: 22.8781 - val_loss: 48.1887 - 
val_distortion_loss: 0.0700 - val_rate_loss: 13.1711 - val_distortion_pass_through_loss: 0.0701 - val_rate_pass_through_loss: 13.1664 Epoch 10/15 469/469 [==============================] - 9s 20ms/step - loss: 54.0630 - distortion_loss: 0.0652 - rate_loss: 21.4487 - distortion_pass_through_loss: 0.0652 - rate_pass_through_loss: 21.4477 - val_loss: 47.5509 - val_distortion_loss: 0.0689 - val_rate_loss: 13.0906 - val_distortion_pass_through_loss: 0.0689 - val_rate_pass_through_loss: 13.0824 Epoch 11/15 469/469 [==============================] - 9s 20ms/step - loss: 52.5058 - distortion_loss: 0.0641 - rate_loss: 20.4323 - distortion_pass_through_loss: 0.0641 - rate_pass_through_loss: 20.4320 - val_loss: 47.0983 - val_distortion_loss: 0.0660 - val_rate_loss: 14.0991 - val_distortion_pass_through_loss: 0.0660 - val_rate_pass_through_loss: 14.0968 Epoch 12/15 469/469 [==============================] - 9s 20ms/step - loss: 51.2286 - distortion_loss: 0.0632 - rate_loss: 19.6388 - distortion_pass_through_loss: 0.0632 - rate_pass_through_loss: 19.6387 - val_loss: 46.6349 - val_distortion_loss: 0.0643 - val_rate_loss: 14.4766 - val_distortion_pass_through_loss: 0.0643 - val_rate_pass_through_loss: 14.4723 Epoch 13/15 469/469 [==============================] - 9s 20ms/step - loss: 50.2295 - distortion_loss: 0.0624 - rate_loss: 19.0260 - distortion_pass_through_loss: 0.0624 - rate_pass_through_loss: 19.0255 - val_loss: 46.3414 - val_distortion_loss: 0.0644 - val_rate_loss: 14.1262 - val_distortion_pass_through_loss: 0.0644 - val_rate_pass_through_loss: 14.1229 Epoch 14/15 469/469 [==============================] - 9s 19ms/step - loss: 49.4662 - distortion_loss: 0.0619 - rate_loss: 18.5085 - distortion_pass_through_loss: 0.0619 - rate_pass_through_loss: 18.5085 - val_loss: 46.0598 - val_distortion_loss: 0.0640 - val_rate_loss: 14.0695 - val_distortion_pass_through_loss: 0.0640 - val_rate_pass_through_loss: 14.0671 Epoch 15/15 469/469 [==============================] - 9s 20ms/step - loss: 48.8475 - distortion_loss: 0.0615 - rate_loss: 18.0763 - distortion_pass_through_loss: 0.0615 - rate_pass_through_loss: 18.0759 - val_loss: 45.9036 - val_distortion_loss: 0.0638 - val_rate_loss: 13.9967 - val_distortion_pass_through_loss: 0.0639 - val_rate_pass_through_loss: 13.9826
The bit rate of our code goes down, as does the fidelity of the digits. However, most of the digits remain recognizable.
Let's reduce \(\lambda\) further.
train_and_visualize_model(lmbda=300)
Epoch 1/15 469/469 [==============================] - ETA: 0s - loss: 113.9453 - distortion_loss: 0.0765 - rate_loss: 91.0087 - distortion_pass_through_loss: 0.0764 - rate_pass_through_loss: 91.0019 WARNING:absl:Computing quantization offsets using offset heuristic within a tf.function. Ideally, the offset heuristic should only be used to determine offsets once after training. Depending on the prior, estimating the offset might be computationally expensive. 469/469 [==============================] - 11s 20ms/step - loss: 113.9453 - distortion_loss: 0.0765 - rate_loss: 91.0087 - distortion_pass_through_loss: 0.0764 - rate_pass_through_loss: 91.0019 - val_loss: 96.5798 - val_distortion_loss: 0.0668 - val_rate_loss: 76.5345 - val_distortion_pass_through_loss: 0.0669 - val_rate_pass_through_loss: 76.5314 Epoch 2/15 469/469 [==============================] - 9s 19ms/step - loss: 85.7681 - distortion_loss: 0.0609 - rate_loss: 67.4997 - distortion_pass_through_loss: 0.0609 - rate_pass_through_loss: 67.4941 - val_loss: 73.9959 - val_distortion_loss: 0.0764 - val_rate_loss: 51.0617 - val_distortion_pass_through_loss: 0.0764 - val_rate_pass_through_loss: 51.0703 Epoch 3/15 469/469 [==============================] - 9s 20ms/step - loss: 68.7888 - distortion_loss: 0.0645 - rate_loss: 49.4513 - distortion_pass_through_loss: 0.0645 - rate_pass_through_loss: 49.4474 - val_loss: 57.9014 - val_distortion_loss: 0.0860 - val_rate_loss: 32.1110 - val_distortion_pass_through_loss: 0.0860 - val_rate_pass_through_loss: 32.1095 Epoch 4/15 469/469 [==============================] - 9s 19ms/step - loss: 58.2579 - distortion_loss: 0.0691 - rate_loss: 37.5138 - distortion_pass_through_loss: 0.0691 - rate_pass_through_loss: 37.5116 - val_loss: 49.0095 - val_distortion_loss: 0.0996 - val_rate_loss: 19.1215 - val_distortion_pass_through_loss: 0.0997 - val_rate_pass_through_loss: 19.1154 Epoch 5/15 469/469 [==============================] - 9s 19ms/step - loss: 52.0423 - distortion_loss: 0.0736 - rate_loss: 29.9553 - distortion_pass_through_loss: 0.0736 - rate_pass_through_loss: 29.9535 - val_loss: 42.9857 - val_distortion_loss: 0.1058 - val_rate_loss: 11.2561 - val_distortion_pass_through_loss: 0.1058 - val_rate_pass_through_loss: 11.2495 Epoch 6/15 469/469 [==============================] - 9s 19ms/step - loss: 48.1614 - distortion_loss: 0.0773 - rate_loss: 24.9860 - distortion_pass_through_loss: 0.0773 - rate_pass_through_loss: 24.9847 - val_loss: 39.5561 - val_distortion_loss: 0.1074 - val_rate_loss: 7.3465 - val_distortion_pass_through_loss: 0.1074 - val_rate_pass_through_loss: 7.3374 Epoch 7/15 469/469 [==============================] - 9s 20ms/step - loss: 45.4303 - distortion_loss: 0.0800 - rate_loss: 21.4250 - distortion_pass_through_loss: 0.0800 - rate_pass_through_loss: 21.4242 - val_loss: 36.2512 - val_distortion_loss: 0.1000 - val_rate_loss: 6.2472 - val_distortion_pass_through_loss: 0.1001 - val_rate_pass_through_loss: 6.2349 Epoch 8/15 469/469 [==============================] - 9s 19ms/step - loss: 43.2415 - distortion_loss: 0.0816 - rate_loss: 18.7648 - distortion_pass_through_loss: 0.0816 - rate_pass_through_loss: 18.7640 - val_loss: 34.6368 - val_distortion_loss: 0.0951 - val_rate_loss: 6.0988 - val_distortion_pass_through_loss: 0.0951 - val_rate_pass_through_loss: 6.0970 Epoch 9/15 469/469 [==============================] - 9s 20ms/step - loss: 41.3811 - distortion_loss: 0.0823 - rate_loss: 16.6768 - distortion_pass_through_loss: 0.0823 - rate_pass_through_loss: 16.6763 - val_loss: 33.9006 - 
val_distortion_loss: 0.0924 - val_rate_loss: 6.1897 - val_distortion_pass_through_loss: 0.0923 - val_rate_pass_through_loss: 6.1920 Epoch 10/15 469/469 [==============================] - 9s 19ms/step - loss: 39.6697 - distortion_loss: 0.0818 - rate_loss: 15.1418 - distortion_pass_through_loss: 0.0818 - rate_pass_through_loss: 15.1415 - val_loss: 33.1051 - val_distortion_loss: 0.0859 - val_rate_loss: 7.3338 - val_distortion_pass_through_loss: 0.0859 - val_rate_pass_through_loss: 7.3265 Epoch 11/15 469/469 [==============================] - 9s 19ms/step - loss: 38.1470 - distortion_loss: 0.0804 - rate_loss: 14.0248 - distortion_pass_through_loss: 0.0804 - rate_pass_through_loss: 14.0245 - val_loss: 32.6648 - val_distortion_loss: 0.0827 - val_rate_loss: 7.8640 - val_distortion_pass_through_loss: 0.0827 - val_rate_pass_through_loss: 7.8598 Epoch 12/15 469/469 [==============================] - 9s 19ms/step - loss: 36.9021 - distortion_loss: 0.0790 - rate_loss: 13.2012 - distortion_pass_through_loss: 0.0790 - rate_pass_through_loss: 13.2009 - val_loss: 32.3153 - val_distortion_loss: 0.0811 - val_rate_loss: 7.9952 - val_distortion_pass_through_loss: 0.0811 - val_rate_pass_through_loss: 7.9855 Epoch 13/15 469/469 [==============================] - 9s 20ms/step - loss: 35.9055 - distortion_loss: 0.0778 - rate_loss: 12.5702 - distortion_pass_through_loss: 0.0778 - rate_pass_through_loss: 12.5696 - val_loss: 32.1251 - val_distortion_loss: 0.0800 - val_rate_loss: 8.1255 - val_distortion_pass_through_loss: 0.0800 - val_rate_pass_through_loss: 8.1183 Epoch 14/15 469/469 [==============================] - 9s 20ms/step - loss: 35.1264 - distortion_loss: 0.0768 - rate_loss: 12.0743 - distortion_pass_through_loss: 0.0768 - rate_pass_through_loss: 12.0742 - val_loss: 31.9446 - val_distortion_loss: 0.0774 - val_rate_loss: 8.7320 - val_distortion_pass_through_loss: 0.0774 - val_rate_pass_through_loss: 8.7187 Epoch 15/15 469/469 [==============================] - 9s 19ms/step - loss: 34.5168 - distortion_loss: 0.0761 - rate_loss: 11.6742 - distortion_pass_through_loss: 0.0761 - rate_pass_through_loss: 11.6742 - val_loss: 31.8502 - val_distortion_loss: 0.0768 - val_rate_loss: 8.8203 - val_distortion_pass_through_loss: 0.0768 - val_rate_pass_through_loss: 8.8122
The strings begin to get much shorter now, on the order of one byte per digit. However, this comes at a cost: more digits are becoming unrecognizable.
This demonstrates that the model is agnostic to human perception of error; it just measures the absolute deviation in pixel values. To achieve better perceived image quality, we would need to replace the pixel loss with a perceptual loss.
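For example, the mean absolute difference in MNISTCompressionTrainer could be swapped for an SSIM-based distortion. The sketch below is an illustration only (tf.image.ssim is one of several possible perceptual metrics, and is not used in this tutorial):
def perceptual_distortion(x, x_tilde):
  # 1 - SSIM as a simple perceptual distortion; inputs are expected in [0, 1].
  return tf.reduce_mean(1. - tf.image.ssim(x, x_tilde, max_val=1.))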
Use the decoder as a generative model.
If we feed the decoder random bits, this will effectively sample from the distribution that the model learned to represent digits.
First, re-instantiate the compressor/decompressor without a sanity check that would detect if the input string isn't completely decoded.
compressor, decompressor = make_mnist_codec(trainer, decode_sanity_check=False)
Now, feed long enough random strings into the decompressor so that it can decode/sample digits from them.
import os
strings = tf.constant([os.urandom(8) for _ in range(16)])
samples = decompressor(strings)
fig, axes = plt.subplots(4, 4, sharex=True, sharey=True, figsize=(5, 5))
axes = axes.ravel()
for i in range(len(axes)):
  axes[i].imshow(tf.squeeze(samples[i]))
  axes[i].axis("off")
plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)