Spaces:
Sleeping
Sleeping
import gradio as gr | |
import whisper | |
import asyncio | |
import httpx | |
import tempfile | |
import os | |
import requests | |
import time | |
import threading | |
from datetime import datetime, timedelta | |
session = requests.Session() | |
from interview_protocol import protocols as interview_protocols | |
# Whisper speech-to-text model ("base" checkpoint), loaded once when the module is imported.
model = whisper.load_model("base")
# Backend service that every HTTP helper below talks to.
base_url = "https://llm4socialisolation-fd4082d0a518.herokuapp.com"
# base_url = "http://localhost:8080"
# Passed as the httpx timeout for every backend request.
timeout = 60
concurrency_count = 10
# Dropdown display name -> internal chatbot_type value sent to the backend.
display_to_value = {
    'Echo': 'enhanced',
    'Breeze': 'baseline',
}
# Reverse lookup, derived from display_to_value so the two can never drift apart.
value_to_display = {internal: shown for shown, internal in display_to_value.items()}
def get_method_index(chapter, method, protocols=None):
    """Return the position of *method* in the flattened list of all interview topics.

    The flattened list is every chapter's topic list concatenated in
    ``protocols`` order. When *chapter* names a chapter that contains
    *method*, the position of *method* inside that chapter is used. This
    disambiguates topic names that appear in more than one chapter.
    Otherwise the first occurrence across all chapters is returned.

    Args:
        chapter: Chapter name used to disambiguate *method*.
        method: Topic name to locate.
        protocols: Mapping of chapter name -> list of topics. Defaults to the
            module-level ``interview_protocols``.

    Raises:
        ValueError: if *method* is not present in any chapter.
    """
    if protocols is None:
        protocols = interview_protocols
    offset = 0
    for chapter_name, methods in protocols.items():
        if chapter_name == chapter and method in methods:
            return offset + list(methods).index(method)
        offset += len(methods)
    # Unknown chapter, or method not in it: fall back to a global search.
    all_methods = [m for methods in protocols.values() for m in methods]
    return all_methods.index(method)
async def initialization(api_key, chapter_name, topic_name, username, prompts, chatbot_type):
    """POST the session configuration to the backend's ``/api/initialization``.

    Args:
        api_key: User's API key, forwarded in the request body.
        chapter_name: Selected interview chapter.
        topic_name: Selected interview topic.
        username: User identifier.
        prompts: Extra prompt fields merged into the request body.
        chatbot_type: Internal chatbot type ('enhanced' / 'baseline').

    Returns:
        A human-readable status string; this function never raises.
    """
    url = f"{base_url}/api/initialization"
    headers = {'Content-Type': 'application/json'}
    data = {
        'api_key': api_key,
        'chapter_name': chapter_name,
        'topic_name': topic_name,
        'username': username,
        'chatbot_type': chatbot_type,
        **prompts,
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            response = await client.post(url, json=data, headers=headers)
            if response.status_code == 200:
                return "Initialization successful."
            return f"Initialization failed: {response.text}"
        # httpx reports timeouts as httpx.TimeoutException, not asyncio.TimeoutError,
        # so the original asyncio-only clause could never match.
        except (httpx.TimeoutException, asyncio.TimeoutError):
            print("The request timed out")
            return "Request timed out during initialization."
        except Exception as e:
            return f"Error in initialization: {str(e)}"
def fetch_default_prompts(chatbot_type):
    """Fetch the default prompt texts for *chatbot_type* from the backend root URL.

    Returns:
        The decoded JSON object on HTTP 200 when it is a dict, otherwise ``{}``.
        Errors are printed rather than raised, so the UI can still be built
        with empty prompts.
    """
    try:
        # Let httpx build and URL-encode the query string instead of formatting it by hand.
        response = httpx.get(base_url, params={'chatbot_type': chatbot_type}, timeout=timeout)
        if response.status_code != 200:
            print(f"Failed to fetch prompts: {response.status_code} - {response.text}")
            return {}
        prompts = response.json()
    except Exception as e:
        print(f"Error fetching prompts: {str(e)}")
        return {}
    print(prompts)
    if not isinstance(prompts, dict):
        # Callers read the result with .get(); anything else would crash them.
        print(f"Unexpected prompts payload: {prompts!r}")
        return {}
    return prompts
async def get_backend_response(api_key, patient_prompt, username, chatbot_type):
    """Send one user message to ``/responses/doctor`` and return the backend's reply.

    ``api_key`` is accepted for call-site compatibility but is not included in
    the request body.

    Returns:
        The decoded JSON body on HTTP 200; otherwise an error string. Callers
        distinguish the two cases with ``isinstance(result, str)``.
    """
    endpoint = f"{base_url}/responses/doctor"
    payload = {
        'username': username,
        'patient_prompt': patient_prompt,
        'chatbot_type': chatbot_type,
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            reply = await client.post(endpoint, json=payload, headers={'Content-Type': 'application/json'})
            if reply.status_code != 200:
                return f"Failed to fetch response from backend: {reply.text}"
            return reply.json()
        except Exception as exc:
            return f"Error contacting backend service: {str(exc)}"
async def save_conversation_and_memory(username, chatbot_type):
    """Ask the backend (``/save/end_and_save``) to end and persist the current session.

    Returns:
        On HTTP 200, the ``'message'`` field of the JSON reply, or
        ``'Saving Error!'`` if that field is missing. Otherwise an error string.
    """
    endpoint = f"{base_url}/save/end_and_save"
    payload = {'username': username, 'chatbot_type': chatbot_type}
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            reply = await client.post(endpoint, json=payload, headers={'Content-Type': 'application/json'})
            if reply.status_code != 200:
                return f"Failed to save conversations and memory graph: {reply.text}"
            return reply.json().get('message', 'Saving Error!')
        except Exception as exc:
            return f"Error contacting backend service: {str(exc)}"
async def get_conversation_histories(username, chatbot_type):
    """Fetch the user's saved conversations from ``/save/download_conversations``.

    Returns:
        The decoded JSON list on HTTP 200, otherwise ``[]``. Failures are
        printed (not raised), so the download button degrades to "no files"
        instead of failing silently.
    """
    url = f"{base_url}/save/download_conversations"
    headers = {'Content-Type': 'application/json'}
    data = {
        'username': username,
        'chatbot_type': chatbot_type,
    }
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            response = await client.post(url, json=data, headers=headers)
            if response.status_code != 200:
                print(f"Failed to fetch conversations: {response.status_code} - {response.text}")
                return []
            conversation_data = response.json()
        except Exception as e:
            print(f"Error fetching conversations: {str(e)}")
            return []
    if not isinstance(conversation_data, list):
        # The caller iterates the result as a list of entries.
        print(f"Unexpected conversations payload: {conversation_data!r}")
        return []
    return conversation_data
def _format_conversation(conversation):
    """Render a list of ``[speaker, message]`` pairs as plain text.

    Each well-formed pair becomes ``"Speaker: message"`` followed by a blank
    line. Any other item is written as ``"Unknown format: <item>"``.
    """
    parts = []
    for message_pair in conversation or []:
        if isinstance(message_pair, list) and len(message_pair) == 2:
            speaker, message = message_pair
            # str() so a non-string speaker does not crash .capitalize()
            parts.append(f"{str(speaker).capitalize()}: {message}\n\n")
        else:
            parts.append(f"Unknown format: {message_pair}\n\n")
    return "".join(parts)


def _unique_output_path(directory, file_name, default_name):
    """Return a path inside *directory* for *file_name* that cannot escape it or clobber a file.

    Backend-supplied names are reduced to their base name (dropping directory
    components such as ``../``). Empty or special names fall back to
    *default_name*. A numeric suffix is appended if the name is already taken.
    """
    base = os.path.basename(str(file_name)) if file_name else ''
    if base in ('', '.', '..'):
        base = default_name
    stem, ext = os.path.splitext(base)
    candidate = base
    counter = 1
    while os.path.exists(os.path.join(directory, candidate)):
        counter += 1
        candidate = f"{stem}_{counter}{ext}"
    return os.path.join(directory, candidate)


def download_conversations(username, chatbot_type):
    """Fetch the user's conversations and write each one to its own text file.

    Returns:
        A list of file paths in a fresh temporary directory, one per
        conversation entry returned by the backend (``[]`` if none).
    """
    conversation_histories = asyncio.run(get_conversation_histories(username, chatbot_type))
    files = []
    temp_dir = tempfile.mkdtemp()
    for conversation_entry in conversation_histories:
        default_name = f"Conversation_{len(files)+1}.txt"
        temp_file_path = _unique_output_path(temp_dir, conversation_entry.get('file_name'), default_name)
        conversation_text = _format_conversation(conversation_entry.get('conversation', []))
        # Explicit UTF-8 so non-ASCII chat text is written the same way on every platform.
        with open(temp_file_path, 'w', encoding='utf-8') as temp_file:
            temp_file.write(conversation_text)
        files.append(temp_file_path)
    return files
async def get_biography(username, chatbot_type):
    """Request an autobiography from ``/save/generate_autobiography``.

    Returns:
        On HTTP 200, the ``'biography'`` field of the JSON reply (``''`` if
        absent). Otherwise ``"Failed to generate biography."``, or an
        ``"Error contacting backend service: ..."`` string on exceptions.
    """
    endpoint = f"{base_url}/save/generate_autobiography"
    payload = {'username': username, 'chatbot_type': chatbot_type}
    async with httpx.AsyncClient(timeout=timeout) as client:
        try:
            reply = await client.post(endpoint, json=payload, headers={'Content-Type': 'application/json'})
            if reply.status_code == 200:
                return reply.json().get('biography', '')
            return "Failed to generate biography."
        except Exception as exc:
            return f"Error contacting backend service: {str(exc)}"
def download_biography(username, chatbot_type):
    """Generate the user's biography and expose it as a downloadable file.

    Returns:
        A pair ``(file_output, textbox_update)``. On success, *file_output* is
        the path of a temporary ``biography.txt``. On failure or empty output,
        it is an update hiding the file component. The textbox always shows
        the returned text.
    """
    biography_text = asyncio.run(get_biography(username, chatbot_type))
    # get_biography signals failure with these fixed message prefixes. The previous
    # substring test ("Failed"/"Error" anywhere) also rejected genuine biographies
    # that merely contained those words.
    error_prefixes = ("Failed to generate biography", "Error contacting backend service:")
    if not biography_text or biography_text.startswith(error_prefixes):
        return gr.update(value=None, visible=False), gr.update(value=biography_text, visible=True)
    temp_dir = tempfile.mkdtemp()
    temp_file_path = os.path.join(temp_dir, "biography.txt")
    with open(temp_file_path, 'w', encoding='utf-8') as temp_file:
        temp_file.write(biography_text)
    return temp_file_path, gr.update(value=biography_text, visible=True)
def transcribe_audio(audio_file):
    """Run the module-level Whisper model on *audio_file* and return only its ``"text"`` field."""
    result = model.transcribe(audio_file)
    return result["text"]
def submit_text_and_respond(edited_text, api_key, username, history, chatbot_type):
    """Send the edited transcription to the backend and append the exchange to *history*.

    *history* is appended to in place and also returned.

    Returns:
        ``(history, "", memory_rows)``. The empty string clears the
        transcription box. *memory_rows* is the table built from the reply's
        ``memory_events``, or ``[]`` on any error.
    """
    response = asyncio.run(get_backend_response(api_key, edited_text, username, chatbot_type))
    print('------')
    print(response)
    if isinstance(response, str):
        # get_backend_response returns a string only for errors; show it in the chat.
        history.append((edited_text, response))
        return history, "", []
    try:
        doctor_response = response['doctor_response']['response']
    except (KeyError, TypeError):
        # Malformed payload: report it in the chat instead of crashing the handler.
        history.append((edited_text, f"Unexpected response from backend: {response}"))
        return history, "", []
    # `or []` also covers an explicit JSON null for memory_events.
    memory_events = response.get('memory_events') or []
    history.append((edited_text, doctor_response))
    memory_graph = update_memory_graph(memory_events)
    return history, "", memory_graph
def set_initialize_button(api_key_input, chapter_name, topic_name, username_input,
                          system_prompt_text, conv_instruction_prompt_text, therapy_prompt_text,
                          autobio_prompt_text, chatbot_display_name):
    """Initialize a backend session from the settings panel.

    The display name is mapped to its internal chatbot type ('enhanced' if
    unknown), and the four prompt texts are bundled into the request.

    Returns:
        ``(status_message, api_key_input, chatbot_type)``; the last two are
        stored in Gradio state by the caller.
    """
    chatbot_type = display_to_value.get(chatbot_display_name, 'enhanced')
    prompt_fields = dict(
        system_prompt=system_prompt_text,
        conv_instruction_prompt=conv_instruction_prompt_text,
        therapy_prompt=therapy_prompt_text,
        autobio_prompt=autobio_prompt_text,
    )
    status = asyncio.run(
        initialization(api_key_input, chapter_name, topic_name, username_input, prompt_fields, chatbot_type)
    )
    print(status)
    return status, api_key_input, chatbot_type
def save_conversation(username, chatbot_type):
    """Synchronously run save_conversation_and_memory and return its status string."""
    return asyncio.run(save_conversation_and_memory(username, chatbot_type))
def start_recording(audio_file):
    """Transcribe a newly recorded clip.

    Returns ``""`` when no file is given, the transcript on success, or a
    ``"Failed to transcribe: ..."`` message if transcription raises.
    """
    if audio_file:
        try:
            return transcribe_audio(audio_file)
        except Exception as err:
            return f"Failed to transcribe: {str(err)}"
    return ""
def update_methods(chapter):
    """Repopulate the topic dropdown with *chapter*'s topics, selecting the first one."""
    topics = interview_protocols[chapter]
    return gr.update(choices=topics, value=topics[0])
def update_memory_graph(memory_data):
    """Convert memory-event dicts into rows for the "Memory Events" table.

    Each row is ``[date, topic, event_description, people_involved]``, with
    ``''`` for missing keys. A ``None`` or empty input yields ``[]``; the
    backend may send ``null`` for an empty list.
    """
    if not memory_data:
        return []
    columns = ('date', 'topic', 'event_description', 'people_involved')
    return [[node.get(column, '') for column in columns] for node in memory_data]
def update_prompts(chatbot_display_name):
    """Load the backend's default prompts for the selected chatbot into the four prompt textboxes.

    Returns a 4-tuple of ``gr.update(value=...)`` for the system,
    conversation-instruction, therapy and autobiography prompts, in that
    order. Missing keys become ``''``.
    """
    prompts = fetch_default_prompts(display_to_value.get(chatbot_display_name, 'enhanced'))
    # NOTE(review): the backend key read here is 'autobio_generation_prompt', while
    # set_initialize_button sends the edited text back as 'autobio_prompt' — confirm
    # both names are what the backend expects.
    prompt_keys = ('system_prompt', 'conv_instruction_prompt', 'therapy_prompt', 'autobio_generation_prompt')
    return tuple(gr.update(value=prompts.get(key, '')) for key in prompt_keys)
def update_chatbot_type(chatbot_display_name):
    """Translate a dropdown display name into its internal chatbot type ('enhanced' if unknown)."""
    return display_to_value.get(chatbot_display_name, 'enhanced')
# Start the countdown: mark it running and fix its deadline eight minutes from now.
def start_timer():
    """Return ``(True, deadline)``, where *deadline* is a naive ``datetime`` eight minutes ahead."""
    deadline = datetime.now() + timedelta(minutes=8)
    return True, deadline
def reset_timer():
    """Stop the countdown and return the idle display text.

    Returns ``(False, "Time remaining: 08:00")``. The text matches the timer
    textbox's initial value and the ``MM:SS`` format shown while running.
    The old text had a typo ("Timer") and an unpadded minute ("8:00").
    """
    return False, "Time remaining: 08:00"
# Compute the countdown display text; a plain (synchronous) function polled each second by the UI.
def periodic_call(is_running, target_timestamp):
    """Return ``"Time remaining: MM:SS"`` for the countdown.

    Args:
        is_running: Truthy while the countdown is active.
        target_timestamp: Naive ``datetime`` at which the countdown reaches zero.

    Returns:
        The idle text ``"Time remaining: 08:00"`` when not running or no
        deadline is set. Otherwise the remaining time rounded to whole
        seconds, clamped at ``00:00`` once the deadline has passed.
    """
    if not is_running or target_timestamp is None:
        return 'Time remaining: 08:00'
    seconds_left = int(round((target_timestamp - datetime.now()).total_seconds()))
    seconds_left = max(seconds_left, 0)
    minutes, seconds = divmod(seconds_left, 60)
    return f'Time remaining: {minutes:02}:{seconds:02}'
# initialize prompts with empty strings, using the same keys update_prompts reads from the backend
# NOTE(review): initial_prompts does not appear to be referenced elsewhere in this file.
initial_prompts = {'system_prompt': '', 'conv_instruction_prompt': '', 'therapy_prompt': '', 'autobio_generation_prompt': ''}
# CSS to keep the buttons small: compact padding/font and natural width for the
# elements with elem_id "start_button" and "reset_button".
# NOTE(review): no component in this file sets elem_id="reset_button"; confirm that rule is still needed.
css = """
#start_button, #reset_button {
padding: 4px 10px !important;
font-size: 12px !important;
width: auto !important;
}
"""
# Build the Gradio UI. The left column holds settings (chatbot type, interview
# protocol, prompts, user info); the right column holds the timer, chat,
# recording and export controls.
# NOTE(review): the copy of this file under review had all indentation stripped;
# the container nesting below is reconstructed from the code's structure. Verify
# which components belong to which Row/Column/Box against the deployed app.
# NOTE(review): gr.Box, Audio(source=...) and Dataframe(max_rows=...) appear to
# target Gradio 3.x; confirm the pinned gradio version.
with gr.Blocks(css=css) as app:
    # Per-session state shared between event handlers.
    chatbot_type_state = gr.State('enhanced')
    api_key_state = gr.State()
    prompt_visibility_state = gr.State(False)
    # Countdown state: written by start_timer, read by periodic_call.
    is_running = gr.State()
    target_timestamp = gr.State()
    with gr.Row():
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("## Settings")
            # chatbot Type Selection
            with gr.Box():
                gr.Markdown("### Chatbot Selection")
                chatbot_type_dropdown = gr.Dropdown(
                    label="Select Chatbot Type",
                    choices=['Echo', 'Breeze'],
                    value='Echo',
                )
                # Keep chatbot_type_state in sync with the dropdown's internal value.
                chatbot_type_dropdown.change(
                    fn=update_chatbot_type,
                    inputs=[chatbot_type_dropdown],
                    outputs=[chatbot_type_state]
                )
            # fetch initial prompts based on the default chatbot type
            # (runs once, while the UI is being built; each result is a gr.update(...)
            # whose 'value' entry seeds the prompt textboxes below).
            system_prompt_value, conv_instruction_prompt_value, therapy_prompt_value, autobio_prompt_value = update_prompts('Echo')
            # interview protocol selection
            with gr.Box():
                gr.Markdown("### Interview Protocol")
                # Defaults to the second chapter and its fourth topic.
                # NOTE(review): these fixed indices raise IndexError if the protocol has fewer entries.
                chapter_dropdown = gr.Dropdown(
                    label="Select Chapter",
                    choices=list(interview_protocols.keys()),
                    value=list(interview_protocols.keys())[1],
                )
                method_dropdown = gr.Dropdown(
                    label="Select Topic",
                    choices=interview_protocols[chapter_dropdown.value],
                    value=interview_protocols[chapter_dropdown.value][3],
                )
                # Repopulate the topic list whenever the chapter changes.
                chapter_dropdown.change(
                    fn=update_methods,
                    inputs=[chapter_dropdown],
                    outputs=[method_dropdown]
                )
            # Update states when selections change
            # NOTE(review): chapter_state / method_state are written here but not read elsewhere in this file.
            def update_chapter(chapter):
                return chapter
            def update_method(method):
                return method
            chapter_state = gr.State()
            method_state = gr.State()
            chapter_dropdown.change(
                fn=update_chapter,
                inputs=[chapter_dropdown],
                outputs=[chapter_state]
            )
            method_dropdown.change(
                fn=update_method,
                inputs=[method_dropdown],
                outputs=[method_state]
            )
            # customize Prompts
            with gr.Box():
                toggle_prompts_button = gr.Button("Show Prompts")
                # wrap the prompts in a component with initial visibility set to False
                with gr.Column(visible=False) as prompt_section:
                    gr.Markdown("### Customize Prompts")
                    system_prompt = gr.Textbox(
                        label="System Prompt",
                        placeholder="Enter the system prompt here.",
                        value=system_prompt_value['value']
                    )
                    conv_instruction_prompt = gr.Textbox(
                        label="Conversation Instruction Prompt",
                        placeholder="Enter the instruction for each conversation here.",
                        value=conv_instruction_prompt_value['value']
                    )
                    therapy_prompt = gr.Textbox(
                        label="Therapy Prompt",
                        placeholder="Enter the instruction for reminiscence therapy.",
                        value=therapy_prompt_value['value']
                    )
                    autobio_prompt = gr.Textbox(
                        label="Autobiography Generation Prompt",
                        placeholder="Enter the instruction for autobiography generation.",
                        value=autobio_prompt_value['value']
                    )
            # update prompts when chatbot_type changes
            # (a second .change listener on the same dropdown; both run on each change).
            chatbot_type_dropdown.change(
                fn=update_prompts,
                inputs=[chatbot_type_dropdown],
                outputs=[system_prompt, conv_instruction_prompt, therapy_prompt, autobio_prompt]
            )
            with gr.Box():
                gr.Markdown("### User Information")
                username_input = gr.Textbox(
                    label="Username", placeholder="Enter your username"
                )
                api_key_input = gr.Textbox(
                    label="OpenAI API Key",
                    placeholder="Enter your openai api key",
                    type="password"
                )
            initialize_button = gr.Button("Initialize", variant="primary", size="large")
            initialization_status = gr.Textbox(
                label="Status", interactive=False, placeholder="Initialization status will appear here."
            )
            # Initialize the backend session and remember the API key / chatbot type in state.
            initialize_button.click(
                fn=set_initialize_button,
                inputs=[api_key_input, chapter_dropdown, method_dropdown, username_input,
                        system_prompt, conv_instruction_prompt, therapy_prompt, autobio_prompt, chatbot_type_dropdown],
                outputs=[initialization_status, api_key_state, chatbot_type_state],
            )
            # define the function to toggle prompts visibility
            def toggle_prompts(visibility):
                new_visibility = not visibility
                button_text = "Hide Prompts" if new_visibility else "Show Prompts"
                return gr.update(value=button_text), gr.update(visible=new_visibility), new_visibility
            toggle_prompts_button.click(
                fn=toggle_prompts,
                inputs=[prompt_visibility_state],
                outputs=[toggle_prompts_button, prompt_section, prompt_visibility_state]
            )
        with gr.Column(scale=3):
            with gr.Row():
                timer_display = gr.Textbox(value="Time remaining: 08:00", label="")
                start_button = gr.Button("Start Timer", elem_id="start_button")
                # Record the deadline, then re-run periodic_call every second to refresh the display.
                start_button.click(start_timer, outputs=[is_running, target_timestamp]).then(
                    periodic_call, inputs=[is_running, target_timestamp], outputs=timer_display, every=1)
            chatbot = gr.Chatbot(label="Chat here for autobiography generation", height=500)
            with gr.Row():
                transcription_box = gr.Textbox(
                    label="Transcription (You can edit this)", lines=3
                )
                audio_input = gr.Audio(
                    source="microphone", type="filepath", label="🎤 Record Audio"
                )
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary", size="large")
                save_conversation_button = gr.Button("End and Save Conversation", variant="secondary")
                download_button = gr.Button("Download Conversations", variant="secondary")
                download_biography_button = gr.Button("Download Biography", variant="secondary")
            memory_graph_table = gr.Dataframe(
                headers=["Date", "Topic", "Description", "People Involved"],
                datatype=["str", "str", "str", "str"],
                interactive=False,
                label="Memory Events",
                max_rows=5
            )
            biography_textbox = gr.Textbox(label="Autobiography", visible=False)
            # Transcribe each new recording into the editable textbox.
            audio_input.change(
                fn=start_recording,
                inputs=[audio_input],
                outputs=[transcription_box]
            )
            # Chat history list; submit_text_and_respond appends to it in place.
            state = gr.State([])
            submit_button.click(
                submit_text_and_respond,
                inputs=[transcription_box, api_key_state, username_input, state, chatbot_type_state],
                outputs=[chatbot, transcription_box, memory_graph_table]
            )
            # The gr.Files() / gr.File() outputs are created inline, presumably rendering at this point in the layout.
            download_button.click(
                fn=download_conversations,
                inputs=[username_input, chatbot_type_state],
                outputs=gr.Files()
            )
            download_biography_button.click(
                fn=download_biography,
                inputs=[username_input, chatbot_type_state],
                outputs=[gr.File(label="Biography.txt"), biography_textbox]
            )
            # The status string returned by save_conversation is discarded (outputs=None).
            save_conversation_button.click(
                fn=save_conversation,
                inputs=[username_input, chatbot_type_state],
                outputs=None
            )
# Enable the request queue (Gradio needs it for the every=1 timer polling). Pass the
# configured worker count; concurrency_count was defined but never used, so the
# queue ran with its default of one worker.
app.queue(concurrency_count=concurrency_count)
app.launch(share=True, max_threads=10)