01_ddim_inversion_CN

DDIM Inversion

Setup

```python
# !pip install -q transformers diffusers accelerate
import torch
import requests
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from io import BytesIO
from tqdm.auto import tqdm
from matplotlib import pyplot as plt
from torchvision import transforms as tfms
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Useful function for later
def load_image(url, size=None):
    response = requests.get(url, timeout=0.2)
    img = Image.open(BytesIO(response.content)).convert('RGB')
    if size is not None:
        img = img.resize(size)
    return img

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

Load a trained pipeline

```python
# Load a pipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to(device)

# Set up a DDIM scheduler:
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
```

```python
# Sample an image to make sure it is all working
prompt = 'Beautiful DSLR Photograph of a penguin on the beach, golden hour'
negative_prompt = 'blurry, ugly, stock photo'
im = pipe(prompt, negative_prompt=negative_prompt).images[0]
im.resize((256, 256)) # resize for convenient viewing
```

DDIM sampling process

At a given timestep $t$, the noisy image $x_t$ is the original image ($x_0$) with some noise ($\epsilon$) added on top. Here is the definition of $x_t$ from the DDIM paper, which we will refer to throughout this section:

$$x_t = \sqrt{\alpha_t}\,x_0 + \sqrt{1-\alpha_t}\,\epsilon$$

Here $\epsilon$ is Gaussian noise with unit variance, and $\alpha_t$ ('alpha') is what the DDPM paper calls $\bar{\alpha}$ ('alpha_bar'); it defines the noise schedule. In diffusers these values are pre-computed and stored in scheduler.alphas_cumprod. Confusing, I know! Let's plot these values, and keep in mind that we will stick with the DDIM notation for the rest of this notebook.

```python
# Plot 'alpha' (alpha_bar in DDPM language, alphas_cumprod in diffusers for clarity)
timesteps = pipe.scheduler.timesteps.cpu()
alphas = pipe.scheduler.alphas_cumprod[timesteps]
plt.plot(timesteps, alphas, label='alpha_t');
plt.legend();
```

At the start (timestep 0, on the left of the plot) we have a clean, noise-free image and $\alpha_t = 1$. As we move to higher timesteps we end up with an almost entirely noisy image, and $\alpha_t$ drops towards 0.
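
To make the formula above concrete, here is a minimal sketch of manually noising a latent using the alphas stored in the scheduler. It re-uses pipe and device from the setup above; the timestep and the 'clean' latent x0 are just random stand-ins, not values from this notebook:

```python
# Sketch of x_t = sqrt(alpha_t) * x_0 + sqrt(1 - alpha_t) * eps
t = 500                                        # an arbitrary timestep (stand-in)
alpha_t = pipe.scheduler.alphas_cumprod[t]     # this is alpha_bar_t in DDPM notation
x0 = torch.randn(1, 4, 64, 64, device=device)  # stand-in for a clean latent
eps = torch.randn_like(x0)                     # unit-variance Gaussian noise
x_t = alpha_t.sqrt() * x0 + (1 - alpha_t).sqrt() * eps
```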

During sampling we start from pure noise at timestep 1000 and slowly move towards timestep 0. To calculate the next point on the sampling trajectory ($x_{t-1}$, since we are moving backwards through time) we predict the noise ($\epsilon_\theta(x_t)$, the output of our model) and use it to compute a prediction of the noise-free image $x_0$. We then use this prediction to take a small step in the 'direction pointing to $x_t$'. Finally, we can add some extra noise scaled by $\sigma_t$. These are exactly the operations described in the sampling section of the DDIM paper.

So: we have a formula for moving from $x_t$ to $x_{t-1}$ with a controllable amount of extra noise. The case we care about today adds no extra noise at all, giving us fully deterministic DDIM sampling. Let's see what this looks like in code.
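
Before the code, for reference, the DDIM update that combines these three ingredients (the predicted $x_0$, the 'direction pointing to $x_t$' and the extra noise term) can be written in the notation above as:

$$x_{t-1} = \sqrt{\alpha_{t-1}}\underbrace{\left(\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\alpha_t}}\right)}_{\text{predicted }x_0} + \underbrace{\sqrt{1-\alpha_{t-1}-\sigma_t^2}\;\epsilon_\theta(x_t)}_{\text{direction pointing to }x_t} + \sigma_t\,\epsilon_t$$

where $\epsilon_t$ is freshly sampled Gaussian noise.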

```python
# Sample function (regular DDIM)
@torch.no_grad()
def sample(prompt, start_step=0, start_latents=None,
           guidance_scale=3.5, num_inference_steps=30,
           num_images_per_prompt=1, do_classifier_free_guidance=True,
           negative_prompt='', device=device):
  
    # Encode prompt
    text_embeddings = pipe._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Create a random starting point if we don't have one already
    if start_latents is None:
        start_latents = torch.randn(1, 4, 64, 64, device=device)
        start_latents *= pipe.scheduler.init_noise_sigma

    latents = start_latents.clone()

    for i in tqdm(range(start_step, num_inference_steps)):
    
        t = pipe.scheduler.timesteps[i]

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)


        # Normally we'd rely on the scheduler to handle the update step:
        # latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

        # Instead, let's do it ourselves:
        prev_t = max(1, t.item() - (1000//num_inference_steps)) # t-1
        alpha_t = pipe.scheduler.alphas_cumprod[t.item()]
        alpha_t_prev = pipe.scheduler.alphas_cumprod[prev_t]
        predicted_x0 = (latents - (1-alpha_t).sqrt()*noise_pred) / alpha_t.sqrt()
        direction_pointing_to_xt = (1-alpha_t_prev).sqrt()*noise_pred
        latents = alpha_t_prev.sqrt()*predicted_x0 + direction_pointing_to_xt

    #Post-processing
    images = pipe.decode_latents(latents)
    images = pipe.numpy_to_pil(images)

    return images

# Test our sampling function by generating an image
sample('Watercolor painting of a beach sunset', negative_prompt=negative_prompt, num_inference_steps=50)[0].resize((256, 256))
```

See if you can match this code to the equations from the paper. Note that $\sigma = 0$ because we only care about the case with no extra noise, so we leave those parts of the equation out.
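
Setting $\sigma_t = 0$ in the update above gives the fully deterministic step that the sampling code implements:

$$x_{t-1} = \sqrt{\alpha_{t-1}}\,\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\alpha_t}} + \sqrt{1-\alpha_{t-1}}\,\epsilon_\theta(x_t)$$

The two terms correspond to predicted_x0 and direction_pointing_to_xt in the code.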

Inversion

The goal of inversion is to 'reverse' the sampling process. We want to end up with a noisy latent which, if used as the starting point of our usual sampling procedure, would reproduce the original image.

Here we first load a source image, but of course you could also generate one to use instead.

```python
# https://www.pexels.com/photo/a-beagle-on-green-grass-field-8306128/
input_image = load_image('https://images.pexels.com/photos/8306128/pexels-photo-8306128.jpeg', size=(512, 512))
input_image
```

We will also use a prompt during the inversion (with classifier-free guidance), so enter a description of the image:

```python
input_image_prompt = "Photograph of a puppy on the grass"
```

Next we turn this PIL image into a set of latents, which we will use as the starting point for the inversion:

```python
# Encode with the VAE
with torch.no_grad(): latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device)*2-1)
l = 0.18215 * latent.latent_dist.sample()
```
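
As an optional sanity check (a quick sketch, not part of the original walkthrough, re-using pipeline helpers already called elsewhere in this notebook), you can decode these latents straight back into pixels to confirm the encode-and-scale step worked:

```python
# Round-trip check: decode the freshly encoded latents back to an image.
# pipe.decode_latents undoes the 0.18215 scaling internally.
with torch.no_grad():
    decoded = pipe.decode_latents(l)
pipe.numpy_to_pil(decoded)[0].resize((256, 256))
```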

OK, here comes the fun part. This function looks very similar to the sampling function above, but we move through the timesteps in the opposite direction, starting at $t=0$ and moving towards higher and higher noise levels. And instead of updating the latents to make them less noisy, we estimate the predicted noise and use it to undo one update step, moving them from $t$ to $t+1$.
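
Concretely, re-arranging the deterministic update above so that the next (noisier) latent is written in terms of the current one, and re-using the same noise prediction $\epsilon_\theta(x_t)$ for both timesteps, gives the update that the code below implements:

$$x_{t+1} = \sqrt{\alpha_{t+1}}\,\frac{x_t - \sqrt{1-\alpha_t}\,\epsilon_\theta(x_t)}{\sqrt{\alpha_t}} + \sqrt{1-\alpha_{t+1}}\,\epsilon_\theta(x_t)$$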

```python
## Inversion
@torch.no_grad()
def invert(start_latents, prompt, guidance_scale=3.5, num_inference_steps=80,
           num_images_per_prompt=1, do_classifier_free_guidance=True,
           negative_prompt='', device=device):
  
    # Encode prompt
    text_embeddings = pipe._encode_prompt(
            prompt, device, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
    )

    # latents are now the specified start latents
    latents = start_latents.clone()

    # We'll keep a list of the inverted latents as the process goes on
    intermediate_latents = []

    # Set num inference steps
    pipe.scheduler.set_timesteps(num_inference_steps, device=device)

    # Reversed timesteps <<<<<<<<<<<<<<<<<<<<
    timesteps = reversed(pipe.scheduler.timesteps)

    for i in tqdm(range(1, num_inference_steps), total=num_inference_steps-1):

        # We'll skip the final iteration
        if i >= num_inference_steps - 1: continue

        t = timesteps[i]

        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        current_t = max(0, t.item() - (1000//num_inference_steps))#t
        next_t = t # min(999, t.item() + (1000//num_inference_steps)) # t + 1
        alpha_t = pipe.scheduler.alphas_cumprod[current_t]
        alpha_t_next = pipe.scheduler.alphas_cumprod[next_t]

        # Inverted update step (re-arranging the update step to get x(t) (new latents) as a function of x(t-1) (current latents))
        latents = (latents - (1-alpha_t).sqrt()*noise_pred)*(alpha_t_next.sqrt()/alpha_t.sqrt()) + (1-alpha_t_next).sqrt()*noise_pred


        #Store
        intermediate_latents.append(latents)
            
    return torch.cat(intermediate_latents)
```

Running this on the latent representation of our puppy picture gives us a set of intermediate latents from the inversion process (48 of them for 50 inference steps, since the first and final steps are skipped):

```python
inverted_latents = invert(l, input_image_prompt, num_inference_steps=50)
inverted_latents.shape
```

torch.Size([48, 4, 64, 64])

We can take a look at the final noisy latents, which will hopefully serve as a starting point for new sampling attempts:

```python
# Decode the final inverted latents:
with torch.no_grad():
    im = pipe.decode_latents(inverted_latents[-1].unsqueeze(0))
pipe.numpy_to_pil(im)[0]
```

You can pass these inverted latents to the pipeline using its normal call method:

```python
pipe(input_image_prompt, latents=inverted_latents[-1][None], num_inference_steps=50, guidance_scale=3.5).images[0]
```


But here we hit our first problem: this is not the image we started with! That is because DDIM inversion relies on an important assumption: that the noise predicted at time $t$ is the same as the noise predicted at time $t+1$, which does not hold well when we only invert over 50 or 100 steps. We could use more timesteps to get a more accurate inversion, but we can also 'cheat' and start directly from, say, step 20 of 50, using the corresponding intermediate latents saved during inversion:

```python
# The reason we want to be able to specify start step
start_step = 20
sample(input_image_prompt, start_latents=inverted_latents[-(start_step + 1)][None],
       start_step=start_step, num_inference_steps=50)[0]
```


It is very close to our input image! Why do this? Because if we now generate with a new prompt, we get an image that matches the source everywhere except for the content related to the new prompt. For example, replacing 'puppy' with 'cat' gives us a cat on an almost identical grassy background:

```python
# Sampling with a new prompt
start_step = 10
new_prompt = input_image_prompt.replace('puppy', 'cat')
sample(new_prompt, start_latents=inverted_latents[-(start_step + 1)][None],
       start_step=start_step, num_inference_steps=50)[0]
```

Why not just use img2img?

Why do we need inversion at all? Isn't it unnecessary? Why not just add noise to the input image and then denoise it directly with the new prompt? We could, but that would result either in an image where everything has changed dramatically (if we add a lot of noise) or one where hardly anything changes (if we add too little). Try it yourself:

```python
start_step = 10
num_inference_steps = 50
pipe.scheduler.set_timesteps(num_inference_steps)
noisy_l = pipe.scheduler.add_noise(l, torch.randn_like(l), pipe.scheduler.timesteps[start_step])
sample(new_prompt, start_latents=noisy_l, start_step=start_step, num_inference_steps=num_inference_steps)[0]
```

Notice that the background and lawn have changed significantly.

Putting it all together

Let's wrap everything we have written so far into a simple function: given an input image and two prompts, it produces an edited image via inversion:

```python
def edit(input_image, input_image_prompt, edit_prompt, num_steps=100, start_step=30, guidance_scale=3.5):
    with torch.no_grad(): latent = pipe.vae.encode(tfms.functional.to_tensor(input_image).unsqueeze(0).to(device)*2-1)
    l = 0.18215 * latent.latent_dist.sample()
    inverted_latents = invert(l, input_image_prompt, num_inference_steps=num_steps)
    final_im = sample(edit_prompt, start_latents=inverted_latents[-(start_step + 1)][None],
                      start_step=start_step, num_inference_steps=num_steps, guidance_scale=guidance_scale)[0]
    return final_im
```

And in action:

```python
edit(input_image, 'A puppy on the grass', 'an old gray dog on the grass', num_steps=50, start_step=10)
```

```python
edit(input_image, 'A puppy on the grass', 'A blue dog on the lawn', num_steps=50, start_step=12, guidance_scale=6)
```