HashamUllah commited on
Commit
fceec31
1 Parent(s): 523d797

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -4
app.py CHANGED
@@ -1,7 +1,50 @@
 
 
1
  import tensorflow as tf
 
 
 
 
2
 
3
- # Load the existing model
4
- model = tf.keras.models.load_model('./plant_disease_detection_saved_model')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Save it in the SavedModel format
7
- model.save('./plant_disease_detection_saved_model', save_format='tf')
 
1
+ from fastapi import FastAPI, File, UploadFile, HTTPException
2
+ from fastapi.responses import JSONResponse
3
  import tensorflow as tf
4
+ import numpy as np
5
+ from PIL import Image
6
+ import io
7
+ import json
8
 
9
+ app = FastAPI()
10
+
11
+ # Load the TensorFlow model
12
+ model = tf.keras.models.load_model('./plant_disease_detection.h5')
13
+
14
+ # Load categories
15
+ with open('./categories.json') as f:
16
+ categories = json.load(f)
17
+
18
+ def preprocess_image(image_bytes):
19
+ # Convert the image to a NumPy array
20
+ image = Image.open(io.BytesIO(image_bytes))
21
+ image = image.resize((224, 224)) # Adjust size as needed
22
+ image_array = np.array(image) / 255.0 # Normalize to [0, 1]
23
+ image_array = np.expand_dims(image_array, axis=0) # Add batch dimension
24
+ return image_array
25
+
26
+ @app.post('/predict')
27
+ async def predict(file: UploadFile = File(...)):
28
+ if file.content_type.startswith('image/') is False:
29
+ raise HTTPException(status_code=400, detail='Invalid file type')
30
+
31
+ image_bytes = await file.read()
32
+ image_array = preprocess_image(image_bytes)
33
+
34
+ # Make prediction
35
+ predictions = model.predict(image_array)
36
+ predicted_class = np.argmax(predictions, axis=1)[0]
37
+
38
+ # Map to category names
39
+ predicted_label = categories.get(str(predicted_class), 'Unknown')
40
+
41
+ return JSONResponse(content={
42
+ 'class': predicted_label,
43
+ 'confidence': float(predictions[0][predicted_class])
44
+ })
45
+
46
+
47
+ if __name__ == '__main__':
48
+ import uvicorn
49
+ uvicorn.run(app, host='0.0.0.0', port=8080)
50