sanghan commited on
Commit
efac2d0
·
1 Parent(s): c1cd135

check for gpu if none exists

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -1,24 +1,27 @@
1
  import torch
2
  import gradio as gr
3
 
4
- model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3").cuda() # or "resnet50"
 
 
 
5
 
6
  convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
7
 
8
 
9
  def inference(video):
10
  convert_video(
11
- model, # The loaded model, can be on any device (cpu or cuda).
12
- input_source=video, # A video file or an image sequence directory.
13
- downsample_ratio=0.25, # [Optional] If None, make downsampled max size be 512px.
14
- output_type='video', # Choose "video" or "png_sequence"
15
- output_composition='com.mp4', # File path if video; directory path if png sequence.
16
- output_alpha=None, # [Optional] Output the raw alpha prediction.
17
- output_foreground=None, # [Optional] Output the raw foreground prediction.
18
- output_video_mbps=4, # Output video mbps. Not needed for png sequence.
19
- seq_chunk=12, # Process n frames at once for better parallelism.
20
- num_workers=1, # Only for image sequence input. Reader threads.
21
- progress=True # Print conversion progress.
22
  )
23
  return "com.mp4"
24
 
 
1
  import torch
2
  import gradio as gr
3
 
4
+ model = torch.hub.load("PeterL1n/RobustVideoMatting", "mobilenetv3")
5
+
6
+ if torch.cuda.is_available():
7
+ model = model.cuda()
8
 
9
  convert_video = torch.hub.load("PeterL1n/RobustVideoMatting", "converter")
10
 
11
 
12
  def inference(video):
13
  convert_video(
14
+ model, # The loaded model, can be on any device (cpu or cuda).
15
+ input_source=video, # A video file or an image sequence directory.
16
+ downsample_ratio=0.25, # [Optional] If None, make downsampled max size be 512px.
17
+ output_type="video", # Choose "video" or "png_sequence"
18
+ output_composition="com.mp4", # File path if video; directory path if png sequence.
19
+ output_alpha=None, # [Optional] Output the raw alpha prediction.
20
+ output_foreground=None, # [Optional] Output the raw foreground prediction.
21
+ output_video_mbps=4, # Output video mbps. Not needed for png sequence.
22
+ seq_chunk=12, # Process n frames at once for better parallelism.
23
+ num_workers=1, # Only for image sequence input. Reader threads.
24
+ progress=True, # Print conversion progress.
25
  )
26
  return "com.mp4"
27