Spaces:
Running
Running
File size: 2,155 Bytes
be14aa6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
# Import standard library: bytes buffers, filesystem, logging, temp files
import logging
import os
import shutil
import tempfile
from io import BytesIO

# Import Fast API
from fastapi import FastAPI, File, Request, UploadFile
from fastapi.responses import StreamingResponse
from fastapi.templating import Jinja2Templates
# Import machine learning
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Import utilities
from src.predict import predict
from src.utils.utils import IMAGE_FORMATS
# Initializing FastAPI application
app = FastAPI()
# Initializing templates (looked up in the relative directory "templates")
templates = Jinja2Templates(directory="templates")
# Initializing logger
logger = logging.getLogger(__name__)
# NOTE(review): no logging configuration is visible here, so this INFO record
# may be dropped by the default WARNING threshold — confirm it is configured.
logger.info("Loading YOLO model...")
# Download the YOLOv8 face-detection weights from the Hugging Face Hub;
# the returned value is the local path of the downloaded file.
model_path = hf_hub_download(
    repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt"
)
# Load the YOLO model once at import time so every request reuses it
model = YOLO(model_path)
# Index route: serves the landing page.
@app.get("/")
async def root(request: Request):
    """Render the ``index.html`` template.

    The incoming request is placed in the template context under the
    ``"request"`` key, as Starlette's ``TemplateResponse`` expects.
    """
    return templates.TemplateResponse("index.html", {"request": request})
def _save_upload_to_temp(upload: UploadFile) -> str:
    """Validate *upload*'s filename and copy its contents to a new temporary file.

    Args:
        upload: The file received by the ``/predict-img`` endpoint.

    Returns:
        Path of the newly created temporary file. The caller owns it and is
        responsible for deleting it.

    Raises:
        ValueError: if the filename is missing or does not end with one of
            ``IMAGE_FORMATS``.
        OSError: if the temporary file cannot be created or written.
    """
    filename = upload.filename
    # Validate before touching the disk so rejected uploads leave nothing behind.
    if not filename or not filename.endswith(IMAGE_FORMATS):
        raise ValueError("Invalid image format")
    # Never use the client-supplied name as a path (it may contain "../" or
    # collide with a concurrent upload); keep only its extension, in case the
    # image loader dispatches on the file suffix.
    suffix = os.path.splitext(os.path.basename(filename))[1]
    fd, path = tempfile.mkstemp(suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as out:
            # Stream the upload instead of reading it fully into memory.
            shutil.copyfileobj(upload.file, out)
    except BaseException:
        # Don't leave a partially written file behind.
        os.remove(path)
        raise
    return path


# Upload images decorator
@app.post("/predict-img")
def predict_image(file: UploadFile = File(...)):
    """Run the face-detection model on an uploaded image and stream the result.

    The upload is validated by filename extension against ``IMAGE_FORMATS``
    and copied to a uniquely named temporary file. That path is passed to
    ``predict(model, ...)``. The temporary file is always deleted afterwards,
    even if prediction fails.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        A ``StreamingResponse`` containing the result encoded as JPEG
        (``image/jpeg``). If the upload is rejected or cannot be saved, a
        one-element list holding the error message is returned instead. FastAPI
        serializes it as a JSON array, the same body the previous set-literal
        response produced.
    """
    try:
        image = _save_upload_to_temp(file)
    except (ValueError, OSError) as e:
        # Same wire format as before: JSON array with the message, HTTP 200.
        # NOTE(review): consider raising HTTPException(status_code=400) instead.
        return [f"{e}"]
    finally:
        file.file.close()

    try:
        # Predicting
        results = predict(model, image)
        # TODO: extract extension from image and use it to save the image
        # Assumes predict returns an object with a PIL-style save(fp, format)
        # method — confirm against src.predict.
        img_bytes = BytesIO()
        results.save(img_bytes, "JPEG")
        img_bytes.seek(0)
    finally:
        # Always remove the temporary upload, even when prediction fails.
        os.remove(image)
    # Render image
    return StreamingResponse(content=img_bytes, media_type="image/jpeg")
|