from typing import Any, Callable, Dict, List, Optional, Union
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util.commons import disable_safety_checker, download_image

class AbstractPipeline:
    """Common interface: `load` builds a pipeline from a model directory,
    `create` reuses the components of an already-loaded pipeline."""

    def load(self, model_dir: str):
        pass

    def create(self, pipe):
        pass

class Text2Img(AbstractPipeline):
    def load(self, model_dir: str):
        # Load the custom two-step pipeline in fp16 and move it to the GPU.
        self.pipe = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        # Reuse the components of an already-loaded pipeline instead of
        # reloading the weights from disk.
        self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
        # Enable memory-efficient attention through xformers.
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: Optional[Union[str, List[str]]] = None,
        modified_prompts: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
iteration: float = 3.0,
):
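        # Delegate generation to the custom two-step pipeline.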
result = self.pipe.two_step_pipeline(
prompt=prompt,
modified_prompts=modified_prompts,
height=height,
width=width,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
eta=eta,
generator=generator,
latents=latents,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
output_type=output_type,
return_dict=return_dict,
callback=callback,
callback_steps=callback_steps,
cross_attention_kwargs=cross_attention_kwargs,
iteration=iteration,
)
return Result.from_result(result)

class Img2Img(AbstractPipeline):
    def load(self, model_dir: str):
        # Load the img2img pipeline in fp16 and move it to the GPU.
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        # Reuse the components of an already-loaded pipeline instead of
        # reloading the weights from disk.
        self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
            "cuda"
        )
        self.__patch()

    def __patch(self):
        # Enable memory-efficient attention through xformers.
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
self,
prompt: List[str],
imageUrl: str,
negative_prompt: List[str],
strength: float,
guidance_scale: float,
steps: int,
width: int,
height: int,
):
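        # Download the source image and resize it to the requested resolution.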
image = download_image(imageUrl).resize((width, height))
        result = self.pipe(
prompt=prompt,
image=image,
strength=strength,
negative_prompt=negative_prompt,
guidance_scale=guidance_scale,
num_images_per_prompt=1,
num_inference_steps=steps,
)
return Result.from_result(result)
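
# Usage sketch (illustrative, not part of the original module): wiring the two
# pipelines together so Img2Img reuses the components already loaded by
# Text2Img. The model directory and image URL below are hypothetical
# placeholders, and a CUDA GPU with fp16 weights is assumed.
#
#   text2img = Text2Img()
#   text2img.load(model_dir="/path/to/model")        # hypothetical path
#
#   img2img = Img2Img()
#   img2img.create(text2img)                         # share loaded components
#
#   result = img2img.process(
#       prompt=["a watercolor lighthouse at dusk"],
#       imageUrl="https://example.com/input.png",    # hypothetical URL
#       negative_prompt=[""],
#       strength=0.75,
#       guidance_scale=7.5,
#       steps=30,
#       width=512,
#       height=512,
#   )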