File size: 1,610 Bytes
e3d77b5
2e28b6e
 
ed9e6ae
 
 
2e28b6e
 
 
 
 
 
 
e3d77b5
 
 
 
 
2e28b6e
e3d77b5
 
2e28b6e
 
 
 
 
 
 
e3d77b5
2e28b6e
 
 
ed9e6ae
2e28b6e
 
ed9e6ae
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import os
from diffusers import AutoPipelineForText2Image
import torch
from PIL import Image
from io import BytesIO
import base64

class EndpointHandler:
    """Inference endpoint handler that loads FLUX.1-dev with a LoRA adapter
    and generates PNG images from text prompts."""

    def __init__(self, path: str = ""):
        """
        Initialize the handler, loading the model and LoRA weights.

        Args:
            path: Unused here; kept for the Hugging Face Inference Endpoints
                  handler contract, which passes the model directory path.
        """
        # Prefer GPU when available; fall back to CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Retrieve the Hugging Face token from the environment.
        # HF_TOKEN must be set for gated models such as FLUX.1-dev.
        hf_token = os.getenv("HF_TOKEN")

        # Load the base model. `use_auth_token` is deprecated in recent
        # diffusers/huggingface_hub releases; `token` is the supported kwarg.
        # float16 halves GPU memory; CPU needs float32 for full op coverage.
        self.pipeline = AutoPipelineForText2Image.from_pretrained(
            'black-forest-labs/FLUX.1-dev',
            token=hf_token,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        ).to(self.device)

        # Apply the fine-tuned LoRA adapter on top of the base pipeline.
        lora_weights_path = 'krtk00/pan_crd_lora_v2'
        self.pipeline.load_lora_weights(lora_weights_path, weight_name='lora.safetensors')

    def __call__(self, data):
        """
        Generate an image for the prompt in ``data["inputs"]``.

        Args:
            data: Request payload dict; the prompt is read from "inputs".

        Returns:
            PNG-encoded bytes of the first generated image (the serving
            layer handles serialization of the raw bytes).

        Raises:
            ValueError: If "inputs" is missing or empty.
        """
        prompt = data.get("inputs", None)
        if not prompt:
            raise ValueError("No prompt provided in the input")

        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            images = self.pipeline(prompt).images

        # Serialize the first generated image to PNG bytes in memory.
        pil_image = images[0]
        buffered = BytesIO()
        pil_image.save(buffered, format="PNG")
        return buffered.getvalue()