Ahsen Khaliq commited on
Commit
6dab69a
·
1 Parent(s): 719bab7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -4
app.py CHANGED
@@ -12,6 +12,7 @@ import torchtext
12
  from stat import ST_CTIME
13
  from datetime import datetime, timedelta
14
  import shutil
 
15
  # Images
16
  torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2021/08/04/14/16/tower-6521842_1280.jpg', 'tower.jpg')
17
  torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2017/08/31/05/36/buildings-2699520_1280.jpg', 'city.jpg')
@@ -357,7 +358,6 @@ def crop(img, h, w):
357
  def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
358
  if not os.path.exists(output_dir):
359
  os.mkdir(output_dir)
360
-
361
  for entry in os.listdir(output_dir):
362
  path = os.path.join(output_dir, entry)
363
  stats = os.stat(path)
@@ -485,12 +485,25 @@ def gradio_inference(image):
485
  resize_h=400, # resize original input to this size. None means do not resize.
486
  resize_w=400, # resize original input to this size. None means do not resize.
487
  serial=True) # if need animation, serial must be True.
488
-
 
 
 
 
 
 
 
 
 
 
 
 
 
489
  title = "Paint Transformer"
490
  description = "Gradio demo for Paint Transformer: Feed Forward Neural Painting with Stroke Prediction. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
491
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.03798'>Paint Transformer: Feed Forward Neural Painting with Stroke Prediction</a> | <a href='https://github.com/Huage001/PaintTransformer'>Github Repo</a></p>"
492
  gr.Interface(
493
- gradio_inference,
494
  gr.inputs.Image(type="file", label="Input"),
495
  [gr.outputs.Image(type="file", label="Output GIF"),
496
  gr.outputs.Image(type="pil", label="Output Image")],
@@ -500,4 +513,5 @@ gr.Interface(
500
  examples=[
501
  ['city.jpg'],
502
  ['tower.jpg']
503
- ],enable_queue=True).launch(debug=True)
 
 
12
  from stat import ST_CTIME
13
  from datetime import datetime, timedelta
14
  import shutil
15
+ torch.hub.download_url_to_file('https://i.imgur.com/tXrot31.jpg', 'gpu.jpg')
16
  # Images
17
  torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2021/08/04/14/16/tower-6521842_1280.jpg', 'tower.jpg')
18
  torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2017/08/31/05/36/buildings-2699520_1280.jpg', 'city.jpg')
 
358
  def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
359
  if not os.path.exists(output_dir):
360
  os.mkdir(output_dir)
 
361
  for entry in os.listdir(output_dir):
362
  path = os.path.join(output_dir, entry)
363
  stats = os.stat(path)
 
485
  resize_h=400, # resize original input to this size. None means do not resize.
486
  resize_w=400, # resize original input to this size. None means do not resize.
487
  serial=True) # if need animation, serial must be True.
488
+ inferences_running = 0
489
+ def throttled_inference(image):
490
+ global inferences_running
491
+ current = inferences_running
492
+ if current >= 5:
493
+ print(f"Rejected inference when we already had {current} running")
494
+ return "gpu.jpg",Image.open("gpu.jpg")
495
+ print(f"Inference starting when we already had {current} running")
496
+ inferences_running += 1
497
+ try:
498
+ return gradio_inference(image)
499
+ finally:
500
+ print("Inference finished")
501
+ inferences_running -= 1
502
  title = "Paint Transformer"
503
  description = "Gradio demo for Paint Transformer: Feed Forward Neural Painting with Stroke Prediction. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
504
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.03798'>Paint Transformer: Feed Forward Neural Painting with Stroke Prediction</a> | <a href='https://github.com/Huage001/PaintTransformer'>Github Repo</a></p>"
505
  gr.Interface(
506
+ throttled_inference,
507
  gr.inputs.Image(type="file", label="Input"),
508
  [gr.outputs.Image(type="file", label="Output GIF"),
509
  gr.outputs.Image(type="pil", label="Output Image")],
 
513
  examples=[
514
  ['city.jpg'],
515
  ['tower.jpg']
516
+ ]
517
+ ).launch(debug=True)