from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import StableDiffusionImg2ImgPipeline

from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util.commons import disable_safety_checker, download_image


class AbstractPipeline:
    """Minimal interface shared by the pipeline wrappers in this module.

    Concrete subclasses populate ``self.pipe`` either by loading weights from
    a model directory (``load``) or by reusing the components of another
    wrapper's ``pipe`` (``create``). The base implementations are no-ops.
    """

    def load(self, model_dir: str):
        pass

    def create(self, pipe):
        pass


class Text2Img(AbstractPipeline):
    """Text-to-image wrapper around the project's ``two_step_pipeline``."""

    def load(self, model_dir: str):
        """Load *model_dir* in float16, move it to CUDA and apply patches."""
        self.pipe = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        """Build a ``two_step_pipeline`` from the components of ``pipeline.pipe``.

        The component objects are passed through as-is, so the new pipeline
        references the same modules as *pipeline* rather than reloading them.
        """
        self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
        # Switch attention to xformers' memory-efficient implementation.
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        modified_prompts: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        """Run two-step generation and wrap the output in a ``Result``.

        Every argument is forwarded unchanged, by keyword, to
        ``self.pipe.two_step_pipeline``. The meaning of ``modified_prompts``
        and ``iteration`` is defined by that method, which is not shown here.
        Runs under ``torch.inference_mode()``, so no autograd graph is built.
        """
        # NOTE(review): this calls a *method* named ``two_step_pipeline`` on
        # the pipeline instance (not ``self.pipe(...)``); presumably the custom
        # pipeline class defines it — confirm in twoStepPipeline.
        result = self.pipe.two_step_pipeline(
            prompt=prompt,
            modified_prompts=modified_prompts,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            cross_attention_kwargs=cross_attention_kwargs,
            iteration=iteration,
        )
        return Result.from_result(result)


class Img2Img(AbstractPipeline):
    """Image-to-image wrapper around diffusers' ``StableDiffusionImg2ImgPipeline``."""

    def load(self, model_dir: str):
        """Load *model_dir* in float16, move it to CUDA and apply patches."""
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        """Build an img2img pipeline from the components of ``pipeline.pipe``.

        The component objects are passed through as-is, so the new pipeline
        references the same modules as *pipeline* rather than reloading them.
        """
        self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
            "cuda"
        )
        self.__patch()

    def __patch(self):
        # Switch attention to xformers' memory-efficient implementation.
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        strength: float,
        guidance_scale: float,
        steps: int,
        width: int,
        height: int,
        num_images_per_prompt: int = 1,
    ):
        """Download the source image, run img2img on it, and wrap the output.

        Args:
            prompt: Prompt(s) passed to the pipeline.
            imageUrl: Location handed to ``download_image`` to fetch the
                source image.
            negative_prompt: Negative prompt(s) passed to the pipeline.
            strength: Passed through as the pipeline's ``strength``.
            guidance_scale: Passed through as the pipeline's ``guidance_scale``.
            steps: Passed through as ``num_inference_steps``.
            width: Target width the source image is resized to.
            height: Target height the source image is resized to.
            num_images_per_prompt: Images generated per prompt (default 1,
                matching the previously hard-coded value).

        Returns:
            ``Result.from_result`` applied to the pipeline output.
        """
        # Presumably download_image returns a PIL image, whose resize() takes
        # (width, height) in that order — confirm in internals.util.commons.
        image = download_image(imageUrl).resize((width, height))
        result = self.pipe(
            prompt=prompt,
            image=image,
            strength=strength,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=num_images_per_prompt,
            num_inference_steps=steps,
        )
        return Result.from_result(result)