Spaces:
Sleeping
Sleeping
Razzaqi3143
commited on
Commit
•
69c2a91
1
Parent(s):
1395209
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from groq import Groq, GroqError
|
3 |
+
import gradio as gr
|
4 |
+
import torch
|
5 |
+
from parler_tts import ParlerTTSForConditionalGeneration
|
6 |
+
from transformers import AutoTokenizer
|
7 |
+
import soundfile as sf
|
8 |
+
|
9 |
+
# Initialize Groq client with API key
|
10 |
+
GROQ_API_KEY = "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"
|
11 |
+
|
12 |
+
# Device setup for Parler-TTS
|
13 |
+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
14 |
+
parler_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
|
15 |
+
parler_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
|
16 |
+
|
17 |
+
# Function to transcribe audio using Whisper through Groq, with error handling
|
18 |
+
def transcribe_audio(audio):
|
19 |
+
try:
|
20 |
+
# Ensure the audio is in the correct format supported by Groq
|
21 |
+
audio_input = audio
|
22 |
+
transcription_response = client.transcriptions.create(
|
23 |
+
model="openai/whisper-large-v3",
|
24 |
+
audio=audio_input,
|
25 |
+
)
|
26 |
+
return transcription_response['text']
|
27 |
+
except GroqError as e:
|
28 |
+
print(f"Groq transcription error: {e}")
|
29 |
+
return "Error: Failed to transcribe audio."
|
30 |
+
|
31 |
+
# Function to generate a response using LLaMA through Groq, with error handling
|
32 |
+
def generate_response(text):
|
33 |
+
try:
|
34 |
+
chat_completion = client.chat.completions.create(
|
35 |
+
messages=[{"role": "user", "content": text}],
|
36 |
+
model="llama3-70b-8192", # Modify based on the model you're using
|
37 |
+
)
|
38 |
+
return chat_completion.choices[0].message['content']
|
39 |
+
except GroqError as e:
|
40 |
+
print(f"Groq response generation error: {e}")
|
41 |
+
return "Error: Failed to generate a response."
|
42 |
+
|
43 |
+
# Function to convert text to speech using Parler-TTS, unchanged
|
44 |
+
def text_to_speech(text):
|
45 |
+
try:
|
46 |
+
description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
|
47 |
+
input_ids = parler_tokenizer(description, return_tensors="pt").input_ids.to(device)
|
48 |
+
prompt_input_ids = parler_tokenizer(text, return_tensors="pt").input_ids.to(device)
|
49 |
+
generation = parler_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
|
50 |
+
audio_arr = generation.cpu().numpy().squeeze()
|
51 |
+
sf.write("parler_tts_out.wav", audio_arr, parler_model.config.sampling_rate)
|
52 |
+
return "parler_tts_out.wav"
|
53 |
+
except Exception as e:
|
54 |
+
print(f"Parler-TTS error: {e}")
|
55 |
+
return "Error: Failed to convert text to speech."
|
56 |
+
|
57 |
+
# Gradio interface combining all the components, with error handling in each step
|
58 |
+
def chatbot_pipeline(audio):
|
59 |
+
# Step 1: Convert speech to text using Whisper through Groq
|
60 |
+
transcribed_text = transcribe_audio(audio)
|
61 |
+
|
62 |
+
# If there was an error in transcription, return the error message
|
63 |
+
if "Error" in transcribed_text:
|
64 |
+
return transcribed_text, None
|
65 |
+
|
66 |
+
# Step 2: Generate a response using LLaMA through Groq
|
67 |
+
response_text = generate_response(transcribed_text)
|
68 |
+
|
69 |
+
# If there was an error in response generation, return the error message
|
70 |
+
if "Error" in response_text:
|
71 |
+
return response_text, None
|
72 |
+
|
73 |
+
# Step 3: Convert response text to speech using Parler-TTS
|
74 |
+
response_audio_path = text_to_speech(response_text)
|
75 |
+
|
76 |
+
# If there was an error in TTS conversion, return the error message
|
77 |
+
if "Error" in response_audio_path:
|
78 |
+
return response_text, None
|
79 |
+
|
80 |
+
# Return both text and audio for output
|
81 |
+
return response_text, response_audio_path
|
82 |
+
|
83 |
+
# Gradio interface setup
|
84 |
+
ui = gr.Interface(
|
85 |
+
fn=chatbot_pipeline,
|
86 |
+
inputs=gr.Audio(type="numpy"), # Removed 'source' and 'streaming'
|
87 |
+
outputs=[gr.Textbox(label="Chatbot Response"), gr.Audio(label="Chatbot Voice Response")],
|
88 |
+
live=True
|
89 |
+
)
|
90 |
+
|
91 |
+
ui.launch()
|