Published: 18.12.2023

Idempotent Generative Networks

Introduction

Recently, I stumbled upon Idempotent Generative Network by Shocher et al. (I don't really recall how, or through which medium). The paper was very well written, presented a simple yet novel and elegant idea, and kicked off with a funny quote:

GEORGE: You're gonna "overdry" it.
JERRY: You, you can't "overdry."
GEORGE: Why not?
JERRY: Same as you can't "overwet." You see, once something is wet, it's wet. Same thing with dead: like once you die you're dead, right? Let's say you drop dead and I shoot you: you're not gonna die again, you're already dead. You can't "overdie," you can't "overdry."

"Seinfeld", Season 1, Episode 1, NBC 1989

Underneath the irony of this quote lies an interesting idea: some concepts are binary by nature. Something is either wet or dry. Someone is either dead or alive (Schrödinger's cat laughs in superposition). A sample is either in distribution or out of distribution. Functions that push such a binary state to one of its two ends, like drying, are therefore idempotent: if you dry something, it gets dry; if you dry it again, it stays just as dry as it was.
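To make the definition concrete, here is a small illustration of my own (not from the paper): a function $f$ is idempotent if $f(f(x)) = f(x)$ for every $x$. Clamping to $[0, 1]$ behaves like drying, while adding one keeps "overdrying" forever. The helper names `clamp01` and `is_idempotent` are made up for this example.

```python
def clamp01(x):
    """Clamp to [0, 1]: like drying, once a value is inside the range it stays put."""
    return min(max(x, 0.0), 1.0)


def is_idempotent(f, inputs):
    """Check f(f(x)) == f(x) on a finite set of inputs."""
    return all(f(f(x)) == f(x) for x in inputs)


samples = [-2.5, -1.0, 0.0, 0.3, 1.0, 4.2]

print(is_idempotent(clamp01, samples))          # True
print(is_idempotent(abs, samples))              # True
print(is_idempotent(round, samples))            # True
print(is_idempotent(lambda x: x + 1, samples))  # False: applying it again keeps changing x
```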

Driven by this last observation, the question (in the context of generative models) becomes: what if we learned a model that could "in-distribufy"? In other words, what if we required samples from a prior distribution (random noise) to be mapped onto our data distribution, while samples that are already in distribution cannot be made any more in-distribution?

Then, we could draw something like this:

Figure 1 from Idempotent Generative Network by Shocher et al.

where we have samples $\mathbf{z}$ from a source distribution that have to be put in distribution by a function $f_{\theta}$, while $f_{\theta}$ has to leave in-distribution samples $\mathbf{x}$ untouched. This gives rise to two objectives:

$$f_{\theta}(\mathbf{x}) = \mathbf{x}$$

Reconstruction term

$$f_{\theta}(f_{\theta}(\mathbf{z})) = f_{\theta}(\mathbf{z})$$

Idempotent term

where, with the first, we encourage the function to return in-distribution samples unchanged and, with the second, we encourage it to be idempotent: applying it twice should have the same effect as applying it once (like drying).

There is still something missing, though: both objectives are perfectly satisfied if $f_\theta$ is the identity function. In fact, the identity is the most basic idempotent function there is, and so far nothing in our objectives rules it out.
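A toy 1D analogy (entirely hypothetical, just to illustrate the point) shows why these two objectives are not enough: let the "data manifold" be the integers and the prior produce arbitrary real numbers. Both the identity and a projection that rounds onto the integers score zero on the reconstruction and idempotent terms, yet only the projection actually maps noise onto the data. The function names below are made up for this sketch, and $D$ is the L1 distance.

```python
import random


def l1(a, b):
    return abs(a - b)


def rec_loss(f, data):
    # Reconstruction term: real samples should be left untouched.
    return sum(l1(f(x), x) for x in data) / len(data)


def idem_loss(f, noise):
    # Idempotent term: applying f twice should equal applying it once.
    return sum(l1(f(f(z)), f(z)) for z in noise) / len(noise)


def off_manifold(f, noise):
    # How far f(z) lands from the nearest "in-distribution" point (an integer).
    return sum(abs(f(z) - round(f(z))) for z in noise) / len(noise)


def identity(v):
    return v


def projection(v):
    # Snaps any input onto the integer "manifold".
    return float(round(v))


random.seed(0)
data = [float(random.randint(-5, 5)) for _ in range(100)]
noise = [random.uniform(-5, 5) for _ in range(100)]

for name, f in [("identity", identity), ("projection", projection)]:
    # rec and idem are 0.0 for both; only off_manifold tells them apart.
    print(name, rec_loss(f, data), idem_loss(f, noise), off_manifold(f, noise))
```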

This is where the authors make the key observation that the gradient of the idempotent term has two pathways, one through each application of $f_\theta$. The first, which is the desired one, encourages $f_\theta(\mathbf{z})$ to land in the target distribution. In fact, if $f_\theta(\mathbf{z}) \approx \mathbf{x}$, then, since we already encourage $f_\theta(\mathbf{x}) = \mathbf{x}$, we automatically get $f_{\theta}(f_{\theta}(\mathbf{z})) = f_{\theta}(\mathbf{z})$. The second pathway, which we want to avoid, encourages $f_\theta$ to act as the identity regardless of whether its input is in distribution or not. This is well represented by the second figure from the paper:

Figure 2 from Idempotent Generative Network by Shocher et al.

In the figure, the data manifold is shown in blue and the prior distribution as the sphere / circle. The red pathway $\color{red}{\Delta f}$ pushes $f_\theta(\mathbf{z})$ onto the data manifold, where $f_\theta(f_\theta(\mathbf{z}))=f_\theta(\mathbf{z})$ is also enforced by $f_\theta(\mathbf{x}) = \mathbf{x}$. The green pathway $\color{green}{\Delta f}$ instead tries to map the output back to whatever we got in the first pass, whether it was on the data manifold or not.
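The two pathways are just the two branches of the chain rule. Here is a toy sketch of my own (not from the paper), assuming a hypothetical one-parameter model $f_\theta(u) = \theta u$ and $D$ = L1. Giving the inner and outer applications their own copy of the parameter splits the gradient of the idempotent term into the two pathways, and they add up to the gradient of the naive, fully differentiable loss.

```python
import torch


def f(u, theta):
    # Hypothetical toy model with a single scalar parameter.
    return theta * u


def idem_loss(z, theta_outer, theta_inner):
    # D(f_outer(f_inner(z)), f_inner(z)) with D = L1 distance
    fz = f(z, theta_inner)
    return (f(fz, theta_outer) - fz).abs().mean()


torch.manual_seed(0)
z = torch.randn(8)

# 1) Naive version: the same parameter is used in both places.
theta = torch.tensor(1.5, requires_grad=True)
idem_loss(z, theta, theta).backward()
total_grad = theta.grad.clone()

# 2) Split the pathways by giving each application its own leaf tensor.
theta_outer = torch.tensor(1.5, requires_grad=True)
theta_inner = torch.tensor(1.5, requires_grad=True)
idem_loss(z, theta_outer, theta_inner).backward()

print("total grad   :", total_grad.item())
print("inner pathway:", theta_inner.grad.item())  # moves f(z) itself (the desired one)
print("outer pathway:", theta_outer.grad.item())  # makes f act as identity at f(z)
print("inner + outer:", (theta_inner.grad + theta_outer.grad).item())
```

Cutting one branch or the other (by freezing or detaching the corresponding application) is exactly the lever that lets us favour one pathway over the other.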

So the key idea now is to favour the good pathway, while discouraging the bad one. In this way, we get the desired behaviour out of the model by tightening the data manifold! This gives rise to a third term:

$$f_{\theta}(f_{\theta}(\mathbf{z})) \ne f_{\theta}(\mathbf{z})$$

Tightening term

This is the exact opposite of the idempotent term... Wait, what?

Yes, we are in fact optimizing two exactly opposing objectives: we want the function to be idempotent, but also to be the exact opposite of idempotent. Here's the catch, though: we want it to be idempotent only within the data manifold, and anything but idempotent outside of it!

To do so, we encourage the model to output something that will remain the same whether it is mapped through a copy of the model again or not. At the same time, we encourage the model to map something that already went through the model once to something as different as possible.

This is perhaps better explained with an example. Let's assume that $f_\theta(\mathbf{z})$ maps onto the data manifold. Then, although one term tries to make $f_\theta(f_\theta(\mathbf{z})) \ne f_\theta(\mathbf{z})$, two terms oppose this effect: the idempotent term (acting on the inner application) and the reconstruction term. Our function will therefore act idempotently inside the data manifold.

Let's now assume that $f_\theta(\mathbf{z})$ maps out of distribution. Now only one term fights the tightening effect, since the reconstruction term only acts on real data samples. Our function will thus try to map things closer to the data manifold.

Here are the individual loss terms:

$$\mathcal{L}_{\text{rec}}(\theta) = \mathbb{E}_\mathbf{x} [ D(f_\theta(\mathbf{x}), \mathbf{x})]$$
$$\mathcal{L}_{\text{idem}}(\mathbf{z}; \theta, \hat{\theta}) = D(f_{\hat{\theta}}(f_\theta(\mathbf{z})), f_\theta(\mathbf{z}))$$
$$\mathcal{L}_{\text{tight}}(\mathbf{z}; \theta, \hat{\theta}) = - D(f_\theta(f_{\hat{\theta}}(\mathbf{z})), f_{\hat{\theta}}(\mathbf{z}))$$

where $D$ is any distance metric such as L1 or L2, and only $\theta$ is actually optimized: $\hat{\theta}$ denotes a copy of the model that we do not optimize. Ultimately, the total loss function is the sum of these terms, weighted by some hyper-parameters:

$$\mathcal{L}(\theta, \hat{\theta}) = \mathcal{L}_{\text{rec}}(\theta) + \lambda_i \mathcal{L}_{\text{idem}}(\theta, \hat{\theta}) + \lambda_t \mathcal{L}_{\text{tight}}(\theta, \hat{\theta})$$
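As a compact bridge between these equations and code, here is a minimal sketch of the total loss, assuming PyTorch 2.x, $D$ = L1, a hypothetical toy MLP and placeholder values for $\lambda_i$ and $\lambda_t$. Rather than keeping a separate deep copy, it emulates $f_{\hat{\theta}}$ by calling the same module with detached parameters through `torch.func.functional_call`; `ign_loss` is a name made up for this illustration.

```python
import torch
from torch import nn
from torch.func import functional_call


def ign_loss(model, x, z, lambda_i=1.0, lambda_t=0.1):
    """Sketch of L = L_rec + lambda_i * L_idem + lambda_t * L_tight with D = L1."""
    # theta_hat: same values as theta, but cut off from the autograd graph.
    frozen = {name: p.detach() for name, p in model.named_parameters()}

    def f_hat(inp):
        return functional_call(model, frozen, (inp,))

    dist = nn.functional.l1_loss

    l_rec = dist(model(x), x)              # D(f_theta(x), x)
    fz = model(z)                          # f_theta(z), differentiable
    l_idem = dist(f_hat(fz), fz)           # D(f_hat(f_theta(z)), f_theta(z))
    fz_hat = fz.detach()                   # f_hat(z): same values, no gradient
    l_tight = -dist(model(fz_hat), fz_hat) # -D(f_theta(f_hat(z)), f_hat(z))
    return l_rec + lambda_i * l_idem + lambda_t * l_tight


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 16), nn.ReLU(), nn.Linear(16, 4))
x = torch.randn(32, 4)  # stand-in for real data
z = torch.randn(32, 4)  # stand-in for prior samples
loss = ign_loss(model, x, z)
loss.backward()
print(loss.item(), all(p.grad is not None for p in model.parameters()))
```

The toy model has no dropout or batch normalization on purpose: with such layers every extra forward pass draws new dropout masks and updates batch-norm running statistics, so how the frozen copy is realized starts to matter.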

Implementation

Given the simple yet elegant idea, and the fact that this could be tried quite quickly, I could not resist re-implementing this paper. My full re-implementation is available on GitHub.

The real highlight is the IdempotentNetwork Lightning module, which can train any model architecture with the objectives defined above:

```python
from copy import deepcopy

import torch
from torch.optim import Adam
from torch.nn import L1Loss
import pytorch_lightning as pl


class IdempotentNetwork(pl.LightningModule):
    def __init__(
        self,
        prior,
        model,
        lr=1e-4,
        criterion=L1Loss(),
        lrec_w=20.0,
        lidem_w=20.0,
        ltight_w=2.5,
    ):
        super(IdempotentNetwork, self).__init__()
        self.prior = prior
        self.model = model
        # Copy f_{theta_hat}: refreshed at every step, never optimized
        self.model_copy = deepcopy(model).requires_grad_(False)
        self.lr = lr
        self.criterion = criterion
        self.lrec_w = lrec_w
        self.lidem_w = lidem_w
        self.ltight_w = ltight_w

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        # Only the trainable model (theta) is optimized
        optim = Adam(self.model.parameters(), lr=self.lr, betas=(0.5, 0.999))
        return optim

    def get_losses(self, x):
        # Prior samples, one per data sample
        z = self.prior.sample((x.shape[0],)).to(x.device)

        # Updating the copy so that theta_hat = theta
        self.model_copy.load_state_dict(self.model.state_dict())

        # Forward passes
        fx = self(x)
        fz = self(z)
        fzd = fz.detach()  # f_{theta_hat}(z): same values, no gradient

        # D(f(x), x)
        l_rec = self.lrec_w * self.criterion(fx, x)
        # D(f_hat(f(z)), f(z)): only the inner application updates theta
        l_idem = self.lidem_w * self.criterion(self.model_copy(fz), fz)
        # -D(f(f_hat(z)), f_hat(z)): only the outer application updates theta
        l_tight = -self.ltight_w * self.criterion(self(fzd), fzd)

        return l_rec, l_idem, l_tight

    def training_step(self, batch, batch_idx):
        l_rec, l_idem, l_tight = self.get_losses(batch)
        loss = l_rec + l_idem + l_tight

        self.log_dict(
            {
                "train/loss_rec": l_rec,
                "train/loss_idem": l_idem,
                "train/loss_tight": l_tight,
                "train/loss": loss,
            },
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        l_rec, l_idem, l_tight = self.get_losses(batch)
        loss = l_rec + l_idem + l_tight

        self.log_dict(
            {
                "val/loss_rec": l_rec,
                "val/loss_idem": l_idem,
                "val/loss_tight": l_tight,
                "val/loss": loss,
            },
            sync_dist=True,
        )

    def test_step(self, batch, batch_idx):
        l_rec, l_idem, l_tight = self.get_losses(batch)
        loss = l_rec + l_idem + l_tight

        self.log_dict(
            {
                "test/loss_rec": l_rec,
                "test/loss_idem": l_idem,
                "test/loss_tight": l_tight,
                "test/loss": loss,
            },
            sync_dist=True,
        )

    @torch.no_grad()
    def generate_n(self, n, device=None):
        # Sample n prior points and map them in a single forward pass
        z = self.prior.sample((n,))

        if device is not None:
            z = z.to(device)

        return self(z)
```

The module does just what we covered above: with l_rec we encourage the function to act as the identity on in-distribution samples; with l_idem we encourage the model to produce outputs that stay the same whether or not they are mapped again through model_copy (which the optimizer never updates, so only the inner application contributes to the update); and, finally, with l_tight we encourage the model to make input and output as different as possible when it is re-applied to its own detached output fzd (so only the outer application contributes to the update). To be fair, the whole idea is quite unique and brilliant.

Now that we can train any model with the IGN objectives, we just need some classic training boilerplate to try and generate MNIST digits. I went for my favourite stack: PyTorch Lightning with Weights & Biases:

```python
import os
from argparse import ArgumentParser

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.utils import save_image
from torchvision.transforms import Compose, ToTensor, Lambda
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

from model import DCGANLikeModel
from ign import IdempotentNetwork


def main(args):
    # Set seed
    pl.seed_everything(args["seed"])

    # Load data
    # Map pixels from [0, 1] to [-1, 1], the range of the generator's Tanh output
    normalize = Lambda(lambda x: (x - 0.5) * 2)
    # Light Gaussian noise on training images, clamped back to [-1, 1]
    noise = Lambda(lambda x: (x + torch.randn_like(x) * 0.15).clamp(-1, 1))
    train_transform = Compose([ToTensor(), normalize, noise])
    val_transform = Compose([ToTensor(), normalize])

    train_set = MNIST(
        root="mnist", train=True, download=True, transform=train_transform
    )
    val_set = MNIST(root="mnist", train=False, download=True, transform=val_transform)

    def collate_fn(samples):
        # Keep only the images, drop the labels
        return torch.stack([sample[0] for sample in samples])

    train_loader = DataLoader(
        train_set,
        batch_size=args["batch_size"],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=args["num_workers"],
    )
    val_loader = DataLoader(
        val_set,
        batch_size=args["batch_size"],
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=args["num_workers"],
    )

    # Initialize model
    # Standard normal prior with the same shape as an MNIST image
    prior = torch.distributions.Normal(torch.zeros(1, 28, 28), torch.ones(1, 28, 28))
    net = DCGANLikeModel()
    model = IdempotentNetwork(prior, net, args["lr"])

    if not args["skip_train"]:
        # Train model
        logger = WandbLogger(name="IGN", project="Papers Re-implementations")
        callbacks = [
            ModelCheckpoint(
                monitor="val/loss",
                mode="min",
                dirpath="checkpoints",
                filename="best",
            )
        ]
        trainer = pl.Trainer(
            strategy="ddp",
            accelerator="auto",
            max_epochs=args["epochs"],
            logger=logger,
            callbacks=callbacks,
        )
        trainer.fit(model, train_loader, val_loader)

    # Loading the best model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = (
        IdempotentNetwork.load_from_checkpoint(
            "checkpoints/best.ckpt", prior=prior, model=net
        )
        .eval()
        .to(device)
    )

    # Generating images with the trained model
    os.makedirs("generated", exist_ok=True)

    images = model.generate_n(100, device=device)
    save_image(
        images, os.path.join("generated", "generated.png"), nrow=10, normalize=True
    )

    print("Done!")


if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--skip_train", action="store_true")
    args = vars(parser.parse_args())

    main(args)
```

At first, I found IGN training to be unstable. I half expected this, since we are basically dealing with an instance of adversarial training. With adversarial training, as in GANs, you can run into a loop where the two players keep changing and neither converges, because each one has to adapt to a moving counterpart.

However, I was also skeptical of the model used in this work, the DCGAN from Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks by Radford et al., which, dating back to 2015, is stone-age old and does not include modern architectural choices. For example, it completely lacks dropout and normalization layers, as well as residual connections, which are key to virtually any model nowadays. While no big models are needed to generate MNIST digits, doing without these core components felt wrong. I suspected that this might further harm training stability, so I opted to add batch normalization and dropout layers here and there.

```python
"""DCGAN code from https://github.com/kpandey008/dcgan"""
import torch.nn as nn


class Discriminator(nn.Module):
    def __init__(self, in_channels=1, base_c=64):
        super(Discriminator, self).__init__()
        # Shape comments assume the defaults (in_channels=1, base_c=64)
        self.main = nn.Sequential(
            # Input Size: 1 x 28 x 28
            nn.Conv2d(in_channels, base_c, 4, 2, 1, bias=False),
            nn.Dropout2d(0.1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Input Size: 64 x 14 x 14
            nn.BatchNorm2d(base_c),
            nn.Conv2d(base_c, base_c * 2, 4, 2, 1, bias=False),
            nn.Dropout2d(0.1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Input Size: 128 x 7 x 7
            nn.BatchNorm2d(base_c * 2),
            nn.Conv2d(base_c * 2, base_c * 4, 3, 1, 0, bias=False),
            nn.Dropout2d(0.1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Input Size: 256 x 5 x 5
            nn.BatchNorm2d(base_c * 4),
            nn.Conv2d(base_c * 4, base_c * 8, 3, 1, 0, bias=False),
            nn.Dropout2d(0.1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            # Input Size: 512 x 3 x 3
            nn.Conv2d(base_c * 8, base_c * 8, 3, 1, 0, bias=False),
            # Output Size: 512 x 1 x 1
        )

    def forward(self, input):
        return self.main(input)


class Generator(nn.Module):
    def __init__(self, in_channels=512, out_channels=1):
        super(Generator, self).__init__()
        # Shape comments assume the defaults (in_channels=512, out_channels=1)
        self.main = nn.Sequential(
            # Input Size: 512 x 1 x 1
            nn.BatchNorm2d(in_channels),
            nn.ConvTranspose2d(in_channels, in_channels // 2, 3, 1, 0, bias=False),
            nn.Dropout2d(0.1),
            nn.ReLU(True),
            # Input Size: 256 x 3 x 3
            nn.BatchNorm2d(in_channels // 2),
            nn.ConvTranspose2d(in_channels // 2, in_channels // 4, 3, 1, 0, bias=False),
            nn.Dropout2d(0.1),
            nn.ReLU(True),
            # Input Size: 128 x 5 x 5
            nn.BatchNorm2d(in_channels // 4),
            nn.ConvTranspose2d(in_channels // 4, in_channels // 8, 3, 1, 0, bias=False),
            nn.Dropout2d(0.1),
            nn.ReLU(True),
            # Input Size: 64 x 7 x 7
            nn.BatchNorm2d(in_channels // 8),
            nn.ConvTranspose2d(
                in_channels // 8, in_channels // 16, 4, 2, 1, bias=False
            ),
            nn.Dropout2d(0.1),
            nn.ReLU(True),
            # Input Size: 32 x 14 x 14
            nn.ConvTranspose2d(in_channels // 16, out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
            # Final Output: 1 x 28 x 28
        )

    def forward(self, input):
        return self.main(input)


class DCGANLikeModel(nn.Module):
    def __init__(self, in_channels=1, base_c=64):
        super(DCGANLikeModel, self).__init__()
        # Encoder: image -> (base_c * 8) x 1 x 1 latent; decoder: latent -> image
        self.discriminator = Discriminator(in_channels=in_channels, base_c=base_c)
        self.generator = Generator(base_c * 8, out_channels=in_channels)

    def forward(self, x):
        return self.generator(self.discriminator(x))
```
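Shape comments in encoder/decoder stacks are easy to get wrong, so here is a small pure-Python check of the spatial sizes, using the output-size formulas from the PyTorch documentation for `nn.Conv2d` and `nn.ConvTranspose2d` (assuming square inputs, dilation 1 and `output_padding=0`). The `(kernel, stride, padding)` tuples are read off the layers above; the helper names are made up for this sketch.

```python
def conv_out(size, kernel, stride, padding):
    # nn.Conv2d: floor((size + 2 * padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1


def conv_t_out(size, kernel, stride, padding):
    # nn.ConvTranspose2d: (size - 1) * stride - 2 * padding + kernel
    return (size - 1) * stride - 2 * padding + kernel


# (kernel, stride, padding) of each convolution in the model
encoder = [(4, 2, 1), (4, 2, 1), (3, 1, 0), (3, 1, 0), (3, 1, 0)]
decoder = [(3, 1, 0), (3, 1, 0), (3, 1, 0), (4, 2, 1), (4, 2, 1)]

size, trace = 28, [28]
for k, s, p in encoder:
    size = conv_out(size, k, s, p)
    trace.append(size)
for k, s, p in decoder:
    size = conv_t_out(size, k, s, p)
    trace.append(size)

print(trace)  # [28, 14, 7, 5, 3, 1, 3, 5, 7, 14, 28]
```

With the default `base_c=64`, the image is thus squeezed into a 512 x 1 x 1 latent before being decoded back to 1 x 28 x 28.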

Surprisingly, this fixed the training instabilities, although I would not be surprised if IGNs turned out to be difficult to train on other datasets.

Here I share the Weights and Biases run of the final model that I used to generate MNIST digits.

And of course, ✨ dulcis in fundo ✨, here are the generated images:

Generated images with Idempotent Generative Network

Thank you for reading! If you found this helpful / interesting, or have suggestions on how to improve, please do not hesitate to contact me at me@brianpulfer.ch