Update app.py
app.py CHANGED
@@ -23,6 +23,9 @@ import torch
 from phonemizer import phonemize
 from faster_whisper import WhisperModel
 from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
+import edge_tts
+import asyncio
+import base64

 app = FastAPI()
 app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
@@ -32,8 +35,14 @@ phoneme_processor = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
 phoneme_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
 whisper_model = WhisperModel("small", compute_type="float32")

-# Cache for phoneme lookups
+# Cache for phoneme lookups and TTS audio
 phoneme_cache = {}
+tts_cache = {}
+
+# TTS configuration
+TTS_VOICE = "en-US-AriaNeural"  # High quality English voice
+TTS_RATE = "+0%"                # Normal speed
+TTS_PITCH = "+0Hz"              # Normal pitch

 def log(msg):
     print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
@@ -55,6 +64,75 @@ def words_sim(a: str, b: str) -> float:
     """Cached similarity calculation"""
     return SequenceMatcher(None, a.lower(), b.lower()).ratio()

+async def generate_expected_audio(word: str) -> str:
+    """Generate TTS audio for expected pronunciation and return as base64"""
+    # Check cache first
+    cache_key = f"{word.lower()}_{TTS_VOICE}_{TTS_RATE}_{TTS_PITCH}"
+    if cache_key in tts_cache:
+        log(f"TTS cache hit for: {word}")
+        return tts_cache[cache_key]
+
+    log(f"Generating TTS for: {word}")
+
+    try:
+        # Generate TTS audio
+        communicate = edge_tts.Communicate(word, TTS_VOICE, rate=TTS_RATE, pitch=TTS_PITCH)
+
+        # Collect audio data
+        audio_data = b""
+        async for chunk in communicate.stream():
+            if chunk["type"] == "audio":
+                audio_data += chunk["data"]
+
+        if not audio_data:
+            log(f"No audio data generated for: {word}")
+            return ""
+
+        # Convert to base64
+        audio_b64 = base64.b64encode(audio_data).decode('utf-8')
+
+        # Cache the result
+        tts_cache[cache_key] = audio_b64
+        log(f"TTS generated and cached for: {word} ({len(audio_data)} bytes)")
+
+        return audio_b64
+
+    except Exception as e:
+        log(f"TTS generation failed for '{word}': {str(e)}")
+        return ""
+
+def extract_user_audio_segment(waveform: torch.Tensor, sample_rate: int,
+                               start_time: float, end_time: float) -> str:
+    """Extract user's pronunciation segment and return as base64 WAV"""
+    try:
+        # Add small buffer around the word
+        buffer_time = 0.05  # 50ms buffer
+        adj_start = max(0, start_time - buffer_time)
+        adj_end = end_time + buffer_time
+
+        # Convert to sample indices
+        start_sample = int(adj_start * sample_rate)
+        end_sample = int(adj_end * sample_rate)
+        end_sample = min(waveform.shape[-1], end_sample)
+
+        # Extract segment
+        segment = waveform[:, start_sample:end_sample]
+
+        # Convert to bytes
+        buffer = io.BytesIO()
+        torchaudio.save(buffer, segment, sample_rate, format="wav")
+        buffer.seek(0)
+
+        # Convert to base64
+        audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
+
+        log(f"Extracted user audio segment: {start_time:.2f}s-{end_time:.2f}s ({len(audio_b64)} chars)")
+        return audio_b64
+
+    except Exception as e:
+        log(f"Failed to extract audio segment {start_time}-{end_time}: {str(e)}")
+        return ""
+
 def get_reference_phonemes(words: List[str]) -> List[str]:
     """Get reference phonemes for all words at once with caching"""
     cache_key = tuple(words)
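For reference, the edge-tts streaming pattern used by generate_expected_audio above can be exercised in isolation. A minimal sketch, assuming edge-tts is installed (pip install edge-tts) and outbound network access is available, since the library calls Microsoft's online TTS service:

import asyncio
import base64

import edge_tts

async def main():
    # Same call pattern as generate_expected_audio in the hunk above.
    communicate = edge_tts.Communicate("phoneme", "en-US-AriaNeural",
                                       rate="+0%", pitch="+0Hz")
    audio = b""
    async for chunk in communicate.stream():
        if chunk["type"] == "audio":
            audio += chunk["data"]
    # The app base64-encodes these bytes and stores them in tts_cache.
    print(f"received {len(audio)} audio bytes, "
          f"{len(base64.b64encode(audio))} base64 chars")

asyncio.run(main())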
@@ -316,6 +394,10 @@ async def transcribe(audio: UploadFile = File(...), similarity: float = Form(0.4
     # Find best phoneme matches considering all variants
     scores = find_best_phoneme_matches(segment_variants, reference_phonemes, sample_rate)

+    # Generate audio data for playback
+    log("Generating audio data for playback...")
+    audio_data = await generate_audio_data(words, word_texts, waveform, sample_rate)
+
     # Format output
     full_text = " ".join(word_texts)
     resolved_output = []
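generate_audio_data itself is defined elsewhere in app.py and is not part of this diff. Judging only from its call site above and the debug_info keys in the next hunk, each entry appears to carry a user_audio and an expected_audio field; a purely hypothetical sketch of its shape, built from the two helpers added earlier (the "start"/"end" keys on the word timings are a guess):

# Hypothetical reconstruction; the real generate_audio_data may differ.
async def generate_audio_data(words, word_texts, waveform, sample_rate):
    entries = []
    for timing, text in zip(words, word_texts):
        entries.append({
            "word": text,
            # learner's own pronunciation, cut from the uploaded audio
            "user_audio": extract_user_audio_segment(
                waveform, sample_rate, timing["start"], timing["end"]),
            # reference pronunciation synthesized with edge-tts
            "expected_audio": await generate_expected_audio(text),
        })
    return entries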
@@ -336,18 +418,58 @@
     return {
         "transcript": full_text,
         "resolved": " ".join(resolved_output),
-        "resolved_colored": " ".join(resolved_colored)
+        "resolved_colored": " ".join(resolved_colored),
+        "audio_data": audio_data,
+        "debug_info": {
+            "total_words": len(words),
+            "audio_segments_generated": len([a for a in audio_data if a["user_audio"]]),
+            "tts_segments_generated": len([a for a in audio_data if a["expected_audio"]]),
+            "cache_stats": {
+                "phoneme_cache_size": len(phoneme_cache),
+                "tts_cache_size": len(tts_cache)
+            }
+        }
     }

 @app.get("/")
 def root():
     return "fonetik running (optimized)"

-# Optional: Add
+# Optional: Add endpoints for debugging and cache management
 @app.post("/api/clear-cache")
 def clear_cache():
-    global phoneme_cache
+    global phoneme_cache, tts_cache
     phoneme_cache.clear()
+    tts_cache.clear()
     normalize_phoneme_string.cache_clear()
     words_sim.cache_clear()
-    return {"message": "
+    return {"message": "All caches cleared"}
+
+@app.get("/api/debug/cache-stats")
+def get_cache_stats():
+    return {
+        "phoneme_cache_size": len(phoneme_cache),
+        "tts_cache_size": len(tts_cache),
+        "lru_cache_info": {
+            "normalize_phoneme_string": normalize_phoneme_string.cache_info()._asdict(),
+            "words_sim": words_sim.cache_info()._asdict()
+        }
+    }
+
+@app.post("/api/debug/test-tts")
+async def test_tts(word: str = Form(...)):
+    """Test TTS generation for a single word"""
+    try:
+        audio_b64 = await generate_expected_audio(word)
+        return {
+            "word": word,
+            "success": len(audio_b64) > 0,
+            "audio_length": len(audio_b64),
+            "audio_data": audio_b64 if len(audio_b64) > 0 else None
+        }
+    except Exception as e:
+        return {
+            "word": word,
+            "success": False,
+            "error": str(e)
+        }
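The new endpoints can be smoke-tested with a small client. A sketch, assuming the API is reachable at http://localhost:7860 (the usual Hugging Face Spaces port, adjust to your deployment) and that requests is installed; edge-tts emits MP3 audio by default, so the decoded TTS bytes are written with an .mp3 extension:

import base64

import requests

BASE_URL = "http://localhost:7860"  # assumption: adjust to your deployment

# Generate reference TTS for a single word via the debug endpoint.
result = requests.post(f"{BASE_URL}/api/debug/test-tts",
                       data={"word": "phoneme"}).json()
if result.get("success"):
    with open("phoneme_expected.mp3", "wb") as f:
        f.write(base64.b64decode(result["audio_data"]))

# Inspect cache occupancy, then clear all caches.
print(requests.get(f"{BASE_URL}/api/debug/cache-stats").json())
print(requests.post(f"{BASE_URL}/api/clear-cache").json())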