Modèle de diffusion conditionné par la classe

Dans ce notebook, nous allons illustrer une façon d’ajouter des informations de conditionnement à un modèle de diffusion. Plus précisément, nous allons entraîner un modèle de diffusion conditionné par la classe sur MNIST à la suite de l’exemple d’entraînement à partir de 0 de l’unité 1, où nous pouvons spécifier quel chiffre nous voulons que le modèle génère au moment de l’inférence.

Comme indiqué dans l’introduction de cette unité, il s’agit d’une des nombreuses façons d’ajouter des informations de conditionnement supplémentaires à un modèle de diffusion, et elle a été choisie pour sa relative simplicité. Tout comme le notebook de l’unité 1, ce notebook n’a qu’un but illustratif et vous pouvez l’ignorer si vous le souhaitez.

Configuration et préparation des données

!pip install -q diffusers
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

device = 'mps' if torch.backends.mps.is_available() else 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
# Charger le jeu de données
dataset = torchvision.datasets.MNIST(root="mnist/", train=True, download=True, transform=torchvision.transforms.ToTensor())

# Introduire les données dans un chargeur de données (batch de taille 8 ici pour la démonstration)
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

# Visualiser quelques exemples
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys')
Input shape: torch.Size([8, 1, 28, 28]) 
Labels: tensor([8, 1, 5, 9, 7, 6, 2, 2])
Bref aperçu du contenu du cours.

Création d’une UNet conditionnée par la classe

La façon dont nous introduirons le conditionnement de la classe est la suivante :

  • Créer un UNet2DModel standard avec quelques canaux d’entrée supplémentaires.
  • Associer l’étiquette de la classe à un vecteur appris de forme (class_emb_size) via une couche d’enchâssement.
  • Concaténer ces informations en tant que canaux supplémentaires pour l’entrée interne du UNet avec net_input = torch.cat((x, class_cond), 1)
  • Introduire ce net_input (qui a (class_emb_size+1) canaux au total) dans l’UNet pour obtenir la prédiction finale.

Dans cet exemple, nous avons fixé la taille de class_emb_size à 4, mais c’est complètement arbitraire et vous pourriez envisager de la fixer à 1 (pour voir si cela fonctionne toujours), à 10 (pour correspondre au nombre de classes), ou de remplacer le nn.Embedding appris par un simple encodage à un coup de l’étiquette de la classe directement.

Voici à quoi ressemble l’implémentation :

class ClassConditionedUNet(nn.Module):
  def __init__(self, num_classes=10, class_emb_size=4):
    super().__init__()
    
    # La couche d'intégration associe l'étiquette de la classe à un vecteur de taille `class_emb_size`
    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    # Self.model est un UNet inconditionnel avec des canaux d'entrée supplémentaires pour accepter les informations de conditionnement (l'enchâssement de la classe).
    self.model = UNet2DModel(
        sample_size=28,           # la résolution de l'image cible
        in_channels=1 + class_emb_size, # Canaux d'entrée supplémentaires pour la classe conditionnée
        out_channels=1,           # le nombre de canaux de sortie
        layers_per_block=2,       # le nombre de couches ResNet à utiliser par bloc UNet
        block_out_channels=(32, 64, 64), 
        down_block_types=( 
            "DownBlock2D",        # un bloc de sous-échantillonnage ResNet standard
            "AttnDownBlock2D",    # un bloc de sous-échantillonnage ResNet avec auto-attention spatiale
            "AttnDownBlock2D",
        ), 
        up_block_types=(
            "AttnUpBlock2D", 
            "AttnUpBlock2D",      # un bloc de suréchantillonnage ResNet avec auto-attention spatiale
            "UpBlock2D",          # un bloc de suréchantillonnage ResNet standard
          ),
    )

  # Notre méthode de transfert prend maintenant les étiquettes de la classe comme argument supplémentaire
  def forward(self, x, t, class_labels):
    # Forme de x :
    bs, ch, w, h = x.shape
    
    # conditionnement de la classe en bon état pour l'ajouter comme canaux d'entrée supplémentaires
    class_cond = self.class_emb(class_labels) # Map to embedding dinemsion
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], w, h)
    # x est de forme (bs, 1, 28, 28) et class_cond est maintenant (bs, 4, 28, 28)

    # L'entrée nette est maintenant x et la classe cond concaténée ensemble le long de la dimension 1
    net_input = torch.cat((x, class_cond), 1) # (bs, 5, 28, 28)

    # Cette information est transmise à l'UNet en même temps que le pas de temps et renvoie la prédiction
    return self.model(net_input, t).sample # (bs, 1, 28, 28)

Si l’une des formes ou des transformations vous semble confuse, ajoutez des print() pour afficher les formes pertinentes et vérifiez qu’elles correspondent à vos attentes. Nous avons également annoté les formes de certaines variables intermédiaires dans l’espoir de rendre les choses plus claires.

Entraînement et échantillonnage

Alors qu’auparavant nous faisions quelque chose comme prediction = UNet(x, t), nous allons maintenant ajouter les bonnes étiquettes comme troisième argument (prediction = UNet(x, t, y)) pendant l’entraînement, et lors de l’inférence nous pouvons passer les étiquettes que nous voulons et si tout va bien le modèle devrait générer des images qui correspondent. $y$ dans ce cas est l’étiquette des chiffres MNIST, avec des valeurs de 0 à 9.

La boucle d’entraînement est très similaire à l’exemple de l’unité 1. Nous prédisons maintenant le bruit (plutôt que l’image débruitée comme dans l’unité 1) pour correspondre à l’objectif attendu par le DDPMScheduler par défaut que nous utilisons pour ajouter du bruit pendant l’entraînement et pour générer des échantillons au moment de l’inférence. L’entraînement prend du temps. L’accélérer pourrait être un mini-projet amusant, mais la plupart d’entre vous peuvent probablement parcourir le code (et en fait tout ce notebook) sans l’exécuter puisque nous ne faisons qu’illustrer une idée.

# Créer un planificateur
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')
# Redéfinition du chargeur de données pour fixer la taille du batch à un niveau supérieur à la démonstration de 8
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

# Combien de fois devrions-nous passer les données en revue ?
n_epochs = 10

# Notre réseau 
net = ClassConditionedUNet().to(device)

# Notre fonction de perte
loss_fn = nn.MSELoss()

# L'optimiseur
opt = torch.optim.Adam(net.parameters(), lr=1e-3) 

# Conserver une trace des pertes pour les consulter ultérieurement
losses = []

# La boucle d'entraînement
for epoch in range(n_epochs):
    for x, y in tqdm(train_dataloader):
        
        # Obtenir des données et préparer la version corrompue
        x = x.to(device) * 2 - 1 # Données sur le GPU (sur (-1, 1))
        y = y.to(device)
        noise = torch.randn_like(x)
        timesteps = torch.randint(0, 999, (x.shape[0],)).long().to(device)
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps)

        # Obtenir la prédiction du modèle
        pred = net(noisy_x, timesteps, y) # Notez que nous passons les étiquettes y

        # Calculer la perte
        loss = loss_fn(pred, noise) # Quelle est la distance entre la sortie et le bruit ?

        # Rétropopagation et mise à jour des paramètres :
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Stocker la perte pour plus tard
        losses.append(loss.item())

    # Afficher la moyenne des 100 dernières valeurs de perte pour vous faire une idée de la progression :
    avg_loss = sum(losses[-100:])/100
    print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')

# Visualiser la courbe des pertes
plt.plot(losses)
Bref aperçu du contenu du cours.

Une fois l’entraînement terminé, nous pouvons échantillonner quelques images en introduisant différentes étiquettes comme conditionnement :

# Préparer un x aléatoire comme point de départ, ainsi que les étiquettes souhaitées y
x = torch.randn(80, 1, 28, 28).to(device)
y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)

# Boucle d'échantillonnage
for i, t in tqdm(enumerate(noise_scheduler.timesteps)):

    # Obtenir la prédiction du modèle
    with torch.no_grad():
        residual = net(x, t, y)  # Notez à nouveau que nous transmettons nos étiquettes y

    # Mise à jour de l'échantillon avec l'étape
    x = noise_scheduler.step(residual, t, x).prev_sample

# Montrer les résultats
fig, ax = plt.subplots(1, 1, figsize=(12, 12))
ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
Bref aperçu du contenu du cours.

Nous y voilà ! Nous pouvons maintenant contrôler les images produites.

Nous espérons que cet exemple vous a plu. Comme toujours, n’hésitez pas à poser des questions sur Discord.

✏️ À votre tour !
Essayez de refaire la même chose avec FashionMNIST. Modifiez le taux d’apprentissage, la taille du batch et le nombre d’époques.

Pouvez-vous obtenir des images de mode décentes avec moins de temps d’entraînement que dans l’exemple ci-dessus ?