Pijush2023 committed
Commit 7e66356 · verified · 1 Parent(s): 8527f42

Update app.py

Files changed (1):
  1. app.py +62 -148
app.py CHANGED
@@ -1,39 +1,15 @@
 import gradio as gr
 import os
 import logging
-from langchain_core.prompts import ChatPromptTemplate
-from langchain_core.output_parsers import StrOutputParser
-from langchain_openai import ChatOpenAI
-from langchain_community.graphs import Neo4jGraph
-from typing import List, Tuple
-from pydantic import BaseModel, Field
-from langchain_core.messages import AIMessage, HumanMessage
-from langchain_core.runnables import (
-    RunnableBranch,
-    RunnableLambda,
-    RunnablePassthrough,
-    RunnableParallel,
-)
-from langchain_core.prompts.prompt import PromptTemplate
 import requests
 import tempfile
-from langchain.memory import ConversationBufferWindowMemory
-import time
-import logging
-from langchain.chains import ConversationChain
+from langchain_openai import ChatOpenAI
+from langchain_community.graphs import Neo4jGraph
 import torch
-import torchaudio
-from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
 import numpy as np
+from transformers import pipeline, AutoModelForSpeechSeq2Seq, AutoProcessor
 import threading
 
-# Setup conversational memory
-conversational_memory = ConversationBufferWindowMemory(
-    memory_key='chat_history',
-    k=10,
-    return_messages=True
-)
-
 # Setup Neo4j connection
 graph = Neo4jGraph(
     url="neo4j+s://6457770f.databases.neo4j.io",
@@ -41,20 +17,7 @@ graph = Neo4jGraph(
     password="Z10duoPkKCtENuOukw3eIlvl0xJWKtrVSr-_hGX1LQ4"
 )
 
-# Define entity extraction and retrieval functions
-class Entities(BaseModel):
-    names: List[str] = Field(
-        ..., description="All the person, organization, or business entities that appear in the text"
-    )
-
-entity_prompt = ChatPromptTemplate.from_messages([
-    ("system", "You are extracting organization and person entities from the text."),
-    ("human", "Use the given format to extract information from the following input: {question}"),
-])
-
-chat_model = ChatOpenAI(temperature=0, model_name="gpt-4o", api_key=os.environ['OPENAI_API_KEY'])
-entity_chain = entity_prompt | chat_model.with_structured_output(Entities)
-
+# Function to clean input for Neo4j full-text query
 def remove_lucene_chars(input: str) -> str:
     return input.translate(str.maketrans({
         "\\": r"\\", "+": r"\+", "-": r"\-", "&": r"\&", "|": r"\|", "!": r"\!",
@@ -63,6 +26,7 @@ def remove_lucene_chars(input: str) -> str:
         ";": r"\;", " ": r"\ "
     }))
 
+# Function to generate a full-text query
 def generate_full_text_query(input: str) -> str:
     full_text_query = ""
     words = [el for el in remove_lucene_chars(input).split() if el]
@@ -71,60 +35,29 @@ def generate_full_text_query(input: str) -> str:
     full_text_query += f" {words[-1]}~2"
     return full_text_query.strip()
 
-# Setup logging to a file to capture debug information
-logging.basicConfig(filename='neo4j_retrieval.log', level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
-
-def structured_retriever(question: str) -> str:
-    result = ""
-    entities = entity_chain.invoke({"question": question})
-    for entity in entities.names:
+# Define the function to query Neo4j and get a response
+def get_response(question):
+    query = generate_full_text_query(question)
+    try:
+        # Query the Neo4j database using a full-text search
         response = graph.query(
-            """CALL db.index.fulltext.queryNodes('entity', $query, {limit:2})
-            YIELD node,score
-            CALL {
-                WITH node
-                MATCH (node)-[r:!MENTIONS]->(neighbor)
-                RETURN node.id + ' - ' + type(r) + ' -> ' + neighbor.id AS output
-                UNION ALL
-                WITH node
-                MATCH (node)<-[r:!MENTIONS]-(neighbor)
-                RETURN neighbor.id + ' - ' + type(r) + ' -> ' + node.id AS output
-            }
-            RETURN output LIMIT 50
+            """
+            CALL db.index.fulltext.queryNodes('entity', $query)
+            YIELD node, score
+            RETURN node.content AS content, score
+            ORDER BY score DESC LIMIT 1
             """,
-            {"query": generate_full_text_query(entity)},
+            {"query": query}
         )
-        result += "\n".join([el['output'] for el in response])
-    return result
-
-def retriever_neo4j(question: str):
-    structured_data = structured_retriever(question)
-    logging.debug(f"Structured data: {structured_data}")
-    return structured_data
-
-# Define the chain for Neo4j-based retrieval and response generation
-chain_neo4j = (
-    RunnableParallel(
-        {
-            "context": RunnableLambda(lambda x: retriever_neo4j(x["question"])),
-            "question": RunnablePassthrough(),
-        }
-    )
-    | ChatPromptTemplate.from_template("Answer: {context} Question: {question}")
-    | chat_model
-    | StrOutputParser()
-)
-
-# Define the function to get the response
-def get_response(question):
-    try:
-        return chain_neo4j.invoke({"question": question})
+        # Extract the content from the top response
+        if response:
+            result = response[0]['content']
+            return result
+        else:
+            return "Sorry, I couldn't find any relevant information in the database."
     except Exception as e:
-        return f"Error: {str(e)}"
-
-# Define the function to clear input and output
-def clear_fields():
-    return [], "", None
+        logging.error(f"Error querying Neo4j: {e}")
+        return "An error occurred while fetching data from the database."
 
 # Function to generate audio with Eleven Labs TTS
 def generate_audio_elevenlabs(text):
@@ -152,75 +85,56 @@ def generate_audio_elevenlabs(text):
                 if chunk:
                     f.write(chunk)
             audio_path = f.name
-        logging.debug(f"Audio saved to {audio_path}")
-        return audio_path  # Return audio path for automatic playback
+        return audio_path
     else:
-        logging.error(f"Error generating audio: {response.text}")
         return None
 
-# Function to handle voice to voice conversation
-def handle_voice_to_voice(chat_history, question):
-    response = get_response(question)
-    audio_path = generate_audio_elevenlabs(response)
-    chat_history.append(("[Voice Input]", "[Voice Response]"))
-    return chat_history, "", audio_path
-
-# Function to transcribe audio input
-def transcribe_function(stream, new_chunk):
-    try:
-        sr, y = new_chunk[0], new_chunk[1]
-    except TypeError:
-        print(f"Error chunk structure: {type(new_chunk)}, content: {new_chunk}")
-        return stream, "", None
-
-    if y is None or len(y) == 0:
-        return stream, "", None
-
-    y = y.astype(np.float32)
-    max_abs_y = np.max(np.abs(y))
-    if max_abs_y > 0:
-        y = y / max_abs_y
-
-    if stream is not None and len(stream) > 0:
-        stream = np.concatenate([stream, y])
-    else:
-        stream = y
+# Define ASR model for speech-to-text
+model_id = 'openai/whisper-large-v3'
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch_dtype).to(device)
+processor = AutoProcessor.from_pretrained(model_id)
+
+pipe_asr = pipeline(
+    "automatic-speech-recognition",
+    model=model,
+    tokenizer=processor.tokenizer,
+    feature_extractor=processor.feature_extractor,
+    max_new_tokens=128,
+    chunk_length_s=15,
+    batch_size=16,
+    torch_dtype=torch_dtype,
+    device=device,
+    return_timestamps=True
+)
 
-    result = pipe_asr({"array": stream, "sampling_rate": sr}, return_timestamps=False)
-    full_text = result.get("text", "")
+# Function to handle voice input, generate response from Neo4j, and return audio output
+def handle_voice_to_voice(audio):
+    # Transcribe audio input to text
+    sr, y = audio
+    result = pipe_asr({"array": y, "sampling_rate": sr}, return_timestamps=False)
+    question = result.get("text", "")
 
-    threading.Thread(target=auto_reset_state).start()
+    # Get response using the transcribed question
+    response = get_response(question)
 
-    return stream, full_text, full_text
+    # Generate audio from the response
+    audio_path = generate_audio_elevenlabs(response)
+    return audio_path
 
 # Define the Gradio interface
-with gr.Blocks(theme="rawrsor1/Everforest") as demo:
-    chatbot = gr.Chatbot([], elem_id="RADAR", bubble_full_width=False)
-    mode_selection = gr.Radio(
-        choices=["Normal Chatbot", "Voice to Voice Conversation"],
-        label="Mode Selection",
-        value="Normal Chatbot"
-    )
-    question_input = gr.Textbox(label="Ask a Question", placeholder="Type your question here...")
-    audio_input = gr.Audio(sources=["microphone"], streaming=True, type='numpy', every=0.1, label="Speak to Ask")
+with gr.Blocks() as demo:
+    audio_input = gr.Audio(sources=["microphone"], type='numpy', streaming=False, label="Speak to Ask")
     submit_voice_btn = gr.Button("Submit Voice")
-    audio_output = gr.Audio(label="Audio", type="filepath", autoplay=True, interactive=False)
+    audio_output = gr.Audio(label="Response Audio", type="filepath", autoplay=True, interactive=False)
 
     # Interactions for Submit Voice Button
    submit_voice_btn.click(
         fn=handle_voice_to_voice,
-        inputs=[chatbot, question_input],
-        outputs=[chatbot, question_input, audio_output],
-        api_name="api_voice_to_voice_translation"
-    )
-
-    # Speech-to-Text functionality
-    state = gr.State()
-    audio_input.stream(
-        transcribe_function,
-        inputs=[state, audio_input],
-        outputs=[state, question_input],
-        api_name="api_voice_to_text"
+        inputs=audio_input,
+        outputs=audio_output
     )
 
-demo.launch(show_error=True, share=True)
+# Launch the Gradio interface
+demo.launch(show_error=True, share=True)
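
After this commit the app's whole flow is: microphone audio → Whisper ASR (pipe_asr) → Neo4j full-text lookup (get_response) → ElevenLabs TTS (generate_audio_elevenlabs). Below is a minimal smoke-test sketch of that flow outside Gradio; it is not part of the commit. It assumes the definitions above are in scope (importing app also runs demo.launch(), which sits at module level), that the Neo4j database already has a full-text index named 'entity' over a node property 'content' (queryNodes fails without it), and that soundfile is installed; sample_question.wav is a placeholder path. The float32 peak-normalization mirrors the removed transcribe_function: the new handle_voice_to_voice forwards gr.Audio's raw int16 PCM to Whisper unscaled, while Whisper models are trained on waveforms in [-1, 1], so normalized input is the safer choice.

import numpy as np
import soundfile as sf  # assumed extra dependency, used only in this sketch

# Load a spoken question from disk (placeholder path) as float32 samples.
y, sr = sf.read("sample_question.wav", dtype="float32")
if y.ndim > 1:
    y = y.mean(axis=1)  # downmix stereo to mono
peak = np.max(np.abs(y))
if peak > 0:
    y = y / peak  # peak-normalize, as the removed transcribe_function did

# The same three steps handle_voice_to_voice performs on microphone input.
question = pipe_asr({"array": y, "sampling_rate": sr}, return_timestamps=False)["text"]
answer = get_response(question)
audio_path = generate_audio_elevenlabs(answer)
print(question, "->", answer, "->", audio_path)

If the 'entity' index is missing, something like graph.query("CREATE FULLTEXT INDEX entity IF NOT EXISTS FOR (n:Document) ON EACH [n.content]") would create it (the Document label and content property are assumptions about this database); without the index, db.index.fulltext.queryNodes raises and get_response falls back to its generic error message.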