File size: 1,882 Bytes
01a6f2b
c794c9c
01a6f2b
 
b16b8cb
54d526a
8aac058
7b30620
98550e9
c794c9c
 
 
 
b1e03fc
c794c9c
b1e03fc
c794c9c
b1e03fc
b16b8cb
c794c9c
54d526a
7b30620
c794c9c
 
 
7b30620
c794c9c
 
7b30620
c794c9c
b1e03fc
c794c9c
7b30620
13a7c45
 
01a6f2b
 
54d526a
13a7c45
01a6f2b
54d526a
01a6f2b
54d526a
7b30620
c794c9c
01a6f2b
c794c9c
b16b8cb
 
b1e03fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import diffusers
import torch
from fastapi import FastAPI, UploadFile, HTTPException, File
from fastapi.responses import StreamingResponse
from PIL import Image
import io

app = FastAPI()

# Inicializa el pipeline al arrancar el servidor
@app.on_event("startup")
async def startup_event():
    global pipe
    print("[DEBUG] Cargando modelo Marigold-v1-0...")
    pipe = diffusers.MarigoldDepthPipeline.from_pretrained(
        "prs-eth/marigold-v1-0", variant="fp16", torch_dtype=torch.float16
    ).to("cuda")
    print("[DEBUG] Modelo Marigold-v1-0 cargado exitosamente.")

@app.post("/predict-depth/")
async def predict_depth(file: UploadFile = File(...)):
    try:
        # Verifica si el archivo es una imagen v谩lida
        if not file.content_type.startswith("image/"):
            raise HTTPException(status_code=400, detail="El archivo subido no es una imagen.")

        # Carga la imagen desde el archivo subido
        image = Image.open(file.file).convert("RGB")

        # Realiza la predicci贸n de profundidad
        print("[DEBUG] Realizando predicci贸n de profundidad con Marigold-v1-0...")
        depth = pipe(image)

        # Exporta la profundidad como una imagen 16-bit PNG
        depth_16bit = pipe.image_processor.export_depth_to_16bit_png(depth.prediction)

        # Guarda la imagen generada en un buffer
        img_buffer = io.BytesIO()
        depth_16bit[0].save(img_buffer, format="PNG")
        img_buffer.seek(0)

        # Devuelve la imagen como respuesta
        return StreamingResponse(img_buffer, media_type="image/png")
    except Exception as e:
        print(f"[ERROR] {str(e)}")
        raise HTTPException(status_code=500, detail="Error procesando la imagen.")

@app.get("/")
async def root():
    return {"message": "API de generaci贸n de mapas de profundidad con Marigold-v1-0"}