# Hugging Face Spaces status at capture time: Runtime error
import os
import tempfile
import time
from io import BytesIO

import gradio as gr
import requests
from pydub import AudioSegment
# --- Hosted Hugging Face Inference API endpoints, one per pipeline stage ---
API_URL_ROBERTA: str = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"  # question answering
API_URL_TTS: str = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_tacotron"  # text-to-speech
API_URL_WHISPER: str = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"  # speech-to-text

# --- Retry policy used by the API helper functions ---
MAX_RETRIES: int = 5  # total attempts per call (range(MAX_RETRIES)), not extra retries
RETRY_DELAY: int = 1  # seconds slept between failed attempts
def query_whisper(api_token, audio_path, timeout=60):
    """Transcribe a recorded audio file with the hosted Whisper model.

    Input problems (no path, missing file, unreadable file) are reported
    immediately because retrying cannot fix them; only the HTTP request is
    retried, up to ``MAX_RETRIES`` attempts with ``RETRY_DELAY`` seconds
    between them.

    Args:
        api_token: Hugging Face API token, sent as a Bearer token.
        audio_path: Path to the audio file; may be None/empty when nothing
            was recorded.
        timeout: Seconds to wait for each HTTP request before that attempt
            is treated as failed.

    Returns:
        The decoded JSON response on success (presumably a dict with a
        ``"text"`` key -- confirm against the API), or
        ``{"error": <message>}`` on failure.
    """
    headers = {"Authorization": f"Bearer {api_token}"}

    # Bad input is not transient: fail fast instead of sleeping through retries.
    if not audio_path:
        problem = "Audio file path is None"
    elif not os.path.exists(audio_path):
        problem = f"Audio file does not exist: {audio_path}"
    else:
        problem = None
    if problem is not None:
        print(f"Whisper model query failed: {problem}")
        return {"error": problem}

    # Read the recording once; every attempt reuses the same bytes.
    try:
        with open(audio_path, "rb") as f:
            data = f.read()
    except OSError as e:
        print(f"Whisper model query failed: {e}")
        return {"error": str(e)}

    last_error = "Whisper model query was not attempted (MAX_RETRIES < 1)"
    for attempt in range(MAX_RETRIES):
        try:
            # The Inference API takes the raw audio bytes as the request body;
            # a multipart upload (files=...) would wrap them in form-data framing.
            response = requests.post(
                API_URL_WHISPER, headers=headers, data=data, timeout=timeout
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:  # any failure is reported through the error dict
            last_error = str(e)
            print(f"Whisper model query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying Whisper model query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
    return {"error": last_error}
def query_roberta(api_token, prompt, context, timeout=60):
    """Ask the hosted RoBERTa question-answering model *prompt* about *context*.

    The request is retried up to ``MAX_RETRIES`` attempts with
    ``RETRY_DELAY`` seconds between them. An empty question or context
    cannot produce an answer, so those are rejected before any request is
    made.

    Args:
        api_token: Hugging Face API token, sent as a Bearer token.
        prompt: The question text.
        context: The passage the answer should be extracted from.
        timeout: Seconds to wait for each HTTP request before that attempt
            is treated as failed.

    Returns:
        The decoded JSON response on success (presumably a dict with an
        ``"answer"`` key -- confirm against the API), or
        ``{"error": <message>}`` on failure.
    """
    # Fail fast on input no retry can fix.
    if not (prompt and str(prompt).strip()):
        return {"error": "Question must not be empty"}
    if not (context and str(context).strip()):
        return {"error": "Context must not be empty"}

    headers = {"Authorization": f"Bearer {api_token}"}
    payload = {"inputs": {"question": prompt, "context": context}}

    last_error = "RoBERTa model query was not attempted (MAX_RETRIES < 1)"
    for attempt in range(MAX_RETRIES):
        try:
            # Without a timeout a stalled connection would block the UI forever.
            response = requests.post(
                API_URL_ROBERTA, headers=headers, json=payload, timeout=timeout
            )
            response.raise_for_status()
            return response.json()
        except Exception as e:  # any failure is reported through the error dict
            last_error = str(e)
            print(f"RoBERTa model query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying RoBERTa model query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
    return {"error": last_error}
def generate_speech(api_token, answer, timeout=60):
    """Synthesize *answer* to a WAV file using the hosted ESPnet TTS model.

    The response body is decoded as FLAC and re-exported as WAV into a new,
    uniquely named temporary file, so concurrent requests never overwrite
    each other's audio. The request is retried up to ``MAX_RETRIES``
    attempts with ``RETRY_DELAY`` seconds between them.

    Args:
        api_token: Hugging Face API token, sent as a Bearer token.
        answer: Text to speak; empty/whitespace-only text is rejected.
        timeout: Seconds to wait for each HTTP request before that attempt
            is treated as failed.

    Returns:
        The path of the written ``.wav`` file on success, or
        ``{"error": <message>}`` on failure. The caller owns the file; it is
        not deleted here.
    """
    # Nothing to synthesize: don't spend retries on a request that can't help.
    if not (answer and str(answer).strip()):
        return {"error": "No text to synthesize"}

    headers = {"Authorization": f"Bearer {api_token}"}
    payload = {"inputs": answer}

    last_error = "ESPnet TTS query was not attempted (MAX_RETRIES < 1)"
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(
                API_URL_TTS, headers=headers, json=payload, timeout=timeout
            )
            response.raise_for_status()
            # NOTE(review): assumes the endpoint returns FLAC audio -- confirm.
            audio_segment = AudioSegment.from_file(BytesIO(response.content), format="flac")
            # A per-call file instead of a shared fixed path avoids races between users.
            fd, audio_file_path = tempfile.mkstemp(suffix=".wav")
            os.close(fd)
            try:
                audio_segment.export(audio_file_path, format="wav")
            except Exception:
                os.remove(audio_file_path)  # don't leave an empty temp file behind
                raise
            return audio_file_path
        except Exception as e:  # any failure is reported through the error dict
            last_error = str(e)
            print(f"ESPnet TTS query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying ESPnet TTS query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
    return {"error": last_error}
def _stage_error(result):
    """Return the error message carried by a stage result, or None on success.

    Stage helpers report failure as ``{"error": message}``. Checking the type
    first avoids a substring test on a successful string result (e.g. a file
    path that happens to contain "error").
    """
    if isinstance(result, dict) and "error" in result:
        return str(result["error"])
    return None


def handle_all(api_token, context, audio):
    """Run the voice Q&A pipeline: transcribe, answer, then speak the answer.

    The stage helpers (query_whisper, query_roberta, generate_speech) each
    retry their own request, so this function adds no retry loop of its own.
    An outer loop would multiply attempts per stage and re-run failures that
    retrying cannot fix, such as a missing recording.

    Args:
        api_token: Hugging Face API token forwarded to every stage.
        context: Passage the question is answered from.
        audio: Path to the recorded question (from the Gradio Audio input).

    Returns:
        ``(answer_text, audio_file_path)`` on success, or
        ``(message, None)`` when a stage fails or nothing was transcribed.
    """
    try:
        # Step 1: transcribe the recorded question.
        transcription = query_whisper(api_token, audio)
        error = _stage_error(transcription)
        if error is not None:
            print(f"Process failed: {error}")
            return error, None
        question = (transcription.get("text") or "").strip()
        if not question:
            # Sending a placeholder string as the question would yield a meaningless answer.
            return "No transcription found", None

        # Step 2: answer the question from the supplied context.
        answer = query_roberta(api_token, question, context)
        error = _stage_error(answer)
        if error is not None:
            print(f"Process failed: {error}")
            return error, None
        answer_text = answer.get("answer") or "No answer found"

        # Step 3: turn the answer into speech.
        audio_file_path = generate_speech(api_token, answer_text)
        error = _stage_error(audio_file_path)
        if error is not None:
            print(f"Process failed: {error}")
            return error, None
        return answer_text, audio_file_path
    except Exception as e:  # UI boundary: surface unexpected failures as text
        print(f"Process failed: {e}")
        return str(e), None
# Gradio UI: API token + context + recorded question in; answer text + spoken answer out.
interface_inputs = [
    gr.Textbox(
        lines=1,
        label="Hugging Face API Token",
        type="password",
        placeholder="Enter your Hugging Face API token...",
    ),
    gr.Textbox(lines=2, label="Context", placeholder="Enter the context here..."),
    gr.Audio(type="filepath", label="Record your voice"),
]
# Order must match handle_all's (answer_text, audio_file_path) return tuple.
interface_outputs = [
    gr.Textbox(label="Answer"),
    gr.Audio(label="Answer as Speech", type="filepath"),
]

iface = gr.Interface(
    fn=handle_all,
    inputs=interface_inputs,
    outputs=interface_outputs,
    title="Chat with Roberta with Voice",
    description="Record your voice, get the transcription, use it as a question for the Roberta model, and hear the response via text-to-speech.",
)

# Start the app; share=True additionally asks Gradio for a public share link.
iface.launch(share=True)