jhj0517 commited on
Commit
88d2794
·
1 Parent(s): bd79b3c

Use conditional `@skipif`

Browse files
tests/test_bgm_separation.py CHANGED
@@ -6,9 +6,14 @@ from test_transcription import download_file, test_transcribe
6
 
7
  import gradio as gr
8
  import pytest
 
9
  import os
10
 
11
 
 
 
 
 
12
  @pytest.mark.parametrize(
13
  "whisper_type,vad_filter,bgm_separation,diarization",
14
  [
@@ -26,7 +31,10 @@ def test_bgm_separation_pipeline(
26
  test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
27
 
28
 
29
- @pytest.mark.skip(reason="Too heavy to run in actions with all of other tests")
 
 
 
30
  @pytest.mark.parametrize(
31
  "whisper_type,vad_filter,bgm_separation,diarization",
32
  [
 
6
 
7
  import gradio as gr
8
  import pytest
9
+ import torch
10
  import os
11
 
12
 
13
+ @pytest.mark.skipif(
14
+ not is_cuda_available(),
15
+ reason="Skipping because the test only works on GPU"
16
+ )
17
  @pytest.mark.parametrize(
18
  "whisper_type,vad_filter,bgm_separation,diarization",
19
  [
 
31
  test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
32
 
33
 
34
+ @pytest.mark.skipif(
35
+ not is_cuda_available(),
36
+ reason="Skipping because the test only works on GPU"
37
+ )
38
  @pytest.mark.parametrize(
39
  "whisper_type,vad_filter,bgm_separation,diarization",
40
  [
tests/test_config.py CHANGED
@@ -1,6 +1,7 @@
1
  from modules.utils.paths import *
2
 
3
  import os
 
4
 
5
  TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
6
  TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
@@ -11,3 +12,6 @@ TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
11
  TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
12
  TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")
13
 
 
 
 
 
1
  from modules.utils.paths import *
2
 
3
  import os
4
+ import torch
5
 
6
  TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
7
  TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
 
12
  TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
13
  TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")
14
 
15
+
16
+ def is_cuda_available():
17
+ return torch.cuda.is_available()