demo buildout 9
app.py
CHANGED
@@ -49,41 +49,15 @@ def infer(
     text: str,
     max_new_tokens: int
 ) -> str:
-    inputs = processor(text=text, images=resize_and_pad(image, 448), return_tensors="pt").to(device)
+    inputs = processor(text=text, images=resize_and_pad(image, 448), return_tensors="pt", padding="longest", do_convert_rgb=True).to(device)
     with torch.inference_mode():
         generated_ids = model.generate(
             **inputs,
             max_new_tokens=max_new_tokens,
-            do_sample=False
         )
     result = processor.batch_decode(generated_ids, skip_special_tokens=True)
     return result[0][len(text):].lstrip("\n")
 
-
-##### Parse segmentation output tokens into masks
-##### Also returns bounding boxes with their labels
-
-def parse_segmentation(input_image, input_text):
-    out = infer(input_image, input_text, max_new_tokens=100)
-    objs = extract_objs(out.lstrip("\n"), input_image.size[0], input_image.size[1], unique_labels=True)
-    labels = set(obj.get('name') for obj in objs if obj.get('name'))
-    color_map = {l: COLORS[i % len(COLORS)] for i, l in enumerate(labels)}
-    highlighted_text = [(obj['content'], obj.get('name')) for obj in objs]
-    annotated_img = (
-        input_image,
-        [
-            (
-                obj['mask'] if obj.get('mask') is not None else obj['xyxy'],
-                obj['name'] or '',
-            )
-            for obj in objs
-            if 'mask' in obj or 'xyxy' in obj
-        ],
-    )
-    has_annotations = bool(annotated_img[1])
-    return annotated_img
-
-
 ######## Demo
 
 INTRO_TEXT = """## Curacel Handwritten Arabic demo\n\n
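(Reviewer note, not part of the commit: the only behavioural change in this hunk is the preprocessing call, which now passes padding="longest" and do_convert_rgb=True to the processor, while do_sample=False is dropped from model.generate. Below is a minimal sketch of how the updated infer() would be driven, assuming model, processor, device and resize_and_pad are defined earlier in app.py; the image path comes from the Space's own examples list.)

from PIL import Image

# Illustrative only; infer(), model, processor, device and resize_and_pad come from app.py.
image = Image.open("./diagnosis-1.jpg")            # example file listed in this Space
prompt = "Transcribe the Arabic text."
# generate() is left at its default sampling setting, so decoding is typically greedy.
output = infer(image, prompt, max_new_tokens=100)
print(output)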
@@ -138,198 +112,6 @@ with gr.Blocks(css="style.css") as demo:
             examples=examples,
             inputs=chat_inputs,
         )
-    '''
-    with gr.Tab("Segment/Detect"):
-        image = gr.Image(type="pil")
-        seg_input = gr.Text(label="Entities to Segment/Detect")
-        seg_btn = gr.Button("Submit")
-        annotated_image = gr.AnnotatedImage(label="Output")
-
-        examples = [["./diagnosis-1.jpg", "Transcribe the Arabic text."],
-                    ["./examples/sign.jpg", "Transcribe the Arabic text."]]
-        gr.Markdown(
-            "")
-        gr.Examples(
-            examples=examples,
-            inputs=[image, seg_input],
-        )
-
-        seg_inputs = [
-            image,
-            seg_input
-        ]
-        seg_outputs = [
-            annotated_image
-        ]
-        seg_btn.click(
-            fn=parse_segmentation,
-            inputs=seg_inputs,
-            outputs=seg_outputs,
-        )
-    '''
-
-### Postprocessing Utils for Segmentation Tokens
-### Segmentation tokens are passed to another VAE which decodes them to a mask
-
-_MODEL_PATH = 'vae-oid.npz'
-
-_SEGMENT_DETECT_RE = re.compile(
-    r'(.*?)' +
-    r'<loc(\d{4})>' * 4 + r'\s*' +
-    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
-    r'\s*([^;<>]+)? ?(?:; )?',
-)
-
-
-def _get_params(checkpoint):
-    """Converts PyTorch checkpoint to Flax params."""
-
-    def transp(kernel):
-        return np.transpose(kernel, (2, 3, 1, 0))
-
-    def conv(name):
-        return {
-            'bias': checkpoint[name + '.bias'],
-            'kernel': transp(checkpoint[name + '.weight']),
-        }
-
-    def resblock(name):
-        return {
-            'Conv_0': conv(name + '.0'),
-            'Conv_1': conv(name + '.2'),
-            'Conv_2': conv(name + '.4'),
-        }
-
-    return {
-        '_embeddings': checkpoint['_vq_vae._embedding'],
-        'Conv_0': conv('decoder.0'),
-        'ResBlock_0': resblock('decoder.2.net'),
-        'ResBlock_1': resblock('decoder.3.net'),
-        'ConvTranspose_0': conv('decoder.4'),
-        'ConvTranspose_1': conv('decoder.6'),
-        'ConvTranspose_2': conv('decoder.8'),
-        'ConvTranspose_3': conv('decoder.10'),
-        'Conv_1': conv('decoder.12'),
-    }
-
-
-def _quantized_values_from_codebook_indices(codebook_indices, embeddings):
-    batch_size, num_tokens = codebook_indices.shape
-    assert num_tokens == 16, codebook_indices.shape
-    unused_num_embeddings, embedding_dim = embeddings.shape
-
-    encodings = jnp.take(embeddings, codebook_indices.reshape((-1)), axis=0)
-    encodings = encodings.reshape((batch_size, 4, 4, embedding_dim))
-    return encodings
-
-
-@functools.cache
-def _get_reconstruct_masks():
-    """Reconstructs masks from codebook indices.
-    Returns:
-      A function that expects indices shaped `[B, 16]` of dtype int32, each
-      ranging from 0 to 127 (inclusive), and that returns a decoded masks sized
-      `[B, 64, 64, 1]`, of dtype float32, in range [-1, 1].
-    """
-
-    class ResBlock(nn.Module):
-        features: int
-
-        @nn.compact
-        def __call__(self, x):
-            original_x = x
-            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
-            x = nn.relu(x)
-            x = nn.Conv(features=self.features, kernel_size=(3, 3), padding=1)(x)
-            x = nn.relu(x)
-            x = nn.Conv(features=self.features, kernel_size=(1, 1), padding=0)(x)
-            return x + original_x
-
-    class Decoder(nn.Module):
-        """Upscales quantized vectors to mask."""
-
-        @nn.compact
-        def __call__(self, x):
-            num_res_blocks = 2
-            dim = 128
-            num_upsample_layers = 4
-
-            x = nn.Conv(features=dim, kernel_size=(1, 1), padding=0)(x)
-            x = nn.relu(x)
-
-            for _ in range(num_res_blocks):
-                x = ResBlock(features=dim)(x)
-
-            for _ in range(num_upsample_layers):
-                x = nn.ConvTranspose(
-                    features=dim,
-                    kernel_size=(4, 4),
-                    strides=(2, 2),
-                    padding=2,
-                    transpose_kernel=True,
-                )(x)
-                x = nn.relu(x)
-                dim //= 2
-
-            x = nn.Conv(features=1, kernel_size=(1, 1), padding=0)(x)
-
-            return x
-
-    def reconstruct_masks(codebook_indices):
-        quantized = _quantized_values_from_codebook_indices(
-            codebook_indices, params['_embeddings']
-        )
-        return Decoder().apply({'params': params}, quantized)
-
-    with open(_MODEL_PATH, 'rb') as f:
-        params = _get_params(dict(np.load(f)))
-
-    return jax.jit(reconstruct_masks, backend='cpu')
-
-
-def extract_objs(text, width, height, unique_labels=False):
-    """Returns objs for a string with "<loc>" and "<seg>" tokens."""
-    objs = []
-    seen = set()
-    while text:
-        m = _SEGMENT_DETECT_RE.match(text)
-        if not m:
-            break
-        print("m", m)
-        gs = list(m.groups())
-        before = gs.pop(0)
-        name = gs.pop()
-        y1, x1, y2, x2 = [int(x) / 1024 for x in gs[:4]]
-
-        y1, x1, y2, x2 = map(round, (y1 * height, x1 * width, y2 * height, x2 * width))
-        seg_indices = gs[4:20]
-        if seg_indices[0] is None:
-            mask = None
-        else:
-            seg_indices = np.array([int(x) for x in seg_indices], dtype=np.int32)
-            m64, = _get_reconstruct_masks()(seg_indices[None])[..., 0]
-            m64 = np.clip(np.array(m64) * 0.5 + 0.5, 0, 1)
-            m64 = PIL.Image.fromarray((m64 * 255).astype('uint8'))
-            mask = np.zeros([height, width])
-            if y2 > y1 and x2 > x1:
-                mask[y1:y2, x1:x2] = np.array(m64.resize([x2 - x1, y2 - y1])) / 255.0
-
-        content = m.group()
-        if before:
-            objs.append(dict(content=before))
-            content = content[len(before):]
-        while unique_labels and name in seen:
-            name = (name or '') + "'"
-        seen.add(name)
-        objs.append(dict(
-            content=content, xyxy=(x1, y1, x2, y2), mask=mask, name=name))
-        text = text[len(before) + len(content):]
-
-    if text:
-        objs.append(dict(content=text))
-
-    return objs
-
 
 #########
 
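(Reviewer note, not part of the commit: for reference, a self-contained sketch of what the removed _SEGMENT_DETECT_RE and extract_objs were doing with <loc....> tokens. Each token encodes a coordinate on a 0-1023 grid that gets rescaled to pixel units; the sample string and image size below are made up for illustration.)

import re

_SEGMENT_DETECT_RE = re.compile(
    r'(.*?)' +
    r'<loc(\d{4})>' * 4 + r'\s*' +
    '(?:%s)?' % (r'<seg(\d{3})>' * 16) +
    r'\s*([^;<>]+)? ?(?:; )?',
)

sample = "<loc0128><loc0256><loc0768><loc0896> stamp"    # hypothetical model output
m = _SEGMENT_DETECT_RE.match(sample)
gs = list(m.groups())
before, name = gs.pop(0), gs.pop()
y1, x1, y2, x2 = [int(v) / 1024 for v in gs[:4]]          # normalise the 0-1023 grid
width, height = 448, 448                                  # hypothetical image size
box = tuple(map(round, (x1 * width, y1 * height, x2 * width, y2 * height)))
print(name, box)   # -> stamp (112, 56, 392, 336), i.e. (x1, y1, x2, y2) in pixels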
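(Reviewer note, not part of the commit: the removed mask decoder can be sanity-checked by shape alone. The 16 <seg...> codes are reshaped to a 4x4 latent grid, and the four stride-2 ConvTranspose layers double it to the 64x64 mask described in the _get_reconstruct_masks docstring.)

# Illustration only: the shape arithmetic behind the removed mask decoder.
num_seg_tokens = 16
grid = int(num_seg_tokens ** 0.5)   # 16 codes are reshaped to a 4x4 grid
size = grid
for _ in range(4):                  # num_upsample_layers in the removed Decoder
    size *= 2                       # each stride-(2, 2) ConvTranspose doubles H and W
assert (grid, size) == (4, 64)      # matches the documented [B, 64, 64, 1] output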