# RobertaSpeak / app.py
import requests
import gradio as gr
import os
from pydub import AudioSegment
from io import BytesIO
import time
# Hugging Face API URLs
API_URL_ROBERTA = "https://api-inference.huggingface.co/models/deepset/roberta-base-squad2"
API_URL_TTS = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_tacotron"
API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
# Retry settings
MAX_RETRIES = 5
RETRY_DELAY = 1 # seconds
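# Note: the serverless Inference API typically responds with HTTP 503 while a
# model is still loading, so transient failures are expected on a cold first
# call; the retry loops below are there to ride out that warm-up window.
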
# Function to query the Whisper model for audio transcription
def query_whisper(api_token, audio_path):
    headers = {"Authorization": f"Bearer {api_token}"}
    for attempt in range(MAX_RETRIES):
        try:
            if not audio_path:
                raise ValueError("Audio file path is None")
            if not os.path.exists(audio_path):
                raise FileNotFoundError(f"Audio file does not exist: {audio_path}")
            with open(audio_path, "rb") as f:
                data = f.read()
            # Send the raw audio bytes as the request body; the inference API
            # expects the file content directly, not a multipart upload.
            response = requests.post(API_URL_WHISPER, headers=headers, data=data)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            print(f"Whisper model query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying Whisper model query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}

# Function to query the RoBERTa model
def query_roberta(api_token, prompt, context):
    headers = {"Authorization": f"Bearer {api_token}"}
    payload = {"inputs": {"question": prompt, "context": context}}
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(API_URL_ROBERTA, headers=headers, json=payload)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            print(f"RoBERTa model query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying RoBERTa model query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}

# Function to generate speech from text using ESPnet TTS
def generate_speech(api_token, answer):
    headers = {"Authorization": f"Bearer {api_token}"}
    payload = {"inputs": answer}
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(API_URL_TTS, headers=headers, json=payload)
            response.raise_for_status()
            audio = response.content
            audio_segment = AudioSegment.from_file(BytesIO(audio), format="flac")
            audio_file_path = "/tmp/answer.wav"
            audio_segment.export(audio_file_path, format="wav")
            return audio_file_path
        except Exception as e:
            print(f"ESPnet TTS query failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying ESPnet TTS query ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}

# Function to handle the entire process
def handle_all(api_token, context, audio):
    for attempt in range(MAX_RETRIES):
        try:
            # Step 1: Transcribe audio
            transcription = query_whisper(api_token, audio)
            if 'error' in transcription:
                raise Exception(transcription['error'])
            question = transcription.get("text", "No transcription found")

            # Step 2: Get answer from RoBERTa
            answer = query_roberta(api_token, question, context)
            if 'error' in answer:
                raise Exception(answer['error'])
            answer_text = answer.get('answer', 'No answer found')

            # Step 3: Generate speech from answer
            # generate_speech returns a file path on success or an error dict on failure
            audio_file_path = generate_speech(api_token, answer_text)
            if isinstance(audio_file_path, dict) and 'error' in audio_file_path:
                raise Exception(audio_file_path['error'])

            return answer_text, audio_file_path
        except Exception as e:
            print(f"Process failed: {e}")
            if attempt < MAX_RETRIES - 1:
                print(f"Retrying entire process ({attempt + 1}/{MAX_RETRIES})...")
                time.sleep(RETRY_DELAY)
            else:
                return str(e), None

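# Minimal local sketch of the full pipeline (commented out so it does not run on
# import; the token, context, and audio path below are placeholders):
#   answer_text, wav_path = handle_all(
#       "hf_xxx",                               # your Hugging Face API token
#       "Paris is the capital of France.",      # context paragraph
#       "question.wav",                         # a short recorded question
#   )
#   print(answer_text, wav_path)
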
# Define the Gradio interface
iface = gr.Interface(
    fn=handle_all,
    inputs=[
        gr.Textbox(lines=1, label="Hugging Face API Token", type="password", placeholder="Enter your Hugging Face API token..."),
        gr.Textbox(lines=2, label="Context", placeholder="Enter the context here..."),
        gr.Audio(type="filepath", label="Record your voice")
    ],
    outputs=[
        gr.Textbox(label="Answer"),
        gr.Audio(label="Answer as Speech", type="filepath")
    ],
    title="Chat with RoBERTa with Voice",
    description="Record your voice, get the transcription, use it as a question for the RoBERTa model, and hear the response via text-to-speech."
)

# Launch the Gradio app
iface.launch(share=True)