greg0rs commited on
Commit
2c6f1eb
·
verified ·
1 Parent(s): f1ddfba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -5
app.py CHANGED
@@ -23,6 +23,9 @@ import torch
23
  from phonemizer import phonemize
24
  from faster_whisper import WhisperModel
25
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
 
 
 
26
 
27
  app = FastAPI()
28
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
@@ -32,8 +35,14 @@ phoneme_processor = Wav2Vec2Processor.from_pretrained("vitouphy/wav2vec2-xls-r-3
32
  phoneme_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
33
  whisper_model = WhisperModel("small", compute_type="float32")
34
 
35
- # Cache for phoneme lookups
36
  phoneme_cache = {}
 
 
 
 
 
 
37
 
38
  def log(msg):
39
  print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
@@ -55,6 +64,75 @@ def words_sim(a: str, b: str) -> float:
55
  """Cached similarity calculation"""
56
  return SequenceMatcher(None, a.lower(), b.lower()).ratio()
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def get_reference_phonemes(words: List[str]) -> List[str]:
59
  """Get reference phonemes for all words at once with caching"""
60
  cache_key = tuple(words)
@@ -316,6 +394,10 @@ async def transcribe(audio: UploadFile = File(...), similarity: float = Form(0.4
316
  # Find best phoneme matches considering all variants
317
  scores = find_best_phoneme_matches(segment_variants, reference_phonemes, sample_rate)
318
 
 
 
 
 
319
  # Format output
320
  full_text = " ".join(word_texts)
321
  resolved_output = []
@@ -336,18 +418,58 @@ async def transcribe(audio: UploadFile = File(...), similarity: float = Form(0.4
336
  return {
337
  "transcript": full_text,
338
  "resolved": " ".join(resolved_output),
339
- "resolved_colored": " ".join(resolved_colored)
 
 
 
 
 
 
 
 
 
 
340
  }
341
 
342
  @app.get("/")
343
  def root():
344
  return "fonetik running (optimized)"
345
 
346
- # Optional: Add an endpoint to clear caches if needed
347
  @app.post("/api/clear-cache")
348
  def clear_cache():
349
- global phoneme_cache
350
  phoneme_cache.clear()
 
351
  normalize_phoneme_string.cache_clear()
352
  words_sim.cache_clear()
353
- return {"message": "Cache cleared"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  from phonemizer import phonemize
24
  from faster_whisper import WhisperModel
25
  from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
26
+ import edge_tts
27
+ import asyncio
28
+ import base64
29
 
30
  app = FastAPI()
31
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"])
 
35
  phoneme_model = Wav2Vec2ForCTC.from_pretrained("vitouphy/wav2vec2-xls-r-300m-timit-phoneme")
36
  whisper_model = WhisperModel("small", compute_type="float32")
37
 
38
+ # Cache for phoneme lookups and TTS audio
39
  phoneme_cache = {}
40
+ tts_cache = {}
41
+
42
+ # TTS configuration
43
+ TTS_VOICE = "en-US-AriaNeural" # High quality English voice
44
+ TTS_RATE = "+0%" # Normal speed
45
+ TTS_PITCH = "+0Hz" # Normal pitch
46
 
47
  def log(msg):
48
  print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}")
 
64
  """Cached similarity calculation"""
65
  return SequenceMatcher(None, a.lower(), b.lower()).ratio()
66
 
67
+ async def generate_expected_audio(word: str) -> str:
68
+ """Generate TTS audio for expected pronunciation and return as base64"""
69
+ # Check cache first
70
+ cache_key = f"{word.lower()}_{TTS_VOICE}_{TTS_RATE}_{TTS_PITCH}"
71
+ if cache_key in tts_cache:
72
+ log(f"TTS cache hit for: {word}")
73
+ return tts_cache[cache_key]
74
+
75
+ log(f"Generating TTS for: {word}")
76
+
77
+ try:
78
+ # Generate TTS audio
79
+ communicate = edge_tts.Communicate(word, TTS_VOICE, rate=TTS_RATE, pitch=TTS_PITCH)
80
+
81
+ # Collect audio data
82
+ audio_data = b""
83
+ async for chunk in communicate.stream():
84
+ if chunk["type"] == "audio":
85
+ audio_data += chunk["data"]
86
+
87
+ if not audio_data:
88
+ log(f"No audio data generated for: {word}")
89
+ return ""
90
+
91
+ # Convert to base64
92
+ audio_b64 = base64.b64encode(audio_data).decode('utf-8')
93
+
94
+ # Cache the result
95
+ tts_cache[cache_key] = audio_b64
96
+ log(f"TTS generated and cached for: {word} ({len(audio_data)} bytes)")
97
+
98
+ return audio_b64
99
+
100
+ except Exception as e:
101
+ log(f"TTS generation failed for '{word}': {str(e)}")
102
+ return ""
103
+
104
+ def extract_user_audio_segment(waveform: torch.Tensor, sample_rate: int,
105
+ start_time: float, end_time: float) -> str:
106
+ """Extract user's pronunciation segment and return as base64 WAV"""
107
+ try:
108
+ # Add small buffer around the word
109
+ buffer_time = 0.05 # 50ms buffer
110
+ adj_start = max(0, start_time - buffer_time)
111
+ adj_end = end_time + buffer_time
112
+
113
+ # Convert to sample indices
114
+ start_sample = int(adj_start * sample_rate)
115
+ end_sample = int(adj_end * sample_rate)
116
+ end_sample = min(waveform.shape[-1], end_sample)
117
+
118
+ # Extract segment
119
+ segment = waveform[:, start_sample:end_sample]
120
+
121
+ # Convert to bytes
122
+ buffer = io.BytesIO()
123
+ torchaudio.save(buffer, segment, sample_rate, format="wav")
124
+ buffer.seek(0)
125
+
126
+ # Convert to base64
127
+ audio_b64 = base64.b64encode(buffer.read()).decode('utf-8')
128
+
129
+ log(f"Extracted user audio segment: {start_time:.2f}s-{end_time:.2f}s ({len(audio_b64)} chars)")
130
+ return audio_b64
131
+
132
+ except Exception as e:
133
+ log(f"Failed to extract audio segment {start_time}-{end_time}: {str(e)}")
134
+ return ""
135
+
136
  def get_reference_phonemes(words: List[str]) -> List[str]:
137
  """Get reference phonemes for all words at once with caching"""
138
  cache_key = tuple(words)
 
394
  # Find best phoneme matches considering all variants
395
  scores = find_best_phoneme_matches(segment_variants, reference_phonemes, sample_rate)
396
 
397
+ # Generate audio data for playback
398
+ log("Generating audio data for playback...")
399
+ audio_data = await generate_audio_data(words, word_texts, waveform, sample_rate)
400
+
401
  # Format output
402
  full_text = " ".join(word_texts)
403
  resolved_output = []
 
418
  return {
419
  "transcript": full_text,
420
  "resolved": " ".join(resolved_output),
421
+ "resolved_colored": " ".join(resolved_colored),
422
+ "audio_data": audio_data,
423
+ "debug_info": {
424
+ "total_words": len(words),
425
+ "audio_segments_generated": len([a for a in audio_data if a["user_audio"]]),
426
+ "tts_segments_generated": len([a for a in audio_data if a["expected_audio"]]),
427
+ "cache_stats": {
428
+ "phoneme_cache_size": len(phoneme_cache),
429
+ "tts_cache_size": len(tts_cache)
430
+ }
431
+ }
432
  }
433
 
434
  @app.get("/")
435
  def root():
436
  return "fonetik running (optimized)"
437
 
438
+ # Optional: Add endpoints for debugging and cache management
439
  @app.post("/api/clear-cache")
440
  def clear_cache():
441
+ global phoneme_cache, tts_cache
442
  phoneme_cache.clear()
443
+ tts_cache.clear()
444
  normalize_phoneme_string.cache_clear()
445
  words_sim.cache_clear()
446
+ return {"message": "All caches cleared"}
447
+
448
+ @app.get("/api/debug/cache-stats")
449
+ def get_cache_stats():
450
+ return {
451
+ "phoneme_cache_size": len(phoneme_cache),
452
+ "tts_cache_size": len(tts_cache),
453
+ "lru_cache_info": {
454
+ "normalize_phoneme_string": normalize_phoneme_string.cache_info()._asdict(),
455
+ "words_sim": words_sim.cache_info()._asdict()
456
+ }
457
+ }
458
+
459
+ @app.post("/api/debug/test-tts")
460
+ async def test_tts(word: str = Form(...)):
461
+ """Test TTS generation for a single word"""
462
+ try:
463
+ audio_b64 = await generate_expected_audio(word)
464
+ return {
465
+ "word": word,
466
+ "success": len(audio_b64) > 0,
467
+ "audio_length": len(audio_b64),
468
+ "audio_data": audio_b64 if len(audio_b64) > 0 else None
469
+ }
470
+ except Exception as e:
471
+ return {
472
+ "word": word,
473
+ "success": False,
474
+ "error": str(e)
475
+ }