File size: 5,524 Bytes
7cc3853
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import io
import platform

from openai import APIConnectionError, AsyncOpenAI, UnprocessableEntityError
import pytest
import soundfile as sf

from faster_whisper_server.routers.speech import (
    DEFAULT_MODEL,
    DEFAULT_RESPONSE_FORMAT,
    DEFAULT_VOICE,
    SUPPORTED_RESPONSE_FORMATS,
    ResponseFormat,
)

# Text that every speech request in this module asks the server to synthesize.
DEFAULT_INPUT: str = "Hello, world!"

# Host CPU architecture; the speech tests in this module are skipped unless it is "x86_64".
platform_machine: str = platform.machine()


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("response_format", SUPPORTED_RESPONSE_FORMATS)
async def test_create_speech_formats(openai_client: AsyncOpenAI, response_format: ResponseFormat) -> None:
    """Every supported response format must produce a non-empty audio payload."""
    res = await openai_client.audio.speech.create(
        model=DEFAULT_MODEL,
        voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
        input=DEFAULT_INPUT,
        response_format=response_format,
    )
    # Not raising is not enough: a successful response that carries no audio must fail the test.
    assert len(res.content) > 0, f"empty audio payload for response_format={response_format!r}"


# (model, voice) pairs the server must accept. OpenAI-style names and the server's
# Piper defaults are mixed so that every combination of the two naming schemes is covered.
GOOD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [
    # OpenAI model + OpenAI voice
    ("tts-1", "alloy"),
    ("tts-1-hd", "echo"),
    # OpenAI model + Piper voice
    ("tts-1", DEFAULT_VOICE),
    # Piper model + OpenAI voice
    (DEFAULT_MODEL, "echo"),
    # Piper model + Piper voice
    (DEFAULT_MODEL, DEFAULT_VOICE),
]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize(("model", "voice"), GOOD_MODEL_VOICE_PAIRS)
async def test_create_speech_good_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
    """Every accepted (model, voice) combination must synthesize a non-empty audio payload."""
    res = await openai_client.audio.speech.create(
        model=model,
        voice=voice,  # type: ignore  # noqa: PGH003
        input=DEFAULT_INPUT,
        response_format=DEFAULT_RESPONSE_FORMAT,
    )
    # A successful status alone is not enough: make sure audio was actually returned.
    assert len(res.content) > 0, f"empty audio payload for model={model!r}, voice={voice!r}"


# (model, voice) pairs in which at least one element is unknown to the server;
# test_create_speech_bad_model_voice_pair expects each of them to be rejected.
BAD_MODEL_VOICE_PAIRS: list[tuple[str, str]] = [
    # valid OpenAI model, unknown voice
    ("tts-1", "invalid"),
    # unknown model, valid OpenAI voice
    ("invalid", "echo"),
    # valid Piper model, unknown voice
    (DEFAULT_MODEL, "invalid"),
    # unknown model, valid Piper voice
    ("invalid", DEFAULT_VOICE),
    # both unknown
    ("invalid", "invalid"),
]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize(("model", "voice"), BAD_MODEL_VOICE_PAIRS)
async def test_create_speech_bad_model_voice_pair(openai_client: AsyncOpenAI, model: str, voice: str) -> None:
    """A request naming an unknown model and/or voice must be rejected."""
    # NOTE: the server is expected to answer with `UnprocessableEntityError` (HTTP 422), but
    # `APIConnectionError` is sometimes raised instead; the cause is not understood yet, so
    # both are accepted.
    expected_errors = (UnprocessableEntityError, APIConnectionError)
    with pytest.raises(expected_errors):
        await openai_client.audio.speech.create(
            input=DEFAULT_INPUT,
            model=model,
            voice=voice,  # type: ignore  # noqa: PGH003
            response_format=DEFAULT_RESPONSE_FORMAT,
        )


# Speeds the server must accept, in ascending order: 0.25, 0.5, 1.0, 2.0, 4.0.
# Consecutive values double (0.25 * 2**step), all exactly representable as floats.
SUPPORTED_SPEEDS: list[float] = [0.25 * 2.0**step for step in range(5)]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
async def test_create_speech_with_varying_speed(openai_client: AsyncOpenAI) -> None:
    """Higher ``speed`` values must produce proportionally shorter PCM audio.

    ``SUPPORTED_SPEEDS`` is ascending and each entry is twice the previous one, so ideally
    every step halves the number of PCM bytes. The assertion only demands a shrink by more
    than ``min_shrink_factor`` (1.5) to leave slack for output length not scaling exactly
    with speed.
    """
    # Tolerance for the ideal 2x shrink between consecutive speeds (see docstring).
    min_shrink_factor = 1.5
    previous_speed: float | None = None
    previous_size: int | None = None
    for speed in SUPPORTED_SPEEDS:
        res = await openai_client.audio.speech.create(
            model=DEFAULT_MODEL,
            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
            input=DEFAULT_INPUT,
            response_format="pcm",
            speed=speed,
        )
        # Read the buffered body via `.content`, the same way the resample test does.
        size = len(res.content)
        if previous_size is not None:
            assert size * min_shrink_factor < previous_size, (
                f"speed {previous_speed} -> {speed}: expected more than a {min_shrink_factor}x shrink, "
                f"got {previous_size} -> {size} bytes"
            )
        previous_speed, previous_size = speed, size


# Speeds just outside the span of SUPPORTED_SPEEDS (0.25 .. 4.0): one below, one above.
# test_create_speech_with_unsupported_speed expects the server to reject both.
UNSUPPORTED_SPEEDS: list[float] = [0.1, 4.1]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("speed", UNSUPPORTED_SPEEDS)
async def test_create_speech_with_unsupported_speed(openai_client: AsyncOpenAI, speed: float) -> None:
    """An out-of-range ``speed`` must be rejected with ``UnprocessableEntityError``."""
    create_speech = openai_client.audio.speech.create
    with pytest.raises(UnprocessableEntityError):
        await create_speech(
            input=DEFAULT_INPUT,
            model=DEFAULT_MODEL,
            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
            speed=speed,
            response_format="pcm",
        )


# Output sample rates (Hz) the server must be able to resample WAV audio to.
VALID_SAMPLE_RATES: list[int] = [16_000, 22_050, 24_000, 48_000]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("sample_rate", VALID_SAMPLE_RATES)
async def test_speech_valid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
    """The returned WAV file must be encoded at the requested ``sample_rate``."""
    res = await openai_client.audio.speech.create(
        input=DEFAULT_INPUT,
        model=DEFAULT_MODEL,
        voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
        response_format="wav",
        extra_body={"sample_rate": sample_rate},
    )
    # `sf.read` returns (samples, samplerate); only the rate is checked here.
    decoded = sf.read(io.BytesIO(res.content))
    assert decoded[1] == sample_rate


# Sample rates (Hz) the server must refuse. Presumably chosen one step outside an accepted
# 8000..48000 Hz range — NOTE(review): confirm against the server's validation bounds.
INVALID_SAMPLE_RATES: list[int] = [7_999, 48_001]


@pytest.mark.asyncio
@pytest.mark.skipif(platform_machine != "x86_64", reason="Only supported on x86_64")
@pytest.mark.parametrize("sample_rate", INVALID_SAMPLE_RATES)
async def test_speech_invalid_resample(openai_client: AsyncOpenAI, sample_rate: int) -> None:
    """Requesting an unsupported output ``sample_rate`` must fail with ``UnprocessableEntityError``."""
    create_speech = openai_client.audio.speech.create
    with pytest.raises(UnprocessableEntityError):
        await create_speech(
            input=DEFAULT_INPUT,
            model=DEFAULT_MODEL,
            voice=DEFAULT_VOICE,  # type: ignore  # noqa: PGH003
            response_format="wav",
            extra_body={"sample_rate": sample_rate},
        )


# TODO: implement the following test

# NUMBER_OF_MODELS = 1
# NUMBER_OF_VOICES = 124
#
#
# @pytest.mark.asyncio
# async def test_list_tts_models(openai_client: AsyncOpenAI) -> None:
#     raise NotImplementedError