m-adil-ali committed on
Commit
59adbe1
·
verified ·
1 Parent(s): cb9c026

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -22
app.py CHANGED
@@ -1,29 +1,11 @@
1
  import streamlit as st
2
  import torchaudio
3
  import torchaudio.transforms as T
4
- from transformers import pipeline, AutoProcessor, AutoModelForCTC, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTextToSpectrogram
5
  import torch
6
  import numpy as np
7
  import io
8
 
9
- def transcribe_audio(audio_bytes):
10
- # Load audio
11
- waveform, sample_rate = torchaudio.load(io.BytesIO(audio_bytes), normalize=True)
12
-
13
- # Resample to 16kHz if necessary
14
- if sample_rate != 16000:
15
- resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
16
- waveform = resampler(waveform)
17
- sample_rate = 16000
18
-
19
- # Transcription
20
- inputs = asr_processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt", padding=True)
21
- with torch.no_grad():
22
- logits = asr_model(input_values=inputs.input_values).logits
23
- predicted_ids = torch.argmax(logits, dim=-1)
24
- transcription = asr_processor.decode(predicted_ids[0])
25
- return transcription
26
-
27
  # Load models
28
  asr_model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h")
29
  asr_processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h")
@@ -68,9 +50,20 @@ def generate_reply(text):
68
  return reply
69
 
70
  def text_to_speech(text):
 
 
 
 
 
 
 
 
71
  inputs = tts_processor(text=text, return_tensors="pt")
72
  with torch.no_grad():
73
- spectrogram = tts_model.generate(**inputs)
 
 
 
74
  return spectrogram
75
 
76
  # Streamlit app
@@ -98,9 +91,11 @@ if audio_input:
98
 
99
  # Convert text to speech
100
  spectrogram = text_to_speech(reply_text)
101
- # Save spectrogram to file
 
 
102
  audio_file = io.BytesIO()
103
- torchaudio.save(audio_file, spectrogram, 22050) # assuming 22050 Hz sample rate
104
  audio_file.seek(0)
105
 
106
  st.audio(audio_file, format="audio/wav")
 
1
  import streamlit as st
2
  import torchaudio
3
  import torchaudio.transforms as T
4
+ from transformers import AutoProcessor, AutoModelForCTC, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForTextToSpectrogram
5
  import torch
6
  import numpy as np
7
  import io
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Load models
10
  asr_model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h")
11
  asr_processor = AutoProcessor.from_pretrained("facebook/wav2vec2-large-robust-ft-swbd-300h")
 
50
  return reply
51
 
52
  def text_to_speech(text):
53
+ # Load speaker embeddings
54
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
55
+ from datasets import load_dataset
56
+
57
+ # Load pre-trained speaker embeddings (assuming you have downloaded them)
58
+ dataset = load_dataset("Matthijs/cmu-arctic-xvectors")
59
+ speaker_embeddings = dataset['train'][0]['xvector']
60
+
61
  inputs = tts_processor(text=text, return_tensors="pt")
62
  with torch.no_grad():
63
+ spectrogram = tts_model.generate(
64
+ **inputs,
65
+ speaker_embeddings=speaker_embeddings
66
+ )
67
  return spectrogram
68
 
69
  # Streamlit app
 
91
 
92
  # Convert text to speech
93
  spectrogram = text_to_speech(reply_text)
94
+
95
+ # Convert spectrogram to waveform for saving
96
+ waveform = tts_processor.convert_spectrogram_to_waveform(spectrogram)
97
  audio_file = io.BytesIO()
98
+ torchaudio.save(audio_file, waveform, 22050) # assuming 22050 Hz sample rate
99
  audio_file.seek(0)
100
 
101
  st.audio(audio_file, format="audio/wav")