Spaces:
Sleeping
Sleeping
File size: 1,882 Bytes
01a6f2b c794c9c 01a6f2b b16b8cb 54d526a 8aac058 7b30620 98550e9 c794c9c b1e03fc c794c9c b1e03fc c794c9c b1e03fc b16b8cb c794c9c 54d526a 7b30620 c794c9c 7b30620 c794c9c 7b30620 c794c9c b1e03fc c794c9c 7b30620 13a7c45 01a6f2b 54d526a 13a7c45 01a6f2b 54d526a 01a6f2b 54d526a 7b30620 c794c9c 01a6f2b c794c9c b16b8cb b1e03fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import diffusers
import torch
from fastapi import FastAPI, UploadFile, HTTPException, File
from fastapi.responses import StreamingResponse
from PIL import Image
import io
app = FastAPI()
# Inicializa el pipeline al arrancar el servidor
@app.on_event("startup")
async def startup_event():
global pipe
print("[DEBUG] Cargando modelo Marigold-v1-0...")
pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
"prs-eth/marigold-v1-0", variant="fp16", torch_dtype=torch.float16
).to("cuda")
print("[DEBUG] Modelo Marigold-v1-0 cargado exitosamente.")
@app.post("/predict-depth/")
async def predict_depth(file: UploadFile = File(...)):
try:
# Verifica si el archivo es una imagen v谩lida
if not file.content_type.startswith("image/"):
raise HTTPException(status_code=400, detail="El archivo subido no es una imagen.")
# Carga la imagen desde el archivo subido
image = Image.open(file.file).convert("RGB")
# Realiza la predicci贸n de profundidad
print("[DEBUG] Realizando predicci贸n de profundidad con Marigold-v1-0...")
depth = pipe(image)
# Exporta la profundidad como una imagen 16-bit PNG
depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)
# Guarda la imagen generada en un buffer
img_buffer = io.BytesIO()
depth_16bit[0].save(img_buffer, format="PNG")
img_buffer.seek(0)
# Devuelve la imagen como respuesta
return StreamingResponse(img_buffer, media_type="image/png")
except Exception as e:
print(f"[ERROR] {str(e)}")
raise HTTPException(status_code=500, detail="Error procesando la imagen.")
@app.get("/")
async def root():
return {"message": "API de generaci贸n de mapas de profundidad con Marigold-v1-0"}
|