Finetuning et guidage
Dans ce notebook, nous allons couvrir deux approches principales pour adapter les modĂšles de diffusion existants :
- Avec le finetuning, nous rĂ©entraĂźnerons les modĂšles existants sur de nouvelles donnĂ©es afin de modifier le type de rĂ©sultats quâils produisent.
- Avec le guidage, nous prenons un modĂšle existant et dirigeons le processus de gĂ©nĂ©ration au moment de lâinfĂ©rence pour un contrĂŽle supplĂ©mentaire.
Ce que vous apprendrez :
A la fin de ce notebook, vous saurez comment :
- CrĂ©er une boucle dâĂ©chantillonnage et gĂ©nĂ©rer des Ă©chantillons plus rapidement Ă lâaide dâun nouveau planificateur
- Finetuner un modÚle de diffusion existant sur de nouvelles données, y compris :
- Utiliser lâaccumulation du gradient pour contourner certains des problĂšmes liĂ©s aux petits batchs.
- Enregistrer les Ă©chantillons dans Weights and Biases pendant lâentraĂźnement pour suivre la progression (via le script dâexemple joint).
- Sauvegarder le pipeline résultant et le télécharger sur le Hub
- Guider le processus dâĂ©chantillonnage avec des fonctions de perte supplĂ©mentaires pour ajouter un contrĂŽle sur les modĂšles existants, y compris :
- Explorer différentes approches de guidage avec une simple perte basée sur la couleur
- Utiliser CLIP pour guider la gĂ©nĂ©ration Ă lâaide dâun prompt de texte
- Partager une boucle dâĂ©chantillonnage personnalisĂ©e en utilisant Gradio et đ€ Spaces.
Configuration et importations
Pour enregistrer vos modĂšles finetunĂ©s sur le Hub dâHugging Face, vous devrez vous connecter avec un token qui a un accĂšs en Ă©criture. Le code ci-dessous vous invite Ă le faire et vous renvoie Ă la page des tokens de votre compte. Vous aurez Ă©galement besoin dâun compte Weights and Biases si vous souhaitez utiliser le script dâentraĂźnement pour enregistrer des Ă©chantillons au fur et Ă mesure que le modĂšle sâentraĂźne. LĂ encore, le code devrait vous inviter Ă vous connecter lĂ oĂč câest nĂ©cessaire.
A part cela, la seule chose Ă faire est dâinstaller quelques dĂ©pendances, dâimporter tout ce dont nous aurons besoin et de spĂ©cifier lâappareil que nous utiliserons :
!pip install -qq diffusers datasets accelerate wandb open-clip-torch
# Code pour se connecter au Hub d'Hugging Face, nécessaire pour partager les modÚles
# Assurez-vous d'utiliser un *token* avec un accÚs WRITE (écriture)
from huggingface_hub import notebook_login
notebook_login()
Token is valid.
Your token has been saved in your configured git credential helpers (store).
Your token has been saved to /root/.huggingface/token
Login successful
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
device = (
"mps"
if torch.backends.mps.is_available()
else "cuda"
if torch.cuda.is_available()
else "cpu"
)
Chargement dâun pipeline prĂ©-entraĂźnĂ©
Pour commencer ce notebook, chargeons un pipeline existant et voyons ce que nous pouvons en faire :
image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
image_pipe.to(device);
La gĂ©nĂ©ration dâimages est aussi simple que lâexĂ©cution de la mĂ©thode __call__ du pipeline en lâappelant comme une fonction :
images = image_pipe().images
images[0]
Sympathique, mais LENT ! Avant dâaborder les sujets principaux du jour, jetons un coup dâĆil Ă la boucle dâĂ©chantillonnage proprement dite et voyons comment nous pouvons utiliser un Ă©chantillonneur plus sophistiquĂ© pour lâaccĂ©lĂ©rer.
Ăchantillonnage plus rapide avec DDIM
Ă chaque Ă©tape, le modĂšle est nourri dâune entrĂ©e bruyante et il lui est demandĂ© de prĂ©dire le bruit (et donc une estimation de ce Ă quoi lâimage entiĂšrement dĂ©bruitĂ©e pourrait ressembler). Au dĂ©part, ces prĂ©dictions ne sont pas trĂšs bonnes, câest pourquoi nous dĂ©composons le processus en plusieurs Ă©tapes. Cependant, lâutilisation de plus de 1000 Ă©tapes sâest avĂ©rĂ©e inutile, et une multitude de recherches rĂ©centes ont explorĂ© la maniĂšre dâobtenir de bons Ă©chantillons avec le moins dâĂ©tapes possible.
Dans la bibliothĂšque đ€ Diffusers, ces mĂ©thodes dâĂ©chantillonnage sont gĂ©rĂ©es par un planificateur, qui doit effectuer chaque mise Ă jour via la fonction step(). Pour gĂ©nĂ©rer une image, on commence par un bruit alĂ©atoire $x$. Ensuite, pour chaque pas de temps dans le planificateur de bruit, nous introduisons lâentrĂ©e bruitĂ©e $x$ dans le modĂšle et transmettons la prĂ©diction rĂ©sultante Ă la fonction step(). Celle-ci renvoie une sortie avec un attribut prev_sample. âpreviousâ parce que nous revenons en arriĂšre dans le temps, dâun niveau de bruit Ă©levĂ© Ă un niveau de bruit faible (Ă lâinverse du processus de diffusion vers lâavant).
Voyons cela en action ! Tout dâabord, nous chargeons un planificateur, ici un DDIMScheduler basĂ© sur le papier Denoising Diffusion Implicit Models qui peut donner des Ă©chantillons dĂ©cents en beaucoup moins dâĂ©tapes que lâimplĂ©mentation originale du DDPM :
# Créer un nouveau planificateur et définir le nombre d'étapes d'inférence
scheduler = DDIMScheduler.from_pretrained("google/ddpm-celebahq-256")
scheduler.set_timesteps(num_inference_steps=40)
Vous pouvez constater que ce modÚle effectue 40 étapes au total, chaque saut équivalant à 25 étapes du programme original de 1000 étapes :
scheduler.timesteps
tensor([975, 950, 925, 900, 875, 850, 825, 800, 775, 750, 725, 700, 675, 650,
625, 600, 575, 550, 525, 500, 475, 450, 425, 400, 375, 350, 325, 300,
275, 250, 225, 200, 175, 150, 125, 100, 75, 50, 25, 0])
CrĂ©ons 4 images alĂ©atoires et exĂ©cutons la boucle dâĂ©chantillonnage, en visualisant Ă la fois le $x$ actuel et la version dĂ©bruitĂ©e prĂ©dite au fur et Ă mesure de lâavancement du processus :
# Le point de départ aléatoire
x = torch.randn(4, 3, 256, 256).to(device) # Batch de 4 images Ă 3 canaux de 256 x 256 px
# Boucle sur les pas de temps d'échantillonnage
for i, t in tqdm(enumerate(scheduler.timesteps)):
# Préparer l'entrée du modÚle
model_input = scheduler.scale_model_input(x, t)
# Obtenir la prédiction
with torch.no_grad():
noise_pred = image_pipe.unet(model_input, t)["sample"]
# Calculer la forme que devrait prendre l'échantillon mis à jour à l'aide du planificateur
scheduler_output = scheduler.step(noise_pred, t, x)
# Mise Ă jour de x
x = scheduler_output.prev_sample
# Occasionnellement, afficher à la fois x et les images débruitées prédites
if i % 10 == 0 or i == len(scheduler.timesteps) - 1:
fig, axs = plt.subplots(1, 2, figsize=(12, 5))
grid = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0)
axs[0].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)
axs[0].set_title(f"Current x (step {i})")
pred_x0 = (
scheduler_output.pred_original_sample
) # Non disponible pour tous les planificateurs
grid = torchvision.utils.make_grid(pred_x0, nrow=4).permute(1, 2, 0)
axs[1].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)
axs[1].set_title(f"Predicted denoised images (step {i})")
plt.show()
Comme vous pouvez le voir, les prĂ©dictions initiales ne sont pas trĂšs bonnes, mais au fur et Ă mesure que le processus se poursuit, les rĂ©sultats prĂ©dits deviennent de plus en plus prĂ©cis. Si vous ĂȘtes curieux de savoir ce qui se passe Ă lâintĂ©rieur de la fonction step(), inspectez le code (bien commentĂ©) avec :
# ??scheduler.step
Vous pouvez également insérer ce nouveau planificateur à la place du planificateur original fourni avec le pipeline, et échantillonner de la maniÚre suivante :
image_pipe.scheduler = scheduler
images = image_pipe(num_inference_steps=40).images
images[0]
TrÚs bien, nous pouvons maintenant obtenir des échantillons dans un délai raisonnable ! Cela devrait accélérer les choses au fur et à mesure que nous avançons dans le reste de ce notebook :)
Finetuning
Et maintenant, le plus amusant ! Ătant donnĂ© ce pipeline prĂ©-entraĂźnĂ©, comment pouvons-nous rĂ©entraĂźner le modĂšle pour gĂ©nĂ©rer des images sur la base de nouvelles donnĂ©es dâentraĂźnement ?
Il sâavĂšre que cela est presque identique Ă entraĂźner un modĂšle Ă partir de zĂ©ro (comme nous lâavons vu dans lâunitĂ© 1), sauf que nous commençons avec le modĂšle existant. Voyons cela en action et abordons quelques considĂ©rations supplĂ©mentaires au fur et Ă mesure.
Tout dâabord, le jeu de donnĂ©es : vous pouvez essayer ce jeu de donnĂ©es de visages vintage ou ces visages animĂ©s pour quelque chose de plus proche des donnĂ©es dâentraĂźnement originales de ce modĂšle de visages. Mais pour le plaisir, utilisons plutĂŽt le mĂȘme petit jeu de donnĂ©es de papillons que nous avons utilisĂ© pour nous entraĂźner Ă partir de zĂ©ro dans lâunitĂ© 1. ExĂ©cutez le code ci-dessous pour tĂ©lĂ©charger le jeu de donnĂ©es papillons et crĂ©er un chargeur de donnĂ©es Ă partir duquel nous pouvons Ă©chantillonner un batch dâimages :
# Pas sur Colab ? Les commentaires avec #@ permettent de modifier l'interface utilisateur comme les titres ou les entrées
# mais peuvent ĂȘtre ignorĂ©s si vous travaillez sur une plateforme diffĂ©rente.
dataset_name = "huggan/smithsonian_butterflies_subset" # @param
dataset = load_dataset(dataset_name, split="train")
image_size = 256 # @param
batch_size = 4 # @param
preprocess = transforms.Compose(
[
transforms.Resize((image_size, image_size)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def transform(examples):
images = [preprocess(image.convert("RGB")) for image in examples["image"]]
return {"images": images}
dataset.set_transform(transform)
train_dataloader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=True
)
print("Previewing batch:")
batch = next(iter(train_dataloader))
grid = torchvision.utils.make_grid(batch["images"], nrow=4)
plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5);
ConsidĂ©ration 1 : notre taille de batch ici (4) est assez petite, puisque nous entraĂźnons sur une grande taille dâimage (256 pixels) en utilisant un modĂšle assez grand et que nous manquerons de RAM du GPU si nous augmentons trop la taille du batch. Vous pouvez rĂ©duire la taille de lâimage pour accĂ©lĂ©rer les choses et permettre des batchs plus importants, mais ces modĂšles ont Ă©tĂ© conçus et entraĂźnĂ©s Ă lâorigine pour une gĂ©nĂ©ration de 256 pixels.
Passons maintenant Ă la boucle dâentraĂźnement. Nous allons mettre Ă jour les poids du modĂšle prĂ©-entraĂźnĂ© en fixant la cible dâoptimisation Ă image_pipe.unet.parameters(). Le reste est presque identique Ă lâexemple de boucle dâentraĂźnement de lâunitĂ© 1. Cela prend environ 10 minutes Ă exĂ©cuter sur Colab, câest donc le bon moment pour prendre un cafĂ© ou un thĂ© pendant que vous attendez :
num_epochs = 2 # @param
lr = 1e-5 # 2param
grad_accumulation_steps = 2 # @param
optimizer = torch.optim.AdamW(image_pipe.unet.parameters(), lr=lr)
losses = []
for epoch in range(num_epochs):
for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
clean_images = batch["images"].to(device)
# bruit Ă ajouter aux images
noise = torch.randn(clean_images.shape).to(clean_images.device)
bs = clean_images.shape[0]
# un pas de temps aléatoire pour chaque image
timesteps = torch.randint(
0,
image_pipe.scheduler.num_train_timesteps,
(bs,),
device=clean_images.device,
).long()
# Ajouter du bruit aux images propres en fonction de la magnitude du bruit Ă chaque pas de temps
# (il s'agit du processus de diffusion vers l'avant)
noisy_images = image_pipe.scheduler.add_noise(clean_images, noise, timesteps)
# Obtenir la prédiction du modÚle pour le bruit
noise_pred = image_pipe.unet(noisy_images, timesteps, return_dict=False)[0]
# Comparez la prédiction avec le bruit réel :
loss = F.mse_loss(
noise_pred, noise
) # NB : essayer de prédire le bruit (eps) pas (noisy_ims-clean_ims) ou juste (clean_ims)
# Stocker pour un plot ultérieur
losses.append(loss.item())
# Mettre Ă jour les paramĂštres du modĂšle avec l'optimiseur sur la base de cette perte
loss.backward(loss)
# Accumulation des gradients
if (step + 1) % grad_accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
print(
f"Epoch {epoch} average loss: {sum(losses[-len(train_dataloader):])/len(train_dataloader)}"
)
# Tracer la courbe de perte :
plt.plot(losses)
ConsidĂ©ration 2 : notre signal de perte est extrĂȘmement bruyant, puisque nous ne travaillons quâavec quatre exemples Ă des niveaux de bruit alĂ©atoires pour chaque Ă©tape. Ce nâest pas idĂ©al pour lâentraĂźnement. Une solution consiste Ă utiliser un taux dâapprentissage extrĂȘmement faible pour limiter la taille de la mise Ă jour Ă chaque Ă©tape. Ce serait encore mieux si nous pouvions trouver un moyen dâobtenir les mĂȘmes avantages quâen utilisant une taille de batch plus importante sans que les besoins en mĂ©moire ne montent en flĂšcheâŠ
Entrez dans lâaccumulation des gradients. Si nous appelons loss.backward() plusieurs fois avant dâexĂ©cuter optimizer.step() et optimizer.zero_grad(), PyTorch accumule (somme) les gradients, fusionnant effectivement le signal de plusieurs batchs pour donner une seule (meilleure) estimation qui est ensuite utilisĂ©e pour mettre Ă jour les paramĂštres. Il en rĂ©sulte moins de mises Ă jour totales, tout comme nous le verrions si nous utilisions une taille de batch plus importante. Câest quelque chose que de nombreux frameworks gĂšrent pour vous (par exemple, đ€ Accelerate rend cela facile), mais il est agrĂ©able de le voir mis en Ćuvre Ă partir de zĂ©ro car il sâagit dâune technique utile pour traiter lâentraĂźnement sous les contraintes de mĂ©moire du GPU ! Comme vous pouvez le voir dans le code ci-dessus (aprĂšs le commentaire # Gradient accumulation), il nây a pas vraiment besoin de beaucoup de code.
âïž Ă votre tour ! Voyez si vous pouvez ajouter lâaccumulation des gradients Ă la boucle dâentraĂźnement de lâunitĂ© 1. Comment se comporte-t-elle ? RĂ©flĂ©chissez Ă la maniĂšre dont vous pourriez ajuster le taux dâapprentissage en fonction du nombre dâĂ©tapes dâaccumulation des gradients ; devrait-il rester identique Ă auparavant ?
ConsidĂ©ration 3 : Cela prend encore beaucoup de temps, et afficher une mise Ă jour dâune ligne Ă chaque Ă©poque nâest pas suffisant pour nous donner une bonne idĂ©e de ce qui se passe. Nous devrions probablement :
- GĂ©nĂ©rer quelques Ă©chantillons de temps en temps pour examiner visuellement la performance qualitativement au fur et Ă mesure que le modĂšle sâentraĂźne.
- Enregistrer des Ă©lĂ©ments tels que la perte et les gĂ©nĂ©rations dâĂ©chantillons pendant lâentraĂźnement, peut-ĂȘtre en utilisant quelque chose comme Weights and Biases ou Tensorboard.
Nous avons créé un script rapide (finetune_model.py) qui reprend le code dâentraĂźnement ci-dessus et y ajoute une fonctionnalitĂ© minimale de logging. Vous pouvez voir les logs dâun entraĂźnement ci-dessous :
%wandb johnowhitaker/dm_finetune/2upaa341 # Vous aurez besoin d'un compte W&B pour que cela fonctionne - sautez si vous ne voulez pas vous connecter.
Il est amusant de voir comment les Ă©chantillons gĂ©nĂ©rĂ©s changent au fur et Ă mesure que lâentraĂźnement progresse. MĂȘme si la perte ne semble pas sâamĂ©liorer beaucoup, on peut voir une progression du domaine original (images de chambres Ă coucher) vers les nouvelles donnĂ©es dâentraĂźnement (wikiart). A la fin de ce notebook se trouve un code commentĂ© pour finetunĂ© un modĂšle en utilisant ce script comme alternative Ă lâexĂ©cution de la cellule ci-dessus.
âïž Ă votre tour ! Voyez si vous pouvez modifier lâexemple officiel de script dâentraĂźnement que nous avons vu dans lâunitĂ© 1 pour commencer avec un modĂšle prĂ©-entraĂźnĂ© plutĂŽt que dâentraĂźner Ă partir de zĂ©ro. Comparez-le au script minimal dont le lien figure ci-dessus ; quelles sont les fonctionnalitĂ©s supplĂ©mentaires qui manquent au script minimal ? En gĂ©nĂ©rant quelques images avec ce modĂšle, nous pouvons voir que ces visages ont dĂ©jĂ lâair trĂšs Ă©tranges !
x = torch.randn(8, 3, 256, 256).to(device) # Batch de 8
for i, t in tqdm(enumerate(scheduler.timesteps)):
model_input = scheduler.scale_model_input(x, t)
with torch.no_grad():
noise_pred = image_pipe.unet(model_input, t)["sample"]
x = scheduler.step(noise_pred, t, x).prev_sample
grid = torchvision.utils.make_grid(x, nrow=4)
plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5);
ConsidĂ©ration 4 : Le finetuning peut ĂȘtre tout Ă fait imprĂ©visible ! Si nous entraĂźnions plus longtemps, nous pourrions voir des papillons parfaits. Mais les Ă©tapes intermĂ©diaires peuvent ĂȘtre extrĂȘmement intĂ©ressantes en elles-mĂȘmes, surtout si vos intĂ©rĂȘts sont plutĂŽt artistiques ! EntraĂźnez sur des pĂ©riodes trĂšs courtes ou trĂšs longues et faites varier le taux dâapprentissage pour voir comment cela affecte les types de rĂ©sultats produits par le modĂšle final.
Code pour finetuner un modĂšle en utilisant le script dâexemple minimal que nous avons utilisĂ© sur le modĂšle de dĂ©monstration WikiArt
Si vous souhaitez entraßner un modÚle similaire à celui que nous avons créé sur WikiArt, vous pouvez décommenter et exécuter les cellules ci-dessous. Comme cela prend un certain temps et peut épuiser la mémoire de votre GPU, nous vous conseillons de le faire aprÚs avoir parcouru le reste de ce notebook.
## Pour télécharger le script de finetuning :
# !wget https://github.com/huggingface/diffusion-models-class/raw/main/unit2/finetune_model.py
## Pour exécuter le script, entraßnant le modÚle de visage sur des visages vintage
## (l'idéal est d'exécuter ce script dans un terminal) :
# !python finetune_model.py --image_size 128 --batch_size 8 --num_epochs 16\
# --grad_accumulation_steps 2 --start_model "google/ddpm-celebahq-256"\
# --dataset_name "Norod78/Vintage-Faces-FFHQAligned" --wandb_project 'dm-finetune'\
# --log_samples_every 100 --save_model_every 1000 --model_save_name 'vintageface'
Sauvegarde et chargement des pipelines finetunés
Maintenant que nous avons finetuné le UNet dans notre modÚle de diffusion, sauvegardons-le dans un dossier local en exécutant :
image_pipe.save_pretrained("my-finetuned-model")
Comme nous lâavons vu dans lâunitĂ© 1, cela permet de sauvegarder la configuration, le modĂšle et le planificateur :
!ls {"my-finetuned-model"}
Ensuite, vous pouvez suivre les mĂȘmes Ă©tapes que celles dĂ©crites dans le notebook dâintroduction Ă đ€ Diffusers de lâunitĂ© 1 pour pousser le modĂšle vers le Hub en vue dâune utilisation ultĂ©rieure :
# Code pour télécharger un pipeline sauvegardé localement vers le Hub
from huggingface_hub import HfApi, ModelCard, create_repo, get_full_repo_name
# Mise en place du repo et téléchargement des fichiers
model_name = "ddpm-celebahq-finetuned-butterflies-2epochs" # @param Le nom que vous souhaitez lui donner sur le Hub
local_folder_name = "my-finetuned-model" # @param Créé par le script ou par vous via image_pipe.save_pretrained('save_name')
description = "Describe your model here" # @param
hub_model_id = get_full_repo_name(model_name)
create_repo(hub_model_id)
api = HfApi()
api.upload_folder(
folder_path=f"{local_folder_name}/scheduler", path_in_repo="", repo_id=hub_model_id
)
api.upload_folder(
folder_path=f"{local_folder_name}/unet", path_in_repo="", repo_id=hub_model_id
)
api.upload_file(
path_or_fileobj=f"{local_folder_name}/model_index.json",
path_in_repo="model_index.json",
repo_id=hub_model_id,
)
# Ajouter une carte modĂšle (facultatif mais sympa !)
content = f"""
---
license: mit
tags:
- pytorch
- diffusers
- unconditional-image-generation
- diffusion-models-class
---
# Example Fine-Tuned Model for Unit 2 of the [Diffusion Models Class đ§š](https://github.com/huggingface/diffusion-models-class)
{description}
## Usage
```python
from diffusers import DDPMPipeline
pipeline = DDPMPipeline.from_pretrained('{hub_model_id}')
image = pipeline().images[0]
image
```python
"""
card = ModelCard(content)
card.push_to_hub(hub_model_id)
'https://huggingface.co/lewtun/ddpm-celebahq-finetuned-butterflies-2epochs/blob/main/README.md'
Félicitations, vous avez maintenant finetuné votre premier modÚle de diffusion !
Pour le reste de ce notebook, nous utiliserons un modĂšle que nous avons finetunĂ© Ă partir dâun modĂšle entraĂźnĂ© sur LSUN bedrooms environ une fois sur le WikiArt dataset. Si vous prĂ©fĂ©rez, vous pouvez sauter cette cellule et utiliser le pipeline faces/butterflies que nous avons finetunĂ© dans la section prĂ©cĂ©dente ou en charger un depuis le Hub Ă la place :
# Chargement du pipeline pré-entraßné
pipeline_name = "johnowhitaker/sd-class-wikiart-from-bedrooms"
image_pipe = DDPMPipeline.from_pretrained(pipeline_name).to(device)
# Ăchantillon d'images avec un planificateur DDIM sur 40 Ă©tapes
scheduler = DDIMScheduler.from_pretrained(pipeline_name)
scheduler.set_timesteps(num_inference_steps=40)
# Point de départ aléatoire (batch de 8 images)
x = torch.randn(8, 3, 256, 256).to(device)
# Boucle d'échantillonnage minimale
for i, t in tqdm(enumerate(scheduler.timesteps)):
model_input = scheduler.scale_model_input(x, t)
with torch.no_grad():
noise_pred = image_pipe.unet(model_input, t)["sample"]
x = scheduler.step(noise_pred, t, x).prev_sample
# Voir les résultats
grid = torchvision.utils.make_grid(x, nrow=4)
plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5);
ConsidĂ©ration 5 : Il est souvent difficile de savoir si le finetunĂ© fonctionne bien, et ce que lâon entend par âbonnes performancesâ peut varier selon le cas dâutilisation. Par exemple, si vous finetunĂ© un modĂšle conditionnĂ© par du texte comme Stable Diffusion sur un petit jeu de donnĂ©es, vous voudrez probablement quâil conserve la plus grande partie de son apprentissage original afin de pouvoir comprendre des prompts arbitraires non couverts par votre nouveau jeu de donnĂ©es, tout en sâadaptant pour mieux correspondre au style de vos nouvelles donnĂ©es dâentraĂźnement. Cela pourrait signifier lâutilisation dâun faible taux dâapprentissage avec quelque chose comme la moyenne exponentielle du modĂšle, comme dĂ©montrĂ© dans cet excellent article de blog sur la crĂ©ation dâune version Pokemon de Stable Diffusion. Dans une autre situation, vous pouvez vouloir rĂ©-entraĂźner complĂštement un modĂšle sur de nouvelles donnĂ©es (comme notre exemple chambre â wikiart), auquel cas un taux dâapprentissage plus Ă©levĂ© et un entraĂźnement plus poussĂ© sâavĂšrent judicieux. MĂȘme si le graphique de la perte ne montre pas beaucoup dâamĂ©lioration, les Ă©chantillons sâĂ©loignent clairement des donnĂ©es dâorigine et sâorientent vers des rĂ©sultats plus âartistiquesâ, bien quâils restent pour la plupart incohĂ©rents.
Ce qui nous amĂšne Ă la section suivante, oĂč nous examinons comment nous pourrions ajouter des conseils supplĂ©mentaires Ă un tel modĂšle pour mieux contrĂŽler les rĂ©sultats.
Guidage
Que faire si lâon souhaite exercer un certain contrĂŽle sur les Ă©chantillons gĂ©nĂ©rĂ©s ? Par exemple, supposons que nous voulions biaiser les images gĂ©nĂ©rĂ©es pour quâelles soient dâune couleur spĂ©cifique. Comment procĂ©der ? Câest lĂ quâintervient le guidage, une technique qui permet dâajouter un contrĂŽle supplĂ©mentaire au processus dâĂ©chantillonnage.
La premiĂšre Ă©tape consiste Ă crĂ©er notre fonction de conditionnement : une mesure (perte) que nous souhaitons minimiser. En voici une pour lâexemple de la couleur, qui compare les pixels dâune image Ă une couleur cible (par dĂ©faut, une sorte de sarcelle claire) et renvoie lâerreur moyenne :
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
"""Ătant donnĂ© une couleur cible (R, G, B), retourner une perte correspondant Ă la distance moyenne entre
les pixels de l'image et cette couleur. Par défaut, il s'agit d'une couleur sarcelle claire : (0.1, 0.9, 0.5)"""
target = (
torch.tensor(target_color).to(images.device) * 2 - 1
) # Map target color to (-1, 1)
target = target[
None, :, None, None
] # Obtenir la forme nécessaire pour fonctionner avec les images (b, c, h, w)
error = torch.abs(
images - target
).mean() # Différence absolue moyenne entre les pixels de l'image et la couleur cible
return error
Ensuite, nous allons crĂ©er une version modifiĂ©e de la boucle dâĂ©chantillonnage oĂč, Ă chaque Ă©tape, nous ferons ce qui suit :
- Créer une nouvelle version de
xavecrequires_grad = True - Calculer la version débruitée (
x0) - Introduire la version prédite
x0dans notre fonction de perte - Trouver le gradient de cette fonction de perte par rapport Ă
x - Utiliser ce gradient de conditionnement pour modifier
xavant dâutiliser le planificateur, en espĂ©rant pousser x dans une direction qui conduira Ă une perte plus faible selon notre fonction dâorientation.
Il existe deux variantes que vous pouvez explorer. Dans la premiĂšre, nous fixons requires_grad sur x aprĂšs avoir obtenu notre prĂ©diction de bruit du UNet, ce qui est plus efficace en termes de mĂ©moire (puisque nous nâavons pas Ă retracer les gradients Ă travers le modĂšle de diffusion), mais donne un gradient moins prĂ©cis. Dans le second cas, nous dĂ©finissons dâabord requires_grad sur x, puis nous le faisons passer par lâunet et nous calculons le x0 prĂ©dit.
# Variante 1 : méthode rapide
# L'échelle de guidance détermine l'intensité de l'effet
guidance_loss_scale = 40 # Envisagez de modifier cette valeur Ă 5, ou Ă 100
x = torch.randn(8, 3, 256, 256).to(device)
for i, t in tqdm(enumerate(scheduler.timesteps)):
# Préparer l'entrée du modÚle
model_input = scheduler.scale_model_input(x, t)
# Prédire le bruit résiduel
with torch.no_grad():
noise_pred = image_pipe.unet(model_input, t)["sample"]
# Fixer x.requires_grad Ă True
x = x.detach().requires_grad_()
# Obtenir la valeur prédite x0
x0 = scheduler.step(noise_pred, t, x).pred_original_sample
# Calculer la perte
loss = color_loss(x0) * guidance_loss_scale
if i % 10 == 0:
print(i, "loss:", loss.item())
# Obtenir le gradient
cond_grad = -torch.autograd.grad(loss, x)[0]
# Modifier x en fonction de ce gradient
x = x.detach() + cond_grad
# Le planificateur
x = scheduler.step(noise_pred, t, x).prev_sample
# Voir le résultat
grid = torchvision.utils.make_grid(x, nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray(np.array(im * 255).astype(np.uint8))
0 loss: 27.279136657714844
10 loss: 11.286816596984863
20 loss: 10.683112144470215
30 loss: 10.942476272583008
Cette deuxiĂšme option nĂ©cessite presque le double de RAM GPU pour fonctionner, mĂȘme si nous ne gĂ©nĂ©rons quâun batch de quatre images au lieu de huit. Voyez si vous pouvez repĂ©rer la diffĂ©rence et rĂ©flĂ©chissez Ă la raison pour laquelle cette mĂ©thode est plus « prĂ©cise » :
# Variante 2 : définir x.requires_grad avant de calculer les prédictions du modÚle
guidance_loss_scale = 40
x = torch.randn(4, 3, 256, 256).to(device)
for i, t in tqdm(enumerate(scheduler.timesteps)):
# Définir requires_grad avant la passe avant du modÚle
x = x.detach().requires_grad_()
model_input = scheduler.scale_model_input(x, t)
# prédire (avec grad cette fois)
noise_pred = image_pipe.unet(model_input, t)["sample"]
# Obtenir la valeur prédite x0 :
x0 = scheduler.step(noise_pred, t, x).pred_original_sample
# Calculer la perte
loss = color_loss(x0) * guidance_loss_scale
if i % 10 == 0:
print(i, "loss:", loss.item())
# Obtenir le gradient
cond_grad = -torch.autograd.grad(loss, x)[0]
# Modifier x en fonction de ce gradient
x = x.detach() + cond_grad
# Le planificateur
x = scheduler.step(noise_pred, t, x).prev_sample
grid = torchvision.utils.make_grid(x, nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray(np.array(im * 255).astype(np.uint8))
0 loss: 30.750328063964844
10 loss: 18.550724029541016
20 loss: 17.515094757080078
30 loss: 17.55681037902832
Dans la seconde variante, les besoins en mĂ©moire sont plus importants et lâeffet est moins prononcĂ©, de sorte que vous pouvez penser quâelle est infĂ©rieure. Cependant, les rĂ©sultats sont sans doute plus proches des types dâimages sur lesquels le modĂšle a Ă©tĂ© entraĂźnĂ©, et vous pouvez toujours augmenter lâĂ©chelle de guidage pour obtenir un effet plus important. Lâapproche que vous utiliserez dĂ©pendra en fin de compte de ce qui fonctionne le mieux sur le plan expĂ©rimental.
âïž Ă votre tour ! Choisissez votre couleur prĂ©fĂ©rĂ©e et recherchez ses valeurs dans lâespace RGB. Modifiez la ligne
color_loss()dans la cellule ci-dessus pour recevoir ces nouvelles valeurs RGB et examinez les résultats ; correspondent-ils à ce que vous attendez ?
Guidage avec CLIP
Guider vers une couleur nous donne un peu de contrÎle, mais que se passerait-il si nous pouvions simplement taper un texte décrivant ce que nous voulons ?
CLIP est un modĂšle créé par OpenAI qui nous permet de comparer des images Ă des lĂ©gendes textuelles. Câest extrĂȘmement puissant, car cela nous permet de quantifier Ă quel point une image correspond Ă un prompt. Et comme le processus est diffĂ©rentiable, nous pouvons lâutiliser comme fonction de perte pour guider notre modĂšle de diffusion !
Nous nâentrerons pas dans les dĂ©tails ici. Lâapproche de base est la suivante :
- EnchĂąsser le prompt pour obtenir un enchĂąssement CLIP Ă 512 dimensions
- Pour chaque étape du processus du modÚle de diffusion :
- CrĂ©er plusieurs variantes de lâimage dĂ©bruitĂ©e prĂ©dite (le fait dâavoir plusieurs variantes permet dâobtenir un signal de perte plus propre).
- Pour chacune dâentre elles, enchĂąsser lâimage avec CLIP et comparez cet enchĂąssement avec celui du prompt (Ă lâaide dâune mesure appelĂ©e « distance du grand cercle »).
- Calculer le gradient de cette perte par rapport Ă lâimage bruyante actuelle x et utiliser ce gradient pour modifier x avant de le mettre Ă jour avec le planificateur.
Pour une explication plus approfondie de CLIP, consultez cette leçon sur le sujet ou ce rapport sur le projet OpenCLIP que nous utilisons pour charger le modÚle CLIP. Exécutez la cellule suivante pour charger un modÚle CLIP :
import open_clip
clip_model, _, preprocess = open_clip.create_model_and_transforms(
"ViT-B-32", pretrained="openai"
)
clip_model.to(device)
# Transformations pour redimensionner et augmenter une image + normalisation pour correspondre aux données entraßnées par CLIP
tfms = torchvision.transforms.Compose(
[
torchvision.transforms.RandomResizedCrop(224), # CROP aléatoire à chaque fois
torchvision.transforms.RandomAffine(
5
), # Une augmentation aléatoire possible : biaiser l'image
torchvision.transforms.RandomHorizontalFlip(), # Vous pouvez ajouter des augmentations supplémentaires si vous le souhaitez
torchvision.transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
# Et définir une fonction de perte qui prend une image, l'enchùsse et la compare avec les caractéristiques textuelles du prompt
def clip_loss(image, text_features):
image_features = clip_model.encode_image(
tfms(image)
) # Note : applique les transformations ci-dessus
input_normed = torch.nn.functional.normalize(image_features.unsqueeze(1), dim=2)
embed_normed = torch.nn.functional.normalize(text_features.unsqueeze(0), dim=2)
dists = (
input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
) # Distance du grand cercle
return dists.mean()
Une fois la fonction de perte dĂ©finie, notre boucle dâĂ©chantillonnage guidĂ© ressemble aux exemples prĂ©cĂ©dents, en remplaçant color_loss() par notre nouvelle fonction de perte basĂ©e sur CLIP :
prompt = "Red Rose (still life), red flower painting" # @param
# Explorer en changeant ça
guidance_scale = 8 # @param
n_cuts = 4 # @param
# Plus d'étapes -> plus de temps pour que le guidage ait un effet
scheduler.set_timesteps(50)
# Nous enchĂąssons un prompt avec CLIP comme cible
text = open_clip.tokenize([prompt]).to(device)
with torch.no_grad(), torch.cuda.amp.autocast():
text_features = clip_model.encode_text(text)
x = torch.randn(4, 3, 256, 256).to(
device
) # L'utilisation de la RAM est Ă©levĂ©e, vous ne voulez peut-ĂȘtre qu'une seule image Ă la fois.
for i, t in tqdm(enumerate(scheduler.timesteps)):
model_input = scheduler.scale_model_input(x, t)
# prédire le bruit résiduel
with torch.no_grad():
noise_pred = image_pipe.unet(model_input, t)["sample"]
cond_grad = 0
for cut in range(n_cuts):
# nécessite un grad sur x
x = x.detach().requires_grad_()
# Obtenir le x0 prédit
x0 = scheduler.step(noise_pred, t, x).pred_original_sample
# Calculer la perte
loss = clip_loss(x0, text_features) * guidance_scale
# Obtenir le gradient (échelle par n_cuts puisque nous voulons la moyenne)
cond_grad -= torch.autograd.grad(loss, x)[0] / n_cuts
if i % 25 == 0:
print("Step:", i, ", Guidance loss:", loss.item())
# Modifier x en fonction de ce gradient
alpha_bar = scheduler.alphas_cumprod[i]
x = (
x.detach() + cond_grad * alpha_bar.sqrt()
) # Note the additional scaling factor here!
# Le planificateur
x = scheduler.step(noise_pred, t, x).prev_sample
grid = torchvision.utils.make_grid(x.detach(), nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray(np.array(im * 255).astype(np.uint8))
Step: 0 , Guidance loss: 7.437869548797607
Step: 25 , Guidance loss: 7.174620628356934
Cela ressemble un peu Ă des roses ! Ce nâest pas parfait, mais si vous jouez avec les paramĂštres, vous pouvez obtenir des images agrĂ©ables.
Si vous examinez le code ci-dessus, vous verrez que nous mettons Ă lâĂ©chelle le gradient de conditionnement par un facteur de alpha_bar.sqrt(). Il existe des thĂ©ories sur la âbonneâ maniĂšre dâĂ©chelonner ces gradients, mais en pratique, vous pouvez expĂ©rimenter. Pour certains types de guidage, vous voudrez peut-ĂȘtre que la plupart des effets soient concentrĂ©s dans les premiĂšres Ă©tapes, pour dâautres (par exemple, une perte de style axĂ©e sur les textures), vous prĂ©fĂ©rerez peut-ĂȘtre quâils nâinterviennent que vers la fin du processus de gĂ©nĂ©ration. Quelques programmes possibles sont prĂ©sentĂ©s ci-dessous :
plt.plot([1 for a in scheduler.alphas_cumprod], label="no scaling")
plt.plot([a for a in scheduler.alphas_cumprod], label="alpha_bar")
plt.plot([a.sqrt() for a in scheduler.alphas_cumprod], label="alpha_bar.sqrt()")
plt.plot(
[(1 - a).sqrt() for a in scheduler.alphas_cumprod], label="(1-alpha_bar).sqrt()"
)
plt.legend()
plt.title("Possible guidance scaling schedules")
ExpĂ©rimentez avec diffĂ©rents planificateurs, Ă©chelles de guidage et toute autre astuce Ă laquelle vous pouvez penser (lâĂ©crĂȘtage des gradients dans une certaine plage est une modification populaire) pour voir jusquâĂ quel point vous pouvez obtenir ce rĂ©sultat ! Nâoubliez pas non plus dâessayer dâintervertir dâautres modĂšles. Peut-ĂȘtre le modĂšle de visages que nous avons chargĂ© au dĂ©but ; pouvez-vous le guider de maniĂšre fiable pour produire un visage masculin ? Que se passe-t-il si vous combinez le guidage CLIP avec la perte de couleur que nous avons utilisĂ©e plus tĂŽt ? Etc.
Si vous consultez quelques codes pour la diffusion guidĂ©e par CLIP en pratique, vous verrez une approche plus complexe avec une meilleure classe pour choisir des dĂ©coupes alĂ©atoires dans les images et de nombreux ajustements supplĂ©mentaires de la fonction de perte pour de meilleures performances. Avant lâapparition des modĂšles de diffusion conditionnĂ©s par le texte, il sâagissait du meilleur systĂšme de conversion texte-image qui soit ! La petite version de notre jouet peut encore ĂȘtre amĂ©liorĂ©e, mais elle capture lâidĂ©e principale : grĂące au guidage et aux capacitĂ©s Ă©tonnantes de CLIP, nous pouvons ajouter le contrĂŽle du texte Ă un modĂšle de diffusion inconditionnel đš.
Partager une boucle dâĂ©chantillonnage personnalisĂ©e en tant que dĂ©mo Gradio
Vous avez peut-ĂȘtre trouvĂ© une perte amusante pour guider la gĂ©nĂ©ration et vous souhaitez maintenant partager avec le monde entier votre modĂšle finetunĂ© et cette stratĂ©gie dâĂ©chantillonnage personnalisĂ©eâŠ
Entrez dans Gradio. Gradio est un outil gratuit et open-source qui permet aux utilisateurs de crĂ©er et de partager facilement des modĂšles interactifs dâapprentissage automatique via une simple interface web. Avec Gradio, les utilisateurs peuvent construire des interfaces personnalisĂ©es pour leurs modĂšles dâapprentissage automatique, qui peuvent ensuite ĂȘtre partagĂ©s avec dâautres par le biais dâune URL unique. Il est Ă©galement intĂ©grĂ© Ă đ€ Spaces, ce qui permet dâhĂ©berger facilement des dĂ©mos et de les partager avec dâautres.
Nous placerons notre logique de base dans une fonction qui prend certaines entrĂ©es et produit une image en sortie. Cette fonction peut ensuite ĂȘtre enveloppĂ©e dans une interface simple qui permet Ă lâutilisateur de spĂ©cifier certains paramĂštres (qui sont transmis en tant quâentrĂ©es Ă la fonction principale de gĂ©nĂ©ration). De nombreux composants sont disponibles ; pour cet exemple, nous utiliserons un curseur pour lâĂ©chelle dâorientation et un sĂ©lecteur de couleurs pour dĂ©finir la couleur cible.
!pip install -q gradio
import gradio as gr
from PIL import Image, ImageColor
# La fonction qui fait le gros du travail
def generate(color, guidance_loss_scale):
target_color = ImageColor.getcolor(color, "RGB") # Couleur cible en RGB
target_color = [a / 255 for a in target_color] # Rééchelonner de (0, 255) à (0, 1)
x = torch.randn(1, 3, 256, 256).to(device)
for i, t in tqdm(enumerate(scheduler.timesteps)):
model_input = scheduler.scale_model_input(x, t)
with torch.no_grad():
noise_pred = image_pipe.unet(model_input, t)["sample"]
x = x.detach().requires_grad_()
x0 = scheduler.step(noise_pred, t, x).pred_original_sample
loss = color_loss(x0, target_color) * guidance_loss_scale
cond_grad = -torch.autograd.grad(loss, x)[0]
x = x.detach() + cond_grad
x = scheduler.step(noise_pred, t, x).prev_sample
grid = torchvision.utils.make_grid(x, nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
im = Image.fromarray(np.array(im * 255).astype(np.uint8))
im.save("test.jpeg")
return im
# Voir la documentation de gradio pour les types d'entrées et de sorties disponibles.
inputs = [
gr.ColorPicker(label="color", value="55FFAA"), # Ajoutez ici toutes les entrées dont vous avez besoin
gr.Slider(label="guidance_scale", minimum=0, maximum=30, value=3),
]
outputs = gr.Image(label="result")
# Et l'interface minimale
demo = gr.Interface(
fn=generate,
inputs=inputs,
outputs=outputs,
examples=[
["#BB2266", 3],
["#44CCAA", 5], # Vous pouvez fournir des exemples d'entrées pour aider les gens à démarrer
],
)
demo.launch(debug=True) # debug=True vous permet de voir les erreurs et les sorties dans Colab
Il est possible de construire des interfaces beaucoup plus compliquĂ©es, avec un style fantaisiste et un large Ă©ventail dâentrĂ©es possibles, mais pour cette dĂ©mo, nous la gardons aussi simple que possible.
Les dĂ©mos sur đ€ Spaces sâexĂ©cutent par dĂ©faut sur CPU, il est donc prĂ©fĂ©rable de prototyper votre interface dans Colab (comme ci-dessus) avant de la migrer. Lorsque vous ĂȘtes prĂȘt Ă partager votre dĂ©mo, vous devez crĂ©er un Space, mettre en place un fichier requirements.txt listant les bibliothĂšques que votre code utilisera, puis placer tout le code dans un fichier app.py qui dĂ©finit les fonctions pertinentes et lâinterface.
Heureusement pour vous, il est Ă©galement possible de âdupliquerâ un Space. Vous pouvez visiter le Space ici et cliquer sur âDupliquer cet espaceâ pour obtenir un modĂšle que vous pouvez ensuite modifier pour utiliser votre propre modĂšle et votre propre fonction dâorientation.
Dans les paramĂštres, vous pouvez configurer votre Space pour quâil fonctionne avec du matĂ©riel plus sophistiquĂ© (qui est facturĂ© Ă lâheure). Vous avez créé quelque chose dâextraordinaire et vous voulez le partager sur un meilleur matĂ©riel, mais vous nâavez pas lâargent nĂ©cessaire ? Faites-le nous savoir via Discord et nous verrons si nous pouvons vous aider !
Résumé et prochaines étapes
Nous avons couvert beaucoup de choses dans ce notebook ! Récapitulons les idées principales :
- Il est relativement facile de charger des modÚles existants et de les échantillonner avec différents planificateurs
- Le finetuning ressemble Ă lâentraĂźnement Ă partir de zĂ©ro, sauf quâen partant dâun modĂšle existant, nous espĂ©rons obtenir de meilleurs rĂ©sultats plus rapidement.
- Pour finetuner de grands modĂšles sur de grandes images, nous pouvons utiliser des astuces comme lâaccumulation de gradient pour contourner les limitations de la taille des batchs.
- Lâenregistrement dâĂ©chantillons dâimages est important pour le finetuning, oĂč une courbe de perte peut ne pas fournir beaucoup dâinformations utiles.
- Le guidage nous permet de prendre un modĂšle inconditionnel et dâorienter le processus de gĂ©nĂ©ration sur la base dâune fonction de guidage/perte, oĂč Ă chaque Ă©tape nous trouvons le gradient de la perte par rapport Ă lâimage bruitĂ©e $x$ et lâactualisons en fonction de ce gradient avant de passer Ă lâĂ©tape temporelle suivante.
- Le guidage avec CLIP nous permet de contrĂŽler des modĂšles inconditionnels avec du texte !
Pour mettre cela en pratique, voici quelques étapes spécifiques que vous pouvez suivre :
- FinetunĂ© votre propre modĂšle et le pousser vers le Hub. Cela implique de choisir un point de dĂ©part (par exemple, un modĂšle entraĂźnĂ© sur faces, bedrooms, cats ou wikiart et un jeu de donnĂ©es (peut-ĂȘtre ces faces dâanimaux ou vos propres images), puis dâentraĂźner soit le code de ce notebook, soit le script dâexemple (utilisation de dĂ©monstration ci-dessous).
- Explorer le guidage en utilisant votre modĂšle finetunĂ©, soit en utilisant lâune des fonctions de guidage de lâexemple (color_loss ou CLIP), soit en inventant la vĂŽtre.
- Partagez une dĂ©mo basĂ©e sur ceci en utilisant Gradio, soit en modifiant le Space dâexemple pour utiliser votre propre modĂšle, soit en crĂ©ant votre propre version personnalisĂ©e avec plus de fonctionnalitĂ©s.
Nous sommes impatients de voir vos rĂ©sultats sur Discord, Twitter et ailleurs đ€ !