Maltokar committed (verified)
Commit a71fc3b
1 Parent(s): 5ff73ca

Upload modeling_got.py

Files changed (1):
  1. modeling_got.py +149 -108
modeling_got.py CHANGED
@@ -408,46 +408,71 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

-    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None, print_prompt=False, gradio_input=False, stream_flag=False, num_workers=4):
-
-        self.disable_torch_init()
-
-        image_processor_high = GOTImageEvalProcessor(image_size=1024)
-        use_im_start_end = True
-        image_token_len = 256
-
-        # Load the image either from Gradio input or directly
-        if gradio_input:
-            image = image_file.copy()
-        else:
-            image = self.load_image(image_file)
-
-        w, h = image.size
-
-        # Prepare OCR query
-        if ocr_type == 'format':
-            qs = 'OCR with format: '
-        else:
-            qs = 'OCR: '
-
-        # Process bounding box for OCR
-        if ocr_box:
-            bbox = eval(ocr_box)
-            bbox = [int(bbox[i]/w*1000) if i % 2 == 0 else int(bbox[i]/h*1000) for i in range(len(bbox))]
-            qs = str(bbox) + ' ' + qs
-
-        # Process OCR color if provided
-        if ocr_color:
-            qs = '[' + ocr_color + '] ' + qs
-
-        # Image token embedding
-        if use_im_start_end:
-            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
-        else:
-            qs = DEFAULT_IMAGE_TOKEN + '\n' + qs
-
-        # Conversation setup
-        conv_mpt = Conversation(
             system="""<|im_start|>system
         You should follow the instructions carefully and explain your answers in detail.""",
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
@@ -456,90 +481,106 @@ class GOTQwenForCausalLM(Qwen2ForCausalLM):
             offset=0,
             sep_style=SeparatorStyle.MPT,
             sep="<|im_end|>",
-        )
-
-        conv = conv_mpt.copy()
-        conv.append_message(conv.roles[0], qs)
-        conv.append_message(conv.roles[1], None)
-        prompt = conv.get_prompt()
-
-        if print_prompt:
-            print(prompt)
-
-        # Tokenize input
-        inputs = tokenizer([prompt])
-
-        # Process image
-        image_tensor_1 = image_processor_high(image)
-        input_ids = torch.as_tensor(inputs.input_ids).cpu()
-
-        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
-        keywords = [stop_str]
-        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
-        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-
-        def generate_output(image_tensor, input_ids):
-            with torch.autocast("cpu", dtype=torch.bfloat16):
-                return self.generate(
-                    input_ids,
-                    images=[image_tensor.unsqueeze(0).half().cpu()],
-                    do_sample=False,
-                    num_beams=1,
-                    no_repeat_ngram_size=20,
-                    streamer=streamer if stream_flag else None,
-                    max_new_tokens=4096,
-                    stopping_criteria=[stopping_criteria]
-                )
-
-        # Multiprocessing to parallelize generation
-        with mp.Pool(processes=num_workers) as pool:
-            results = pool.starmap(generate_output, [(image_tensor_1, input_ids)])
-
-        # Post-processing the output
-        outputs = tokenizer.decode(results[0][0, input_ids.shape[1]:]).strip()
-
-        if outputs.endswith(stop_str):
-            outputs = outputs[:-len(stop_str)]
-        response_str = outputs.strip()
-
-        # Optional rendering for output formatting
-        if render:
-            print('==============rendering===============')
-            if '**kern' in outputs:
-                import verovio
-                tk = verovio.toolkit()
-                tk.loadData(outputs)
-                tk.setOptions({
-                    "pageWidth": 2100,
-                    "footer": 'none',
-                    'barLineWidth': 0.5,
-                    'beamMaxSlope': 15,
-                    'staffLineWidth': 0.2,
-                    'spacingStaff': 6
-                })
-                svg = tk.renderToSVG()
-                svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
-                svg_to_html(svg, save_render_file)
             else:
-                # If 'format' OCR is being used without '**kern'
                 html_path_2 = save_render_file
-                outputs = outputs.replace('"', '``').replace('$', '')
-
-                # Properly balance the brackets
-                outputs = outputs.replace('\left(', '(').replace('\\right)', ')')
-                outputs = outputs.replace('\left[', '[').replace('\\right]', ']')
-                outputs = outputs.replace('\left{', '{').replace('\\right}', '}')
-                outputs = outputs.replace('\left|', '|').replace('\\right|', '|')
-
-                lines = content_mmd_to_html if '\\begin{tikzpicture}' not in outputs else tik_html
                 lines = lines.split("const text =")
-                gt = '"' + '\\n'.join(outputs.split('\n')).replace('\\', '\\\\') + '"'
-
-                new_web = lines[0] + 'const text =' + gt + lines[1]
-                with open(html_path_2, 'w') as web_f_new:
-                    web_f_new.write(new_web)
-
-        return response_str

  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
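
Both the removed one-liner above (`bbox = [int(bbox[i]/w*1000) if i % 2 == 0 ...]`) and the unrolled version in the new chat() below map pixel box coordinates into the model's 0-1000 reference frame: x values are scaled by the image width, y values by the height. A standalone sketch of that transform (normalize_box is a hypothetical helper name, not part of modeling_got.py):

    def normalize_box(bbox, w, h):
        # Even indices are x coordinates (scale by width); odd indices are y (scale by height).
        return [int(v / w * 1000) if i % 2 == 0 else int(v / h * 1000)
                for i, v in enumerate(bbox)]

    print(normalize_box([100, 50, 400, 300], w=800, h=600))  # -> [125, 83, 500, 500]

The new code unrolls this into explicit assignments for the two supported box shapes, a 2-element point and a 4-element box.
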
     setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
     setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)

+# generate_output is moved outside the chat method (to module level) so that
+# multiprocessing can pickle it and hand it to worker processes
+def generate_output(input_ids, image_tensor, model, tokenizer, stopping_criteria, stream_flag):
+    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    if stream_flag:
+        with torch.autocast("cpu", dtype=torch.bfloat16):
+            output_ids = model.generate(
+                input_ids,
+                images=[image_tensor.unsqueeze(0).half().cpu()],
+                do_sample=False,
+                num_beams=1,
+                no_repeat_ngram_size=20,
+                streamer=streamer,
+                max_new_tokens=4096,
+                stopping_criteria=[stopping_criteria]
+            )
+    else:
+        with torch.autocast("cpu", dtype=torch.bfloat16):
+            output_ids = model.generate(
+                input_ids,
+                images=[image_tensor.unsqueeze(0).half().cpu()],
+                do_sample=False,
+                num_beams=1,
+                no_repeat_ngram_size=20,
+                max_new_tokens=4096,
+                stopping_criteria=[stopping_criteria]
+            )
+    return output_ids

+    # The chat method optimized for CPU performance with multiprocessing
+    def chat(self, tokenizer, image_file, ocr_type, ocr_box='', ocr_color='', render=False, save_render_file=None,
+             print_prompt=False, gradio_input=False, stream_flag=False, num_workers=1):
+
+        self.disable_torch_init()
+
+        image_processor_high = GOTImageEvalProcessor(image_size=1024)
+        image_token_len = 256
+
+        if gradio_input:
+            image = image_file.copy()
+        else:
+            image = self.load_image(image_file)
+
+        w, h = image.size
+        qs = 'OCR with format: ' if ocr_type == 'format' else 'OCR: '
+
+        if ocr_box:
+            bbox = eval(ocr_box)
+            if len(bbox) == 2:
+                bbox[0] = int(bbox[0]/w*1000)
+                bbox[1] = int(bbox[1]/h*1000)
+            if len(bbox) == 4:
+                bbox[0] = int(bbox[0]/w*1000)
+                bbox[1] = int(bbox[1]/h*1000)
+                bbox[2] = int(bbox[2]/w*1000)
+                bbox[3] = int(bbox[3]/h*1000)
+            qs = str(bbox) + ' ' + qs
+
+        if ocr_color:
+            qs = f"[{ocr_color}] " + qs
+
+        qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_PATCH_TOKEN*image_token_len + DEFAULT_IM_END_TOKEN + '\n' + qs
+
+        # Setup conversation prompt
+        conv_mpt = Conversation(
             system="""<|im_start|>system
         You should follow the instructions carefully and explain your answers in detail.""",
             roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
             offset=0,
             sep_style=SeparatorStyle.MPT,
             sep="<|im_end|>",
+        )

+        conv = conv_mpt.copy()
+        conv.append_message(conv.roles[0], qs)
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()

+        if print_prompt:
+            print(prompt)

+        inputs = tokenizer([prompt])
+        image_tensor_1 = image_processor_high(image)
+        input_ids = torch.as_tensor(inputs.input_ids).cpu()

+        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
+        keywords = [stop_str]
+        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

+        # Multiprocessing setup (Pool is multiprocessing.Pool)
+        with Pool(num_workers) as pool:
+            results = pool.starmap(
+                generate_output,
+                [(input_ids, image_tensor_1, self, tokenizer, stopping_criteria, stream_flag)] * num_workers
+            )

+        output_ids = results[0]  # All workers run the identical task; take the first result
+        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
+
+        if outputs.endswith(stop_str):
+            outputs = outputs[:-len(stop_str)]
+        outputs = outputs.strip()
+
+        response_str = outputs
+
+        # Rendering logic
+        if render:
+            print('==============rendering===============')
+            from .render_tools import svg_to_html, content_mmd_to_html, tik_html, translation_table
+
+            if '**kern' in outputs:
+                import verovio
+                tk = verovio.toolkit()
+                tk.loadData(outputs)
+                tk.setOptions({
+                    "pageWidth": 2100, "footer": 'none',
+                    'barLineWidth': 0.5, 'beamMaxSlope': 15,
+                    'staffLineWidth': 0.2, 'spacingStaff': 6
+                })
+                tk.getPageCount()
+                svg = tk.renderToSVG()
+                svg = svg.replace("overflow=\"inherit\"", "overflow=\"visible\"")
+                svg_to_html(svg, save_render_file)
+
+            if ocr_type == 'format' and '**kern' not in outputs:
+                if '\\begin{tikzpicture}' not in outputs:
+                    html_path_2 = save_render_file
+                    right_num = outputs.count('\\right')
+                    left_num = outputs.count('\\left')

+                    # Strip \left/\right only when their counts are unbalanced
+                    if right_num != left_num:
+                        outputs = outputs.replace('\left(', '(').replace('\\right)', ')').replace('\left[', '[').replace('\\right]', ']').replace('\left{', '{').replace('\\right}', '}').replace('\left|', '|').replace('\\right|', '|').replace('\left.', '.').replace('\\right.', '.')

+                    outputs = outputs.replace('"', '``').replace('$', '')
+                    outputs_list = outputs.split('\n')
+                    gt = ''
+                    for out in outputs_list:
+                        gt += '"' + out.replace('\\', '\\\\') + r'\n' + '"' + '+' + '\n'
+                    gt = gt[:-2]

+                    lines = content_mmd_to_html
+                    lines = lines.split("const text =")
+                    new_web = lines[0] + 'const text =' + gt + lines[1]
                 else:
                     html_path_2 = save_render_file
+                    outputs = outputs.translate(translation_table)
+                    outputs_list = outputs.split('\n')
+                    gt = ''
+                    for out in outputs_list:
+                        if out:
+                            if '\\begin{tikzpicture}' not in out and '\\end{tikzpicture}' not in out:
+                                while out and out[-1] == ' ':  # guard against all-space lines
+                                    out = out[:-1]
+                                if out:
+                                    if out[-1] != ';':
+                                        gt += out[:-1] + ';\n'
+                                    else:
+                                        gt += out + '\n'
+                            else:
+                                gt += out + '\n'
+
+                    lines = tik_html
                     lines = lines.split("const text =")
+                    new_web = lines[0] + gt + lines[1]

+                with open(html_path_2, 'w') as web_f_new:
+                    web_f_new.write(new_web)

+        return response_str

  def dynamic_preprocess(self, image, min_num=1, max_num=6, image_size=1024, use_thumbnail=True):
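
The first added comment ("Move the generate_output function outside the chat method") is what makes the Pool call work at all: multiprocessing pickles the task callable by reference to send it to worker processes, and that fails for functions defined inside another function. The added code also presumes a `from multiprocessing import Pool` import, since it calls Pool(num_workers) directly. A minimal sketch of the pickling constraint, independent of this repo:

    import multiprocessing as mp

    def square(x):
        # Module-level: picklable, so Pool workers can receive it.
        return x * x

    def main():
        def local_square(x):  # Nested: cannot be pickled for Pool workers
            return x * x

        with mp.Pool(2) as pool:
            print(pool.map(square, [1, 2, 3]))  # [1, 4, 9]
            # pool.map(local_square, [1, 2, 3]) would raise:
            # AttributeError: Can't pickle local object 'main.<locals>.local_square'

    if __name__ == '__main__':
        main()

Note that each starmap task still pickles its arguments, including the model passed as `self`, so worker counts above 1 replicate both the model and the identical generation task across processes; that is presumably why the default dropped from num_workers=4 to num_workers=1.
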
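For completeness, a minimal usage sketch of the revised chat() signature, assuming the repository loads through transformers with trust_remote_code (the repo id and image path below are placeholders, not values from this commit):

    from transformers import AutoModel, AutoTokenizer

    repo = 'REPO_ID_PLACEHOLDER'
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModel.from_pretrained(repo, trust_remote_code=True, low_cpu_mem_usage=True)
    model = model.eval().cpu()  # this fork targets CPU inference

    # Plain OCR over a local image; ocr_type='format' returns formatted text,
    # and render=True additionally writes an HTML file to save_render_file.
    text = model.chat(tokenizer, 'page.png', ocr_type='ocr', num_workers=1)
    print(text)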