from fastapi import FastAPI, File, UploadFile
from tensorflow.keras.models import load_model
from PIL import Image, ImageOps
from scipy.stats import entropy
import numpy as np
import io
import time
from tensorflow.keras.layers import DepthwiseConv2D
# Application instance; the title and description are the metadata FastAPI
# exposes for this service.
# NOTE(review): the call was left unclosed in the reviewed text; if further
# keyword arguments (e.g. a version) were lost with it, restore them here.
app = FastAPI(
    title="Dental Diseases Prediction API",
    description="Cette API prédit le type de maladie bucco-dentaire à partir d'images.",
)
def read_root():
    """Return the static greeting payload that identifies this API."""
    # NOTE(review): no route decorator is visible on this function in the
    # reviewed text -- confirm it is actually registered on `app`.
    return dict(message="API Dental Diseases Prediction")
def custom_depthwise_conv2d(*args, **kwargs):
    """Build a ``DepthwiseConv2D`` layer, ignoring any ``groups`` keyword.

    All positional and remaining keyword arguments are forwarded unchanged.

    Returns:
        The ``DepthwiseConv2D`` instance created from the filtered arguments.
    """
    # Presumably the saved model's layer config carries a `groups` entry that
    # the installed DepthwiseConv2D does not accept -- confirm against the
    # Keras version in use.
    kwargs.pop('groups', None)
    return DepthwiseConv2D(*args, **kwargs)
# Load the network once at import time. `custom_objects` routes every
# DepthwiseConv2D in the saved file through the wrapper above, and
# compile=False is passed because the model is only used for inference here.
model = load_model(
    "keras_model.h5",
    custom_objects={'DepthwiseConv2D': custom_depthwise_conv2d},
    compile=False,
)

# One label per line; the list position is matched against the model's
# argmax output index in `predict`.
with open("labels.txt", "r") as file:
    class_names = [line.strip() for line in file]
def preprocess(image):
    """Turn a PIL image into a normalized single-image batch for the model.

    The image is converted to RGB, cropped/resized to 224x224 with
    ``ImageOps.fit`` and rescaled from ``[0, 255]`` to ``[-1, 1]``.

    Args:
        image: A ``PIL.Image.Image`` in any mode Pillow can convert to RGB.

    Returns:
        A ``float32`` NumPy array of shape ``(1, 224, 224, 3)``.
    """
    # Convert BEFORE resizing: Pillow forces nearest-neighbour resampling for
    # palette ("P") and bilevel ("1") images, so resizing first would silently
    # ignore the LANCZOS filter for those inputs (e.g. palette PNGs).
    rgb_image = image.convert("RGB")
    fitted = ImageOps.fit(rgb_image, (224, 224), Image.Resampling.LANCZOS)

    pixels = np.asarray(fitted, dtype=np.float32)
    # Map 0..255 onto -1..1.
    normalized = (pixels / 127.5) - 1

    # Add the leading batch axis instead of copying into an uninitialised
    # buffer.
    return np.expand_dims(normalized, axis=0)
@app.post(
    "/predict/",
    summary="Upload an image for dental disease prediction",
    description="This endpoint allows you to upload an image of the teeth and get a prediction of the dental disease.",
)
async def predict(file: UploadFile = File(...)):
    """
    Load an uploaded image and pass it to the model for prediction.

    - **file**: image file (jpeg or png) to analyze

    Returns a dict with the predicted label, its probability, the margin
    between the two highest probabilities, the entropy of the output
    distribution and the inference time in seconds. If the upload cannot be
    read or decoded as an image, returns ``{"error": ...}`` instead.
    """
    try:
        image = Image.open(io.BytesIO(await file.read()))
        # Image.open only identifies the file; pixel data is decoded lazily,
        # so a truncated or corrupt upload fails during preprocessing. Keep
        # that step inside the try so it is reported like any other bad image
        # instead of surfacing as an unhandled server error.
        processed_image = preprocess(image)
    except Exception as e:
        # Broad on purpose: Pillow raises several unrelated exception types
        # for bad input, and this is the request boundary.
        return {"error": f"Invalid image file: {str(e)}"}

    # perf_counter is monotonic, so the measured interval cannot be skewed by
    # wall-clock adjustments.
    start_time = time.perf_counter()
    prediction = model.predict(processed_image)
    processing_time = time.perf_counter() - start_time

    probabilities = prediction[0]
    predicted_class_index = np.argmax(probabilities)
    predicted_class = class_names[predicted_class_index]

    # Gap between the best and second-best class probabilities.
    sorted_probs = np.sort(probabilities)[::-1]
    confidence_margin = sorted_probs[0] - sorted_probs[1]

    # Entropy of the predicted distribution, reported as "uncertainty".
    uncertainty = entropy(probabilities)

    # Cast NumPy scalars to plain floats so the response is JSON-serialisable.
    return {
        "prediction": predicted_class,
        "confidence": float(probabilities[predicted_class_index]),
        "confidence_margin": float(confidence_margin),
        "uncertainty": float(uncertainty),
        "processing_time": float(processing_time),
    }
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the file is run directly.
    import uvicorn

    # NOTE(review): host is an empty string -- confirm the intended bind
    # address. The stray trailing "|" after this call was a syntax error and
    # has been removed.
    uvicorn.run(app, host="", port=8000)