mattraj committed on
Commit 7c511d2 · 1 Parent(s): 00a0836

demo buildout 9

Files changed (1)
  1. app.py +1 -219
app.py CHANGED
@@ -49,41 +49,15 @@ def infer(
     text: str,
     max_new_tokens: int
 ) -> str:
-    inputs = processor(text=text, images=resize_and_pad(image, 448), return_tensors="pt").to(device)
+    inputs = processor(text=text, images=resize_and_pad(image, 448), return_tensors="pt", padding="longest", do_convert_rgb=True).to(device)
     with torch.inference_mode():
         generated_ids = model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
-            do_sample=False
         )
     result = processor.batch_decode(generated_ids, skip_special_tokens=True)
     return result[0][len(text):].lstrip("\n")
 
-
-##### Parse segmentation output tokens into masks
-##### Also returns bounding boxes with their labels
-
-def parse_segmentation(input_image, input_text):
-    out = infer(input_image, input_text, max_new_tokens=100)
-    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
-    labels = set(obj.get('name') for obj in objs if obj.get('name'))
-    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
-    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
-    annotated_img = (
-        input_image,
-        [
-            (
-                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
-                obj['name'] or '',
-            )
-            for obj in objs
-            if 'mask' in obj or 'xyxy' in obj
-        ],
-    )
-    has_annotations = bool(annotated_img[1])
-    return annotated_img
-
-
 ######## Demo
 
 INTRO_TEXT = """## Curacel Handwritten Arabic demo\n\n
@@ -138,198 +112,6 @@ with gr.Blocks(css="style.css") as demo:
             examples=examples,
             inputs=chat_inputs,
         )
-    '''
-    with gr.Tab("Segment/Detect"):
-        image = gr.Image(type="pil")
-        seg_input = gr.Text(label="Entities to Segment/Detect")
-        seg_btn = gr.Button("Submit")
-        annotated_image = gr.AnnotatedImage(label="Output")
-
-        examples = [["./diagnosis-1.jpg", "Transcribe the Arabic text."],
-                    ["./examples/sign.jpg", "Transcribe the Arabic text."]]
-        gr.Markdown(
-            "")
-        gr.Examples(
-            examples=examples,
-            inputs=[image, seg_input],
-        )
-
-        seg_inputs = [
-            image,
-            seg_input
-        ]
-        seg_outputs = [
-            annotated_image
-        ]
-        seg_btn.click(
-            fn=parse_segmentation,
-            inputs=seg_inputs,
-            outputs=seg_outputs,
-        )
-    '''
-
-### Postprocessing Utils for Segmentation Tokens
-### Segmentation tokens are passed to another VAE which decodes them to a mask
-
-_MODEL_PATH = 'vae-oid.npz'
-
-_SEGMENT_DETECT_RE = re.compile(
-    r'(.*?)' +
-    r'<loc(\d{4})>' * 4 + r'\s*' +
-    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
-    r'\s*([^;<>]+)? ?(?:; )?',
-)
-
-
-def _get_params(checkpoint):
-    """Converts PyTorch checkpoint to Flax params."""
-
-    def transp(kernel):
-        return np.transpose(kernel, (2, 3, 1, 0))
-
-    def conv(name):
-        return {
-            'bias': checkpoint[name + '.bias'],
-            'kernel': transp(checkpoint[name + '.weight']),
-        }
-
-    def resblock(name):
-        return {
-            'Conv_0': conv(name + '.0'),
-            'Conv_1': conv(name + '.2'),
-            'Conv_2': conv(name + '.4'),
-        }
-
-    return {
-        '_embeddings': checkpoint['_vq_vae._embedding'],
-        'Conv_0': conv('decoder.0'),
-        'ResBlock_0': resblock('decoder.2.net'),
-        'ResBlock_1': resblock('decoder.3.net'),
-        'ConvTranspose_0': conv('decoder.4'),
-        'ConvTranspose_1': conv('decoder.6'),
-        'ConvTranspose_2': conv('decoder.8'),
-        'ConvTranspose_3': conv('decoder.10'),
-        'Conv_1': conv('decoder.12'),
-    }
-
-
-def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
-    batch_size, num_tokens = codebook_indices.shape
-    assert num_tokens == 16, codebook_indices.shape
-    unused_num_embeddings, embedding_dim = embeddings.shape
-
-    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
-    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
-    return encodings
-
-
-@functools.cache
-def _get_reconstruct_masks():
-    """Reconstructs masks from codebook indices.
-    Returns:
-      A function that expects indices shaped `[B, 16]` of dtype int32, each
-      ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
-      `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
-    """
-
-    class ResBlock(nn.Module):
-        features: int
-
-        @nn.compact
-        def __call__(self, x):
-            original_x = x
-            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
-            x = nn.relu(x)
-            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
-            x = nn.relu(x)
-            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
-            return x + original_x
-
-    class Decoder(nn.Module):
-        """Upscales quantized vectors to mask."""
-
-        @nn.compact
-        def __call__(self, x):
-            num_res_blocks = 2
-            dim = 128
-            num_upsample_layers = 4
-
-            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
-            x = nn.relu(x)
-
-            for _ in range(num_res_blocks):
-                x = ResBlock(features=dim)(x)
-
-            for _ in range(num_upsample_layers):
-                x = nn.ConvTranspose(
-                    features=dim,
-                    kernel_size=(4, 4),
-                    strides=(2, 2),
-                    padding=2,
-                    transpose_kernel=True,
-                )(x)
-                x = nn.relu(x)
-                dim //= 2
-
-            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
-
-            return x
-
-    def reconstruct_masks(codebook_indices):
-        quantized = _quantized_values_from_codebook_indices(
-            codebook_indices, params['_embeddings']
-        )
-        return Decoder().apply({'params': params}, quantized)
-
-    with open(_MODEL_PATH, 'rb') as f:
-        params = _get_params(dict(np.load(f)))
-
-    return jax.jit(reconstruct_masks, backend='cpu')
-
-
-def extract_objs(text, width, height, unique_labels=False):
-    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
-    objs = []
-    seen = set()
-    while text:
-        m = _SEGMENT_DETECT_RE.match(text)
-        if not m:
-            break
-        print("m", m)
-        gs = list(m.groups())
-        before = gs.pop(0)
-        name = gs.pop()
-        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
-
-        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
-        seg_indices = gs[4:20]
-        if seg_indices[0] is None:
-            mask = None
-        else:
-            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
-            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
-            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
-            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
-            mask = np.zeros([height, width])
-            if y2 > y1 and x2 > x1:
-                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
-
-        content = m.group()
-        if before:
-            objs.append(dict(content=before))
-            content = content[len(before):]
-        while unique_labels and name in seen:
-            name = (name or '') + "'"
-        seen.add(name)
-        objs.append(dict(
-            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
-        text = text[len(before) + len(content):]
-
-    if text:
-        objs.append(dict(content=text))
-
-    return objs
-
 
 #########
 
 
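For reference, a minimal usage sketch of the revised `infer` after this commit. It is not part of the diff above; it assumes `model`, `processor`, `device`, and `resize_and_pad` are already set up earlier in app.py (as the hunk context implies), and it reuses the image path and prompt from the demo's examples list, with `max_new_tokens=100` mirroring the value the removed `parse_segmentation` helper used.

```python
from PIL import Image

# Hypothetical call into the infer() shown in the diff above.
# "./diagnosis-1.jpg" and the prompt come from the demo's examples list.
image = Image.open("./diagnosis-1.jpg")
prompt = "Transcribe the Arabic text."
print(infer(image, prompt, max_new_tokens=100))
```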