mintheinwin commited on
Commit
1283eb6
1 Parent(s): fadd86a

update app

Browse files
Files changed (1) hide show
  1. app.py +91 -24
app.py CHANGED
@@ -1,8 +1,11 @@
1
- from ultralytics import YOLO
2
- from PIL import Image
3
  import gradio as gr
4
  from huggingface_hub import snapshot_download
 
5
  import os
 
 
 
 
6
 
7
  #public model path location
8
  #MODEL_REPO_ID = "mintheinwin/3907578Y"
@@ -10,34 +13,98 @@ import os
10
  #Organizations model path location
11
  MODEL_REPO_ID = "ITI107-2024S2/3907578Y"
12
 
13
- #load model
14
  def load_model(repo_id):
15
  download_dir = snapshot_download(repo_id)
16
- print(download_dir)
17
- path = os.path.join(download_dir, "best_int8_openvino_model")
18
- print(path)
19
- detection_model = YOLO(path, task='detect')
20
  return detection_model
21
-
 
22
  detection_model = load_model(MODEL_REPO_ID)
23
 
24
  #Student ID
25
  student_info = "Student Id: 3907578Y, Name: Min Thein Win"
26
 
27
- #prdeict
28
- def predict(pilimg):
29
- source = pilimg
30
- result = detection_model.predict(source, conf=0.5, iou=0.5)
31
- img_bgr = result[0].plot()
32
- out_pilimg = Image.fromarray(img_bgr[..., ::-1]) # RGB-order PIL image
33
  return out_pilimg
34
-
35
- #UI interface
36
- gr.Markdown("# Wild Animal Detection (Tiger/Lion)")
37
- gr.Markdown(student_info)
38
- gr.Interface(fn=predict,
39
- inputs=gr.Image(type="pil",label="Input"),
40
- outputs=gr.Image(type="pil",label="Output"),
41
- title="Wild Animal Detection (Tiger/Lion)",
42
- description=student_info,
43
- ).launch(share=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  from huggingface_hub import snapshot_download
3
+ from ultralytics import YOLO
4
  import os
5
+ from PIL import Image
6
+ import cv2
7
+ import numpy as np
8
+ import tempfile
9
 
10
  #public model path location
11
  #MODEL_REPO_ID = "mintheinwin/3907578Y"
 
13
  #Organizations model path location
14
  MODEL_REPO_ID = "ITI107-2024S2/3907578Y"
15
 
16
+ # Load model
17
  def load_model(repo_id):
18
  download_dir = snapshot_download(repo_id)
19
+ path = os.path.join(download_dir, "best_int8_openvino_model")
20
+ detection_model = YOLO(path, task="detect")
 
 
21
  return detection_model
22
+
23
+
24
  detection_model = load_model(MODEL_REPO_ID)
25
 
26
  #Student ID
27
  student_info = "Student Id: 3907578Y, Name: Min Thein Win"
28
 
29
+ #Prediction for images
30
+ def predict_image(pil_img):
31
+ result = detection_model.predict(pil_img, conf=0.5, iou=0.5)
32
+ img_bgr = result[0].plot() # Annotated image
33
+ out_pilimg = Image.fromarray(img_bgr[..., ::-1]) # Convert to RGB PIL image
 
34
  return out_pilimg
35
+
36
+ #Prediction for videos
37
+ def predict_video(video):
38
+ cap = cv2.VideoCapture(video)
39
+ frames = []
40
+ temp_dir = tempfile.mkdtemp()
41
+
42
+ while cap.isOpened():
43
+ ret, frame = cap.read()
44
+ if not ret:
45
+ break
46
+
47
+ # Detection
48
+ result = detection_model.predict(frame, conf=0.5, iou=0.5)
49
+ annotated_frame = result[0].plot()
50
+ frames.append(annotated_frame)
51
+
52
+ cap.release()
53
+
54
+ # Save annotated video
55
+ height, width, _ = frames[0].shape
56
+ output_path = os.path.join(temp_dir, "annotated_video.mp4")
57
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), 20, (width, height))
58
+
59
+ for frame in frames:
60
+ out.write(frame)
61
+
62
+ out.release()
63
+ return output_path
64
+
65
+ # Unified prediction function
66
+ def unified_predict(file):
67
+ if isinstance(file, Image.Image):
68
+ # If the input is a PIL Image, treat it as an image
69
+ return predict_image(file)
70
+ elif isinstance(file, str) and file.endswith(('.mp4', '.avi', '.mov')):
71
+ # If the input is a video file path, treat it as a video
72
+ return predict_video(file)
73
+ else:
74
+ raise ValueError("Unsupported file type. Please upload an image or a video.")
75
+
76
+ # UI Interface
77
+ with gr.Blocks() as interface:
78
+ gr.Markdown("# Wild Animal Detection (Tiger/Lion)")
79
+ gr.Markdown(student_info)
80
+
81
+ # Unified Section
82
+ with gr.Row():
83
+ with gr.Column():
84
+ gr.Markdown("### Upload an Image or Video:")
85
+ input_file = gr.File(label="Input File")
86
+
87
+ with gr.Column():
88
+ gr.Markdown("### Output Results:")
89
+ output_display = gr.Output(label="Output")
90
+
91
+ clear_btn= gr.Button("CLEAR")
92
+ submit_btn = gr.Button("SUBMIT")
93
+
94
+ def process_file(file):
95
+ if file.name.endswith((".jpg", ".jpeg", ".png")):
96
+ pil_image = Image.open(file.name)
97
+ return predict_image(pil_image)
98
+ elif file.name.endswith((".mp4", ".avi", ".mov")):
99
+ return predict_video(file.name)
100
+ else:
101
+ return "Unsupported file type. Please upload an image or a video."
102
+
103
+ def clear_all():
104
+ return None, ""
105
+
106
+ submit_btn.click(fn=process_file, inputs=input_file, outputs=output_display)
107
+ clear_btn.click(fn=clear_all, inputs=None, outputs=[input_file, output_display])
108
+
109
+ # Launch app
110
+ interface.launch(share=True)