# handler.py — last commit 41181c9 ("Update handler.py") by huyai123, verified; 2.4 kB.
import os
import torch
from PIL import Image
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from io import BytesIO
import logging
class EndpointHandler:
    """Inference handler running FLUX.1-dev with a ControlNet checkpoint as an upscaler.

    The request's control image is converted to RGB, resized to
    ``CONTROL_IMAGE_SIZE`` and passed to ``FluxControlNetPipeline``; the
    generated image is returned as PNG bytes in a ``BytesIO`` buffer.
    """

    # Target (width, height) the control image is resized to before inference;
    # kept small to reduce memory usage. Aspect ratio is not preserved.
    CONTROL_IMAGE_SIZE = (512, 512)
    # Defaults used when the request does not override them.
    DEFAULT_NUM_INFERENCE_STEPS = 10
    DEFAULT_CONTROLNET_CONDITIONING_SCALE = 0.5

    def __init__(self, model_dir="huyai123/Flux.1-dev-Image-Upscaler"):
        """Load the ControlNet and the FLUX.1-dev pipeline and enable memory savings.

        Args:
            model_dir: Hub repo id or local directory holding the
                ``FluxControlNetModel`` weights.

        Raises:
            ValueError: if the ``HF_TOKEN`` environment variable is unset or
                empty. The token is passed to ``from_pretrained`` for both the
                ControlNet and the ``black-forest-labs/FLUX.1-dev`` base weights.
        """
        # Cap the CUDA caching allocator's split block size to limit
        # fragmentation. The allocator reads this when it initialises, so it is
        # set before any model is loaded. setdefault lets a value provided by
        # the deployment environment take precedence.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")

        hf_token = os.getenv("HF_TOKEN")
        if not hf_token:
            raise ValueError("HF_TOKEN environment variable is not set")
        logging.basicConfig(level=logging.INFO)
        logging.info("Using HF_TOKEN")

        # Release cached-but-unused GPU memory before loading large weights.
        torch.cuda.empty_cache()

        # `token=` replaces the deprecated `use_auth_token=` argument.
        # NOTE(review): FLUX.1 is commonly run in bfloat16; float16 is kept
        # here as in the original deployment. Verify output quality.
        self.controlnet = FluxControlNetModel.from_pretrained(
            model_dir, torch_dtype=torch.float16, token=hf_token
        )
        self.pipe = FluxControlNetPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            token=hf_token,
        )
        self.pipe.enable_attention_slicing("auto")
        # Sequential CPU offload keeps the weights on the CPU and moves each
        # submodule to the GPU only while it runs. Do not call
        # `pipe.to("cuda")` first: that would materialise the whole model on
        # the GPU (defeating the memory savings and risking OOM), and recent
        # diffusers versions move it back to the CPU when offload is enabled.
        self.pipe.enable_sequential_cpu_offload()
        # Memory-efficient attention is exposed by diffusers pipelines only as
        # `enable_xformers_memory_efficient_attention()`. That requires the
        # optional `xformers` package, so it is intentionally not enabled here.

    def preprocess(self, data):
        """Extract the control image from the request and prepare it for the pipeline.

        Args:
            data: request mapping. ``data["control_image"]`` may be a file path,
                a binary file-like object, raw image bytes or a
                ``PIL.Image.Image``.

        Returns:
            An RGB ``PIL.Image.Image`` resized to ``CONTROL_IMAGE_SIZE``.

        Raises:
            ValueError: if ``control_image`` is missing or empty.
        """
        image_file = data.get("control_image", None)
        if not image_file:
            raise ValueError("Missing control_image in input.")
        # Resize to reduce memory usage.
        return self._load_rgb_image(image_file).resize(self.CONTROL_IMAGE_SIZE)

    @staticmethod
    def _load_rgb_image(source):
        """Return an independent RGB copy of *source* (path, file object, bytes or PIL image)."""
        # RGB is forced because RGBA, greyscale or palette inputs would otherwise
        # reach the pipeline with an unexpected channel count.
        if isinstance(source, Image.Image):
            return source.convert("RGB")
        if isinstance(source, (bytes, bytearray)):
            source = BytesIO(source)
        # convert() returns a new, fully loaded image, so it remains valid after
        # the context manager closes the underlying file.
        with Image.open(source) as img:
            return img.convert("RGB")

    def postprocess(self, output):
        """Encode *output* (a PIL image) as PNG into a ``BytesIO`` rewound to offset 0."""
        buffer = BytesIO()
        output.save(buffer, format="PNG")
        buffer.seek(0)
        return buffer

    def inference(self, data):
        """Run the ControlNet upscaling pipeline for one request.

        Args:
            data: request mapping with ``control_image`` (see ``preprocess``),
                optional ``prompt`` (default ``""``), and optional overrides
                ``num_inference_steps`` (default 10) and
                ``controlnet_conditioning_scale`` (default 0.5).

        Returns:
            A ``BytesIO`` holding the generated image encoded as PNG.
        """
        control_image = self.preprocess(data)
        num_steps = int(
            data.get("num_inference_steps", self.DEFAULT_NUM_INFERENCE_STEPS)
        )
        conditioning_scale = float(
            data.get(
                "controlnet_conditioning_scale",
                self.DEFAULT_CONTROLNET_CONDITIONING_SCALE,
            )
        )
        torch.cuda.empty_cache()
        # PIL reports size as (width, height).
        width, height = control_image.size
        output_image = self.pipe(
            prompt=data.get("prompt", ""),
            control_image=control_image,
            controlnet_conditioning_scale=conditioning_scale,
            num_inference_steps=num_steps,
            height=height,
            width=width,
        ).images[0]
        return self.postprocess(output_image)

    def __call__(self, data):
        """Handle a request; delegates to ``inference``.

        Hugging Face Inference Endpoints custom handlers are invoked by calling
        the handler instance with the request dict.
        NOTE(review): confirm the serving toolkit can serialize the ``BytesIO``
        returned here.
        """
        return self.inference(data)
if __name__ == "__main__":
    # Example invocation: replace the placeholder path and prompt. Requires the
    # HF_TOKEN environment variable (checked in EndpointHandler.__init__).
    data = {'control_image': 'path/to/your/image.png', 'prompt': 'Your prompt here'}
    handler = EndpointHandler()
    output = handler.inference(data)
    # `inference` returns an in-memory PNG buffer; save it to disk rather than
    # printing the buffer object's repr.
    output_path = "output.png"
    with open(output_path, "wb") as fh:
        fh.write(output.getvalue())
    print(f"Saved {output_path}")