Spaces:
Running
Running
jhj0517
commited on
Commit
·
88d2794
1
Parent(s):
bd79b3c
Use conditional `@skipif`
Browse files- tests/test_bgm_separation.py +9 -1
- tests/test_config.py +4 -0
tests/test_bgm_separation.py
CHANGED
@@ -6,9 +6,14 @@ from test_transcription import download_file, test_transcribe
|
|
6 |
|
7 |
import gradio as gr
|
8 |
import pytest
|
|
|
9 |
import os
|
10 |
|
11 |
|
|
|
|
|
|
|
|
|
12 |
@pytest.mark.parametrize(
|
13 |
"whisper_type,vad_filter,bgm_separation,diarization",
|
14 |
[
|
@@ -26,7 +31,10 @@ def test_bgm_separation_pipeline(
|
|
26 |
test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
|
27 |
|
28 |
|
29 |
-
@pytest.mark.
|
|
|
|
|
|
|
30 |
@pytest.mark.parametrize(
|
31 |
"whisper_type,vad_filter,bgm_separation,diarization",
|
32 |
[
|
|
|
6 |
|
7 |
import gradio as gr
|
8 |
import pytest
|
9 |
+
import torch
|
10 |
import os
|
11 |
|
12 |
|
13 |
+
@pytest.mark.skipif(
|
14 |
+
not is_cuda_available(),
|
15 |
+
reason="Skipping because the test only works on GPU"
|
16 |
+
)
|
17 |
@pytest.mark.parametrize(
|
18 |
"whisper_type,vad_filter,bgm_separation,diarization",
|
19 |
[
|
|
|
31 |
test_transcribe(whisper_type, vad_filter, bgm_separation, diarization)
|
32 |
|
33 |
|
34 |
+
@pytest.mark.skipif(
|
35 |
+
not is_cuda_available(),
|
36 |
+
reason="Skipping because the test only works on GPU"
|
37 |
+
)
|
38 |
@pytest.mark.parametrize(
|
39 |
"whisper_type,vad_filter,bgm_separation,diarization",
|
40 |
[
|
tests/test_config.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from modules.utils.paths import *
|
2 |
|
3 |
import os
|
|
|
4 |
|
5 |
TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
|
6 |
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
|
@@ -11,3 +12,6 @@ TEST_NLLB_MODEL = "facebook/nllb-200-distilled-600M"
|
|
11 |
TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
|
12 |
TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")
|
13 |
|
|
|
|
|
|
|
|
1 |
from modules.utils.paths import *
|
2 |
|
3 |
import os
|
4 |
+
import torch
|
5 |
|
6 |
TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
|
7 |
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
|
|
|
12 |
TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
|
13 |
TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")
|
14 |
|
15 |
+
|
16 |
+
def is_cuda_available():
|
17 |
+
return torch.cuda.is_available()
|