File size: 2,155 Bytes
be14aa6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
# Import Fast API
from fastapi import FastAPI, File, HTTPException, Request, UploadFile
from fastapi.templating import Jinja2Templates
from fastapi.responses import StreamingResponse

# Import standard-library I/O and filesystem helpers
from io import BytesIO
import os
import shutil
import tempfile

# Import logging
import logging

# Import utilities
from src.utils.utils import IMAGE_FORMATS

# Import machine learning
from src.predict import predict
from ultralytics import YOLO
from huggingface_hub import hf_hub_download


# Initializing FastAPI application
app = FastAPI()

# Initializing Jinja2 templates; template files are looked up in the
# "templates" directory (a relative path, resolved against the working dir)
templates = Jinja2Templates(directory="templates")

# Initializing module-level logger
logger = logging.getLogger(__name__)


logger.info("Loading YOLO model...")

# Download the YOLOv8 face-detection weights from the Hugging Face Hub.
# hf_hub_download returns a local filesystem path to the downloaded file.
# NOTE: this runs at import time, so importing this module performs network
# I/O unless the file is already in the local Hugging Face cache.
model_path = hf_hub_download(
    repo_id="arnabdhar/YOLOv8-Face-Detection", filename="model.pt"
)


# Load the YOLO model once at start-up so every request reuses the same instance
model = YOLO(model_path)


# Index route: serves the landing page
@app.get("/")
async def root(request: Request):
    """Render the ``index.html`` landing page.

    The incoming request is placed in the template context under the
    ``"request"`` key, which is what Starlette's ``TemplateResponse`` expects.
    """
    return templates.TemplateResponse("index.html", {"request": request})


# Upload images decorator
@app.post("/predict-img")
def predict_image(file: UploadFile = File(...)):
    """Run the YOLO model on an uploaded image and stream back the result as JPEG.

    The upload's filename must end with one of ``IMAGE_FORMATS``. The image
    is copied to a private, uniquely named temporary file (keeping only the
    original extension), passed to ``predict`` together with the module-level
    ``model``, and the returned result is re-encoded as JPEG into memory.
    The temporary file is always removed, even if prediction fails.

    Args:
        file: The multipart-uploaded image.

    Returns:
        A ``StreamingResponse`` with ``media_type="image/jpeg"``.

    Raises:
        HTTPException: 400 if the filename is missing or does not end with
            one of ``IMAGE_FORMATS``.
    """
    # Starlette may report a missing filename as None; treat it as invalid.
    filename = file.filename or ""

    try:
        # Validate before anything touches the disk, so rejected uploads
        # never leave files behind.
        if not filename.endswith(IMAGE_FORMATS):
            raise HTTPException(status_code=400, detail="Invalid image format")

        # Never use the client-supplied filename as a path: it may contain
        # "../" (path traversal), and concurrent uploads with the same name
        # would overwrite each other. Keep only the extension, in case
        # ``predict`` relies on it.
        suffix = os.path.splitext(filename)[1]
        fd, image_path = tempfile.mkstemp(suffix=suffix)
        try:
            # Stream the upload to disk instead of reading it all into memory.
            with os.fdopen(fd, "wb") as tmp:
                shutil.copyfileobj(file.file, tmp)

            # Predicting
            results = predict(model, image_path)

            # TODO: extract the extension from the image and use it to save
            # the result instead of always re-encoding as JPEG.
            img_bytes = BytesIO()
            results.save(img_bytes, "JPEG")
            img_bytes.seek(0)
        finally:
            # Always remove the temporary copy, even when prediction fails.
            os.remove(image_path)
    finally:
        file.file.close()

    # Render image
    return StreamingResponse(content=img_bytes, media_type="image/jpeg")