# ddh-pipelined/app.py
import os

# Install missing dependencies at runtime (Hugging Face Spaces workaround).
try:
    import detectron2
except ImportError:
    os.system('pip install git+https://github.com/haya-alwarthan/detectron2.git')
try:
    import roboflow
except ImportError:
    os.system('pip install git+https://github.com/roboflow-ai/roboflow-python.git')
import io

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import roboflow
import torch
from detectron2 import model_zoo
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultPredictor
from detectron2.utils.visualizer import ColorMode, Visualizer
from matplotlib import patches
from matplotlib.colors import to_rgba
# URL of the pretrained Keypoint R-CNN weights
keypoint_model_path = "https://huggingface.co/yahaoh/ddh-maskrcnn/resolve/main/keypoint_rcnn_binarymask.pth"
# Configure the Roboflow segmentation project
rf = roboflow.Roboflow(api_key="0uv4bY5n7Vluj0yRHOOm")
project = rf.workspace().project("ddh-seg")
model = project.version(1).model
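# The code below assumes the Roboflow response JSON has roughly this shape
# (a sketch inferred from how the predictions are consumed, not the full schema):
#   {"predictions": [{"class": "LOWER_BONE",
#                     "points": [{"x": 12.0, "y": 34.0}, ...]},
#                    ...]}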
# Drop the visibility flag from Detectron2 keypoint triplets (x, y, v),
# keeping only the (x, y) pairs.
def convert_to_pairs(three):
    pairs = [element for i, element in enumerate(three) if i % 3 != 2]
    return pairs
# Overlay the segmentation polygons on the radiograph and return the
# rendered figure as an RGB image array.
def segm_imf(prediction, rec_img):
    h = rec_img.shape[0]
    w = rec_img.shape[1]
    colors = {"LOWER_BONE": "#fabee6", "UPPER_BONE": "#96e7e6", "MIDDLE_BONE": "#fffa5b"}
    figure, axes = plt.subplots(figsize=(w / 100.0, h / 100.0))
    for pred in prediction["predictions"]:
        points = [[p["x"], p["y"]] for p in pred["points"]]
        polygon = patches.Polygon(
            points, linewidth=2, edgecolor=colors[pred["class"]],
            facecolor=to_rgba(colors[pred["class"]], 0.4)
        )
        axes.add_patch(polygon)
    plt.imshow(rec_img)
    plt.axis("off")
    plt.tight_layout()
    # Render the figure to a PNG buffer, then decode it back into an array.
    figure.canvas.draw()
    buf = io.BytesIO()
    figure.canvas.print_png(buf)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype='uint8')
    img = cv2.imdecode(img_arr, cv2.IMREAD_UNCHANGED)
    plt.close(figure)
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# Rasterize the segmentation polygons into a white-on-black binary mask
# with the same height and width as the input image.
def binary_cv2(rec_img, prediction):
    h = rec_img.shape[0]
    w = rec_img.shape[1]
    mask = np.zeros((h, w), dtype=np.float64)
    for pred in prediction["predictions"]:
        points = [[p["x"], p["y"]] for p in pred["points"]]
        cv2.fillPoly(mask, np.array([points]).astype(np.int32), color=255)
    # Stack the single-channel mask into a 3-channel image for the predictor.
    return cv2.merge((mask, mask, mask))
# Extend the segment p1-p2 far beyond both endpoints so the drawn line
# spans the whole image.
def extend_line(p1, p2, distance=10000):
    angle = np.arctan2(p1[1] - p2[1], p1[0] - p2[0])
    p3 = (int(p1[0] + distance * np.cos(angle)), int(p1[1] + distance * np.sin(angle)))
    p4 = (int(p1[0] - distance * np.cos(angle)), int(p1[1] - distance * np.sin(angle)))
    return (p3, p4)
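# e.g. extend_line((0, 0), (1, 0)) -> ((-10000, 0), (10000, 0)):
# the segment is stretched along its own direction in both senses.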
# Draw the three landmark lines (left, horizontal, right) and a filled
# circle at each of the four predicted keypoints.
def visualize(image, pred):
    op_img = image.copy()
    pred = [int(i) for i in pred]
    (p1_left, p2_left) = extend_line((pred[6], pred[7]), (pred[4], pred[5]))
    (p1_h, p2_h) = extend_line((pred[4], pred[5]), (pred[2], pred[3]))
    (p1_right, p2_right) = extend_line((pred[2], pred[3]), (pred[0], pred[1]))
    op_img = cv2.line(op_img, p1_left, (pred[4], pred[5]), (152, 216, 170), 1)
    op_img = cv2.line(op_img, p1_h, p2_h, (152, 216, 170), 1)
    op_img = cv2.line(op_img, (pred[2], pred[3]), p2_right, (152, 216, 170), 1)
    # Marker radius scales with the image size.
    radius = int((image.shape[0] + image.shape[1]) / 150)
    for i in range(0, 7, 2):
        op_img = cv2.circle(op_img, (pred[i], pred[i + 1]), radius, (255, 150, 128), -1)
    return op_img
# Keypoint R-CNN model and metadata setup
KEYPOINT_NAMES = ["RU", "RD", "LD", "LU"]
KEYPOINT_FLIP_MAP = [
    ("RU", "LU"),
    ("RD", "LD"),
]
KEYPOINT_CONNECTION_RULES = [
    ("RU", "RD", (102, 204, 255)),
    ("RD", "LD", (51, 153, 255)),
    ("LU", "LD", (102, 0, 204)),
]
kp_meta = MetadataCatalog.get("ddh_jordan_kp_train")
kp_meta.set(keypoint_names=KEYPOINT_NAMES)
kp_meta.set(keypoint_flip_map=KEYPOINT_FLIP_MAP)
kp_meta.set(keypoint_connection_rules=KEYPOINT_CONNECTION_RULES)
kp_cfg = get_cfg()
kp_cfg.merge_from_file(model_zoo.get_config_file("COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml"))
kp_cfg.DATALOADER.NUM_WORKERS = 2
kp_cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 4
kp_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1  # only one class (landmarks); see https://detectron2.readthedocs.io/tutorials/datasets.html#update-the-config-for-new-datasets
kp_cfg.TEST.DETECTIONS_PER_IMAGE = 1
# Load the fine-tuned weights from Hugging Face.
kp_cfg.MODEL.WEIGHTS = keypoint_model_path
# kp_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7  # set a custom testing threshold
# Fall back to CPU when no GPU is available.
if not torch.cuda.is_available():
    kp_cfg.MODEL.DEVICE = "cpu"
# Build the predictor for inference.
kp_predictor = DefaultPredictor(kp_cfg)
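# DefaultPredictor returns {"instances": Instances}; per the Detectron2
# output format, pred_keypoints is a tensor of shape (N, num_keypoints, 3)
# holding (x, y, score) per keypoint, so (N, 4, 3) with this config.
# convert_to_pairs() strips the score column after flattening.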
# Pack the flat [x1, y1, x2, y2, ...] keypoint list into a labeled dict.
def output_json(keypoints):
    landmarks_labels = ["RU", "RD", "LD", "LU"]
    output = {}
    for i in range(len(landmarks_labels)):
        output[landmarks_labels[i].lower()] = {'x': keypoints[2 * i].item(), 'y': keypoints[2 * i + 1].item()}
    return output
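# e.g. {"ru": {"x": 120.5, "y": 88.0}, "rd": {...}, "ld": {...}, "lu": {...}}
# (illustrative values)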
# Run the full pipeline on one image: segment with Roboflow, build a
# binary mask, predict keypoints on the mask, then draw the overlay.
def predict_fn(img_path):
    # Read and transform the input image
    preds = model.predict(img_path).json()
    og_img = cv2.imread(img_path)
    img = binary_cv2(og_img, preds)
    img = np.array(img, dtype=np.uint8)
    outputs = kp_predictor(img)  # format documented at https://detectron2.readthedocs.io/tutorials/models.html#model-output-format
    print("outputs=={}".format(outputs["instances"].to("cpu")))
    p = np.asarray(outputs["instances"].to("cpu").pred_keypoints, dtype='float32')
    landmarks = {}
    segm = segm_imf(preds, og_img)
    if p.size > 0:
        p = p[0].reshape(-1)
        pairs = convert_to_pairs(p)
        landmarks = output_json(pairs)
        img = visualize(segm, pairs)
    else:
        # No keypoints detected: return the segmentation overlay only.
        img = segm
    return (img, landmarks)
inputs_image = [
    gr.components.Image(type="filepath", label="Upload an XRay Image of the Pelvis"),
]
outputs = [
    gr.components.Image(type="numpy", label="Output Image"),
    gr.components.JSON(label="Output Landmarks"),
]
gr.Interface(
predict_fn,
inputs=inputs_image,
outputs=outputs,
title="Coordinates of the Landmarks",
).launch()