AMead10 committed on
Commit
6af554a
·
1 Parent(s): b851225

phase 2 initial

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +132 -7
  3. requirements.txt +4 -1
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ examples
app.py CHANGED
@@ -1,20 +1,31 @@
1
  import gradio as gr
2
  import mysql.connector
3
  import os
 
 
 
 
 
 
4
 
5
  # Use a pipeline as a high-level helper
6
  from transformers import pipeline
7
 
 
 
8
  classifier_model = pipeline(
9
  "zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v1"
10
  )
11
 
 
 
12
  # get db info from env vars
13
  db_host = os.environ.get("DB_HOST")
14
  db_user = os.environ.get("DB_USER")
15
  db_pass = os.environ.get("DB_PASS")
16
  db_name = os.environ.get("DB_NAME")
17
 
 
18
 
19
  db_connection = mysql.connector.connect(
20
  host=db_host,
@@ -29,6 +40,14 @@ ORG_ID = 731
29
 
30
  potential_labels = []
31
 
 
 
 
 
 
 
 
 
32
 
33
  def get_potential_labels():
34
  # get potential labels from db
@@ -48,7 +67,7 @@ potential_labels = get_potential_labels()
48
 
49
 
50
  # Function to handle the classification
51
- def classify_email(constituent_email):
52
  potential_labels = get_potential_labels()
53
  print("classifying email")
54
  model_out = classifier_model(constituent_email, potential_labels, multi_label=True)
@@ -62,9 +81,43 @@ def classify_email(constituent_email):
62
  # Find the index of the highest score
63
  max_score_index = model_out["scores"].index(max(model_out["scores"]))
64
  # Return the label with the highest score
65
- return model_out["labels"][max_score_index]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
- return ", ".join(top_labels)
 
 
 
 
 
 
 
68
 
69
 
70
  def remove_spaces_after_comma(s):
@@ -73,6 +126,78 @@ def remove_spaces_after_comma(s):
73
  return ",".join(parts)
74
 
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # Function to handle saving data
77
  def save_data(orig_user_email, constituent_email, labels, user_response, current_user):
78
  # save the data to the database
@@ -224,9 +349,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
224
 
225
  # Define button actions
226
  classify_button.click(
227
- fn=classify_email,
228
- inputs=constituent_response_input,
229
- outputs=classification_output,
230
  )
231
 
232
  save_button.click(
@@ -242,4 +367,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
242
  )
243
 
244
  # Launch the app
245
- app.launch(auth=auth, debug=True)
 
1
  import gradio as gr
2
  import mysql.connector
3
  import os
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.prompts import (
6
+ ChatPromptTemplate,
7
+ PromptTemplate,
8
+ FewShotPromptTemplate,
9
+ )
10
 
11
  # Use a pipeline as a high-level helper
12
  from transformers import pipeline
13
 
14
+ from sentence_transformers import SentenceTransformer, util
15
+
16
  classifier_model = pipeline(
17
  "zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v1"
18
  )
19
 
20
+ embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
21
+
22
  # get db info from env vars
23
  db_host = os.environ.get("DB_HOST")
24
  db_user = os.environ.get("DB_USER")
25
  db_pass = os.environ.get("DB_PASS")
26
  db_name = os.environ.get("DB_NAME")
27
 
28
+ openai_api_key = os.environ.get("OPENAI_API_KEY")
29
 
30
  db_connection = mysql.connector.connect(
31
  host=db_host,
 
40
 
41
  potential_labels = []
42
 
43
+ llm = ChatOpenAI(openai_api_key=openai_api_key, model="gpt-4")
44
+
45
+ system_prompt = "You are a representative for a local government. A constituent has reached out to you with a question about a local policy. Base your response using the examples below. Be sure to address all points and concerns raised by the constituent. If you do not have enough information to be able to answer the question (you do not see an example that answers the question from the constituent), please make note and another representative will fill in the missing information.\n\n"
46
+
47
+ examples_prompt = PromptTemplate(
48
+ input_variables=["example"], template="Example:\n\n {example}"
49
+ )
50
+
51
 
52
  def get_potential_labels():
53
  # get potential labels from db
 
67
 
68
 
69
  # Function to handle the classification
70
+ def classify_email_and_generate_response(representative_email, constituent_email):
71
  potential_labels = get_potential_labels()
72
  print("classifying email")
73
  model_out = classifier_model(constituent_email, potential_labels, multi_label=True)
 
81
  # Find the index of the highest score
82
  max_score_index = model_out["scores"].index(max(model_out["scores"]))
83
  # Return the label with the highest score
84
+ top_labels = [model_out["labels"][max_score_index]]
85
+
86
+ labels_with_enough_examples = ["Enforcement", "Financial", "Rules"]
87
+ # see if any of the labels are in labels_with_enough_examples, if so get the messages for that category, else return
88
+
89
+ examples = get_similar_messages(constituent_email)
90
+
91
+ if representative_email != "":
92
+ current_thread = (
93
+ "Representative message: \n\n"
94
+ + representative_email
95
+ + "\n\nConstituent message: \n\n"
96
+ + constituent_email
97
+ )
98
+ else:
99
+ current_thread = "Constituent message: \n\n" + constituent_email
100
+
101
+ prompt = FewShotPromptTemplate(
102
+ examples=examples,
103
+ example_prompt=examples_prompt,
104
+ prefix=system_prompt,
105
+ suffix="Current thread:\n\n {current_thread}\n\nYour response:\n\n",
106
+ input_variables=["current_thread"],
107
+ )
108
+
109
+ formatted_prompt = prompt.format(current_thread=current_thread)
110
+
111
+ print(formatted_prompt)
112
 
113
+ print("Generating GPT4 response")
114
+ import time
115
+
116
+ start = time.time()
117
+ out = "GPT4:\n\n" + llm.invoke(formatted_prompt).content
118
+ print("GPT4 response generated in", time.time() - start, "seconds")
119
+
120
+ return ", ".join(top_labels), out
121
 
122
 
123
  def remove_spaces_after_comma(s):
 
126
  return ",".join(parts)
127
 
128
 
129
def get_similar_messages(constituent_email):
    """Fetch categorized message threads from the DB and return the ones most
    similar to *constituent_email*.

    Each thread is rebuilt by following ``previous_message_id`` links, rendered
    as alternating "Representative message:" / "Constituent message:" text, and
    embedded with the module-level ``embedding_model``. Threads whose cosine
    similarity to the target exceeds 0.98 are skipped (presumably to exclude
    the very message being answered — TODO confirm). Returns at most the top 3
    as ``[{"example": <thread text>}, ...]`` for the few-shot prompt.
    """
    db_connection = mysql.connector.connect(
        host=db_host,
        user=db_user,
        password=db_pass,
        database=db_name,
    )
    db_cursor = db_connection.cursor()

    try:
        # NOTE: cursor.execute() returns None in mysql-connector; results come
        # from fetchall(). The old code assigned the execute() return value and
        # immediately overwrote it — that dead assignment is removed here.
        db_cursor.execute(
            "SELECT id, person_id, body FROM radmap_frog12.messages WHERE id IN (SELECT message_id FROM radmap_frog12.message_category_associations)"
        )
        messages_for_category = db_cursor.fetchall()

        all_message_chains = []

        for message in messages_for_category:
            # TODO: refactor for when integrated with RADMAP
            # if person_id is set, the thread starts with a representative turn
            if message[1] != 0:
                message_chain = "Representative message: \n\n" + message[2] + "\n\n"
                is_representative_turn = False
            else:
                message_chain = "Constituent message: \n\n" + message[2] + "\n\n"
                is_representative_turn = True
            # Embedding of the most recent message seen so far; replaced as we
            # walk the chain, so the final value is the chain's last message.
            embedding = embedding_model.encode([message[2]])[0]

            next_message_id = message[0]

            # Walk the reply chain: each query finds the message whose
            # previous_message_id points at the current one.
            while next_message_id:
                db_cursor.execute(
                    "SELECT id, body FROM radmap_frog12.messages WHERE previous_message_id = %s",
                    (next_message_id,),
                )
                next_message = db_cursor.fetchall()
                if not next_message:
                    break
                if is_representative_turn:
                    message_chain += (
                        "Representative message: \n\n" + next_message[0][1] + "\n\n"
                    )
                    is_representative_turn = False
                else:
                    message_chain += (
                        "Constituent message: \n\n" + next_message[0][1] + "\n\n"
                    )
                    is_representative_turn = True

                embedding = embedding_model.encode([next_message[0][1]])[0]

                next_message_id = next_message[0][0]

            all_message_chains.append((message_chain, embedding))
    finally:
        # The old code leaked the connection on every call; always release it.
        db_cursor.close()
        db_connection.close()

    target_embedding = embedding_model.encode([constituent_email])[0]

    # Compute cosine-similarities and keep the top 3 most similar threads.
    top_messages = []

    for message, embedding in all_message_chains:
        cosine_score = util.pytorch_cos_sim(embedding, target_embedding)
        if cosine_score > 0.98:
            # Near-identical to the target; skip (likely the same email).
            continue
        top_messages.append((message, cosine_score))

    top_messages = sorted(top_messages, key=lambda x: x[1], reverse=True)

    return [{"example": message} for message, score in top_messages[0:3]]
199
+
200
+
201
  # Function to handle saving data
202
  def save_data(orig_user_email, constituent_email, labels, user_response, current_user):
203
  # save the data to the database
 
349
 
350
  # Define button actions
351
  classify_button.click(
352
+ fn=classify_email_and_generate_response,
353
+ inputs=[original_email_input, constituent_response_input],
354
+ outputs=[classification_output, user_response_input],
355
  )
356
 
357
  save_button.click(
 
367
  )
368
 
369
  # Launch the app
370
+ app.launch(debug=True)
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
  mysql-connector-python
2
  torch
3
- transformers
 
 
 
 
1
  mysql-connector-python
2
  torch
3
+ transformers
4
+ langchain-openai
5
+ langchain-core
6
+ sentence_transformers