youl commited on
Commit
965ddf7
·
1 Parent(s): ea8a726

Add application file

Browse files
Files changed (2) hide show
  1. app.py +80 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import cv2
4
+ import torch.nn as nn
5
+ from torchvision.ops import box_iou
6
+ from PIL import Image
7
+ import albumentations as A
8
+ from albumentations.pytorch import ToTensorV2
9
+
10
+ # apply nms algorithm
11
+ def apply_nms(orig_prediction, iou_thresh=0.3):
12
+ # torchvision returns the indices of the bboxes to keep
13
+ keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)
14
+ final_prediction = orig_prediction
15
+ final_prediction['boxes'] = final_prediction['boxes'][keep]
16
+ final_prediction['scores'] = final_prediction['scores'][keep]
17
+ final_prediction['labels'] = final_prediction['labels'][keep]
18
+
19
+ return final_prediction
20
+
21
+ # Draw the bounding box
22
+ def plot_img_bbox(img, target):
23
+ for box in (target['boxes']):
24
+ xmin, ymin, xmax, ymax = int(box[0].cpu()), int(box[1].cpu()), int(box[2].cpu()),int(box[3].cpu())
25
+ cv2.rectangle(img, (xmin, ymin), (xmax, ymax), (0, 0, 255), 2)
26
+ label = "palm"
27
+ # Add the label and confidence score
28
+ label = f'{label}'
29
+ cv2.putText(img, label, (xmin, ymin - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
30
+
31
+ # Display the image with detections
32
+ filename = 'pred.jpg'
33
+ cv2.imwrite(filename, img)
34
+
35
+ # transform image
36
+ test_transforms = A.Compose([
37
+ A.Resize(height=1024, width=1024, always_apply=True),
38
+ A.Normalize(always_apply=True),
39
+ ToTensorV2(always_apply=True),])
40
+
41
+ # select device (whether GPU or CPU)
42
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
43
+
44
+ # model loading
45
+ model = torch.load('pickel.pth',map_location=torch.device('cpu'))
46
+ model = model.to(device)
47
+
48
+ st.title("Palm trees detection")
49
+
50
+ file_name = st.file_uploader("Upload oil palm tree image")
51
+
52
+ if file_name is not None:
53
+ col1, col2 = st.columns(2)
54
+
55
+ image = Image.open(file_name)
56
+ col1.image(image, use_column_width=True)
57
+ transformed = test_transforms(image= image)
58
+ image_transformed = transformed["image"]
59
+ image_transformed = image_transformed.unsqueeze(0)
60
+ image_transformed = image_transformed.to(device)
61
+ # inference
62
+ model.eval()
63
+ with torch.no_grad():
64
+ predictions = model(image_transformed)[0]
65
+
66
+ nms_prediction = apply_nms(predictions, iou_thresh=0.1)
67
+
68
+ plot_img_bbox(image, nms_prediction)
69
+ pred = Image.open("pred.jpg")
70
+ col2.image(pred, use_column_width=True)
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ opencv-python
4
+ albumentations