Spaces:
Running
Running
File size: 4,647 Bytes
b128c76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import os
import queue
from dataclasses import dataclass
from typing import Annotated, Literal
import torch
from pydantic import BaseModel, Field, conint, conlist
from pydantic.functional_validators import SkipValidation
from fish_speech.conversation import Message, TextPart, VQPart
class ServeVQPart(BaseModel):
    # Message part carrying integer codes as a nested list.
    # `type` is the tag for this part kind; it is fixed to "vq".
    type: Literal["vq"] = "vq"
    # Wrapped in SkipValidation, so pydantic does not validate the nested list.
    codes: SkipValidation[list[list[int]]]
class ServeTextPart(BaseModel):
    # Message part carrying plain text.
    # `type` is the tag for this part kind; it is fixed to "text".
    type: Literal["text"] = "text"
    text: str
class ServeAudioPart(BaseModel):
    # Message part carrying raw audio bytes.
    # `type` is the tag for this part kind; it is fixed to "audio".
    # NOTE(review): the audio encoding is not specified here — confirm with callers.
    type: Literal["audio"] = "audio"
    audio: bytes
@dataclass
class ASRPackRequest:
    # Plain dataclass (not a pydantic model) bundling one audio tensor with a
    # queue and a language string.
    # NOTE(review): presumably the consumer puts the ASR result on
    # `result_queue` — confirm against the worker that reads these requests.
    audio: torch.Tensor
    result_queue: queue.Queue
    language: str
class ServeASRRequest(BaseModel):
    # Request body for speech recognition over one or more audio clips.
    # Each entry should be uncompressed PCM float16 audio.
    audios: list[bytes]
    # Defaults to 44100.
    sample_rate: int = 44100
    # One of "zh", "en", "ja" or "auto"; defaults to "auto".
    language: Literal["zh", "en", "ja", "auto"] = "auto"
class ServeASRTranscription(BaseModel):
    # A single transcription result.
    text: str
    # NOTE(review): presumably the clip length in seconds — confirm with producer.
    duration: float
    # NOTE(review): the criterion for a "huge gap" is not defined in this file.
    huge_gap: bool
class ServeASRSegment(BaseModel):
    # One piece of transcribed text together with its start and end positions.
    # NOTE(review): `start` / `end` are presumably seconds — confirm with producer.
    text: str
    start: float
    end: float
class ServeTimedASRResponse(BaseModel):
    # Transcription text plus a list of ServeASRSegment entries and a duration.
    text: str
    segments: list[ServeASRSegment]
    duration: float
class ServeASRResponse(BaseModel):
    # Response wrapper holding a list of ServeASRTranscription results.
    # NOTE(review): presumably one entry per clip in ServeASRRequest.audios — confirm.
    transcriptions: list[ServeASRTranscription]
class ServeMessage(BaseModel):
    # One chat message as exchanged over the API: a role plus an ordered list
    # of VQ-code and/or text parts.
    role: Literal["system", "assistant", "user"]
    parts: list[ServeVQPart | ServeTextPart]

    def to_conversation_message(self):
        """Build a `fish_speech.conversation.Message` from this message.

        The role is copied as-is. Assistant messages get `modality` set to
        "voice". Text parts become `TextPart`; VQ parts become `VQPart` with
        the codes converted to a `torch.int` tensor.

        Raises:
            ValueError: if a part is neither a ServeTextPart nor a ServeVQPart.
        """
        message = Message(role=self.role, parts=[])
        if self.role == "assistant":
            message.modality = "voice"

        for item in self.parts:
            if isinstance(item, ServeTextPart):
                converted = TextPart(text=item.text)
            elif isinstance(item, ServeVQPart):
                code_tensor = torch.tensor(item.codes, dtype=torch.int)
                converted = VQPart(codes=code_tensor)
            else:
                raise ValueError(f"Unsupported part type: {item}")
            message.parts.append(converted)

        return message
class ServeChatRequest(BaseModel):
    # Request body for chat generation.
    #
    # At least one message is required. The constraint is written with
    # Field(min_length=1) -- the same Annotated[..., Field(...)] form that
    # ServeTTSRequest uses for its bounded floats -- so that pydantic sees it
    # as constraint metadata. Nesting conlist(...) inside Annotated passes an
    # already-annotated *type* as metadata rather than a constraint object.
    messages: Annotated[list[ServeMessage], Field(min_length=1)]
    # Sampling / generation settings with their defaults.
    max_new_tokens: int = 1024
    top_p: float = 0.7
    repetition_penalty: float = 1.2
    temperature: float = 0.7
    # When True the caller asks for a streamed response.
    streaming: bool = False
    # NOTE(review): presumably the number of candidate responses to sample and
    # the threshold used to stop sampling early -- confirm with the handler.
    num_samples: int = 1
    early_stop_threshold: float = 1.0
class ServeVQGANEncodeRequest(BaseModel):
    # Request body for VQGAN encoding.
    # Each entry should be an encoded audio file (wav, mp3, etc.).
    audios: list[bytes]
class ServeVQGANEncodeResponse(BaseModel):
    # Response body for VQGAN encoding: a list of nested integer lists.
    # Wrapped in SkipValidation, so pydantic does not validate the nested lists.
    # NOTE(review): presumably one list[list[int]] per input audio — confirm.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeRequest(BaseModel):
    # Request body for VQGAN decoding; same `tokens` shape as
    # ServeVQGANEncodeResponse.
    # Wrapped in SkipValidation, so pydantic does not validate the nested lists.
    tokens: SkipValidation[list[list[list[int]]]]
class ServeVQGANDecodeResponse(BaseModel):
    # Response body for VQGAN decoding.
    # Each entry should be PCM float16 audio.
    audios: list[bytes]
class ServeForwardMessage(BaseModel):
    # Simple role/content pair. Unlike ServeMessage, `role` is a free-form
    # string and the content is a single string rather than a list of parts.
    role: str
    content: str
class ServeResponse(BaseModel):
    # Non-streaming response: the messages, an optional finish reason and a
    # free-form stats mapping.
    messages: list[ServeMessage]
    # "stop", "error", or None when not set.
    finish_reason: Literal["stop", "error"] | None = None
    # NOTE(review): mutable `{}` default; pydantic is expected to copy field
    # defaults per instance rather than share them — confirm.
    stats: dict[str, int | float | str] = {}
class ServeStreamDelta(BaseModel):
    # Incremental update in a streamed response; both fields are optional and
    # default to None.
    role: Literal["system", "assistant", "user"] | None = None
    part: ServeVQPart | ServeTextPart | None = None
class ServeStreamResponse(BaseModel):
    # One item of a streamed response.
    # NOTE(review): `sample_id` presumably identifies which sample this item
    # belongs to when several are generated — confirm with the producer.
    sample_id: int = 0
    delta: ServeStreamDelta | None = None
    # "stop", "error", or None when not set.
    finish_reason: Literal["stop", "error"] | None = None
    stats: dict[str, int | float | str] | None = None
class ServeReferenceAudio(BaseModel):
    # A reference clip: raw audio bytes paired with a text string.
    audio: bytes
    text: str

    def __repr__(self) -> str:
        """Return a compact representation showing the text and the audio size.

        The length of the audio payload is shown instead of the raw bytes.
        """
        audio_size = len(self.audio)
        return f"ServeReferenceAudio(text={self.text!r}, audio_size={audio_size})"
class ServeTTSRequest(BaseModel):
    # Request body for text-to-speech synthesis.

    # Allow arbitrary types for pytorch related types.
    # Declared through `model_config` (pydantic v2) instead of a nested
    # `class Config`, which pydantic v2 only supports as a deprecated form.
    model_config = {"arbitrary_types_allowed": True}

    text: str
    # NOTE(review): conint(...) already returns an annotated int type; nested
    # here as Annotated metadata, the 100..300 / strict bounds may not actually
    # be enforced. Left unchanged on purpose, because enforcing them could
    # reject values existing callers send -- confirm before tightening.
    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
    # Audio format
    format: Literal["wav", "pcm", "mp3"] = "wav"
    # Reference audios for in-context learning
    references: list[ServeReferenceAudio] = []
    # Reference id
    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
    # just pass 7f92f8afb8ec43bf81429cc1c9199cb1
    reference_id: str | None = None
    seed: int | None = None
    use_memory_cache: Literal["on", "off"] = "off"
    # Normalize text for en & zh; this increases stability for numbers
    normalize: bool = True
    # The fields below are not usually used
    streaming: bool = False
    max_new_tokens: int = 1024
    # Bounded, strict floats: values outside the range are rejected, and
    # strict mode rejects inputs that would otherwise be coerced.
    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
|