shaoyent committed on
Commit
59eb726
1 Parent(s): fbe445b
Files changed (1) hide show
  1. app.py +5 -2
app.py CHANGED
@@ -18,7 +18,10 @@ from pytube import YouTube
18
  from youtube_transcript_api import YouTubeTranscriptApi
19
  from youtube_transcript_api.formatters import WebVTTFormatter
20
 
21
- device = 'cpu'
 
 
 
22
  model_name = 'BridgeTower/bridgetower-large-itm-mlm-itc'
23
  model = BridgeTowerForITC.from_pretrained(model_name).to(device)
24
  text_model = BridgeTowerTextFeatureExtractor.from_pretrained(model_name).to(device)
@@ -282,7 +285,7 @@ def process(video_url, text_query):
282
  expanded=False,
283
  batch_size=8,
284
  )
285
- frame_paths, transcripts = run_query(video_id, text_query, path=output_dir)
286
  return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
287
 
288
 
 
18
  from youtube_transcript_api import YouTubeTranscriptApi
19
  from youtube_transcript_api.formatters import WebVTTFormatter
20
 
21
+ if torch.cuda.is_available():
22
+ device = 'cuda'
23
+ else:
24
+ device = 'cpu'
25
  model_name = 'BridgeTower/bridgetower-large-itm-mlm-itc'
26
  model = BridgeTowerForITC.from_pretrained(model_name).to(device)
27
  text_model = BridgeTowerTextFeatureExtractor.from_pretrained(model_name).to(device)
 
285
  expanded=False,
286
  batch_size=8,
287
  )
288
+ frame_paths, transcripts = run_query(video_file, text_query, path=output_dir)
289
  return video_file, [(image, caption) for image, caption in zip(frame_paths, transcripts)]
290
 
291