Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,643 Bytes
8483373 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import torch
import torchvision
import os
import gc
import tqdm
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from lora_w2w import LoRAw2w
from transformers import AutoTokenizer, PretrainedConfig
from PIL import Image
import warnings
# NOTE(review): this installs an "ignore everything" warnings filter for the
# whole process as a side effect of importing this module, hiding warnings
# from every library (not just this file). If only specific noise needs
# silencing, consider a targeted filter or warnings.catch_warnings() at the
# call sites instead.
warnings.filterwarnings("ignore")
######## Editing Utilities
def get_direction(df, label, pinverse, return_dim, device):
    """Compute a linear edit direction for ``label`` by least squares.

    The per-row values of ``df[label]`` (in ``df.index`` order) are mapped
    through the precomputed pseudo-inverse ``pinverse`` to get a row vector
    of component coefficients, which is then zero-padded up to
    ``return_dim`` columns.

    Args:
        df: pandas DataFrame with one row per sample; presumably its row
            order matches the columns of ``pinverse`` -- TODO confirm with
            the caller.
        label: column of ``df`` holding the numeric target for each row.
        pinverse: 2-D matrix whose second dimension equals ``len(df)``; its
            first dimension is the number of components (the original code
            assumed exactly 1000). Must be bfloat16, since the labels are
            cast to bfloat16 before the matmul.
        return_dim: width of the returned direction; must be at least the
            number of components.
        device: torch device for the label tensor and the padding.

    Returns:
        Tensor of shape ``(1, return_dim)`` with the same dtype as
        ``pinverse @ labels``. Entries past the number of components are
        zero.

    Raises:
        ValueError: if ``return_dim`` is smaller than the number of
            components.
    """
    # Read the whole column at once instead of materialising every row.
    labels = torch.tensor(df[label].tolist(), dtype=torch.float32)
    labels = labels.to(device).bfloat16()

    # Least-squares solution, reshaped to a row vector (1, n_components).
    direction = (pinverse @ labels).unsqueeze(0)

    n_components = direction.shape[1]
    if return_dim < n_components:
        raise ValueError(
            f"return_dim ({return_dim}) must be >= number of components ({n_components})"
        )
    if return_dim == n_components:
        return direction

    # Pad with the direction's own dtype: default float32 zeros would make
    # torch.cat promote the padded result, so padded and unpadded outputs
    # would disagree in dtype.
    padding = torch.zeros(
        (1, return_dim - n_components), dtype=direction.dtype, device=device
    )
    return torch.cat((direction, padding), dim=1)
def debias(direction, label, df, pinverse, device):
    """Project out of ``direction`` the component along ``label``'s direction.

    A least-squares direction ``d`` for ``label`` is computed the same way
    ``get_direction`` does (``pinverse @ labels``). It is zero-padded to the
    width of ``direction``, and the orthogonal projection of ``direction``
    onto ``d`` is subtracted:
    ``direction - (direction @ d.T) / ||d||^2 * d``.

    Args:
        direction: row vector of shape ``(1, D)`` to be debiased.
        label: column of ``df`` whose direction should be removed.
        df: pandas DataFrame with one row per sample, read in
            ``df.index`` order.
        pinverse: 2-D matrix with second dimension ``len(df)``; its first
            dimension (the number of components, originally assumed to be
            1000) must not exceed ``D``. Must be bfloat16, since the labels
            are cast to bfloat16 before the matmul.
        device: torch device for the label tensor and the padding.

    Returns:
        Tensor shaped like ``direction``. If the label direction is all
        zeros there is nothing to remove, and ``direction`` is returned
        unchanged instead of dividing by zero.

    Raises:
        ValueError: if the label direction is wider than ``direction``.
    """
    labels = torch.tensor(df[label].tolist(), dtype=torch.float32)
    labels = labels.to(device).bfloat16()

    # Least-squares direction for the attribute being removed: (1, n_components).
    d = (pinverse @ labels).unsqueeze(0)

    width = direction.shape[1]
    if d.shape[1] > width:
        raise ValueError(
            f"label direction has {d.shape[1]} components but direction has only {width}"
        )
    if d.shape[1] < width:
        padding = torch.zeros((1, width - d.shape[1]), dtype=d.dtype, device=device)
        d = torch.cat((d, padding), dim=1)
    # torch.matmul does not promote mixed dtypes, so align d with direction.
    d = d.to(direction.dtype)

    norm_sq = torch.norm(d) ** 2
    if norm_sq == 0:
        # Zero vector: nothing to project out; avoids a 0/0 -> NaN result.
        return direction

    # Remove the projection of `direction` onto d.
    return direction - ((direction @ d.T) / norm_sq) * d
@torch.no_grad()
def edit_inference(network, edited_weights, unet, vae, text_encoder, tokenizer, prompt, negative_prompt, guidance_scale, noise_scheduler, ddim_steps, start_noise, seed, generator, device):
    """Sample one image, switching ``network`` to ``edited_weights`` late in denoising.

    Denoising starts from seeded Gaussian latents of spatial size 64x64
    (``512 // 8``) and uses classifier-free guidance with
    ``negative_prompt`` as the unconditional branch. At every timestep
    ``t <= start_noise``, ``network.proj`` is replaced by
    ``edited_weights`` (followed by ``network.reset()``). Earlier timesteps
    use the weights the network had on entry. Those original weights are
    restored before returning, and also if sampling raises, so a failed
    call does not leave the network in its edited state.

    Args:
        network: LoRA weight-space module exposing ``proj``, ``reset()``
            and the context-manager protocol; every UNet call runs inside
            ``with network:``.
        edited_weights: replacement for ``network.proj`` used once
            ``t <= start_noise``.
        unet, vae, text_encoder, tokenizer: diffusion pipeline components.
        prompt: conditioning text prompt.
        negative_prompt: text for the unconditional/negative branch.
        guidance_scale: classifier-free guidance weight.
        noise_scheduler: scheduler providing ``set_timesteps``,
            ``timesteps``, ``init_noise_sigma``, ``scale_model_input`` and
            ``step``.
        ddim_steps: number of denoising steps.
        start_noise: timestep threshold at or below which the edited
            weights are active.
        seed: seed applied to ``generator`` before sampling the latents.
        generator: torch.Generator used for the initial latents.
        device: device for latents and token ids.

    Returns:
        Decoded image tensor with values clamped to [0, 1]; presumably
        (1, 3, 512, 512) for an SD-1.x VAE -- TODO confirm.
    """
    original_weights = network.proj.clone()
    generator = generator.manual_seed(seed)
    latents = torch.randn(
        (1, unet.in_channels, 512 // 8, 512 // 8),
        generator=generator,
        device=device,
    ).bfloat16()

    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
    max_length = text_input.input_ids.shape[-1]
    # truncation=True: padding="max_length" alone does not shorten an
    # over-long negative prompt, and the torch.cat below would then fail on
    # mismatched sequence lengths.
    uncond_input = tokenizer(
        [negative_prompt],
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    )
    uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    # Unconditional first, conditional second; matches the chunk(2) split below.
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    noise_scheduler.set_timesteps(ddim_steps)
    latents = latents * noise_scheduler.init_noise_sigma
    try:
        for t in tqdm.tqdm(noise_scheduler.timesteps):
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
            if t <= start_noise:
                # Swap in the edited weights for the remaining (lower-noise) steps.
                network.proj = torch.nn.Parameter(edited_weights)
                network.reset()
            with network:
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond=None).sample
            # Classifier-free guidance.
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # NOTE(review): 0.18215 is presumably the SD-1.x VAE latent scaling
        # factor -- confirm it matches this VAE.
        latents = 1 / 0.18215 * latents
        image = vae.decode(latents).sample
        image = (image / 2 + 0.5).clamp(0, 1)
    finally:
        # Always put the original weights back, even if sampling failed.
        network.proj = torch.nn.Parameter(original_weights)
        network.reset()
    return image
|