aico committed on
Commit
91678c8
·
1 Parent(s): c96ffbf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -74
app.py CHANGED
@@ -7,92 +7,26 @@ import cv2
7
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
8
  model = VisionEncoderDecoderModel.from_pretrained("aico/TrOCR-MNIST")
9
 
10
- def _group_rectangles(rec):
11
- """
12
- Union intersecting rectangles.
13
- Args:
14
- rec - list of rectangles in form [x, y, w, h]
15
- Return:
16
- list of grouped rectangles
17
- """
18
- tested = [False for i in range(len(rec))]
19
- final = []
20
- i = 0
21
- while i < len(rec):
22
- if not tested[i]:
23
- j = i+1
24
- while j < len(rec):
25
- if not tested[j] and intersect_area(rec[i], rec[j]):
26
- rec[i] = union(rec[i], rec[j])
27
- tested[j] = True
28
- j = i
29
- j += 1
30
- final += [rec[i]]
31
- i += 1
32
-
33
- return final
34
 
35
  def process_image(image):
36
- bounding_boxes = []
37
- generated_text_list = []
38
- #boundingBoxes_2 = []
39
  #print(np.shape(image))
40
  #print(image)
41
- #dim = (28,28)
42
- #resized = cv2.resize(image, dim, interpolation = cv2.INTER_AREA)
43
  #rint(image.astype('uint8'))
44
  #cv2.imwrite("image.png",image.astype('uint8'),(28, 28))
45
- #mask = np.zeros(np.shape(image), dtype=np.uint8)
46
- thresh = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1]
47
- #gray = cv2.cvtColor(thresh, cv2.COLOR_BGR2GRAY)
48
-
49
- cnts = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
50
- cnts = cnts[0] if len(cnts) == 2 else cnts[1]
51
- (cnts, _) = contours.sort_contours(cnts, method="left-to-right")
52
- dim = (28, 28)
53
- for c in cnts:
54
- area = cv2.contourArea(c)
55
- #print(area)
56
- #if area < 120:
57
- bounding_boxes.append(cv2.boundingRect(c))
58
- #print("for loop bb: ",bounding_boxes)
59
-
60
- boundingBoxes_filter = [i for i in bounding_boxes if i != (0 , 0, 128, 128)]
61
-
62
- boundingBoxes = _group_rectangles(boundingBoxes_filter)
63
- #print(boundingBoxes)
64
- #
65
- #print(boundingBoxes_2)
66
- for (x, y, w, h) in boundingBoxes:
67
- #print(x,y,w,h)
68
- ROI = thresh[y:y+h, x:x+w]
69
- ROI2 = cv2.bitwise_not(ROI)
70
- borderoutput = cv2.copyMakeBorder(ROI2, 30, 30, 30, 30, cv2.BORDER_CONSTANT, value=[0, 0, 0])
71
-
72
- resized = cv2.resize(borderoutput, dim, interpolation = cv2.INTER_AREA)
73
- cv2.imwrite('ROI_{}.png'.format(x), resized)
74
- #imageinv = cv2.bitwise_not(resized)
75
- img = Image.fromarray(resized.astype('uint8')).convert("RGB")
76
-
77
- pixel_values = processor(img, return_tensors="pt").pixel_values
78
- generated_ids = model.generate(pixel_values)
79
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
80
- #print(generated_text)
81
- generated_text_list.append(generated_text)
82
- #img = Image.fromarray(image.astype('uint8')).convert("RGB")
83
  #img = Image.open("image.png").convert("RGB")
84
- #print(img)
85
-
86
  # prepare image
87
- #pixel_values = processor(img, return_tensors="pt").pixel_values
88
 
89
  # generate (no beam search)
90
- #generated_ids = model.generate(pixel_values)
91
 
92
  # decode
93
- #generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
94
- return ''.join(generated_text_list)
95
- #return generated_text
96
 
97
  title = "Interactive demo: Single Digits MNIST"
98
  description = "Aico - University Utrecht"
 
7
  processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
8
  model = VisionEncoderDecoderModel.from_pretrained("aico/TrOCR-MNIST")
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def process_image(image):
12
+
 
 
13
  #print(np.shape(image))
14
  #print(image)
 
 
15
  #rint(image.astype('uint8'))
16
  #cv2.imwrite("image.png",image.astype('uint8'),(28, 28))
17
+ img = Image.fromarray(image.astype('uint8')).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  #img = Image.open("image.png").convert("RGB")
19
+ print(img)
 
20
  # prepare image
21
+ pixel_values = processor(img, return_tensors="pt").pixel_values
22
 
23
  # generate (no beam search)
24
+ generated_ids = model.generate(pixel_values)
25
 
26
  # decode
27
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
28
+
29
+ return generated_text
30
 
31
  title = "Interactive demo: Single Digits MNIST"
32
  description = "Aico - University Utrecht"