from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import StableDiffusionImg2ImgPipeline

from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import get_hf_token, num_return_sequences


class AbstractPipeline:
    """Minimal interface shared by pipeline wrappers.

    Implementations either load weights from disk (``load``) or reuse the
    components of an already-loaded pipeline (``create``) to avoid holding
    duplicate copies of the model in memory.
    """

    def load(self, model_dir: str):
        """Load model weights from ``model_dir``."""
        pass

    def create(self, pipe):
        """Build this pipeline from another pipeline's components."""
        pass


class Text2Img(AbstractPipeline):
    """Text-to-image wrapper around the project's two-step pipeline."""

    @dataclass
    class Params:
        """Prompt bundle for ``process``.

        When both ``prompt_left`` and ``prompt_right`` are set, the
        multi-character branch is taken; otherwise the two-step branch runs
        with ``modified_prompt``.
        """

        # FIX: fields were annotated ``List[str]`` but defaulted to None,
        # which is a type error; the annotations now say Optional explicitly.
        prompt: Optional[List[str]] = None
        modified_prompt: Optional[List[str]] = None
        prompt_left: Optional[List[str]] = None
        prompt_right: Optional[List[str]] = None

    def load(self, model_dir: str):
        """Load the two-step pipeline in fp16 onto the GPU."""
        self.pipe = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
        ).to("cuda")
        self.__patch()

    def is_loaded(self) -> bool:
        """Return True once ``load``/``create`` has set up the pipeline."""
        return hasattr(self, "pipe")

    def create(self, pipeline: AbstractPipeline):
        """Share components with ``pipeline`` instead of re-loading weights."""
        self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
        # Memory-efficient attention via xformers; applied after every (re)build.
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        params: Params,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[str] = None,
        num_images_per_prompt: int = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        """Run text-to-image generation and wrap the output in a ``Result``.

        Dispatches to the multi-character pipeline when both side prompts are
        present in ``params``; otherwise runs the two-step pipeline.
        """
        prompt = params.prompt
        if params.prompt_left and params.prompt_right:
            # Multi-character generation: base prompt plus one prompt per
            # half of the canvas (pos encodes the grid placement).
            prompt = [params.prompt[0], params.prompt_left[0], params.prompt_right[0]]
            result = self.pipe.multi_character_diffusion(
                prompt=prompt,
                pos=["1:1-0:0", "1:2-0:0", "1:2-0:1"],
                mix_val=[0.2, 0.8, 0.8],
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=[negative_prompt or ""] * len(prompt),
                num_images_per_prompt=num_return_sequences,
                eta=eta,
                # generator=generator,
                output_type=output_type,
                return_dict=return_dict,
                callback=callback,
                callback_steps=callback_steps,
            )
        else:
            # Two-step pipeline: original and modified prompts run together.
            modified_prompt = params.modified_prompt
            # NOTE(review): negative_prompt is replicated by the global
            # num_return_sequences while num_images_per_prompt (the local
            # parameter) is passed separately — confirm callers always keep
            # these in sync, otherwise the batch sizes can disagree.
            result = self.pipe.two_step_pipeline(
                prompt=prompt,
                modified_prompts=modified_prompt,
                height=height,
                width=width,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                negative_prompt=[negative_prompt or ""] * num_return_sequences,
                num_images_per_prompt=num_images_per_prompt,
                eta=eta,
                generator=generator,
                latents=latents,
                prompt_embeds=prompt_embeds,
                negative_prompt_embeds=negative_prompt_embeds,
                output_type=output_type,
                return_dict=return_dict,
                callback=callback,
                callback_steps=callback_steps,
                cross_attention_kwargs=cross_attention_kwargs,
                iteration=iteration,
            )
        return Result.from_result(result)


class Img2Img(AbstractPipeline):
    """Image-to-image wrapper around diffusers' StableDiffusionImg2ImgPipeline."""

    # Guard so repeated load()/create() calls don't re-download/rebuild.
    __loaded = False

    def load(self, model_dir: str):
        """Load the img2img pipeline in fp16 onto the GPU (idempotent)."""
        if self.__loaded:
            return
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16, use_auth_token=get_hf_token()
        ).to("cuda")
        self.__patch()
        self.__loaded = True

    def create(self, pipeline: AbstractPipeline):
        """Share components with ``pipeline`` instead of re-loading weights."""
        self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
            "cuda"
        )
        self.__patch()
        self.__loaded = True

    def __patch(self):
        # Memory-efficient attention via xformers; applied after every (re)build.
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        strength: float,
        guidance_scale: float,
        steps: int,
        width: int,
        height: int,
    ):
        """Download the init image, run img2img, and wrap output in ``Result``."""
        image = download_image(imageUrl).resize((width, height))
        # FIX: call the pipeline object directly instead of the un-idiomatic
        # explicit ``self.pipe.__call__(...)``.
        result = self.pipe(
            prompt=prompt,
            image=image,
            strength=strength,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            num_inference_steps=steps,
        )
        return Result.from_result(result)