pszemraj commited on
Commit
435abb4
1 Parent(s): 2b8c4c9

✨ add ability to add custom options in CLI

Browse files

Signed-off-by: peter szemraj <[email protected]>

Files changed (1) hide show
  1. app.py +59 -12
app.py CHANGED
@@ -2,6 +2,9 @@
2
  app.py - the main module for the gradio app for summarization
3
 
4
  Usage:
 
 
 
5
  python app.py --help
6
 
7
  Environment Variables:
@@ -18,6 +21,8 @@ import logging
18
  import os
19
  import random
20
  import re
 
 
21
  import time
22
  from pathlib import Path
23
 
@@ -52,7 +57,7 @@ _here = Path(__file__).parent
52
  nltk.download("punkt", force=True, quiet=True)
53
  nltk.download("popular", force=True, quiet=True)
54
 
55
-
56
  MODEL_OPTIONS = [
57
  "pszemraj/long-t5-tglobal-base-16384-book-summary",
58
  "pszemraj/long-t5-tglobal-base-sci-simplify",
@@ -60,6 +65,14 @@ MODEL_OPTIONS = [
60
  "pszemraj/long-t5-tglobal-base-16384-booksci-summary-v1",
61
  "pszemraj/pegasus-x-large-book-summary",
62
  ] # models users can choose from
 
 
 
 
 
 
 
 
63
 
64
  SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
65
  AGGREGATE_MODEL = "MBZUAI/LaMini-Flan-T5-783M" # model to use for aggregation
@@ -67,8 +80,11 @@ AGGREGATE_MODEL = "MBZUAI/LaMini-Flan-T5-783M" # model to use for aggregation
67
  # if duplicating space: uncomment this line to adjust the max words
68
  # os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
69
  # os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
 
70
 
71
- aggregator = BatchAggregator(AGGREGATE_MODEL)
 
 
72
 
73
 
74
  def aggregate_text(
@@ -364,10 +380,11 @@ def load_uploaded_file(file_obj, max_pages: int = 20, lower: bool = False) -> st
364
  def parse_args():
365
  """arguments for the command line interface"""
366
  parser = argparse.ArgumentParser(
367
- description="Document Summarization with Long-Document Transformers Demo",
368
  formatter_class=argparse.ArgumentDefaultsHelpFormatter,
369
- epilog="Runs a local-only web app to summarize documents. use --share for a public link to share.",
370
  )
 
371
  parser.add_argument(
372
  "--share",
373
  dest="share",
@@ -379,16 +396,34 @@ def parse_args():
379
  "--model",
380
  type=str,
381
  default=None,
382
- help=f"Add a custom model to the list of models: {', '.join(MODEL_OPTIONS)}",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  )
384
  parser.add_argument(
385
  "-level",
386
- "--log-level",
387
  type=str,
388
  default="INFO",
389
  choices=["DEBUG", "INFO", "WARNING", "ERROR"],
390
  help="Set the logging level",
391
  )
 
 
 
 
392
  return parser.parse_args()
393
 
394
 
@@ -397,11 +432,19 @@ if __name__ == "__main__":
397
  logger = logging.getLogger(__name__)
398
  args = parse_args()
399
  logger.setLevel(args.log_level)
400
- logger.info(f"args: {args}")
 
 
401
  if args.model is not None:
402
  logger.info(f"Adding model {args.model} to the list of models")
403
  MODEL_OPTIONS.append(args.model)
404
- logger.info("Starting app instance")
 
 
 
 
 
 
405
  logger.info("Loading OCR model")
406
  with contextlib.redirect_stdout(None):
407
  ocr_model = ocr_predictor(
@@ -410,11 +453,14 @@ if __name__ == "__main__":
410
  pretrained=True,
411
  assume_straight_pages=True,
412
  )
 
 
413
  name_to_path = load_example_filenames(_here / "examples")
414
  logger.info(f"Loaded {len(name_to_path)} examples")
415
 
416
  demo = gr.Blocks(title="Document Summarization with Long-Document Transformers")
417
  _examples = list(name_to_path.keys())
 
418
  with demo:
419
  gr.Markdown("# Document Summarization with Long-Document Transformers")
420
  gr.Markdown(
@@ -436,9 +482,9 @@ if __name__ == "__main__":
436
  label="Model Name",
437
  )
438
  num_beams = gr.Radio(
439
- choices=[2, 3, 4],
440
  label="Beam Search: # of Beams",
441
- value=2,
442
  )
443
  load_examples_button = gr.Button(
444
  "Load Example in Dropdown",
@@ -542,9 +588,10 @@ if __name__ == "__main__":
542
  step=0.05,
543
  )
544
  token_batch_length = gr.Radio(
545
- choices=[1024, 1536, 2048, 2560, 3072],
546
  label="token batch length",
547
- value=2048,
 
548
  )
549
 
550
  with gr.Row(variant="compact"):
 
2
  app.py - the main module for the gradio app for summarization
3
 
4
  Usage:
5
+ app.py [-h] [--share] [-m MODEL] [-nb ADD_BEAM_OPTION] [-batch TOKEN_BATCH_OPTION]
6
+ [-level {DEBUG,INFO,WARNING,ERROR}]
7
+ Details:
8
  python app.py --help
9
 
10
  Environment Variables:
 
21
  import os
22
  import random
23
  import re
24
+ import pprint as pp
25
+ import sys
26
  import time
27
  from pathlib import Path
28
 
 
57
  nltk.download("punkt", force=True, quiet=True)
58
  nltk.download("popular", force=True, quiet=True)
59
 
60
+ # Constants & Globals
61
  MODEL_OPTIONS = [
62
  "pszemraj/long-t5-tglobal-base-16384-book-summary",
63
  "pszemraj/long-t5-tglobal-base-sci-simplify",
 
65
  "pszemraj/long-t5-tglobal-base-16384-booksci-summary-v1",
66
  "pszemraj/pegasus-x-large-book-summary",
67
  ] # models users can choose from
68
+ BEAM_OPTIONS = [2, 3, 4] # beam sizes users can choose from
69
+ TOKEN_BATCH_OPTIONS = [
70
+ 1024,
71
+ 1536,
72
+ 2048,
73
+ 2560,
74
+ 3072,
75
+ ] # token batch sizes users can choose from
76
 
77
  SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
78
  AGGREGATE_MODEL = "MBZUAI/LaMini-Flan-T5-783M" # model to use for aggregation
 
80
  # if duplicating space: uncomment this line to adjust the max words
81
  # os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
82
  # os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
83
+ # os.environ["APP_AGG_FORCE_CPU"] = str(1) # force cpu for aggregation
84
 
85
+ aggregator = BatchAggregator(
86
+ AGGREGATE_MODEL, force_cpu=os.environ.get("APP_AGG_FORCE_CPU", False)
87
+ )
88
 
89
 
90
  def aggregate_text(
 
380
  def parse_args():
381
  """arguments for the command line interface"""
382
  parser = argparse.ArgumentParser(
383
+ description="Document Summarization with Long-Document Transformers - Demo",
384
  formatter_class=argparse.ArgumentDefaultsHelpFormatter,
385
+ epilog="Runs a local-only web UI to summarize documents. pass --share for a public link to share.",
386
  )
387
+
388
  parser.add_argument(
389
  "--share",
390
  dest="share",
 
396
  "--model",
397
  type=str,
398
  default=None,
399
+ help=f"Add a custom model to the list of models: {pp.pformat(MODEL_OPTIONS, compact=True)}",
400
+ )
401
+ parser.add_argument(
402
+ "-nb",
403
+ "--add_beam_option",
404
+ type=int,
405
+ default=None,
406
+ help=f"Add a beam search option to the list of beam search options: {pp.pformat(BEAM_OPTIONS, compact=True)}",
407
+ )
408
+ parser.add_argument(
409
+ "-batch",
410
+ "--token_batch_option",
411
+ type=int,
412
+ default=None,
413
+ help=f"Add a token batch option to the list of token batch options: {pp.pformat(TOKEN_BATCH_OPTIONS, compact=True)}",
414
  )
415
  parser.add_argument(
416
  "-level",
417
+ "--log_level",
418
  type=str,
419
  default="INFO",
420
  choices=["DEBUG", "INFO", "WARNING", "ERROR"],
421
  help="Set the logging level",
422
  )
423
+ # if "--help" in sys.argv or "-h" in sys.argv:
424
+ # parser.print_help()
425
+ # sys.exit(0)
426
+
427
  return parser.parse_args()
428
 
429
 
 
432
  logger = logging.getLogger(__name__)
433
  args = parse_args()
434
  logger.setLevel(args.log_level)
435
+ logger.info(f"args: {pp.pformat(args.__dict__, compact=True)}")
436
+
437
+ # add any custom options
438
  if args.model is not None:
439
  logger.info(f"Adding model {args.model} to the list of models")
440
  MODEL_OPTIONS.append(args.model)
441
+ if args.add_beam_option is not None:
442
+ logger.info(f"Adding beam search option {args.add_beam_option} to the list")
443
+ BEAM_OPTIONS.append(args.add_beam_option)
444
+ if args.token_batch_option is not None:
445
+ logger.info(f"Adding token batch option {args.token_batch_option} to the list")
446
+ TOKEN_BATCH_OPTIONS.append(args.token_batch_option)
447
+
448
  logger.info("Loading OCR model")
449
  with contextlib.redirect_stdout(None):
450
  ocr_model = ocr_predictor(
 
453
  pretrained=True,
454
  assume_straight_pages=True,
455
  )
456
+
457
+ # load the examples
458
  name_to_path = load_example_filenames(_here / "examples")
459
  logger.info(f"Loaded {len(name_to_path)} examples")
460
 
461
  demo = gr.Blocks(title="Document Summarization with Long-Document Transformers")
462
  _examples = list(name_to_path.keys())
463
+ logger.info("Starting app instance")
464
  with demo:
465
  gr.Markdown("# Document Summarization with Long-Document Transformers")
466
  gr.Markdown(
 
482
  label="Model Name",
483
  )
484
  num_beams = gr.Radio(
485
+ choices=BEAM_OPTIONS,
486
  label="Beam Search: # of Beams",
487
+ value=BEAM_OPTIONS[0],
488
  )
489
  load_examples_button = gr.Button(
490
  "Load Example in Dropdown",
 
588
  step=0.05,
589
  )
590
  token_batch_length = gr.Radio(
591
+ choices=TOKEN_BATCH_OPTIONS,
592
  label="token batch length",
593
+ # select median option
594
+ value=TOKEN_BATCH_OPTIONS[len(TOKEN_BATCH_OPTIONS) // 2],
595
  )
596
 
597
  with gr.Row(variant="compact"):