HashamUllah commited on
Commit
8cfdcf9
1 Parent(s): 1c2ca15

Upload 7 files

Browse files
Files changed (7) hide show
  1. ReadMe.txt +6 -0
  2. cnn_model.h5 +3 -0
  3. crop_image.jpg +0 -0
  4. crop_image1.jpg +0 -0
  5. crop_image2.jpg +0 -0
  6. label_transform.pkl +3 -0
  7. main.py +83 -0
ReadMe.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ - cnn_model.h5 and label_transform.pkl is Model File.
2
+
3
+ - main.py is to needed to run to start the server.
4
+ - in main.py, "/test-predict" is for testing purpose, "/predict" endpoint is for uploading image.
5
+ - "/test-predict" endpoint contains image picked from the same folder.
6
+ - after starting the server, call "http://0.0.0.0:8000/test-predict"in browser to test the prediction on the image.
cnn_model.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68953b84edbf58393bdb203c7a08062d4367382883025f5b729d607e968cffeb
3
+ size 697340252
crop_image.jpg ADDED
crop_image1.jpg ADDED
crop_image2.jpg ADDED
label_transform.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a1322197d23b302375c2f8a75a8e6954375a2241af686fa40d02be8f4f1bd07b
3
+ size 2906
main.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.responses import JSONResponse
3
+ import numpy as np
4
+ import cv2
5
+ import pickle
6
+ from tensorflow.keras.models import load_model
7
+ from tensorflow.keras.preprocessing.image import img_to_array
8
+
9
+ app = FastAPI()
10
+
11
+ print("app run")
12
+ # Load the model and the label binarizer
13
+ model = load_model('cnn_model.h5')
14
+ print("model loaded")
15
+ label_binarizer = pickle.load(open('label_transform.pkl', 'rb'))
16
+ print("labels loaded")
17
+
18
+ # Function to convert images to array
19
+ def convert_image_to_array(image_dir):
20
+ try:
21
+ image = cv2.imdecode(np.frombuffer(image_dir, np.uint8), cv2.IMREAD_COLOR)
22
+ if image is not None:
23
+ image = cv2.resize(image, (256, 256))
24
+ return img_to_array(image)
25
+ else:
26
+ return np.array([])
27
+ except Exception as e:
28
+ print(f"Error : {e}")
29
+ return None
30
+
31
+ @app.post("/predict")
32
+ async def predict(file: UploadFile = File(...)):
33
+ try:
34
+ # Read the file and convert it to an array
35
+ image_data = await file.read()
36
+ image_array = convert_image_to_array(image_data)
37
+
38
+ if image_array.size == 0:
39
+ return JSONResponse(content={"error": "Invalid image"}, status_code=400)
40
+
41
+ # Normalize the image
42
+ image_array = np.array(image_array, dtype=np.float16) / 255.0
43
+
44
+ # Ensure the image_array has the correct shape (1, 256, 256, 3)
45
+ image_array = np.expand_dims(image_array, axis=0)
46
+
47
+ # Make a prediction
48
+ prediction = model.predict(image_array)
49
+ predicted_class = label_binarizer.inverse_transform(prediction)[0]
50
+
51
+ return {"prediction": predicted_class}
52
+ except Exception as e:
53
+ return JSONResponse(content={"error": str(e)}, status_code=500)
54
+
55
+ # Add a test GET endpoint to manually trigger the prediction
56
+ @app.get("/test-predict")
57
+ def test_predict():
58
+ try:
59
+ image_path = 'crop_image1.jpg'
60
+ image = cv2.imread(image_path)
61
+ image_array = cv2.resize(image, (256, 256))
62
+ image_array = img_to_array(image_array)
63
+
64
+ if image_array.size == 0:
65
+ return JSONResponse(content={"error": "Invalid image"}, status_code=400)
66
+
67
+ # Normalize the image
68
+ image_array = np.array(image_array, dtype=np.float16) / 255.0
69
+
70
+ # Ensure the image_array has the correct shape (1, 256, 256, 3)
71
+ image_array = np.expand_dims(image_array, axis=0)
72
+
73
+ # Make a prediction
74
+ prediction = model.predict(image_array)
75
+ predicted_class = label_binarizer.inverse_transform(prediction)[0]
76
+
77
+ return {"prediction": predicted_class}
78
+ except Exception as e:
79
+ return JSONResponse(content={"error": str(e)}, status_code=500)
80
+
81
+ if __name__ == "__main__":
82
+ import uvicorn
83
+ uvicorn.run(app, host="127.0.0.1", port=8000)