shng2025 commited on
Commit
69678e6
·
1 Parent(s): 6b7ec9f
Files changed (1) hide show
  1. inference.py +69 -0
inference.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import matplotlib.pyplot as plt
3
+ import glob
4
+ import os
5
+
6
+ def visualize_predictions(result_dir):
7
+ """Visualize up to four prediction results."""
8
+ image_paths = glob.glob(os.path.join(result_dir, '*.jpg'))
9
+ num_images = min(4, len(image_paths))
10
+
11
+ if num_images == 0:
12
+ print("No images found for visualization.")
13
+ return
14
+
15
+ plt.figure(figsize=(15, 12))
16
+ for i, image_path in enumerate(image_paths[:num_images]):
17
+ image = plt.imread(image_path)
18
+ plt.subplot(2, 2, i + 1)
19
+ plt.imshow(image)
20
+ plt.axis('off')
21
+ plt.tight_layout()
22
+ plt.show()
23
+
24
+ def run_inference(checkpoint_path, inference_source='combined_dataset/images/valid', inference_name='yolo_infer_last'):
25
+ """Run inference using the saved checkpoint."""
26
+ if not os.path.exists(checkpoint_path):
27
+ print(f"Checkpoint '{checkpoint_path}' does not exist. Please ensure the path is correct.")
28
+ return
29
+
30
+ print(f"Loading the model from '{checkpoint_path}'...")
31
+ try:
32
+ # Load the model with the saved weights
33
+ model = YOLO(checkpoint_path)
34
+ print("Model loaded successfully.")
35
+ except Exception as e:
36
+ print(f"Error loading model: {e}")
37
+ return
38
+
39
+ # Verify inference source
40
+ if not os.path.exists(inference_source):
41
+ print(f"Inference source '{inference_source}' does not exist. Please provide a valid path.")
42
+ return
43
+
44
+ print(f"Running inference on '{inference_source}'...")
45
+ try:
46
+ results = model.predict(
47
+ source=inference_source,
48
+ save=True,
49
+ project='runs/predict',
50
+ name=inference_name,
51
+ exist_ok=True
52
+ )
53
+ print("Inference completed.")
54
+ except Exception as e:
55
+ print(f"Error during inference: {e}")
56
+ return
57
+
58
+ # Visualize predictions
59
+ visualize_predictions(os.path.join('runs', 'predict', inference_name))
60
+
61
+ def main():
62
+ # Define the path to the checkpoint
63
+ checkpoint_path = 'Edutech/train/weights/last.pt' # Adjust the path if necessary
64
+
65
+ # Run inference
66
+ run_inference(checkpoint_path, inference_name='yolo_infer_last')
67
+
68
+ if __name__ == "__main__":
69
+ main()