from io import BytesIO
import os
from pprint import pprint
import uuid

import chainlit as cl
from chainlit.element import ElementBased
from dotenv import load_dotenv

# modules for audio processing
from langchain.schema.runnable.config import RunnableConfig
from langchain_openai.chat_models import ChatOpenAI
from openai import AsyncOpenAI

from utils.graph import generate_workflow

client = AsyncOpenAI()
# ---- ENV VARIABLES ---- #
load_dotenv()
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# -- AUGMENTED -- #
# Define the LLMs: a low-cost base model for routine steps and a stronger "power" model
base_llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=OPENAI_API_KEY, tags=["base_llm"], temperature=0)
power_llm = ChatOpenAI(model="gpt-4o", openai_api_key=OPENAI_API_KEY, tags=["power_llm"])
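# The tags let callback handlers (e.g. LangSmith traces) tell the two models
# apart in the logs; the original tagged both as "base_llm", which looked like
# a copy-paste slip, so the power model is tagged "power_llm" here.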
# Conversation starters for the first screen
@cl.set_starters
async def set_starters():
    return [
        cl.Starter(
            label="Plan your quick daily meals",
            message="Give me ideas for making an easy weeknight dinner that takes less than 25 minutes to prepare",
            icon="/public/meals4.svg",
        ),
        cl.Starter(
            label="Ideas for special occasions",
            message="What are good Middle Eastern dishes to make for Thanksgiving?",
            icon="/public/occasion4.svg",
        ),
        cl.Starter(
            label="Use ingredients you have",
            message="Suggest Spanish recipes that are good for the summer and make use of tomatoes",
            icon="/public/ingredients4.svg",
        ),
    ]
# Chat Start Function: Initialize the LangGraph RAG (Retrieval-Augmented Generation) workflow at the start of each chat session.
@cl.on_chat_start
async def start_chat():
    """
    Called at the start of every user session.

    Builds the LangGraph workflow and stores it, along with a fresh thread id,
    in the user session - a dictionary unique to each session, kept in the
    server's memory.
    """
    langgraph_chain = generate_workflow(base_llm, power_llm)
    cl.user_session.set("langgraph_chain", langgraph_chain)
    cl.user_session.set("thread_id", str(uuid.uuid4()))
# Message Handling Function: Process and respond to user messages using the LangGraph workflow.
@cl.on_message
async def main(message: cl.Message):
    """
    Called every time a message is received from a session.

    Streams the user question through the LangGraph workflow stored in the
    user session - unique to each session, which is why we can access it here.
    """
    langgraph_chain = cl.user_session.get("langgraph_chain")
    thread_id = cl.user_session.get("thread_id")
    msg = cl.Message(content="")
    langgraph_config = {"configurable": {"thread_id": thread_id, "cl_msg": msg}}
    async for output in langgraph_chain.astream({"question": message.content}, langgraph_config):
        for key, value in output.items():
            pprint(f"================== Node: '{key}':")
    await msg.send()
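# How tokens reach the UI: msg is passed into the graph through the config's
# "cl_msg" key, so - assuming the nodes in utils.graph call something like
# config["configurable"]["cl_msg"].stream_token(token) - the answer streams
# into the Chainlit message while the loop above only logs which node ran.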
# Speech-to-Text Function: Convert an audio file to text with OpenAI Whisper
@cl.step(type="tool")
async def speech_to_text(audio_file):
    response = await client.audio.transcriptions.create(model="whisper-1", file=audio_file)
    return response.text
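# Note: the OpenAI client accepts `file` as a (filename, bytes, mime_type)
# tuple, which is how on_audio_end below passes the buffered recording; the
# filename extension is what lets Whisper detect the container format.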
# Generate Text Answer Function: Take the Speech-to-Text output and generate a text answer
@cl.step(type="tool")
async def generate_text_answer(transcription):
    # Expects a "base_rag_chain" in the user session; nothing in this module
    # sets it, and on_audio_end below streams through the LangGraph workflow
    # instead, so this helper is effectively unused here.
    base_rag_chain = cl.user_session.get("base_rag_chain")
    msg = cl.Message(content="")
    async for chunk in base_rag_chain.astream(
        {"question": transcription},
        config=RunnableConfig(callbacks=[cl.LangchainCallbackHandler()]),
    ):
        if isinstance(chunk, dict) and "response" in chunk and isinstance(chunk["response"], str):
            await msg.stream_token(chunk["response"])
    return msg.content
# Audio Chunk Function: Buffer audio chunks as they arrive from the user's microphone
@cl.on_audio_chunk
async def on_audio_chunk(chunk: cl.AudioChunk):
    if chunk.isStart:
        buffer = BytesIO()
        # A file name with the right extension is required for Whisper to recognize the file type
        buffer.name = f"input_audio.{chunk.mimeType.split('/')[1]}"
        # Initialize the session for a new audio stream
        cl.user_session.set("audio_buffer", buffer)
        cl.user_session.set("audio_mime_type", chunk.mimeType)
    # For now, write the chunks to a buffer and transcribe the whole audio at the end
    cl.user_session.get("audio_buffer").write(chunk.data)
# Audio End Function: Transcribe the recorded audio and generate a response
@cl.on_audio_end
async def on_audio_end(elements: list[ElementBased]):
    # Get the audio buffer from the session
    audio_buffer: BytesIO = cl.user_session.get("audio_buffer")
    audio_buffer.seek(0)  # Move the file pointer to the beginning
    audio_file = audio_buffer.read()
    audio_mime_type: str = cl.user_session.get("audio_mime_type")

    # Echo the recording back to the user as an audio element
    input_audio_el = cl.Audio(mime=audio_mime_type, content=audio_file, name=audio_buffer.name)
    await cl.Message(author="You", type="user_message", content="", elements=[input_audio_el, *elements]).send()

    # Transcribe, then run the transcription through the same LangGraph workflow as typed messages
    whisper_input = (audio_buffer.name, audio_file, audio_mime_type)
    transcription = await speech_to_text(whisper_input)
    print("Transcription: ", transcription)

    langgraph_chain = cl.user_session.get("langgraph_chain")
    thread_id = cl.user_session.get("thread_id")
    msg = cl.Message(content="")
    langgraph_config = {"configurable": {"thread_id": thread_id, "cl_msg": msg}}
    async for output in langgraph_chain.astream({"question": transcription}, langgraph_config):
        for key, value in output.items():
            pprint(f"================== Node: '{key}':")
    await msg.send()
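# To run the app locally (assuming this file is app.py and Chainlit is installed):
#
#     chainlit run app.py -w
#
# The -w flag reloads the server on file changes during development.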