|
from dataclasses import dataclass |
|
from typing import Any, Callable, Dict, List, Optional, Union |
|
|
|
import torch |
|
from diffusers import StableDiffusionImg2ImgPipeline |
|
|
|
from internals.data.result import Result |
|
from internals.pipelines.twoStepPipeline import two_step_pipeline |
|
from internals.util.commons import disable_safety_checker, download_image |
|
from internals.util.config import num_return_sequences |
|
|
|
|
|
class AbstractPipeline:
    """Base class for the pipeline wrappers in this module.

    Both hooks are no-ops here; subclasses override them to build their
    underlying pipeline object.
    """

    def load(self, model_dir: str):
        """Hook for building a pipeline from ``model_dir``; does nothing here."""
        return None

    def create(self, pipe):
        """Hook for building a pipeline from ``pipe``; does nothing here."""
        return None
|
|
|
|
|
class Text2Img(AbstractPipeline):
    """Text-to-image generation backed by the project's ``two_step_pipeline``.

    ``load`` builds the pipeline from a model directory, ``create`` rebuilds it
    from the components of an existing wrapper, and ``process`` runs a
    generation and wraps the output with ``Result.from_result``.
    """

    @dataclass
    class Params:
        """Prompt inputs consumed by ``Text2Img.process``.

        Every field defaults to ``None``, hence the ``Optional`` annotations.
        """

        # Base prompt(s): forwarded unchanged on the two-step path; only the
        # first entry is used on the multi-character path.
        prompt: Optional[List[str]] = None
        # Forwarded as ``modified_prompts`` on the two-step path only.
        modified_prompt: Optional[List[str]] = None
        # The multi-character path is taken only when BOTH of these are
        # truthy; only their first entries are used.
        prompt_left: Optional[List[str]] = None
        prompt_right: Optional[List[str]] = None

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` in float16 and move it to CUDA."""
        self.pipe = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        """Build the pipeline from ``pipeline.pipe.components`` and move it to CUDA."""
        self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
        """Enable xformers memory-efficient attention on ``self.pipe``."""
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        params: Params,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[str] = None,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        """Run generation and return ``Result.from_result`` of the pipeline output.

        Two paths exist:

        * When both ``params.prompt_left`` and ``params.prompt_right`` are
          truthy, ``self.pipe.multi_character_diffusion`` is called with the
          first base, left and right prompts. On this path
          ``num_images_per_prompt``, ``generator``, ``latents``,
          ``prompt_embeds``, ``negative_prompt_embeds``,
          ``cross_attention_kwargs`` and ``iteration`` are NOT forwarded; the
          image count comes from the module-level ``num_return_sequences``.
        * Otherwise ``self.pipe.two_step_pipeline`` is called with
          ``params.prompt`` and ``params.modified_prompt`` and every argument
          of this method is forwarded.

        Args:
            params: Prompt container (see ``Params``).
            negative_prompt: Single negative prompt; ``None`` is treated as
                ``""`` and the string is repeated into a list.
            All remaining arguments are passed through to the underlying
            pipeline call without modification.

        Returns:
            Whatever ``Result.from_result`` returns for the pipeline output.
        """
        negative = negative_prompt or ""

        # Keyword arguments that both pipeline entry points receive verbatim.
        shared_kwargs = dict(
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            eta=eta,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
        )

        if params.prompt_left and params.prompt_right:
            # One prompt per entry of ``pos``/``mix_val`` below: the base
            # prompt first, then the left and right prompts.
            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
            result = self.pipe.multi_character_diffusion(
                prompt=prompt,
                pos=["1:1-0:0", "1:2-0:0", "1:2-0:1"],
                mix_val=[0.2, 0.8, 0.8],
                negative_prompt=[negative] * len(prompt),
                num_images_per_prompt=num_return_sequences,
                **shared_kwargs,
            )
        else:
            result = self.pipe.two_step_pipeline(
                prompt=params.prompt,
                modified_prompts=params.modified_prompt,
                # NOTE(review): the list length is num_return_sequences, not
                # len(prompt) -- presumably callers pass that many prompts;
                # confirm against the caller.
                negative_prompt=[negative] * num_return_sequences,
                num_images_per_prompt=num_images_per_prompt,
                generator=generator,
                latents=latents,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                iteration=iteration,
                **shared_kwargs,
            )

        return Result.from_result(result)
|
|
|
|
|
class Img2Img(AbstractPipeline):
    """Image-to-image generation backed by ``StableDiffusionImg2ImgPipeline``.

    ``load`` builds the pipeline from a model directory, ``create`` rebuilds it
    from the components of an existing wrapper, and ``process`` downloads the
    source image, runs the pipeline and wraps the output with
    ``Result.from_result``.
    """

    def load(self, model_dir: str):
        """Load the pipeline from ``model_dir`` in float16 and move it to CUDA."""
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        """Build the pipeline from ``pipeline.pipe.components`` and move it to CUDA."""
        self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
            "cuda"
        )
        self.__patch()

    def __patch(self):
        """Enable xformers memory-efficient attention on ``self.pipe``."""
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        strength: float,
        guidance_scale: float,
        steps: int,
        width: int,
        height: int,
    ):
        """Run img2img on the image at ``imageUrl`` and wrap the output.

        Args:
            prompt: Prompts forwarded to the pipeline.
            imageUrl: Location passed to ``download_image``.
            negative_prompt: Negative prompts forwarded to the pipeline.
            strength: Forwarded as ``strength``.
            guidance_scale: Forwarded as ``guidance_scale``.
            steps: Forwarded as ``num_inference_steps``.
            width: Target width the downloaded image is resized to.
            height: Target height the downloaded image is resized to.

        Returns:
            Whatever ``Result.from_result`` returns for the pipeline output.
        """
        # NOTE(review): assumes download_image returns an object with a
        # PIL-style ``resize((width, height))`` -- confirm against the helper.
        image = download_image(imageUrl).resize((width, height))

        # Invoke the pipeline directly rather than through ``__call__``.
        # Exactly one image is requested per prompt.
        result = self.pipe(
            prompt=prompt,
            image=image,
            strength=strength,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            num_inference_steps=steps,
        )
        return Result.from_result(result)
|
|