# Hugging Face Spaces page status when this file was captured: Runtime error
import os | |
import json | |
import random | |
import string | |
import numpy as np | |
import gradio as gr | |
import requests | |
import soundfile as sf | |
from transformers import pipeline, set_seed | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import logging | |
import sys | |
import gradio as gr | |
from transformers import pipeline, AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2ProcessorWithLM | |
# Debug printing is on when the DEBUG environment variable starts with
# "t", "y" or "1" (case-insensitive, e.g. "true", "Yes", "1").
# Slicing with [:1] instead of indexing with [0] keeps an empty value from
# raising IndexError, and the tuple membership test keeps the empty string
# from matching (the empty string is a substring of every string).
DEBUG = os.environ.get("DEBUG", "false").strip().lower()[:1] in ("t", "y", "1")

# Upper bound that chat_with_gpt compares word counts against when it trims
# the conversation history.
MAX_LENGTH = int(os.environ.get("MAX_LENGTH", 1024))

# Language selected in the UI on start-up ("English" or "Spanish").
DEFAULT_LANG = os.environ.get("DEFAULT_LANG", "English")

# Markdown rendered at the top of the page.
HEADER = """
# Poor Man's Duplex
Talk to a language model like you talk on a Walkie-Talkie! Well, with larger latencies.
The models are [EleutherAI's GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) for English, and [BERTIN GPT-J-6B](https://huggingface.co/bertin-project/bertin-gpt-j-6B) for Spanish.
""".strip()

# Markdown/HTML rendered at the bottom of the page. The wrapping <div> is now
# properly closed (it previously ended with a second opening tag).
FOOTER = """
<div align=center>
<img src="https://visitor-badge.glitch.me/badge?page_id=versae/poor-mans-duplex"/>
</div>
""".strip()
# --- Spanish: speech recognition and speech synthesis -----------------------
asr_model_name_es = "jonatasgrosman/wav2vec2-large-xlsr-53-spanish"
model_instance_es = AutoModelForCTC.from_pretrained(asr_model_name_es)
processor_es = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_es)

# The recognition pipeline is assembled from the processor's own tokenizer,
# feature extractor and decoder.
asr_es = pipeline(
    "automatic-speech-recognition",
    model=model_instance_es,
    tokenizer=processor_es.tokenizer,
    feature_extractor=processor_es.feature_extractor,
    decoder=processor_es.decoder,
)

# Text-to-speech is delegated to a hosted model loaded through Gradio.
tts_model_name = "facebook/tts_transformer-es-css10"
speak_es = gr.Interface.load(f"huggingface/{tts_model_name}")


def transcribe_es(input_file):
    """Return the text recognised by the Spanish ASR pipeline for `input_file`.

    The audio is processed in 5-second chunks with a 1-second stride.
    """
    recognition = asr_es(input_file, chunk_length_s=5, stride_length_s=1)
    return recognition["text"]
def generate_es(text, **kwargs):
    """Generate a Spanish continuation of `text` through the BERTIN GPT-J-6B Space.

    Args:
        text: Prompt sent to the remote model.
        **kwargs: Accepted for signature parity with `generate_en`; currently
            ignored, the request always uses the fixed settings below.

    Returns:
        The first element of the ``data`` list in the JSON reply, or ``""``
        when the request fails, times out, returns a non-OK status, or the
        reply does not have the expected shape.
    """
    api_uri = "https://hf.space/embed/bertin-project/bertin-gpt-j-6B/+/api/predict/"
    # Positional inputs, in order:
    # max_length=100, top_k=100, top_p=50, temperature=0.95, do_sample=True, do_clean=True
    payload = {"data": [text, 100, 100, 50, 0.95, True, True]}
    try:
        # `json=` serialises the payload and sets the JSON content type; the
        # timeout keeps an unresponsive Space from blocking the UI forever.
        response = requests.post(api_uri, json=payload, timeout=120)
    except requests.RequestException:
        # Same best-effort contract as a non-OK status: no text generated.
        return ""
    if not response.ok:
        return ""
    try:
        body = response.json()
    except ValueError:
        # The reply was not valid JSON.
        return ""
    if DEBUG:
        print(body)
    try:
        return body["data"][0]
    except (KeyError, IndexError, TypeError):
        return ""
# --- English: speech recognition, speech synthesis and text generation ------
asr_model_name_en = "jonatasgrosman/wav2vec2-large-xlsr-53-english"
model_instance_en = AutoModelForCTC.from_pretrained(asr_model_name_en)
processor_en = Wav2Vec2ProcessorWithLM.from_pretrained(asr_model_name_en)

# The recognition pipeline is assembled from the processor's own tokenizer,
# feature extractor and decoder.
asr_en = pipeline(
    "automatic-speech-recognition",
    model=model_instance_en,
    tokenizer=processor_en.tokenizer,
    feature_extractor=processor_en.feature_extractor,
    decoder=processor_en.decoder,
)

# Text-to-speech is delegated to a hosted model loaded through Gradio.
tts_model_name = "facebook/fastspeech2-en-ljspeech"
speak_en = gr.Interface.load(f"huggingface/{tts_model_name}")


def transcribe_en(input_file):
    """Return the text recognised by the English ASR pipeline for `input_file`.

    The audio is processed in 5-second chunks with a 1-second stride.
    """
    recognition = asr_en(input_file, chunk_length_s=5, stride_length_s=1)
    return recognition["text"]


# Hosted text generation model used by generate_en.
generate_iface = gr.Interface.load("huggingface/EleutherAI/gpt-j-6B")

# Zero-sample audio file (16 kHz) used as the agent's output when there is
# nothing to say.
empty_audio = 'empty.flac'
sf.write(empty_audio, [], 16000)

# Hosted model applied to the raw transcript before it is used as the user's
# message.
deuncase = gr.Interface.load("huggingface/pere/DeUnCaser")
def generate_en(text, **kwargs):
    """Generate an English continuation of `text` via the loaded GPT-J interface.

    Extra keyword arguments are accepted but not forwarded. A falsy result
    from the interface is normalised to the empty string.
    """
    generated = generate_iface(text)
    if DEBUG:
        print(generated)
    if not generated:
        return ""
    return generated
def select_lang(lang):
    """Return the ``(generate, transcribe, speak)`` callables for `lang`.

    "spanish" (case-insensitive) selects the Spanish trio; any other value
    selects the English one.
    """
    wants_spanish = lang.lower() == "spanish"
    if not wants_spanish:
        return generate_en, transcribe_en, speak_en
    return generate_es, transcribe_es, speak_es
def select_lang_vars(lang):
    """Return the default ``(agent, user, context)`` strings for `lang`.

    "spanish" (case-insensitive) yields the Spanish defaults; any other value
    yields the English ones. The context is a template that still contains
    the literal ``{AGENT}`` and ``{USER}`` placeholders.
    """
    if lang.lower() != "spanish":
        agent_name = "ELEUTHER"
        user_name = "INTERVIEWER"
        prompt_template = """The next conversation is an excerpt from an interview to {AGENT} that appeared in the New York Times:
{USER}: Welcome, {AGENT}. It is a pleasure to have you here today.
{AGENT}: Thanks. The pleasure is mine."""
        return agent_name, user_name, prompt_template

    agent_name = "BERTIN"
    user_name = "ENTREVISTADOR"
    prompt_template = """La siguiente conversación es un extracto de una entrevista a {AGENT} celebrada en Madrid para Radio Televisión Española:
{USER}: Bienvenido, {AGENT}. Un placer tenerlo hoy con nosotros.
{AGENT}: Gracias. El placer es mío."""
    return agent_name, user_name, prompt_template
def format_chat(history):
    """Render the conversation as a collapsible HTML log.

    Args:
        history: Iterable of ``(user_text, bot_text)`` pairs; ``None`` is
            treated as an empty conversation.

    Returns:
        An HTML string: a ``<details>`` element titled "Conversation log"
        containing one user bubble and one bot bubble per pair.
    """
    # Local import so this function does not rely on the file's import block.
    import html

    # Both texts come from speech recognition / a language model, so they are
    # escaped before being embedded in markup.
    turns = "".join(
        f"""
<div data-testid="user" style="background-color:#16a34a" class="px-3 py-2 rounded-[22px] rounded-bl-none place-self-start text-white ml-7 text-sm">{html.escape(str(user))}</div>
<div data-testid="bot" style="background-color:gray" class="px-3 py-2 rounded-[22px] rounded-br-none text-white ml-7 text-sm">{html.escape(str(bot))}</div>
"""
        for user, bot in history or []
    )
    # The outer element is closed with </details>; <summary> is already
    # closed on the first line.
    return f"""<details><summary>Conversation log</summary>
<div class="overflow-y-auto h-[40vh]">
<div class="flex flex-col items-end space-y-4 p-3">
{turns}
</div>
</div>
</details>"""
def _render_history(history, user, agent):
    """Render ``(message, response)`` turns as speaker-prefixed prompt lines."""
    return "\n".join(
        f"{user}: {past_message.capitalize()}.\n{agent}: {past_response}."
        for past_message, past_response in history
    )


def _clean_response(raw, user, agent, message):
    """Reduce raw generated text to the reply found on its last line."""
    response = raw.split("\n")[-1]
    # Drop everything up to and including the last speaker label, unless
    # that would leave nothing.
    if agent in response and response.split(agent)[-1]:
        response = response.split(agent)[-1]
    if user in response and response.split(user)[-1]:
        response = response.split(user)[-1]
    # Remove one leading punctuation character (e.g. the ":" after a label).
    if response and response[0] in string.punctuation:
        response = response[1:].strip()
    # Remove an echo of the user's own line.
    if response.strip().startswith(f"{user}: {message}"):
        response = response.strip().split(f"{user}: {message}")[-1]
    return response


def chat_with_gpt(lang, agent, user, context, audio_in, history):
    """Run one conversation turn: transcribe, generate a reply, synthesise it.

    Args:
        lang: Selected language; "spanish" (case-insensitive) picks the
            Spanish models, anything else the English ones.
        agent: Label for the model's lines; the language default is used
            when empty.
        user: Label for the speaker's lines; the language default is used
            when empty.
        context: Prompt preamble; may contain ``{USER}`` / ``{AGENT}``
            placeholders.
        audio_in: Recorded audio to transcribe; falsy means nothing recorded.
        history: List of ``(user_message, response)`` tuples from earlier
            turns, or ``None``.

    Returns:
        ``(history, history, audio, html_log)``, where `audio` is whatever
        the language's `speak` callable returns for the reply (or the empty
        audio file when nothing was recorded).
    """
    history = history or []
    if not audio_in:
        return history, history, empty_audio, format_chat(history)

    generate, transcribe, speak = select_lang(lang)
    default_agent, default_user, _ = select_lang_vars(lang)
    # Apply the language defaults everywhere, not only when formatting the
    # context: an empty label would make str.split("") raise further down.
    agent = agent or default_agent
    user = user or default_user

    user_message = deuncase(transcribe(audio_in))
    generation_kwargs = {"max_length": 25}

    # Capitalise the first word only. partition() avoids repeating a
    # single-word utterance ("hello" must not become "Hello hello").
    first_word, separator, remainder = user_message.partition(" ")
    message = first_word.capitalize() + separator + remainder

    try:
        context = context.format(USER=user, AGENT=agent)
    except (KeyError, IndexError, ValueError):
        # The context is user-editable; stray braces must not abort the
        # turn, so the text is used verbatim in that case.
        pass
    context = context.strip()
    if context and context[-1] not in ".:":
        context += "."
    context_length = len(context.split())

    # Drop the oldest turns until the word count fits the budget, stopping
    # once no turns are left.
    budget = MAX_LENGTH - (generation_kwargs["max_length"] + context_length)
    history_take = 0
    history_context = _render_history(history, user, agent)
    while history_take < len(history) and len(history_context.split()) > budget:
        history_take += 1
        history_context = _render_history(history[history_take:], user, agent)
    if history_context:
        # Start the past turns on their own line rather than gluing them to
        # the last sentence of the context.
        context += "\n" + history_context

    # Retry a few times until the model produces a non-empty reply that is
    # not merely a repetition of the user's message.
    for _ in range(5):
        response = generate(f"{context}\n\n{user}: {message}.\n", **generation_kwargs)
        if DEBUG:
            print("\n-----" + response + "-----\n")
        response = _clean_response(response, user, agent, message)
        stripped_response = response.replace(".", "").strip()
        if stripped_response and message.replace(".", "").strip() != stripped_response:
            break

    if DEBUG:
        print()
        print("CONTEXT:")
        print(context)
        print()
        print("MESSAGE")
        print(message)
        print()
        print("RESPONSE:")
        print(response)

    if not response.strip():
        # Compare against the lower-cased name; "Spanish" could never match
        # the result of lang.lower().
        response = (
            "Lo siento, no puedo hablar ahora"
            if lang.lower() == "spanish"
            else "Sorry, can't talk right now"
        )
    history.append((user_message, response))
    return history, history, speak(response), format_chat(history)
# Page layout. Components are created in display order inside the Blocks
# context, so the statement order below is significant.
with gr.Blocks() as demo:
    gr.Markdown(HEADER)
    lang = gr.Radio(
        label="Language",
        choices=["English", "Spanish"],
        value=DEFAULT_LANG,
        type="value",
    )
    AGENT, USER, CONTEXT = select_lang_vars(DEFAULT_LANG)
    context = gr.Textbox(label="Context", lines=5, value=CONTEXT)

    with gr.Row():
        audio_in = gr.Audio(label="User", source="microphone", type="filepath")
        audio_out = gr.Audio(label="Agent", interactive=False, value=empty_audio)

    with gr.Row():
        user = gr.Textbox(label="User", value=USER)
        agent = gr.Textbox(label="Agent", value=AGENT)

    # Switching language resets the speaker labels and the context to that
    # language's defaults. select_lang_vars returns (agent, user, context),
    # hence the output order.
    lang.change(select_lang_vars, inputs=[lang], outputs=[agent, user, context])

    history = gr.Variable(value=[])
    chatbot = gr.Variable()  # stands in for a (disabled) gr.Chatbot component
    log = gr.HTML()

    # There is no submit button: a change of the recorded audio triggers the
    # conversation turn.
    chat_inputs = [lang, agent, user, context, audio_in, history]
    chat_outputs = [chatbot, history, audio_out, log]
    audio_in.change(chat_with_gpt, inputs=chat_inputs, outputs=chat_outputs)

    gr.Markdown(FOOTER)

demo.launch()