from diffusers import DiffusionPipeline
import torch
import torch.nn as nn
import os
from diffusers.utils import BaseOutput
from dataclasses import dataclass
from typing import List, Union, Optional
from PIL import Image
import numpy as np
import json
from safetensors.torch import load_file
from tqdm import tqdm


@dataclass
class SdxsPipelineOutput(BaseOutput):
    images: Union[List[Image.Image], np.ndarray]


class SdxsPipeline(DiffusionPipeline):
    def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, text_projector=None):
        super().__init__()

        self.register_modules(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, scheduler=scheduler
        )

        # Resolve the checkpoint path when the pipeline is loaded via from_pretrained.
        model_path = None
        if hasattr(self, '_internal_dict') and self._internal_dict.get('_name_or_path'):
            model_path = self._internal_dict.get('_name_or_path')

        # Device/dtype used for the optional text projector loaded below.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32

        # Prefer an explicitly passed projector; otherwise try to load one from disk.
        self.text_projector = text_projector

        projector_path = None
        if model_path and os.path.exists(f"{model_path}/text_projector"):
            projector_path = f"{model_path}/text_projector"
        elif os.path.exists("./text_projector"):
            projector_path = "./text_projector"

        if self.text_projector is None and projector_path:
            try:
                # config.json must provide "in_features" and "out_features" for the linear projection.
                with open(f"{projector_path}/config.json", "r") as f:
                    projector_config = json.load(f)

                self.text_projector = nn.Linear(
                    in_features=projector_config["in_features"],
                    out_features=projector_config["out_features"],
                    bias=False
                )

                self.text_projector.load_state_dict(load_file(f"{projector_path}/model.safetensors"))
                self.text_projector.to(device=device, dtype=dtype)
                print(f"Successfully loaded text_projector from {projector_path} ({device}, {dtype})")
            except Exception as e:
                print(f"Error loading text_projector: {e}")

        # Spatial downsampling factor of the VAE (latent resolution = pixel resolution / 8).
        self.vae_scale_factor = 8

    def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
        """Encode text prompts into embeddings.

        Returns:
            text_embeddings: tensor of shape [2 * batch_size, 1, dim], with the
            negative (or zero) embeddings concatenated before the positive ones,
            ready for classifier-free guidance.
        """
        if prompt is None and negative_prompt is None:
            raise ValueError("At least one of prompt or negative_prompt is required")

        device = device or self.device
        dtype = dtype or next(self.unet.parameters()).dtype

        with torch.no_grad():
            if prompt is not None:
                if isinstance(prompt, str):
                    prompt = [prompt]

                text_inputs = self.tokenizer(
                    prompt, return_tensors="pt", padding="max_length",
                    max_length=512, truncation=True
                ).to(device)

                # Project the first-token hidden state into the UNet cross-attention dimension.
                outputs = self.text_encoder(text_inputs.input_ids, text_inputs.attention_mask)
                last_hidden_state = outputs.last_hidden_state.to(device, dtype=dtype)
                pos_embeddings = self.text_projector(last_hidden_state[:, 0])

                if pos_embeddings.ndim == 2:
                    pos_embeddings = pos_embeddings.unsqueeze(1)
            else:
                # No positive prompt: use zero embeddings of the expected shape.
                batch_size = len(negative_prompt) if isinstance(negative_prompt, list) else 1
                pos_embeddings = torch.zeros(
                    batch_size, 1, self.unet.config.cross_attention_dim,
                    device=device, dtype=dtype
                )

            if negative_prompt is not None:
                if isinstance(negative_prompt, str):
                    negative_prompt = [negative_prompt]

                # Match the negative batch size to the positive one.
                if prompt is not None and len(negative_prompt) != len(prompt):
                    neg_batch_size = len(prompt)
                    if len(negative_prompt) == 1:
                        negative_prompt = negative_prompt * neg_batch_size
                    else:
                        negative_prompt = negative_prompt[:neg_batch_size]

                neg_inputs = self.tokenizer(
                    negative_prompt, return_tensors="pt", padding="max_length",
                    max_length=512, truncation=True
                ).to(device)

                neg_outputs = self.text_encoder(neg_inputs.input_ids, neg_inputs.attention_mask)
                neg_last_hidden_state = neg_outputs.last_hidden_state.to(device, dtype=dtype)
                neg_embeddings = self.text_projector(neg_last_hidden_state[:, 0])

                if neg_embeddings.ndim == 2:
                    neg_embeddings = neg_embeddings.unsqueeze(1)

                # Negative embeddings go first, matching the chunk order used in generate_latents.
                text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
            else:
                # No negative prompt: use zero embeddings as the unconditional branch.
                neg_embeddings = torch.zeros_like(pos_embeddings)
                text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)

        return text_embeddings.to(device=device, dtype=dtype)

    @torch.no_grad()
    def generate_latents(
        self,
        text_embeddings,
        height: int = 576,
        width: int = 576,
        num_inference_steps: int = 40,
        guidance_scale: float = 5.0,
        latent_channels: int = 16,
        batch_size: int = 1,
        generator=None,
    ):
        """Generate latents from prompt embeddings via the diffusion loop."""
        device = self.device
        dtype = next(self.unet.parameters()).dtype

        do_classifier_free_guidance = guidance_scale > 0
        # Number of prompts encoded in text_embeddings (half of the tensor when CFG is enabled).
        embed_batch_size = text_embeddings.shape[0] // 2 if do_classifier_free_guidance else text_embeddings.shape[0]

        if batch_size > embed_batch_size:
            # Repeat the embeddings so that each generated image gets a copy.
            if do_classifier_free_guidance:
                neg_embeds, pos_embeds = text_embeddings.chunk(2)
                neg_embeds = neg_embeds.repeat(batch_size // embed_batch_size, 1, 1)
                pos_embeds = pos_embeds.repeat(batch_size // embed_batch_size, 1, 1)
                text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
            else:
                text_embeddings = text_embeddings.repeat(batch_size // embed_batch_size, 1, 1)

        self.scheduler.set_timesteps(num_inference_steps, device=device)

        # Initial Gaussian noise in latent space.
        latent_shape = (
            batch_size,
            latent_channels,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor
        )
        latents = torch.randn(
            latent_shape,
            device=device,
            dtype=dtype,
            generator=generator
        )

        for t in tqdm(self.scheduler.timesteps, desc="Generating"):
            # Duplicate latents so conditional and unconditional passes run in one batch.
            if do_classifier_free_guidance:
                latent_input = torch.cat([latents] * 2)
            else:
                latent_input = latents

            latent_input = self.scheduler.scale_model_input(latent_input, t)

            noise_pred = self.unet(latent_input, t, text_embeddings).sample

            # Classifier-free guidance: push the prediction away from the unconditional branch.
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        return latents

    def decode_latents(self, latents, output_type="pil"):
        """Decode latents into images."""
        # Undo the latent scaling/shift applied by the VAE during encoding.
        latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor

        with torch.no_grad():
            images = self.vae.decode(latents).sample

        # Map from [-1, 1] to [0, 1].
        images = (images / 2 + 0.5).clamp(0, 1)

        if output_type == "pil":
            images = images.cpu().permute(0, 2, 3, 1).float().numpy()
            images = (images * 255).round().astype("uint8")
            return [Image.fromarray(image) for image in images]
        else:
            return images.cpu().permute(0, 2, 3, 1).float().numpy()

    @torch.no_grad()
    def __call__(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        height: int = 576,
        width: int = 576,
        num_inference_steps: int = 40,
        guidance_scale: float = 5.0,
        latent_channels: int = 16,
        output_type: str = "pil",
        return_dict: bool = True,
        batch_size: int = 1,
        seed: Optional[int] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        text_embeddings: Optional[torch.FloatTensor] = None,
    ):
        """Generate images from text prompts or precomputed embeddings."""
        device = self.device

        generator = None
        if seed is not None:
            generator = torch.Generator(device=device).manual_seed(seed)

        if text_embeddings is None:
            if prompt is None and negative_prompt is None:
                raise ValueError("You must provide prompt, negative_prompt, or text_embeddings")

            text_embeddings = self.encode_prompt(
                prompt=prompt,
                negative_prompt=negative_prompt,
                device=device
            )
        else:
            # Reuse precomputed embeddings, only moving them to the right device.
            text_embeddings = text_embeddings.to(device)

        latents = self.generate_latents(
            text_embeddings=text_embeddings,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            latent_channels=latent_channels,
            batch_size=batch_size,
            generator=generator
        )

        images = self.decode_latents(latents, output_type=output_type)

        if not return_dict:
            return images

        return SdxsPipelineOutput(images=images)
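

# Example usage: a minimal sketch, not part of the pipeline itself. The checkpoint path
# below is hypothetical; it assumes a diffusers-style directory containing vae/,
# text_encoder/, tokenizer/, unet/, scheduler/ and, optionally, text_projector/
# (with config.json and model.safetensors), as expected by __init__ above.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = SdxsPipeline.from_pretrained(
        "path/to/sdxs-checkpoint",  # hypothetical local checkpoint directory
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    pipe = pipe.to(device)

    result = pipe(
        prompt="a photo of a red fox in a snowy forest",
        negative_prompt="blurry, low quality",
        num_inference_steps=40,
        guidance_scale=5.0,
        seed=42,
    )
    result.images[0].save("output.png")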