from base64 import b64encode
import os
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import notebook_login
# For video display:
from IPython.display import HTML
from matplotlib import pyplot as plt
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging

# Shared device and pre-loaded Stable Diffusion components set up in device.py
from device import (torch_device, vae, text_encoder, unet, tokenizer, scheduler,
                    token_emb_layer, pos_emb_layer, position_embeddings)



# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()

def pil_to_latent(input_im):
    # Single image -> single latent in a batch (so shape 1, 4, 64, 64)
    with torch.no_grad():
        # ToTensor gives values in [0, 1]; the VAE expects inputs in [-1, 1]
        latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1)
    return 0.18215 * latent.latent_dist.sample()  # scale by SD's latent scaling factor


def latents_to_pil(latents):
    # Batch of latents -> list of PIL images
    latents = (1 / 0.18215) * latents  # undo SD's latent scaling factor
    with torch.no_grad():
        image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)  # map [-1, 1] back to [0, 1]
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(image) for image in images]
    return pil_images
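

# Illustrative usage sketch (added, not part of the original helpers): round-trip an
# image through the VAE with the two helpers above. The image path is a hypothetical
# placeholder; the Stable Diffusion v1 VAE expects 512x512 RGB inputs, which encode
# to latents of shape (1, 4, 64, 64).
def vae_roundtrip_demo(image_path="input.jpg"):
    input_image = Image.open(image_path).convert("RGB").resize((512, 512))  # hypothetical path
    latents = pil_to_latent(input_image)   # torch.Size([1, 4, 64, 64])
    return latents_to_pil(latents)[0]      # reconstruction as a PIL image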


def set_timesteps(scheduler, num_inference_steps):
    # Wrapper around scheduler.set_timesteps that also casts the timesteps to
    # float32 (works around a dtype issue on some backends, e.g. MPS)
    scheduler.set_timesteps(num_inference_steps)
    scheduler.timesteps = scheduler.timesteps.to(torch.float32)


def orange_loss(image):
    # image: batch of decoded images, shape (N, 3, H, W), with values in [0, 1]
    # Approximate an "orange" signal by summing the red and green channels
    # (so each pixel's combined value ranges from 0 to 2)
    orange_channel = image[:, 0, :, :] + image[:, 1, :, :]

    # Target mean intensity for the combined red+green channel
    target_mean = 0.8  # Replace with your desired mean intensity

    # Mean absolute deviation from the target: the further the image is from the
    # desired orange level, the larger the loss
    loss = torch.abs(orange_channel - target_mean).mean()

    return loss
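

# The sketch below is added for illustration and shows one way orange_loss could
# steer generation: every few denoising steps, decode the predicted clean latents,
# evaluate the loss on the resulting image, and nudge the latents along the negative
# gradient. It assumes the objects imported from device.py are standard Stable
# Diffusion v1 components with `scheduler` an LMSDiscreteScheduler; the prompt
# handling follows the usual classifier-free guidance recipe, and the hyperparameters
# (guidance_scale, loss_scale, the every-5-steps cadence) are illustrative defaults,
# not tuned values.
def generate_with_orange_guidance(prompt, num_inference_steps=30, guidance_scale=8,
                                  loss_scale=200, seed=42):
    # Encode the prompt plus an empty prompt for classifier-free guidance
    text_input = tokenizer([prompt], padding="max_length",
                           max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    uncond_input = tokenizer([""], padding="max_length",
                             max_length=tokenizer.model_max_length, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Start from random latents scaled to the scheduler's initial noise level
    set_timesteps(scheduler, num_inference_steps)
    generator = torch.manual_seed(seed)
    latents = torch.randn((1, unet.config.in_channels, 64, 64), generator=generator)
    latents = latents.to(torch_device) * scheduler.init_noise_sigma

    for i, t in tqdm(enumerate(scheduler.timesteps), total=num_inference_steps):
        sigma = scheduler.sigmas[i]

        # Classifier-free guidance: one unconditional and one text-conditioned pass
        latent_model_input = scheduler.scale_model_input(torch.cat([latents] * 2), t)
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t,
                              encoder_hidden_states=text_embeddings).sample
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Every few steps, nudge the latents in the direction that lowers orange_loss
        if i % 5 == 0:
            latents = latents.detach().requires_grad_()
            latents_x0 = latents - sigma * noise_pred                    # predicted clean latents
            decoded = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
            loss = orange_loss(decoded) * loss_scale
            cond_grad = torch.autograd.grad(loss, latents)[0]
            latents = latents.detach() - cond_grad * sigma ** 2

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents_to_pil(latents)[0]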