Spaces: Runtime error
Update app.py
app.py CHANGED
@@ -1,39 +1,15 @@
 import gradio as gr
 import os
 import logging
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-from langchain_openai import ChatOpenAI
-from langchain_community.graphs import Neo4jGraph
-from typing import List, Tuple
-from pydantic import BaseModel, Field
-from langchain_core.messages import AIMessage, HumanMessage
-from langchain_core.runnables import (
-    RunnableBranch,
-    RunnableLambda,
-    RunnablePassthrough,
-    RunnableParallel,
-)
-from langchain_core.prompts.prompt import PromptTemplate
 import requests
 import tempfile
-from
-import
-import logging
-from langchain.chains import ConversationChain
 import torch
-import torchaudio
-from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
 import numpy as np
 import threading

-# Setup conversational memory
-conversational_memory = ConversationBufferWindowMemory(
-    memory_key='chat_history',
-    k=10,
-    return_messages=True
-)
-
 # Setup Neo4j connection
 graph = Neo4jGraph(
     url="neo4j+s://6457770f.databases.neo4j.io",
@@ -41,20 +17,7 @@ graph = Neo4jGraph(
     password="Z10duoPkKCtENuOukw3eIlvl0xJWKtrVSr-_hGX1LQ4"
 )

-#
-class Entities(BaseModel):
-    names: List[str] = Field(
-        ..., description="All the person, organization, or business entities that appear in the text"
-    )
-
-entity_prompt = ChatPromptTemplate.from_messages([
-    ("system", "You are extracting organization and person entities from the text."),
-    ("human", "Use the given format to extract information from the following input: {question}"),
-])
-
-chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o", api_key=os.environ['OPENAI_API_KEY'])
-entity_chain = entity_prompt | chat_model.with_structured_output(Entities)
-
 def remove_lucene_chars(input: str) -> str:
     return input.translate(str.maketrans({
         "\\": r"\\", "+": r"\+", "-": r"\-", "&": r"\&", "|": r"\|", "!": r"\!",
@@ -63,6 +26,7 @@ def remove_lucene_chars(input: str) -> str:
         ";": r"\;", " ": r"\ "
     }))

 def generate_full_text_query(input: str) -> str:
     full_text_query = ""
     words = [el for el in remove_lucene_chars(input).split() if el]
@@ -71,60 +35,29 @@ def generate_full_text_query(input: str) -> str:
     full_text_query += f" {words[-1]}~2"
     return full_text_query.strip()

-#
-
-
-
-
-    entities = entity_chain.invoke({"question": question})
-    for entity in entities.names:
         response = graph.query(
-            """
-
-
-
-
-            RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
-            UNION ALL
-            WITH node
-            MATCH (node)<-[r:!MENTIONS]-(neighbor)
-            RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
-            }
-            RETURN output LIMIT 50
             """,
-            {"query":
         )
-
-
-
-
-
-
-    return structured_data
-
-# Define the chain for Neo4j-based retrieval and response generation
-chain_neo4j = (
-    RunnableParallel(
-        {
-            "context": RunnableLambda(lambda x: retriever_neo4j(x["question"])),
-            "question": RunnablePassthrough(),
-        }
-    )
-    | ChatPromptTemplate.from_template("Answer: {context} Question: {question}")
-    | chat_model
-    | StrOutputParser()
-)
-
-# Define the function to get the response
-def get_response(question):
-    try:
-        return chain_neo4j.invoke({"question": question})
     except Exception as e:
-
-
-# Define the function to clear input and output
-def clear_fields():
-    return [], "", None

 # Function to generate audio with Eleven Labs TTS
 def generate_audio_elevenlabs(text):
@@ -152,75 +85,56 @@ def generate_audio_elevenlabs(text):
                 if chunk:
                     f.write(chunk)
             audio_path = f.name
-
-        return audio_path  # Return audio path for automatic playback
     else:
-        logging.error(f"Error generating audio: {response.text}")
         return None

-#
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    max_abs_y = np.max(np.abs(y))
-    if max_abs_y > 0:
-        y = y / max_abs_y
-
-    if stream is not None and len(stream) > 0:
-        stream = np.concatenate([stream, y])
-    else:
-        stream = y

-
-

-

-

 # Define the Gradio interface
-with gr.Blocks(
-
-    mode_selection = gr.Radio(
-        choices=["Normal Chatbot", "Voice to Voice Conversation"],
-        label="Mode Selection",
-        value="Normal Chatbot"
-    )
-    question_input = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
-    audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1, label="Speak to Ask")
     submit_voice_btn = gr.Button("Submit Voice")
-    audio_output = gr.Audio(label="Audio", type="filepath", autoplay=True, interactive=False)

     # Interactions for Submit Voice Button
     submit_voice_btn.click(
         fn=handle_voice_to_voice,
-        inputs=
-        outputs=
-        api_name="api_voice_to_voice_translation"
-    )
-
-    # Speech-to-Text functionality
-    state = gr.State()
-    audio_input.stream(
-        transcribe_function,
-        inputs=[state, audio_input],
-        outputs=[state, question_input],
-        api_name="api_voice_to_text"
     )

-
app.py (updated file):

@@ -1,39 +1,15 @@
 import gradio as gr
 import os
 import logging
 import requests
 import tempfile
+from langchain_openai import ChatOpenAI
+from langchain_community.graphs import Neo4jGraph
 import torch
 import numpy as np
+from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
 import threading

 # Setup Neo4j connection
 graph = Neo4jGraph(
     url="neo4j+s://6457770f.databases.neo4j.io",
@@ -41,20 +17,7 @@ graph = Neo4jGraph(
     password="Z10duoPkKCtENuOukw3eIlvl0xJWKtrVSr-_hGX1LQ4"
 )

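A note on the block above: the connection credentials are hardcoded in app.py. A minimal alternative sketch, assuming the values are stored as Space secrets (the variable names NEO4J_URI, NEO4J_USERNAME and NEO4J_PASSWORD are hypothetical, not part of this commit):

# Hypothetical alternative: read Neo4j credentials from environment variables
# (e.g. Hugging Face Space secrets) instead of hardcoding them in app.py.
graph = Neo4jGraph(
    url=os.environ["NEO4J_URI"],
    username=os.environ["NEO4J_USERNAME"],
    password=os.environ["NEO4J_PASSWORD"],
)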
+# Function to clean input for Neo4j full-text query
 def remove_lucene_chars(input: str) -> str:
     return input.translate(str.maketrans({
         "\\": r"\\", "+": r"\+", "-": r"\-", "&": r"\&", "|": r"\|", "!": r"\!",
@@ -63,6 +26,7 @@ def remove_lucene_chars(input: str) -> str:
         ";": r"\;", " ": r"\ "
     }))

+# Function to generate a full-text query
 def generate_full_text_query(input: str) -> str:
     full_text_query = ""
     words = [el for el in remove_lucene_chars(input).split() if el]
@@ -71,60 +35,29 @@ def generate_full_text_query(input: str) -> str:
     full_text_query += f" {words[-1]}~2"
     return full_text_query.strip()

+# Define the function to query Neo4j and get a response
+def get_response(question):
+    query = generate_full_text_query(question)
+    try:
+        # Query the Neo4j database using a full-text search
         response = graph.query(
+            """
+            CALL db.index.fulltext.queryNodes('entity', $query)
+            YIELD node, score
+            RETURN node.content AS content, score
+            ORDER BY score DESC LIMIT 1
             """,
+            {"query": query}
         )
+        # Extract the content from the top response
+        if response:
+            result = response[0]['content']
+            return result
+        else:
+            return "Sorry, I couldn't find any relevant information in the database."
     except Exception as e:
+        logging.error(f"Error querying Neo4j: {e}")
+        return "An error occurred while fetching data from the database."

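The Cypher in get_response assumes a full-text index named entity already exists and that the indexed nodes carry a content property. A minimal one-time setup sketch (Neo4j 5 syntax; the Document label and content property are illustrative assumptions, not taken from this commit):

# Hypothetical one-time setup for the full-text index queried by get_response.
# The label `Document` and property `content` are assumptions for illustration.
graph.query(
    "CREATE FULLTEXT INDEX entity IF NOT EXISTS "
    "FOR (n:Document) ON EACH [n.content]"
)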
 # Function to generate audio with Eleven Labs TTS
 def generate_audio_elevenlabs(text):
@@ -152,75 +85,56 @@ def generate_audio_elevenlabs(text):
                 if chunk:
                     f.write(chunk)
             audio_path = f.name
+        return audio_path
     else:
         return None

+# Define ASR model for speech-to-text
+model_id = 'openai/whisper-large-v3'
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
+processor = AutoProcessor.from_pretrained(model_id)
+
+pipe_asr = pipeline(
+    "automatic-speech-recognition",
+    model=model,
+    tokenizer=processor.tokenizer,
+    feature_extractor=processor.feature_extractor,
+    max_new_tokens=128,
+    chunk_length_s=15,
+    batch_size=16,
+    torch_dtype=torch_dtype,
+    device=device,
+    return_timestamps=True
+)

+# Function to handle voice input, generate response from Neo4j, and return audio output
+def handle_voice_to_voice(audio):
+    # Transcribe audio input to text
+    sr, y = audio
+    result = pipe_asr({"array": y, "sampling_rate": sr}, return_timestamps=False)
+    question = result.get("text", "")

+    # Get response using the transcribed question
+    response = get_response(question)

+    # Generate audio from the response
+    audio_path = generate_audio_elevenlabs(response)
+    return audio_path

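One caveat on handle_voice_to_voice: Gradio's numpy microphone input typically arrives as 16-bit integer samples, and the removed streaming handler normalized audio before transcription, while the new code passes the raw array straight to the Whisper pipeline. A small preprocessing sketch in the spirit of the removed code (an assumption, not part of this commit), to be applied to y before calling pipe_asr:

# Hypothetical preprocessing: convert int16 microphone samples to normalized
# float32, mirroring the normalization the old streaming handler performed.
def to_float32(y: np.ndarray) -> np.ndarray:
    y = y.astype(np.float32)
    peak = np.max(np.abs(y))
    return y / peak if peak > 0 else y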
 # Define the Gradio interface
+with gr.Blocks() as demo:
+    audio_input = gr.Audio(sources=["microphone"], type='numpy', streaming=False, label="Speak to Ask")
     submit_voice_btn = gr.Button("Submit Voice")
+    audio_output = gr.Audio(label="Response Audio", type="filepath", autoplay=True, interactive=False)

     # Interactions for Submit Voice Button
     submit_voice_btn.click(
         fn=handle_voice_to_voice,
+        inputs=audio_input,
+        outputs=audio_output
     )

+# Launch the Gradio interface
+demo.launch(show_error=True, share=True)
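For a quick check of the retrieval path without the UI, a hypothetical snippet (not part of the commit):

# Hypothetical smoke test: exercise the Neo4j retrieval path directly.
print(get_response("What does the knowledge graph contain?"))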