File size: 5,316 Bytes
af5e402
 
 
 
 
 
 
4bd76a1
af5e402
 
 
 
 
 
4bd76a1
 
af5e402
4bd76a1
af5e402
 
 
 
 
 
 
f2ea7e1
 
b53225a
f2ea7e1
 
 
 
b53225a
af5e402
 
 
 
 
 
 
f2ea7e1
af5e402
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4bd76a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af5e402
 
 
 
4bd76a1
af5e402
 
 
 
 
 
 
c830096
af5e402
c830096
af5e402
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# Ensure detectron2 is importable before the detectron2 imports below run.
# When it is missing, install it from the GitHub repository into the *same*
# interpreter that is running this script (sys.executable), and fail loudly if
# the install fails instead of continuing to a confusing ImportError later.
try:
    import detectron2
except ImportError:
    import subprocess
    import sys

    subprocess.check_call(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "git+https://github.com/facebookresearch/detectron2.git",
        ]
    )

import cv2
import json
from matplotlib.pyplot import axis
import gradio as gr
import requests
import numpy as np
from torch import nn
import requests
from numpy.lib.type_check import imag
import random

import csv
import torch

from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.data.datasets import register_coco_instances
# Register the "Fiber" COCO-format dataset (annotations in ./labels-fiver.json,
# images under ./Fiber) and look up its metadata and annotation dicts.
# Best-effort: inference() does not use these objects, so a failure (e.g. the
# JSON file is missing, or the name is already registered on a re-run) is
# reported instead of aborting the app. Defaults keep both names defined.
Fiber_metadata = None
dataset_dicts = []
try:
    from detectron2.data import DatasetCatalog

    register_coco_instances("Fiber", {}, "./labels-fiver.json", "./Fiber")
    Fiber_metadata = MetadataCatalog.get("Fiber")
    dataset_dicts = DatasetCatalog.get("Fiber")
except Exception as err:  # best-effort: report the real cause, keep the app running
    print(f"there is an issue registering the Fiber dataset: {err!r}")
model_path = "./model_final.pth"

# Detector configuration shared by the inference() entry point.
# NOTE(review): the config file name suggests a box-only Faster R-CNN model,
# while inference() reads `pred_masks`; confirm that the YAML/weights actually
# produce instance masks.
cfg = get_cfg()
cfg.merge_from_file("./configs/detectron2/faster_rcnn_R_50_FPN_3x.yaml")
cfg.MODEL.WEIGHTS = model_path
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 2

# Without CUDA, run the model on the CPU instead.
if not torch.cuda.is_available():
    cfg.MODEL.DEVICE = "cpu"

# Class names for visualisation: both of the two class ids are labelled "Fiber".
my_metadata = MetadataCatalog.get("dbmdz_coco_all")
my_metadata.thing_classes = ["Fiber"] * 2


def inference(image_url, image, min_score):
    """Detect instances in one image and tabulate a measurement per mask.

    Args:
        image_url: URL of an image to fetch. When non-empty it takes
            precedence over ``image``.
        image: Image array from the Gradio ``Image(type="numpy")`` input, used
            when no URL is given. Its channel order is reversed because the
            model expects BGR.
        min_score: Detection score threshold
            (``MODEL.ROI_HEADS.SCORE_THRESH_TEST``) applied to this request
            only; the shared module-level ``cfg`` is not modified.

    Returns:
        Path of the CSV file written with columns
        ID, Measurement, X_Min, X_Max, Y_Min, Y_Max (one row per non-empty
        mask). A Gradio ``Dataframe`` output accepts a CSV path and renders it.

    Raises:
        requests.RequestException: the URL could not be fetched or returned an
            error status.
        ValueError: no usable input image was provided, or the model output
            contains no instance masks.
    """
    csv_path = "dmeasurements.csv"
    im = _load_bgr_image(image_url, image)

    # Work on a copy so concurrent requests with different thresholds do not
    # race on the shared global config.
    run_cfg = cfg.clone()
    run_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = min_score
    predictor = DefaultPredictor(run_cfg)

    instances = predictor(im)["instances"].to("cpu")
    if not instances.has("pred_masks"):
        raise ValueError(
            "The model did not return instance masks (pred_masks); "
            "check that the configured model is a mask-producing one."
        )
    masks = instances.pred_masks.numpy()

    rows = _measure_masks(masks)
    _write_measurements_csv(rows, csv_path)
    return csv_path


def _load_bgr_image(image_url, image, timeout=30):
    """Return the request's input image as a BGR uint8 array (URL wins over upload)."""
    if image_url:
        response = requests.get(image_url, timeout=timeout)
        response.raise_for_status()
        buffer = np.frombuffer(response.content, dtype="uint8")
        decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        if decoded is None:
            raise ValueError(f"Could not decode an image from {image_url!r}")
        return decoded
    if image is None:
        raise ValueError("Provide either an image URL or an input image.")
    # Model expects BGR: reverse the channel axis; make it contiguous because
    # the reversed view has a negative stride.
    return np.ascontiguousarray(image[:, :, ::-1])


def _measure_masks(masks, pixels_per_unit=600):
    """Return [id, measurement, x_min, x_max, y_min, y_max] for each non-empty mask.

    ``measurement`` is the mask's pixel count divided by ``pixels_per_unit``,
    rounded half-up. NOTE(review): 600 presumably encodes a pixel-to-unit
    calibration; confirm its meaning. The extents are taken from the mask
    itself, not from the predicted bounding box.
    """
    rows = []
    for ind, item_mask in enumerate(masks):
        ys, xs = np.nonzero(item_mask)
        if ys.size == 0:
            # An empty mask has no extent (np.min would raise); skip it but
            # keep the remaining IDs aligned with instance indices.
            continue
        measurement = int(0.5 + ys.size / pixels_per_unit)
        rows.append(
            [ind, measurement, int(xs.min()), int(xs.max()), int(ys.min()), int(ys.max())]
        )
    return rows


def _write_measurements_csv(rows, path):
    """Write the measurement rows, with a header line, to the CSV file at ``path``."""
    # newline="" is required by the csv module to avoid doubled line endings.
    with open(path, mode="w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(["ID", "Measurement", "X_Min", "X_Max", "Y_Min", "Y_Max"])
        writer.writerows(rows)


title = " Detectron2 Model Demo"
description = (
    "This demo introduces an interactive playground for our trained Detectron2 model. "
    "<br>The model detects Fiber instances in a given image and reports, for each one, "
    "a size measurement derived from its mask together with its pixel extent."
)
# NOTE(review): the repository link below has an empty href; fill in the model URL.
article = '<p>Detectron model is available from our repository <a href="">here</a> on the Hugging Face Model Hub.</p>'

# Use the top-level component classes consistently; the legacy `gr.inputs.*`
# namespace is deprecated and was removed in newer Gradio releases.
demo = gr.Interface(
    inference,
    [
        gr.Textbox(
            label="Image URL",
            placeholder="https://api.digitale-sammlungen.de/iiif/image/v2/bsb10483966_00008/full/500,/0/default.jpg",
        ),
        gr.Image(type="numpy", label="Input Image"),
        gr.Slider(minimum=0.0, maximum=1.0, value=0.1, label="Minimum score"),
    ],
    gr.Dataframe(label="Data"),
    title=title,
    description=description,
    article=article,
    examples=[],
)
demo.launch()