# Palm_counting / app.py
# Author: youl — commit c8a1260 ("Update app.py"), 2.54 kB.
import streamlit as st
import torch
import torchvision
import cv2
import numpy as np
import torch.nn as nn
from torchvision.ops import box_iou
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2
# apply nms algorithm
def apply_nms(orig_prediction, iou_thresh=0.3):
    """Suppress overlapping detections with non-maximum suppression.

    Args:
        orig_prediction: detection dict holding 'boxes', 'scores' and
            'labels' tensors that are indexed along their first dimension.
            It is NOT modified.
        iou_thresh: a box whose IoU with a higher-scoring kept box exceeds
            this value is discarded.

    Returns:
        A new dict with the same keys as ``orig_prediction``. Its 'boxes',
        'scores' and 'labels' contain only the detections that survive NMS.
    """
    # torchvision returns the indices of the bboxes to keep.
    keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh)
    # Copy the dict rather than aliasing it. Assigning into the caller's dict
    # would silently replace the caller's tensors with the filtered ones.
    final_prediction = dict(orig_prediction)
    for key in ('boxes', 'scores', 'labels'):
        final_prediction[key] = orig_prediction[key][keep]
    return final_prediction
# Draw the bounding box
def plot_img_bbox(img, target, filename='pred.jpg', label='palm', color=(0, 0, 255)):
    """Draw every box in ``target['boxes']`` onto ``img`` and save the result.

    ``img`` is modified in place: cv2.rectangle and cv2.putText draw straight
    into the array. It is then written to ``filename`` with cv2.imwrite.
    OpenCV treats the array as BGR, so the default ``color`` (0, 0, 255) is red.
    An RGB array passed in is saved with its red and blue channels swapped.

    Args:
        img: uint8 image array to draw on. Presumably (H, W, 3) BGR; confirm
            against callers.
        target: dict whose 'boxes' entry yields (xmin, ymin, xmax, ymax) rows
            in the pixel coordinates of ``img``.
        filename: output path. Defaults to 'pred.jpg' relative to the
            working directory.
        label: text drawn above each box.
        color: BGR colour used for both the boxes and the labels.

    Returns:
        The success flag reported by cv2.imwrite (False if the file could
        not be written).
    """
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale, thickness = 0.9, 2
    # Text height keeps the label on-canvas for boxes touching the top edge.
    # Without it, ymin - 10 can fall above the image.
    (_, text_h), _ = cv2.getTextSize(label, font, font_scale, thickness)
    for box in target['boxes']:
        # .tolist() works for CPU and CUDA tensors alike. int() truncates
        # exactly like the previous per-coordinate int(box[i].cpu()).
        xmin, ymin, xmax, ymax = (int(v) for v in box.tolist())
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, thickness)
        cv2.putText(img, label, (xmin, max(ymin - 10, text_h)), font, font_scale, color, thickness)
    # Save the annotated image.
    return cv2.imwrite(filename, img)
# transform image
# Side length (pixels) the upload is resized to before inference.
MODEL_INPUT_SIZE = 1024

# Resize -> normalize -> HWC uint8 array to CHW float tensor.
# Normalize uses Albumentations' default mean/std (ImageNet statistics);
# presumably these match what the model was trained with, so confirm before
# changing them.
# `always_apply=True` is deprecated in recent Albumentations releases.
# p=1.0 is the equivalent spelling and works on old and new versions.
test_transforms = A.Compose([
    A.Resize(height=MODEL_INPUT_SIZE, width=MODEL_INPUT_SIZE, p=1.0),
    A.Normalize(p=1.0),
    ToTensorV2(p=1.0),
])
# select device (whether GPU or CPU)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


@st.cache_resource
def _load_model(path='pickel.pth'):
    """Load the detection model once per Streamlit server process.

    Streamlit re-executes this whole script on every widget interaction.
    Caching stops the model from being unpickled again on each upload.
    The model is loaded onto the CPU first and then moved to ``device``.
    """
    # NOTE(security): the file holds a whole pickled model object rather
    # than a state_dict (it is called directly later). Restoring it needs
    # weights_only=False, which is the default before torch 2.6 and no longer
    # the default after. Unpickling can execute arbitrary code, so only point
    # this at a trusted, locally shipped file.
    try:
        loaded = torch.load(path, map_location=torch.device('cpu'), weights_only=False)
    except TypeError:
        # torch < 1.13 has no `weights_only` keyword and always fully unpickles.
        loaded = torch.load(path, map_location=torch.device('cpu'))
    return loaded.to(device)


# model loading
model = _load_model()
st.title("Palm trees detection")
file_name = st.file_uploader(
    "Upload oil palm tree image",
    type=["jpg", "jpeg", "png", "bmp", "tif", "tiff"],
)
if file_name is not None:
    col1, col2 = st.columns(2)
    # Force 3-channel RGB. RGBA, grayscale or palette uploads would otherwise
    # break the 3-channel Normalize step and the OpenCV colour conversion.
    image = np.array(Image.open(file_name).convert("RGB"))
    col1.image(image, use_column_width=True)

    transformed = test_transforms(image=image)
    image_transformed = transformed["image"].unsqueeze(0).to(device)
    # ToTensorV2 yields a (C, H, W) tensor, so the network input size is the
    # last two dims of the batched tensor.
    in_h, in_w = image_transformed.shape[-2:]

    # inference
    model.eval()
    with torch.no_grad():
        predictions = model(image_transformed)[0]
    nms_prediction = apply_nms(predictions, iou_thresh=0.1)

    # The model only sees the resized tensor. Its boxes are assumed to be in
    # that coordinate space (true for torchvision detection models).
    # Rescale them to the uploaded image, otherwise boxes are misplaced
    # whenever the upload isn't exactly in_w x in_h.
    orig_h, orig_w = image.shape[:2]
    boxes = nms_prediction['boxes']
    scale = torch.tensor(
        [orig_w / in_w, orig_h / in_h, orig_w / in_w, orig_h / in_h],
        dtype=boxes.dtype,
        device=boxes.device,
    )
    scaled_prediction = dict(nms_prediction, boxes=boxes * scale)

    # plot_img_bbox draws and saves with OpenCV, which assumes BGR. Passing
    # the RGB array directly swapped red and blue in the result.
    # It draws into the array it is given, so the annotated canvas is shown
    # directly. Re-reading the shared 'pred.jpg' raced between concurrent
    # sessions and re-decoded a lossy JPEG.
    canvas = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    plot_img_bbox(canvas, scaled_prediction)
    col2.image(canvas, channels="BGR", use_column_width=True)