Fedir Zadniprovskyi commited on
Commit
a5d2e48
1 Parent(s): f5d1866

test: capture openai's param handling

Browse files
src/faster_whisper_server/routers/stt.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
  import asyncio
4
  from io import BytesIO
5
  import logging
6
- from typing import TYPE_CHECKING, Annotated, Literal
7
 
8
  from fastapi import (
9
  APIRouter,
@@ -30,6 +30,7 @@ from faster_whisper_server.config import (
30
  from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
31
  from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
32
  from faster_whisper_server.server_models import (
 
33
  TranscriptionJsonResponse,
34
  TranscriptionVerboseJsonResponse,
35
  )
@@ -165,7 +166,7 @@ def transcribe_file(
165
  response_format: Annotated[ResponseFormat | None, Form()] = None,
166
  temperature: Annotated[float, Form()] = 0.0,
167
  timestamp_granularities: Annotated[
168
- list[Literal["segment", "word"]],
169
  Form(alias="timestamp_granularities[]"),
170
  ] = ["segment"],
171
  stream: Annotated[bool, Form()] = False,
 
3
  import asyncio
4
  from io import BytesIO
5
  import logging
6
+ from typing import TYPE_CHECKING, Annotated
7
 
8
  from fastapi import (
9
  APIRouter,
 
30
  from faster_whisper_server.core import Segment, segments_to_srt, segments_to_text, segments_to_vtt
31
  from faster_whisper_server.dependencies import ConfigDependency, ModelManagerDependency, get_config
32
  from faster_whisper_server.server_models import (
33
+ TimestampGranularities,
34
  TranscriptionJsonResponse,
35
  TranscriptionVerboseJsonResponse,
36
  )
 
166
  response_format: Annotated[ResponseFormat | None, Form()] = None,
167
  temperature: Annotated[float, Form()] = 0.0,
168
  timestamp_granularities: Annotated[
169
+ TimestampGranularities,
170
  Form(alias="timestamp_granularities[]"),
171
  ] = ["segment"],
172
  stream: Annotated[bool, Form()] = False,
src/faster_whisper_server/server_models.py CHANGED
@@ -107,3 +107,15 @@ class ModelObject(BaseModel):
107
  ]
108
  },
109
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  ]
108
  },
109
  )
110
+
111
+
112
+ TimestampGranularities = list[Literal["segment", "word"]]
113
+
114
+
115
+ TIMESTAMP_GRANULARITIES_COMBINATIONS: list[TimestampGranularities] = [
116
+ [], # should be treated as ["segment"]. https://platform.openai.com/docs/api-reference/audio/createTranscription#audio-createtranscription-timestamp_granularities
117
+ ["segment"],
118
+ ["word"],
119
+ ["word", "segment"],
120
+ ["segment", "word"], # same as ["word", "segment"] but order is different
121
+ ]
tests/conftest.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  from fastapi.testclient import TestClient
6
  from faster_whisper_server.main import create_app
7
  from httpx import ASGITransport, AsyncClient
8
- from openai import OpenAI
9
  import pytest
10
  import pytest_asyncio
11
 
@@ -35,3 +35,10 @@ async def aclient() -> AsyncGenerator[AsyncClient, None]:
35
  @pytest.fixture()
36
  def openai_client(client: TestClient) -> OpenAI:
37
  return OpenAI(api_key="cant-be-empty", http_client=client)
 
 
 
 
 
 
 
 
5
  from fastapi.testclient import TestClient
6
  from faster_whisper_server.main import create_app
7
  from httpx import ASGITransport, AsyncClient
8
+ from openai import AsyncOpenAI, OpenAI
9
  import pytest
10
  import pytest_asyncio
11
 
 
35
  @pytest.fixture()
36
  def openai_client(client: TestClient) -> OpenAI:
37
  return OpenAI(api_key="cant-be-empty", http_client=client)
38
+
39
+
40
+ @pytest.fixture()
41
+ def actual_openai_client() -> AsyncOpenAI:
42
+ return AsyncOpenAI(
43
+ base_url="https://api.openai.com/v1"
44
+ ) # `base_url` is provided in case `OPENAI_API_BASE_URL` is set to a different value
tests/openai_timestamp_granularities_test.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """OpenAI's handling of `response_format` and `timestamp_granularities` is a bit confusing and inconsistent. This test module exists to capture the OpenAI API's behavior with respect to these parameters.""" # noqa: E501
2
+
3
+ from faster_whisper_server.server_models import TIMESTAMP_GRANULARITIES_COMBINATIONS, TimestampGranularities
4
+ from openai import AsyncOpenAI, BadRequestError
5
+ import pytest
6
+
7
+
8
+ @pytest.mark.asyncio()
9
+ @pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
10
+ async def test_openai_json_response_format_and_timestamp_granularities_combinations(
11
+ actual_openai_client: AsyncOpenAI,
12
+ timestamp_granularities: TimestampGranularities,
13
+ ) -> None:
14
+ audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230
15
+
16
+ if "word" in timestamp_granularities:
17
+ with pytest.raises(BadRequestError):
18
+ await actual_openai_client.audio.transcriptions.create(
19
+ file=audio_file,
20
+ model="whisper-1",
21
+ response_format="json",
22
+ timestamp_granularities=timestamp_granularities,
23
+ )
24
+ else:
25
+ await actual_openai_client.audio.transcriptions.create(
26
+ file=audio_file, model="whisper-1", response_format="json", timestamp_granularities=timestamp_granularities
27
+ )
28
+
29
+
30
+ @pytest.mark.asyncio()
31
+ @pytest.mark.parametrize("timestamp_granularities", TIMESTAMP_GRANULARITIES_COMBINATIONS)
32
+ async def test_openai_verbose_json_response_format_and_timestamp_granularities_combinations(
33
+ actual_openai_client: AsyncOpenAI,
34
+ timestamp_granularities: TimestampGranularities,
35
+ ) -> None:
36
+ audio_file = open("audio.wav", "rb") # noqa: SIM115, ASYNC230
37
+
38
+ transcription = await actual_openai_client.audio.transcriptions.create(
39
+ file=audio_file,
40
+ model="whisper-1",
41
+ response_format="verbose_json",
42
+ timestamp_granularities=timestamp_granularities,
43
+ )
44
+
45
+ assert transcription.__pydantic_extra__
46
+ if timestamp_granularities == ["word"]:
47
+ # This is an exception where segments are not present
48
+ assert transcription.__pydantic_extra__.get("segments") is None
49
+ assert transcription.__pydantic_extra__.get("words") is not None
50
+ elif "word" in timestamp_granularities:
51
+ assert transcription.__pydantic_extra__.get("segments") is not None
52
+ assert transcription.__pydantic_extra__.get("words") is not None
53
+ else:
54
+ # Unless explicitly requested, words are not present
55
+ assert transcription.__pydantic_extra__.get("segments") is not None
56
+ assert transcription.__pydantic_extra__.get("words") is None