Ahsen Khaliq commited on
Commit
b76ff72
·
1 Parent(s): 1c68c58

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -2
app.py CHANGED
@@ -10,6 +10,7 @@ import gradio as gr
10
  from torchvision import transforms
11
  import torchtext
12
 
 
13
 
14
  # Images
15
  torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2021/08/04/14/16/tower-6521842_1280.jpg', 'tower.jpg')
@@ -512,7 +513,6 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
512
  save_all=True, append_images=all_frames[1:], optimize=False, duration=40, loop=0)
513
  return os.path.join(frame_dir, "animation.gif"), tensor_to_pil
514
 
515
-
516
  def gradio_inference(image):
517
  return main(input_path=image.name,
518
  model_path='model.pth',
@@ -522,12 +522,26 @@ def gradio_inference(image):
522
  resize_w=300, # resize original input to this size. None means do not resize.
523
  serial=True) # if need animation, serial must be True.
524
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
525
  title = "Paint Transformer"
526
  description = "Gradio demo for Paint Transformer: Feed Forward Neural Painting with Stroke Prediction. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
527
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.03798'>Paint Transformer: Feed Forward Neural Painting with Stroke Prediction</a> | <a href='https://github.com/Huage001/PaintTransformer'>Github Repo</a></p>"
528
 
529
  gr.Interface(
530
- gradio_inference,
531
  gr.inputs.Image(type="file", label="Input"),
532
  [gr.outputs.Image(type="file", label="Output GIF"),
533
  gr.outputs.Image(type="pil", label="Output Image")],
 
10
  from torchvision import transforms
11
  import torchtext
12
 
13
+ torch.hub.download_url_to_file('https://i.imgur.com/WEHmKef.jpg', 'gpu.jpg')
14
 
15
  # Images
16
  torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2021/08/04/14/16/tower-6521842_1280.jpg', 'tower.jpg')
 
513
  save_all=True, append_images=all_frames[1:], optimize=False, duration=40, loop=0)
514
  return os.path.join(frame_dir, "animation.gif"), tensor_to_pil
515
 
 
516
  def gradio_inference(image):
517
  return main(input_path=image.name,
518
  model_path='model.pth',
 
522
  resize_w=300, # resize original input to this size. None means do not resize.
523
  serial=True) # if need animation, serial must be True.
524
 
525
+ def throttled_inference(image):
526
+ global inferences_running
527
+ current = inferences_running
528
+ if current >= 4:
529
+ print(f"Rejected inference when we already had {current} running")
530
+ return load_image("./gpu.jpg")
531
+ print(f"Inference starting when we already had {current} running")
532
+ inferences_running += 1
533
+ try:
534
+ return inference(image)
535
+ finally:
536
+ print("Inference finished")
537
+ inferences_running -= 1
538
+
539
  title = "Paint Transformer"
540
  description = "Gradio demo for Paint Transformer: Feed Forward Neural Painting with Stroke Prediction. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
541
  article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2108.03798'>Paint Transformer: Feed Forward Neural Painting with Stroke Prediction</a> | <a href='https://github.com/Huage001/PaintTransformer'>Github Repo</a></p>"
542
 
543
  gr.Interface(
544
+ throttled_inference,
545
  gr.inputs.Image(type="file", label="Input"),
546
  [gr.outputs.Image(type="file", label="Output GIF"),
547
  gr.outputs.Image(type="pil", label="Output Image")],