# NOTE: scraped web-viewer header (not part of the program), kept as comments.
# File: server/pipelines/pix2pixTurbo.py
# Commit: 5f8cb33 (verified) by radames, "Update server/pipelines/pix2pixTurbo.py"
# Size: 3.45 kB
import torch
from torchvision import transforms
from config import Args
from pydantic import BaseModel, Field
from PIL import Image
from pipelines.pix2pix.pix2pix_turbo import Pix2Pix_Turbo
from pipelines.utils.canny_gpu import ScharrOperator
# Default text used to pre-fill the "prompt" input field of the pipeline.
default_prompt = "close-up photo of the joker"
# HTML fragment exposed through Pipeline.Info.page_content. It is runtime text:
# do not reformat the literal, any whitespace change alters what is emitted.
page_content = """
<h1 class="text-3xl font-bold">Real-Time pix2pix_turbo</h1>
<h3 class="text-xl font-bold">pix2pix turbo</h3>
<p class="text-sm">
This demo showcases
<a
href="https://github.com/GaParmar/img2img-turbo"
target="_blank"
class="text-blue-500 underline hover:no-underline">One-Step Image Translation with Text-to-Image Models
</a>
</p>
<p class="text-sm text-gray-500">
Web app <a href="https://github.com/radames/Real-Time-Latent-Consistency-Model" target="_blank" class="text-blue-500 underline hover:no-underline">
Real-Time Latent Consistency Models
</a>
</p>
"""
class Pipeline:
    """Edge-conditioned image generation pipeline built on ``Pix2Pix_Turbo``.

    ``predict`` extracts an edge map from the incoming image with a
    ``ScharrOperator`` and feeds it, together with the text prompt, to a
    ``Pix2Pix_Turbo`` model constructed with the ``"edge_to_image"`` argument.
    """

    # NOTE: the nested pydantic models are documented with comments rather than
    # docstrings on purpose; a model docstring would be picked up as the
    # ``description`` of the generated schema.

    class Info(BaseModel):
        # Static metadata describing this pipeline.
        name: str = "img2img"
        title: str = "Image-to-Image SDXL"
        description: str = "Generates an image from a text prompt"
        input_mode: str = "image"
        page_content: str = page_content

    class InputParams(BaseModel):
        # Per-request parameters. The extra ``Field`` keyword arguments
        # (``field``, ``hide``, ``id``, ``min``, ``max``, ``step``, ...) are kept
        # verbatim; presumably they are UI hints consumed by the front end
        # rather than validation constraints -- TODO confirm.
        prompt: str = Field(
            default_prompt,
            title="Prompt",
            field="textarea",
            id="prompt",
        )
        width: int = Field(
            512, min=2, max=15, title="Width", disabled=True, hide=True, id="width"
        )
        height: int = Field(
            512, min=2, max=15, title="Height", disabled=True, hide=True, id="height"
        )
        canny_low_threshold: float = Field(
            0.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Canny Low Threshold",
            field="range",
            hide=True,
            id="canny_low_threshold",
        )
        canny_high_threshold: float = Field(
            1.0,
            min=0,
            max=1.0,
            step=0.001,
            title="Canny High Threshold",
            field="range",
            hide=True,
            id="canny_high_threshold",
        )
        debug_canny: bool = Field(
            False,
            title="Debug Canny",
            field="checkbox",
            hide=True,
            id="debug_canny",
        )

    def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
        """Create the generator model and the edge detector.

        Args:
            args: Application configuration. Not read by this constructor.
            device: Device handed to the ``ScharrOperator`` and stored on the
                instance.
            torch_dtype: Requested dtype. Not read by this constructor.
        """
        self.model = Pix2Pix_Turbo("edge_to_image")
        self.canny_torch = ScharrOperator(device=device)
        self.device = device
        # Not used inside this class; kept in case external code reads it.
        self.last_time = 0.0

    def predict(self, params: "Pipeline.InputParams") -> Image.Image:
        """Generate an image from ``params.image`` and ``params.prompt``.

        Args:
            params: Request parameters. Besides the declared ``InputParams``
                fields, this method reads ``params.image``, which is not a
                declared field -- NOTE(review): presumably attached by the
                caller before ``predict`` is invoked; confirm.

        Returns:
            The generated PIL image. When ``params.debug_canny`` is true, a
            200x200 copy of the edge map is pasted into its bottom-right
            corner.
        """
        # Inference only: skip autograd bookkeeping for the detector and model.
        with torch.no_grad():
            canny_pil, canny_tensor = self.canny_torch(
                params.image,
                params.canny_low_threshold,
                params.canny_high_threshold,
                output_type="pil,tensor",
            )
            # Repeat the edge map three times along dim 1 -- presumably the
            # channel axis, turning a 1-channel map into 3 channels; TODO confirm.
            canny_tensor = torch.cat(
                (canny_tensor, canny_tensor, canny_tensor), dim=1
            )
            output_image = self.model(
                canny_tensor,
                params.prompt,
            )

        # Rescale with x * 0.5 + 0.5 (presumably [-1, 1] -> [0, 1]) and clamp:
        # ToPILImage scales floats by 255 and casts to uint8, so any value
        # outside [0, 1] would otherwise wrap around and show up as artifacts.
        rescaled = (output_image[0].cpu() * 0.5 + 0.5).clamp(0.0, 1.0)
        result_image = transforms.ToPILImage()(rescaled)

        if params.debug_canny:
            # Paste the control (edge) image on top of the result image,
            # anchored to the bottom-right corner.
            w0, h0 = (200, 200)
            control_image = canny_pil.resize((w0, h0))
            w1, h1 = result_image.size
            result_image.paste(control_image, (w1 - w0, h1 - h0))
        return result_image