Razzaqi3143 committed
Commit
69c2a91
1 Parent(s): 1395209

Create app.py

Files changed (1): app.py +91 -0
app.py ADDED
@@ -0,0 +1,91 @@
+ import os
+ from groq import Groq, GroqError
+ import gradio as gr
+ import torch
+ from parler_tts import ParlerTTSForConditionalGeneration
+ from transformers import AutoTokenizer
+ import soundfile as sf
+
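+ # NOTE: besides the standard library, this app assumes `pip install groq
+ # gradio torch transformers soundfile`; parler_tts is typically installed
+ # from the project repo (pip install git+https://github.com/huggingface/parler-tts.git).
+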
+ # Initialize the Groq client; read the API key from the environment
+ # instead of hardcoding a secret in the source
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
+
+ # Device setup for Parler-TTS
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
+ parler_model = ParlerTTSForConditionalGeneration.from_pretrained("parler-tts/parler-tts-mini-v1").to(device)
+ parler_tokenizer = AutoTokenizer.from_pretrained("parler-tts/parler-tts-mini-v1")
+
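+ # NOTE: the first from_pretrained() call above downloads the Parler-TTS mini
+ # checkpoint; on CPU, speech generation will be much slower than on a GPU.
+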
+ # Transcribe audio using Whisper through Groq's speech-to-text endpoint,
+ # with error handling; `audio` is a filepath from the Gradio component
+ def transcribe_audio(audio):
+     try:
+         with open(audio, "rb") as audio_file:
+             transcription_response = client.audio.transcriptions.create(
+                 file=(audio, audio_file.read()),
+                 model="whisper-large-v3",
+             )
+         return transcription_response.text
+     except GroqError as e:
+         print(f"Groq transcription error: {e}")
+         return "Error: Failed to transcribe audio."
+
+ # Generate a response using LLaMA through Groq's chat completions API,
+ # with error handling
+ def generate_response(text):
+     try:
+         chat_completion = client.chat.completions.create(
+             messages=[{"role": "user", "content": text}],
+             model="llama3-70b-8192",  # modify based on the model you're using
+         )
+         return chat_completion.choices[0].message.content
+     except GroqError as e:
+         print(f"Groq response generation error: {e}")
+         return "Error: Failed to generate a response."
43
+ # Function to convert text to speech using Parler-TTS, unchanged
44
+ def text_to_speech(text):
45
+ try:
46
+ description = "A female speaker delivers a slightly expressive and animated speech with a moderate speed and pitch."
47
+ input_ids = parler_tokenizer(description, return_tensors="pt").input_ids.to(device)
48
+ prompt_input_ids = parler_tokenizer(text, return_tensors="pt").input_ids.to(device)
49
+ generation = parler_model.generate(input_ids=input_ids, prompt_input_ids=prompt_input_ids)
50
+ audio_arr = generation.cpu().numpy().squeeze()
51
+ sf.write("parler_tts_out.wav", audio_arr, parler_model.config.sampling_rate)
52
+ return "parler_tts_out.wav"
53
+ except Exception as e:
54
+ print(f"Parler-TTS error: {e}")
55
+ return "Error: Failed to convert text to speech."
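+ # NOTE: writing to a fixed output filename means concurrent requests would
+ # overwrite each other's audio; a per-request temp file would avoid that.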
+
+ # Pipeline combining all the components, with error handling at each step
+ def chatbot_pipeline(audio):
+     # Step 1: Convert speech to text using Whisper through Groq
+     transcribed_text = transcribe_audio(audio)
+
+     # If transcription failed, surface the error message
+     if "Error" in transcribed_text:
+         return transcribed_text, None
+
+     # Step 2: Generate a response using LLaMA through Groq
+     response_text = generate_response(transcribed_text)
+
+     # If response generation failed, surface the error message
+     if "Error" in response_text:
+         return response_text, None
+
+     # Step 3: Convert the response text to speech using Parler-TTS
+     response_audio_path = text_to_speech(response_text)
+
+     # If TTS conversion failed, return the text response without audio
+     if "Error" in response_audio_path:
+         return response_text, None
+
+     # Return both text and audio for output
+     return response_text, response_audio_path
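+ # NOTE: detecting failures via substring matching ("Error" in ...) is fragile;
+ # a more robust version might raise and catch exceptions instead.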
+
+ # Gradio interface setup
+ ui = gr.Interface(
+     fn=chatbot_pipeline,
+     inputs=gr.Audio(type="filepath"),  # filepath so the recording can be uploaded to Groq
+     outputs=[gr.Textbox(label="Chatbot Response"), gr.Audio(label="Chatbot Voice Response")],
+     live=True,
+ )
+
+ ui.launch()
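
For a quick smoke test outside the browser UI, the same pipeline can be driven directly from Python. A minimal sketch, assuming GROQ_API_KEY is set in the environment and a local recording named sample.wav exists (hypothetical filename):

    # run the full speech -> LLM -> speech pipeline on one recording
    text, audio_path = chatbot_pipeline("sample.wav")
    print("Response:", text)
    print("Audio written to:", audio_path)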