# pan_crd_lora_v2 / handler.py
# Last commit: ed9e6ae "Update handler.py" (verified) by krtk00 — 1.61 kB
import os
from diffusers import AutoPipelineForText2Image
import torch
from PIL import Image
from io import BytesIO
import base64
class EndpointHandler:
    """Custom inference handler: FLUX.1-dev text-to-image with a LoRA adapter.

    The base pipeline ``black-forest-labs/FLUX.1-dev`` is loaded once at
    construction, the ``krtk00/pan_crd_lora_v2`` LoRA weights are applied on
    top, and each call renders a single prompt to PNG bytes.
    """

    def __init__(self, path: str = ""):
        """Load the base pipeline and the LoRA weights onto the best device.

        Args:
            path: Accepted for handler-interface compatibility but not used;
                both models are fetched by repo id. NOTE(review): if the
                endpoint mounts the adapter repo at ``path``, loading the LoRA
                from there would avoid a second download — confirm first.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Token comes from the environment (ensure HF_TOKEN is set); it is
        # forwarded to both downloads below.
        hf_token = os.getenv("HF_TOKEN")
        # `token` is the current name of the deprecated `use_auth_token` kwarg.
        self.pipeline = AutoPipelineForText2Image.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            token=hf_token,
            torch_dtype=self._select_dtype(self.device),
        ).to(self.device)
        # Apply the LoRA adapter. Pass the same token so a private adapter
        # repo authenticates the same way the base model does.
        lora_weights_path = "krtk00/pan_crd_lora_v2"
        self.pipeline.load_lora_weights(
            lora_weights_path,
            weight_name="lora.safetensors",
            token=hf_token,
        )

    @staticmethod
    def _select_dtype(device: torch.device) -> torch.dtype:
        """Return the dtype to load the pipeline weights in for ``device``.

        Off-GPU this is float32, since half precision is poorly supported on CPU.
        On CUDA, bfloat16 is preferred: FLUX is distributed and recommended in
        bf16, and float16 can overflow inside the model (NaN or black images).
        float16 remains the fallback for GPUs without bf16 support.
        """
        if device.type != "cuda":
            return torch.float32
        return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

    def __call__(self, data):
        """Generate one image for the prompt in ``data["inputs"]``.

        Args:
            data: Request payload mapping; only the ``"inputs"`` key (the
                prompt) is read.

        Returns:
            The first generated image, PNG-encoded, as ``bytes``.

        Raises:
            ValueError: If ``"inputs"`` is missing or falsy (e.g. ``""``).
        """
        prompt = data.get("inputs", None)
        if not prompt:
            raise ValueError("No prompt provided in the input")
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            images = self.pipeline(prompt).images
        # Only the first generated image is returned.
        buffered = BytesIO()
        images[0].save(buffered, format="PNG")
        # NOTE(review): raw bytes are returned unchanged for compatibility;
        # confirm the serving toolkit accepts `bytes` (some serializers expect
        # a PIL image or a JSON-serializable value such as base64 text).
        return buffered.getvalue()