import os
import time
import logging
import tempfile
from operator import itemgetter
from typing import List, Tuple

import gradio as gr
import numpy as np
import requests
import torch
import torchaudio
from pydantic import BaseModel, Field
from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor

from langchain.chains import ConversationChain
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.graphs import Neo4jGraph
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.runnables import (
    RunnableBranch,
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from langchain_openai import ChatOpenAI


conversational_memory = ConversationBufferWindowMemory(
    memory_key='chat_history',
    k=10,
    return_messages=True
)


# Neo4j connection. Credentials are read from the environment rather than being
# hard-coded in the source; the variable names below (NEO4J_URI, NEO4J_USERNAME,
# NEO4J_PASSWORD) are assumed and should match however the deployment is configured.
graph = Neo4jGraph(
    url=os.environ['NEO4J_URI'],
    username=os.environ['NEO4J_USERNAME'],
    password=os.environ['NEO4J_PASSWORD']
)


class Entities(BaseModel):
    names: List[str] = Field(
        ..., description="All the person, organization, or business entities that appear in the text"
    )


entity_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are extracting organization and person entities from the text."),
    ("human", "Use the given format to extract information from the following input: {question}"),
])

chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o", api_key=os.environ['OPENAI_API_KEY'])
entity_chain = entity_prompt | chat_model.with_structured_output(Entities)
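
# A hypothetical illustration of the structured extraction (the exact output depends
# on the model):
#   entity_chain.invoke({"question": "What are some good clubs in Birmingham?"})
#   might return Entities(names=["Birmingham"]).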


def remove_lucene_chars(input: str) -> str:
    # Despite the name, this escapes Lucene full-text special characters so they are
    # treated literally in the full-text query instead of removing them.
    return input.translate(str.maketrans({
        "\\": r"\\", "+": r"\+", "-": r"\-", "&": r"\&", "|": r"\|", "!": r"\!",
        "(": r"\(", ")": r"\)", "{": r"\{", "}": r"\}", "[": r"\[", "]": r"\]",
        "^": r"\^", "~": r"\~", "*": r"\*", "?": r"\?", ":": r"\:", '"': r'\"',
        ";": r"\;", " ": r"\ "
    }))


def generate_full_text_query(input: str) -> str:
    words = [el for el in remove_lucene_chars(input).split() if el]
    # Guard against input that is empty after escaping and splitting.
    if not words:
        return ""
    full_text_query = ""
    for word in words[:-1]:
        full_text_query += f" {word}~2 AND"
    full_text_query += f" {words[-1]}~2"
    return full_text_query.strip()
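
# e.g. generate_full_text_query("John Doe") -> "John~2 AND Doe~2"
# (each term gets a fuzzy ~2 suffix and the terms are AND-ed together).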


logging.basicConfig(filename='neo4j_retrieval.log', level=logging.DEBUG,
                    format='%(asctime)s - %(levelname)s - %(message)s')


def structured_retriever(question: str) -> str:
    """Collect the graph neighborhood of every entity mentioned in the question."""
    result = ""
    entities = entity_chain.invoke({"question": question})
    for entity in entities.names:
        response = graph.query(
            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
            YIELD node, score
            CALL {
              WITH node
              MATCH (node)-[r:!MENTIONS]->(neighbor)
              RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
              UNION ALL
              WITH node
              MATCH (node)<-[r:!MENTIONS]-(neighbor)
              RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
            }
            RETURN output LIMIT 50
            """,
            {"query": generate_full_text_query(entity)},
        )
        result += "\n".join([el['output'] for el in response]) + "\n"
    return result
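
# The Cypher above assumes a full-text index named 'entity' already exists in the
# database. If it does not, it can be created once with something like the following
# statement (the node label and property are assumptions based on the usual LangChain
# graph-construction convention, not taken from this script):
#   CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (n:__Entity__) ON EACH [n.id]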


def retriever_neo4j(question: str):
    structured_data = structured_retriever(question)
    logging.debug(f"Structured data: {structured_data}")
    return structured_data


_template = """Given the following conversation and a follow-up question, rephrase the follow-up question to be a standalone question, in its original language.
Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""

CONDENSE_QUESTION_PROMPT = PromptTemplate.from_template(_template)


def _format_chat_history(chat_history: list[tuple[str, str]]) -> list:
    buffer = []
    for human, ai in chat_history:
        buffer.append(HumanMessage(content=human))
        buffer.append(AIMessage(content=ai))
    return buffer


_search_query = RunnableBranch(
    (
        RunnableLambda(lambda x: bool(x.get("chat_history"))).with_config(
            run_name="HasChatHistoryCheck"
        ),
        RunnablePassthrough.assign(
            chat_history=lambda x: _format_chat_history(x["chat_history"])
        )
        | CONDENSE_QUESTION_PROMPT
        | ChatOpenAI(temperature=0, api_key=os.environ['OPENAI_API_KEY'])
        | StrOutputParser(),
    ),
    RunnableLambda(lambda x: x["question"]),
)
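
# When chat history is present, the branch above rewrites the follow-up into a
# standalone question before retrieval; otherwise the question passes through
# unchanged. A hypothetical illustration (actual rewrites depend on the model):
#   history:   [("What are some good clubs in Birmingham?", "...")]
#   follow-up: "Which of them are open late?"
#   ->         "Which clubs in Birmingham are open late?"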


template = """I am a guide for Birmingham, Alabama. I can provide recommendations and insights about the city, including events and activities.
Ask your question directly, and I'll provide a precise, short, and crisp response in a conversational way, without any greeting.
{context}
Question: {question}
Answer:"""

qa_prompt = ChatPromptTemplate.from_template(template)


chain_neo4j = (
    RunnableParallel(
        {
            "context": _search_query | retriever_neo4j,
            # Pass only the question string to the prompt, not the whole input dict.
            "question": itemgetter("question"),
        }
    )
    | qa_prompt
    | chat_model
    | StrOutputParser()
)
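
# Data flow: {"question": ...} -> parallel {context from the graph retriever, question}
# -> qa_prompt -> chat_model (gpt-4o) -> plain string answer.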


def get_response(question):
    try:
        return chain_neo4j.invoke({"question": question})
    except Exception as e:
        return f"Error: {str(e)}"
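
# Example usage (one of the prompts from the Examples panel below):
#   get_response("What are some popular events in Birmingham?")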


def clear_fields():
    # Reset the chatbot history, the question box, and the audio output.
    return [], "", None


def generate_audio_elevenlabs(text):
    XI_API_KEY = os.environ['ELEVENLABS_API']
    VOICE_ID = 'ehbJzYLQFpwbJmGkqbnW'
    tts_url = f"https://api.elevenlabs.io/v1/text-to-speech/{VOICE_ID}/stream"
    headers = {
        "Accept": "application/json",
        "xi-api-key": XI_API_KEY
    }
    data = {
        "text": str(text),
        "model_id": "eleven_multilingual_v2",
        "voice_settings": {
            "stability": 1.0,
            "similarity_boost": 0.0,
            "style": 0.60,
            "use_speaker_boost": False
        }
    }
    response = requests.post(tts_url, headers=headers, json=data, stream=True)
    if response.ok:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as f:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
            audio_path = f.name
        logging.debug(f"Audio saved to {audio_path}")
        return audio_path
    else:
        logging.error(f"Error generating audio: {response.text}")
        return None
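
# The streamed MP3 bytes are written to a temporary file (delete=False, so the file
# persists after the handle closes) and the path is returned for the gr.Audio output
# component below, which is configured with type="filepath".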


def add_message(history, message):
    # Append the user's message to the history with an empty slot for the bot's reply.
    if message.strip():
        history.append((message, None))
    return history, ""


def chat_with_bot(messages):
    # Guard against an empty history (e.g. the button was clicked with no input).
    if not messages:
        yield messages
        return

    user_message = messages[-1][0]
    messages[-1] = (user_message, "")

    response = get_response(user_message)

    # Stream the answer back character by character for a typing effect.
    for character in response:
        messages[-1] = (user_message, messages[-1][1] + character)
        yield messages
        time.sleep(0.05)

    yield messages


def generate_audio_from_last_response(history):
    if history and len(history) > 0:
        recent_response = history[-1][1]
        if recent_response:
            return generate_audio_elevenlabs(recent_response)
    return None


examples = [
    ["What are some popular events in Birmingham?"],
    ["Who are the top players of the Crimson Tide?"],
    ["Where can I find a hamburger?"],
    ["What are some popular tourist attractions in Birmingham?"],
    ["What are some good clubs in Birmingham?"]
]


def insert_prompt(current_text, prompt):
    return prompt[0] if prompt else current_text


model_id = 'openai/whisper-large-v3'
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
processor = AutoProcessor.from_pretrained(model_id)
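
# The ASR pipeline below transcribes audio in 15-second chunks and batches them
# (batch_size=16), which is what keeps whisper-large-v3 responsive enough for the
# streamed microphone input wired up further down.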


pipe_asr = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    max_new_tokens=128,
    chunk_length_s=15,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
    return_timestamps=True
)


def transcribe_function(stream, new_chunk):
    try:
        sr, y = new_chunk[0], new_chunk[1]
    except TypeError:
        print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
        return stream, ""

    # Normalise the incoming chunk to float32 in [-1, 1].
    y = y.astype(np.float32)
    max_abs_y = np.max(np.abs(y))
    if max_abs_y > 0:
        y = y / max_abs_y

    # Accumulate all audio received so far and re-transcribe the whole stream, so the
    # question box always shows the full transcript rather than just the latest chunk.
    if stream is not None:
        stream = np.concatenate([stream, y])
    else:
        stream = y

    result = pipe_asr({"array": stream, "sampling_rate": sr}, return_timestamps=False)
    full_text = result.get("text", "")

    # Two return values to match the outputs wiring: [state, question_input].
    return stream, full_text


def clear_transcription_state():
    # Return order matches the outputs wiring [question_input, state]:
    # clear the question box and reset the accumulated audio stream.
    return "", None


with gr.Blocks(theme="rawrsor1/Everforest") as demo:
    chatbot = gr.Chatbot([], elem_id="RADAR", bubble_full_width=False)
    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
            audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1, label="Speak to Ask")
        with gr.Column():
            audio_output = gr.Audio(label="Audio", type="filepath", interactive=False)

    with gr.Row():
        with gr.Column():
            get_response_btn = gr.Button("Get Response")
        with gr.Column():
            clear_state_btn = gr.Button("Clear State")
        with gr.Column():
            generate_audio_btn = gr.Button("Generate Audio")
        with gr.Column():
            clean_btn = gr.Button("Clean")

    with gr.Row():
        with gr.Column():
            gr.Markdown("<h1 style='color: red;'>Example Prompts</h1>", elem_id="Example-Prompts")
            gr.Examples(examples=examples, fn=insert_prompt, inputs=question_input, outputs=question_input)

    # Typed questions: add the message to the history, then stream the bot's answer.
    get_response_btn.click(fn=add_message, inputs=[chatbot, question_input], outputs=[chatbot, question_input])\
        .then(fn=chat_with_bot, inputs=[chatbot], outputs=chatbot)

    question_input.submit(fn=add_message, inputs=[chatbot, question_input], outputs=[chatbot, question_input])\
        .then(fn=chat_with_bot, inputs=[chatbot], outputs=chatbot)

    # Spoken questions: accumulate microphone chunks in `state` and keep the running
    # transcript in the question box.
    state = gr.State()
    audio_input.stream(transcribe_function, inputs=[state, audio_input], outputs=[state, question_input])

    # Text-to-speech for the most recent bot reply, plus the reset buttons.
    generate_audio_btn.click(fn=generate_audio_from_last_response, inputs=chatbot, outputs=audio_output)
    clean_btn.click(fn=clear_fields, inputs=[], outputs=[chatbot, question_input, audio_output])
    clear_state_btn.click(fn=clear_transcription_state, outputs=[question_input, state])


demo.launch(show_error=True)