Spaces:
Running
on
T4
Running
on
T4
import diffusers | |
import torch | |
from fastapi import FastAPI, UploadFile, HTTPException, File | |
from fastapi.responses import StreamingResponse | |
from PIL import Image | |
import io | |
app = FastAPI() | |
# Inicializa el pipeline al arrancar el servidor | |
async def startup_event(): | |
global pipe | |
print("[DEBUG] Cargando modelo Marigold-v1-0...") | |
pipe = diffusers.MarigoldDepthPipeline.from_pretrained( | |
"prs-eth/marigold-v1-0", variant="fp16", torch_dtype=torch.float16 | |
).to("cuda") | |
print("[DEBUG] Modelo Marigold-v1-0 cargado exitosamente.") | |
async def predict_depth(file: UploadFile = File(...)): | |
try: | |
# Verifica si el archivo es una imagen v谩lida | |
if not file.content_type.startswith("image/"): | |
raise HTTPException(status_code=400, detail="El archivo subido no es una imagen.") | |
# Carga la imagen desde el archivo subido | |
image = Image.open(file.file).convert("RGB") | |
# Realiza la predicci贸n de profundidad | |
print("[DEBUG] Realizando predicci贸n de profundidad con Marigold-v1-0...") | |
depth = pipe(image) | |
# Exporta la profundidad como una imagen 16-bit PNG | |
depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction) | |
# Guarda la imagen generada en un buffer | |
img_buffer = io.BytesIO() | |
depth_16bit[0].save(img_buffer, format="PNG") | |
img_buffer.seek(0) | |
# Devuelve la imagen como respuesta | |
return StreamingResponse(img_buffer, media_type="image/png") | |
except Exception as e: | |
print(f"[ERROR] {str(e)}") | |
raise HTTPException(status_code=500, detail="Error procesando la imagen.") | |
async def root(): | |
return {"message": "API de generaci贸n de mapas de profundidad con Marigold-v1-0"} | |