The goal of this blog post is to implement the DiffEdit paper to the best of my understanding. While this task is primarily meant to deepen my own understanding of the process, I also hope it helps readers of this post understand the process better.

Before I start, I want to give a big thanks to the DiffEdit authors (Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord) for publishing this paper. Without their openness and willingness to share, this type of implementation would not be possible. I also want to thank the fast.ai community for helping me solve problems when I was unsure how to move forward. Finally, I want to thank Jonathan Whitaker, the author of the Stable Diffusion Deep Dive notebook, which was the starting point for my DiffEdit implementation.

If anybody reading this blog post works in manufacturing and cares about improving product quality (or knows somebody that does), please reach out to me at kevin@problemsolversguild.com.

The first thing we need to do is choose an image to use as a starting point. I went with a picture similar to one of the images used in the paper, but perhaps a little harder.

from fastdownload import FastDownload
from PIL import Image

p = FastDownload().download('https://negativespace.co/wp-content/uploads/2020/11/negative-space-horses-in-field-with-trees-1062x705.jpg')
init_image = Image.open(p).convert("RGB")
# init_image = init_image.resize((init_image.size[0]//2, init_image.size[1]//2))
init_image = init_image.resize((512,512))  # 512x512 matches the resolution Stable Diffusion v1 models were trained at
init_image

Now that I have an image, let's define reference_text and query_text, which the paper calls R and Q. Following the paper, let's keep R and Q simple.

reference_text = "Two horses"
query_text = "Two zebras"

A good amount of the code in the next few cells is adapted from the StableDiffusionImg2ImgPipeline class in the diffusers library, which was very helpful in building the implementation I ended up with.

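import numpy as np
import PIL
import torch
from PIL import Image

# vae, unet, tokenizer, text_encoder and scheduler are assumed to be loaded
# as in the Stable Diffusion Deep Dive notebook

def preprocess(image):
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=PIL.Image.Resampling.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0  # HWC, values in [0, 1]
    image = image[None].transpose(0, 3, 1, 2)           # add batch dim -> NCHW
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0                             # rescale to [-1, 1] for the VAE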
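To make preprocess concrete, here is a small numpy sketch of the two things it does: round each side down to a multiple of 32 and map pixel values from [0, 255] to [-1, 1]. The 1062×705 size is an assumption read off the file name in the image URL. The last lines use the standard Stable Diffusion v1 VAE layout (8× spatial downsampling, 4 latent channels), which is why the latent noise plotted later has 4 channels of 64×64.

```python
import numpy as np

def crop_to_multiple(size, multiple=32):
    """Round each side of a (width, height) size down to a multiple of `multiple`, as preprocess does."""
    return tuple(x - x % multiple for x in size)

def to_model_range(pixels):
    """Map uint8 pixels in [0, 255] to floats in [-1, 1], the range the VAE expects."""
    return 2.0 * (pixels.astype(np.float32) / 255.0) - 1.0

# Size suggested by the image URL (assumption), i.e. without the 512x512 resize
print(crop_to_multiple((1062, 705)))   # (1056, 704)
print(crop_to_multiple((512, 512)))    # (512, 512): already a multiple of 32

# 0 maps to -1.0 and 255 maps to 1.0
print(to_model_range(np.array([0, 255], dtype=np.uint8)))

# The Stable Diffusion v1 VAE downsamples by 8 and uses 4 latent channels,
# so a 512x512 image becomes a 4x64x64 latent
latent_shape = (4, 512 // 8, 512 // 8)
print(latent_shape)  # (4, 64, 64)
```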
def get_text_embeddings(prompt, negative_prompt, tokenizer, text_encoder, do_classifier_free_guidance, device): #outputs text_embeddings
    # get prompt text embeddings
    text_inputs = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, 
                            return_tensors="pt", truncation=True)
    text_input_ids = text_inputs.input_ids
    with torch.no_grad():
        text_embeddings = text_encoder(text_input_ids.to(device))[0]
    if not do_classifier_free_guidance:
        return text_embeddings
    
    # the unconditional embeddings come from the empty prompt (or the negative prompt, if given)
    uncond_tokens = [""] if negative_prompt is None else negative_prompt
    max_length = text_input_ids.shape[-1]
    uncond_input = tokenizer(uncond_tokens, padding="max_length", max_length=max_length, 
                             return_tensors="pt", truncation=True)
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    
    # For classifier free guidance, we need to do two forward passes.
    # Here we concatenate the unconditional and text embeddings into a single batch
    # to avoid doing two forward passes
    return torch.cat([uncond_embeddings, text_embeddings])
def get_timestamps(scheduler, num_inference_steps, strength, device):
    scheduler.set_timesteps(num_inference_steps)

    # get the original timestep using init_timestep
    offset = scheduler.config.get("steps_offset", 0)
    init_timestep = int(num_inference_steps * strength) + offset
    init_timestep = min(init_timestep, num_inference_steps)

    timesteps = scheduler.timesteps[-init_timestep]
    timesteps = torch.tensor([timesteps], device=device)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    return timesteps, t_start
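A pure-Python sketch of the index arithmetic in get_timestamps for the values used in this post (num_inference_steps=50, strength=0.5). A plain descending list stands in for scheduler.timesteps and steps_offset is taken as 0; both are assumptions, since the real values depend on the scheduler. With strength 0.5, half of the schedule is skipped: the latents are noised to the middle timestep and then denoised for the remaining 25 steps.

```python
def start_step(num_inference_steps, strength, offset=0):
    """Mirror of the index arithmetic in get_timestamps (offset plays the role of steps_offset)."""
    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)
    t_start = max(num_inference_steps - init_timestep + offset, 0)
    return init_timestep, t_start

# Hypothetical stand-in for scheduler.timesteps: 50 evenly spaced, descending values
schedule = list(range(980, -1, -20))
init_timestep, t_start = start_step(50, 0.5)
noising_t = schedule[-init_timestep]     # timestep the clean latents are noised to
denoise_ts = schedule[t_start:]          # timesteps the loop in img2noise runs over

print(init_timestep, t_start)            # 25 25
print(noising_t == denoise_ts[0])        # True: denoising starts where noising stopped
print(len(denoise_ts))                   # 25 noise predictions per call
```

A strength of 1.0 would give t_start = 0, i.e. start from (almost) pure noise and run all 50 steps.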
def encode_image(init_image, latents_dtype, device, generator=None):
    # encode the init image into latents and scale the latents
    init_image = preprocess(init_image)
    init_image = init_image.to(device=device, dtype=latents_dtype)
    with torch.no_grad(): init_latent_dist = vae.encode(init_image).latent_dist
    init_latents = init_latent_dist.sample(generator=generator)
    init_latents = 0.18215 * init_latents  # latent scaling factor of the Stable Diffusion v1 VAE
    return init_latents
def img2noise(init_image, 
              prompt,
              mask=None,
              strength = 0.5,
              num_inference_steps = 50,
              guidance_scale = 5,
              negative_prompt=None,
              generator = None, 
              device='cuda'
             ):
    """Noise init_image to `strength`, denoise it toward `prompt`, and return
    the decoded images plus the guided noise prediction from every step."""
    do_classifier_free_guidance = guidance_scale > 1.0
    text_embeddings = get_text_embeddings(prompt, negative_prompt, tokenizer, text_encoder, do_classifier_free_guidance, device)
    latents_dtype = text_embeddings.dtype
    timesteps, t_start = get_timestamps(scheduler, num_inference_steps, strength, device)
    
    # encode the init image into latents and scale the latents
    init_latents = encode_image(init_image, latents_dtype, device, generator)

    # add noise to latents using the timesteps
    noise = torch.randn(init_latents.shape, generator=generator, device=device, dtype=latents_dtype)
    latents = scheduler.add_noise(init_latents, noise, timesteps)

    # Some schedulers like PNDM have timesteps as arrays
    # It's more optimized to move all timesteps to correct device beforehand
    timesteps = scheduler.timesteps[t_start:].to(device)
    noise_preds = []  # guided noise prediction from every denoising step
    
    for i, t in enumerate(timesteps):
        # expand the latents if we are doing classifier free guidance
        latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # perform guidance
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
        
        # keep every step's prediction; keeping only the final one (outside the loop) performed much worse
        noise_preds.append(noise_pred)
        # compute the previous noisy sample x_t -> x_t-1
        latents = scheduler.step(noise_pred, t, latents).prev_sample
        if mask is not None:
            # outside the mask, use the input image's latents noised to the same level as `latents`
            if i + 1 < len(timesteps):
                background = scheduler.add_noise(init_latents, noise, timesteps[i + 1:i + 2])
            else:
                background = init_latents
            latents = mask * latents + (1 - mask) * background
            
    # decode the final latents back to images
    latents = 1 / 0.18215 * latents
    with torch.no_grad(): image = vae.decode(latents).sample

    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images, torch.cat(noise_preds)
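The guidance step inside img2noise can be checked in isolation. Below is a minimal numpy sketch of the classifier-free guidance combination, with small made-up arrays standing in for the two halves of the UNet output; the unconditional half comes first, matching the order in which get_text_embeddings concatenates the embeddings. The helper name is hypothetical.

```python
import numpy as np

def classifier_free_guidance(noise_pred, guidance_scale):
    """Combine a stacked [uncond, text] prediction the same way img2noise does."""
    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2, axis=0)
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# Made-up predictions: batch of 2 (uncond, text), each holding two "pixels"
stacked = np.array([[[0.0, 1.0]],
                    [[0.5, 1.0]]])
print(classifier_free_guidance(stacked, 5.0))   # [[[2.5 1. ]]]: the 0.5 gap is multiplied by 5
print(classifier_free_guidance(stacked, 1.0))   # scale 1 reduces to the text prediction
```

Where the text and unconditional predictions agree (the second pixel), guidance changes nothing. Since the unconditional prediction depends only on the latents, for the same input latents the difference between the R and Q predictions is exactly guidance_scale times the difference between their text-conditioned predictions.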

Estimate noise conditioned on the reference text R

# torch.cuda.manual_seed returns None, so create an explicit generator to pass around
generator = torch.Generator(device='cuda').manual_seed(32)
reference_noises = []
for _ in range(10):
    reference_pil, reference_noise = img2noise(init_image, strength=0.5, prompt=reference_text, generator=generator)
    reference_noises.append(reference_noise)
reference_noises = torch.cat(reference_noises)
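For a sense of the tensor sizes: assuming the img2noise defaults (50 steps, strength 0.5, a scheduler with 50 timesteps) and a 512×512 input, each call returns one guided prediction per denoising step, so ten calls give 250 predictions of shape 4×64×64. The numpy sketch below uses zero arrays as stand-ins for those predictions. Collecting the per-call arrays in a list and concatenating once avoids re-copying an ever-growing array on every iteration.

```python
import numpy as np

n_runs, steps_per_run = 10, 25          # 10 calls; 25 denoising steps each (50 steps x strength 0.5)
latent_shape = (4, 64, 64)              # latent size for a 512x512 image

# Stand-ins for the noise prediction tensor returned by each img2noise call
runs = [np.zeros((steps_per_run, *latent_shape), dtype=np.float32) for _ in range(n_runs)]

# Concatenate once along the first axis instead of growing an array inside the loop
all_preds = np.concatenate(runs, axis=0)
print(all_preds.shape)                  # (250, 4, 64, 64)

# Averaging over axis 0 pools every run and every step into one map per channel
mean_pred = all_preds.mean(axis=0)
print(mean_pred.shape)                  # (4, 64, 64)
```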

Estimate noise conditioned on the query Q

# reseed with the same value so each query run sees the same noise as the matching reference run
generator = torch.Generator(device='cuda').manual_seed(32)
query_noises = []
for _ in range(10):
    query_pil, query_noise = img2noise(init_image, strength=0.5, prompt=query_text, generator=generator)
    query_noises.append(query_noise)
query_noises = torch.cat(query_noises)
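Both loops seed the random number generator with 32 before their ten calls, so the k-th reference run and the k-th query run start from the same random draws. A toy numpy sketch of why that matters: if each prediction is modeled as shared random noise plus a fixed prompt-dependent term (a deliberately simplified, hypothetical model, not how the UNet actually behaves), pairing the seeds makes the noise cancel in the difference, while unpaired seeds leave residual noise behind.

```python
import numpy as np

def toy_predictions(seed, prompt_effect, n_runs=10, shape=(8, 8)):
    """Hypothetical model: each run's prediction = random noise (from seed) + a fixed prompt effect."""
    rng = np.random.default_rng(seed)
    return np.stack([rng.standard_normal(shape) + prompt_effect for _ in range(n_runs)])

effect_r = np.zeros((8, 8))
effect_q = np.zeros((8, 8))
effect_q[2:5, 2:5] = 1.0                 # the "query" only changes a small region

paired = toy_predictions(32, effect_r).mean(0) - toy_predictions(32, effect_q).mean(0)
unpaired = toy_predictions(32, effect_r).mean(0) - toy_predictions(7, effect_q).mean(0)

print(np.allclose(paired, effect_r - effect_q))            # True: the shared noise cancels
print(np.abs(unpaired - (effect_r - effect_q)).max() > 0)  # True: leftover noise remains
```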

View Latent Noise Channels

import matplotlib.pyplot as plt

for name, noises in [("reference", reference_noises), ("query", query_noises)]:
    mean_noise = noises.mean(0)  # average over all runs and denoising steps -> (4, 64, 64)
    fig, axs = plt.subplots(1, 4, figsize=(16, 4))
    for c in range(4):
        axs[c].imshow(mean_noise[c].cpu())
        axs[c].set_title(f"{name} noise, channel {c}")

Neither reference_noises nor query_noises looks very interesting on its own, so let's look at the difference between the two.

diff_noises = (reference_noises.mean(0, keepdim=True) - query_noises.mean(0, keepdim=True))
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for c in range(4):
    axs[c].imshow(diff_noises[0][c].cpu())

Now we are seeing signs that the noise being removed differs substantially over the horses and is fairly similar everywhere else. Note that in some channels the horse area is darker than the rest of the image, while in others it is lighter.

diff_noises.min(), diff_noises.max()
(tensor(-0.8503, device='cuda:0'), tensor(0.6133, device='cuda:0'))
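One caveat when reading these plots: by default, matplotlib's imshow scales each panel to its own minimum and maximum, so "darker" and "lighter" only have meaning within a panel, and zero can land on a different color in each one. With values from about -0.85 to 0.61, one option is a color range symmetric around zero, passed to imshow as vmin/vmax together with a diverging colormap such as 'RdBu'. A small numpy sketch of computing that range (the helper name is hypothetical):

```python
import numpy as np

def symmetric_limits(values):
    """Return (vmin, vmax) centered on zero that cover every value in `values`."""
    bound = float(np.abs(values).max())
    return -bound, bound

# Stand-in holding the min and max printed above
diff = np.array([-0.8503, 0.0, 0.6133])
vmin, vmax = symmetric_limits(diff)
print(vmin, vmax)   # -0.8503 0.8503

# With the real tensor, something like:
# vmin, vmax = symmetric_limits(diff_noises.cpu().numpy())
# axs[c].imshow(diff_noises[0][c].cpu(), cmap='RdBu', vmin=vmin, vmax=vmax)
```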

One thing I've found that improves this is to measure the distance from zero, so I take the absolute value; that way the magnitude of the difference, not its direction, is what counts. The reasoning is that whether the zebra query or the horse reference produces the larger noise estimate at a pixel, a large difference there means the pixel probably belongs in the mask.

diff_noises_abs = diff_noises.abs()
fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for c in range(4):
    axs[c].imshow(diff_noises_abs[0][c].cpu())
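The reasoning behind the absolute value shows up clearly in a toy example. In the numpy sketch below (made-up values), two channels respond to the same region with opposite signs: averaging the signed differences across channels cancels them out, while averaging the absolute differences keeps the region. The final rescale to [0, 1] and binarization loosely follow the paper's description of turning the difference into a mask; the channel averaging and the 0.5 threshold are assumptions in this sketch.

```python
import numpy as np

# Made-up per-channel differences on a 1x6 strip; the "object" occupies columns 2-3
diff = np.zeros((2, 1, 6))
diff[0, 0, 2:4] = 0.8    # channel 0: reference estimate is larger here
diff[1, 0, 2:4] = -0.8   # channel 1: query estimate is larger here

signed = diff.mean(axis=0)             # opposite signs cancel
magnitude = np.abs(diff).mean(axis=0)  # magnitudes add up

print(signed)        # all zeros: the object disappears
print(magnitude)     # 0.8 in columns 2-3, 0 elsewhere

# Rescale to [0, 1] and binarize (0.5 threshold assumed for illustration)
scaled = (magnitude - magnitude.min()) / (magnitude.max() - magnitude.min())
mask = (scaled > 0.5).astype(np.float32)
print(mask)          # 1.0 in columns 2-3, 0.0 elsewhere
```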