Spaces: Runtime error

wendru18 committed · Commit 15a19f2 · 1 Parent(s): 8c3c6c0

added gradio app

Browse files:
- .gitignore +2 -1
- app.py +172 -0
- main.ipynb +107 -143
- semantic_search.py +0 -39
.gitignore CHANGED
@@ -1 +1,2 @@
-__pycache__/
+__pycache__/
+.chroma/
app.py ADDED
@@ -0,0 +1,172 @@
+from langchain.embeddings import TensorflowHubEmbeddings
+from langchain.chains import ConversationalRetrievalChain
+from langchain.vectorstores import Chroma
+from langchain.llms import OpenAI
+from tqdm import tqdm
+import pandas as pd
+import gradio as gr
+import openai
+import praw
+import os
+import re
+
+reddit = None
+bot = None
+chat_history = []
+
+def set_openai_key(key):
+
+    if key == "":
+        key = os.environ.get("OPENAI_API_KEY")
+
+    openai.api_key = key
+
+def set_reddit_keys(client_id, client_secret, user_agent):
+
+    global reddit
+
+    # If all of the keys are empty, fall back to the environment variables
+    if [client_id, client_secret, user_agent] == ["", "", ""]:
+        client_id = os.environ.get("REDDIT_CLIENT_ID")
+        client_secret = os.environ.get("REDDIT_CLIENT_SECRET")
+        user_agent = os.environ.get("REDDIT_USER_AGENT")
+
+    reddit = praw.Reddit(client_id=client_id, client_secret=client_secret, user_agent=user_agent)
+
+def generate_topics(query, model="gpt-3.5-turbo"):
+
+    messages = [
+        {"role": "user", "content": f"Take this query '{query}' and return a list of 10 simple to understand topics (4 words or less) to input in Search so it returns good results."},
+    ]
+
+    response = openai.ChatCompletion.create(
+        model=model,
+        messages=messages,
+        temperature=0
+    )
+
+    response_message = response["choices"][0]["message"]["content"]
+
+    topics = re.sub(r'^\d+\.\s*', '', response_message, flags=re.MULTILINE).split("\n")
+
+    # Post-processing GPT output
+
+    topics = [topic.strip() for topic in topics]
+
+    topics = [topic[1:-1] if (topic.startswith('"') and topic.endswith('"')) or (topic.startswith("'") and topic.endswith("'")) else topic for topic in topics]
+
+    topics = [re.sub(r'[^a-zA-Z0-9\s]', ' ', topic) for topic in topics]
+
+    return topics
+
+def get_relevant_comments(topics):
+
+    global reddit
+
+    comments = []
+
+    for topic in tqdm(topics):
+        for post in reddit.subreddit("all").search(
+            topic, limit=10):
+
+            post.comment_limit = 20
+            post.comment_sort = "top"
+
+            # Top-level comments only
+            post.comments.replace_more(limit=0)
+
+            for comment in post.comments:
+                author = comment.author.name if comment.author else '[deleted]'
+                comments.append([post.id, comment.id, post.subreddit.display_name, post.title, author, comment.body])
+
+    comments = pd.DataFrame(comments, columns=['source', 'comment_id', 'subreddit', 'title', 'author', 'text'])
+
+    # Drop empty or "[deleted]" texts
+    comments = comments[comments['text'].str.len() > 0]
+    comments = comments[comments['text'] != "[deleted]"]
+
+    # Drop AutoModerator comments
+    comments = comments[comments['author'] != "AutoModerator"]
+
+    # Drop duplicate post ids (keeps one comment per post)
+    comments = comments.drop_duplicates(subset=['source'])
+
+    return comments
+
+def construct_retriever(comments, k=20):
+
+    # Convert comments dataframe to a list of records
+    comments = comments.to_dict('records')
+
+    # Combine title, body, and subreddit into one text per comment
+    texts = [comment["title"] + " " + comment["text"] + " " + comment["subreddit"] for comment in comments]
+
+    db = Chroma.from_texts(texts, TensorflowHubEmbeddings(model_url="https://tfhub.dev/google/universal-sentence-encoder/4"), metadatas=[{"source": comment["source"], "comment_id": comment["comment_id"], "author": comment["author"]} for comment in comments])
+
+    retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": k})
+
+    return retriever
+
+def construct_bot(retriever):
+    bot = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), retriever, return_source_documents=True)
+    return bot
+
+def get_response(query, chat_history):
+    response = bot({"question": query, "chat_history": chat_history})
+    return response
+
+def restart():
+
+    global chat_history
+    global bot
+
+    chat_history = []
+    bot = None
+
+    print("Chat history and bot knowledge have been cleared!")
+
+    return None
+
+def main(query):
+
+    global chat_history
+    global bot
+
+    if bot is None:
+        print("Bot knowledge has not been initialised yet! Generating topics...")
+        topics = generate_topics(query)
+
+        print("Fetching relevant comments...")
+        comments = get_relevant_comments(topics)
+
+        print("Embedding relevant comments...")
+        retriever = construct_retriever(comments)
+
+        print("Educating bot...")
+        bot = construct_bot(retriever)
+
+        print("Bot has been constructed and is ready to use!")
+
+    response = get_response(query, chat_history)
+
+    answer, source_documents = response["answer"], response["source_documents"]
+
+    print(source_documents)
+
+    chat_history.append((query, answer))
+
+    return "", chat_history
+
+# Testing only!
+set_openai_key("")
+set_reddit_keys("", "", "")
+
+with gr.Blocks() as demo:
+    chat_bot = gr.Chatbot()
+    query = gr.Textbox()
+    clear = gr.Button("Clear")
+
+    query.submit(main, [query], [query, chat_bot])
+    clear.click(restart, None, chat_bot, queue=False)
+
+demo.launch()
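For reference, a minimal sketch of one chat turn through the functions above, assuming the OpenAI and Reddit keys are already set via set_openai_key/set_reddit_keys; the query string is illustrative. It mirrors what main() does on a first message rather than importing app.py, which launches the Gradio server at module level:

# Minimal sketch of one chat turn (query is illustrative; keys assumed set).
query = "How should I start learning the guitar?"

topics = generate_topics(query)            # ~10 short search phrases
comments = get_relevant_comments(topics)   # DataFrame of Reddit comments
retriever = construct_retriever(comments)  # Chroma similarity retriever
bot = construct_bot(retriever)

response = bot({"question": query, "chat_history": []})
print(response["answer"])
for doc in response["source_documents"]:
    print(doc.metadata["source"], doc.page_content[:80])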
main.ipynb CHANGED
@@ -2,10 +2,11 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 83,
    "metadata": {},
    "outputs": [],
    "source": [
+    "from tqdm import tqdm\n",
     "import pandas as pd\n",
     "import openai\n",
     "import praw\n",
@@ -18,24 +19,19 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 84,
    "metadata": {},
    "outputs": [],
    "source": [
-    "from langchain.embeddings.openai import OpenAIEmbeddings\n",
-    "from langchain.text_splitter import CharacterTextSplitter\n",
     "from langchain.vectorstores import Chroma\n",
-    "from langchain.docstore.document import Document\n",
-    "from langchain.prompts import PromptTemplate\n",
-    "from langchain.indexes.vectorstore import VectorstoreIndexCreator\n",
-    "from langchain.chains.qa_with_sources import load_qa_with_sources_chain\n",
     "from langchain.chains import ConversationalRetrievalChain\n",
-    "from langchain.llms import OpenAI"
+    "from langchain.llms import OpenAI\n",
+    "from langchain.embeddings import TensorflowHubEmbeddings"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 85,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -56,12 +52,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 86,
    "metadata": {},
    "outputs": [],
    "source": [
     "query = '''\n",
-    "
+    "I got laid off last week. How should I go about finding a new job?\n",
     "'''"
    ]
   },
@@ -75,7 +71,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 87,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -95,40 +91,44 @@
     "\n",
     "    topics = re.sub(r'^\\d+\\.\\s*', '', response_message, flags=re.MULTILINE).split(\"\\n\")\n",
     "\n",
+    "    # Post-processing GPT output\n",
+    "\n",
+    "    topics = [topic.strip() for topic in topics]\n",
+    "\n",
+    "    topics = [topic[1:-1] if (topic.startswith('\"') and topic.endswith('\"')) or (topic.startswith(\"'\") and topic.endswith(\"'\")) else topic for topic in topics]\n",
+    "\n",
+    "    topics = [re.sub(r'[^a-zA-Z0-9\\s]', ' ', topic) for topic in topics]\n",
+    "\n",
     "    return topics"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 88,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "['
-       " '
-       " '
-       " '
-       " '
-       " '
-       " '
-       " '
-       " '
-       " '
+       "['Job search tips',\n",
+       " 'Resume writing advice',\n",
+       " 'Networking strategies',\n",
+       " 'Interview preparation tips',\n",
+       " 'Online job boards',\n",
+       " 'Career counseling services',\n",
+       " 'Job fairs near me',\n",
+       " 'Freelance opportunities',\n",
+       " 'Remote work options',\n",
+       " 'Industry specific job listings']"
       ]
     },
-    "execution_count":
+    "execution_count": 88,
     "metadata": {},
    "output_type": "execute_result"
    }
   ],
   "source": [
    "topics = generate_topics(query)\n",
-   "topics = [topic.strip() for topic in topics]\n",
-   "topics = [topic[1:-1] if (topic.startswith('\"') and topic.endswith('\"')) or (topic.startswith(\"'\") and topic.endswith(\"'\")) else topic for topic in topics]\n",
-   "\n",
-   "topics = [re.sub(r'[^a-zA-Z0-9\\s]', ' ', topic) for topic in topics]\n",
    "\n",
    "topics"
   ]
@@ -138,42 +138,62 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "## Relevant
+    "## Relevant Comments Retrieval"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 89,
    "metadata": {},
    "outputs": [],
    "source": [
-    "
+    "def get_relevant_subreddits(topics):\n",
+    "    comments = []\n",
     "\n",
-    "for topic in topics:\n",
-    "
-    "
+    "    for topic in tqdm(topics):\n",
+    "        for post in reddit.subreddit(\"all\").search(\n",
+    "            topic, limit=10):\n",
+    "\n",
+    "            post.comment_limit = 20\n",
+    "            post.comment_sort = \"top\"\n",
     "\n",
-    "
+    "            # Top-level comments only\n",
+    "            post.comments.replace_more(limit=0)\n",
     "\n",
-    "
-    "
-    "
+    "            for comment in post.comments:\n",
+    "                author = comment.author.name if comment.author else '[deleted]'\n",
+    "                comments.append([post.id, comment.id, post.subreddit.display_name, post.title, author, comment.body])\n",
     "\n",
-    "
+    "    comments = pd.DataFrame(comments, columns=['source', 'comment_id', 'subreddit', 'title', 'author', 'text'])\n",
+    "\n",
+    "    # Drop empty or \"[deleted]\" texts\n",
+    "    comments = comments[comments['text'].str.len() > 0]\n",
+    "    comments = comments[comments['text'] != \"[deleted]\"]\n",
+    "\n",
+    "    # Drop AutoModerator comments\n",
+    "    comments = comments[comments['author'] != \"AutoModerator\"]\n",
+    "\n",
+    "    # Drop duplicate post ids (keeps one comment per post)\n",
+    "    comments = comments.drop_duplicates(subset=['source'])\n",
+    "\n",
+    "    return comments"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 90,
    "metadata": {},
-   "outputs": [
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "100%|██████████| 10/10 [00:41<00:00,  4.13s/it]\n"
+     ]
+    }
+   ],
   "source": [
-    "
-    "posts = posts[posts['text'].str.len() > 0]\n",
-    "posts = posts[posts['text'] != \"[deleted]\"]\n",
-    "\n",
-    "# Drop duplicate ids\n",
-    "posts = posts.drop_duplicates(subset=['source'])"
+    "comments = get_relevant_subreddits(topics)"
   ]
  },
  {
@@ -186,26 +206,28 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 91,
    "metadata": {},
    "outputs": [],
    "source": [
-    "
-    "
-    "
+    "def construct_retriever(comments, k=10):\n",
+    "\n",
+    "    # Convert comments dataframe to a list of records\n",
+    "    comments = comments.to_dict('records')\n",
+    "\n",
+    "    # Combine title and body into one text per comment\n",
+    "    texts = [comment[\"title\"] + \" \" + comment[\"text\"] for comment in comments]\n",
+    "\n",
+    "    db = Chroma.from_texts(texts, TensorflowHubEmbeddings(model_url=\"https://tfhub.dev/google/universal-sentence-encoder/4\"), metadatas=[{\"source\": comment[\"source\"], \"comment_id\": comment[\"comment_id\"], \"author\": comment[\"author\"]} for comment in comments])\n",
+    "\n",
+    "    retriever = db.as_retriever(search_type=\"similarity\", search_kwargs={\"k\": k})\n",
+    "\n",
+    "    return retriever"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 92,
    "metadata": {},
    "outputs": [
     {
@@ -217,49 +239,51 @@
     }
    ],
    "source": [
-    "
+    "retriever = construct_retriever(comments)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 93,
    "metadata": {},
    "outputs": [],
    "source": [
-    "
+    "def construct_qa(retriever):\n",
+    "    qa = ConversationalRetrievalChain.from_llm(OpenAI(temperature=0), retriever, return_source_documents=True)\n",
+    "    return qa"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 95,
    "metadata": {},
    "outputs": [],
    "source": [
-    "
+    "chat_history = []\n",
+    "qa = construct_qa(retriever)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 96,
    "metadata": {},
    "outputs": [],
    "source": [
-    "chat_history = []\n",
     "result = qa({\"question\": query, \"chat_history\": chat_history})"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 97,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "
+       "\" Start by updating your resume and LinkedIn profile to reflect your current skills and experience. Then, look for job postings on job boards like Indeed and LinkedIn, as well as industry-specific job listings. You can also reach out to people you know professionally, such as former colleagues or mentors, to see if they know of any job opportunities. Finally, consider using a staffing agency to help you find a job that's a great fit. Good luck!\""
       ]
      },
-     "execution_count":
+     "execution_count": 97,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -270,25 +294,25 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 98,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "[Document(page_content='
-       " Document(page_content='
-       " Document(page_content='
-       " Document(page_content='
-       " Document(page_content='
-       " Document(page_content='
-       " Document(page_content
-       " Document(page_content='
-       " Document(page_content='
-       " Document(page_content
+       "[Document(page_content='Request: job search tips WAITING FOR THE RIGHT JOB WITH DECENT PAY > QUICK JOB', metadata={'source': '122mdcc', 'comment_id': 'jdv57nv', 'author': 'notenoughbeds'}),\n",
+       " Document(page_content='Where to look for jobs? Online job boards? LinkedIn\\n\\nIndeed\\n\\nIf you need a... beginners job. Craigslist has a work section', metadata={'source': 'uj35l8', 'comment_id': 'i7gbuat', 'author': 'No-Statement-3019'}),\n",
+       " Document(page_content='Job search tips in Canada You’re selling yourself way too cheap. Look for a senior position and then people will want to hire you more too', metadata={'source': '1156zsa', 'comment_id': 'j91mo9r', 'author': 'pxpxy'}),\n",
+       " Document(page_content='Job search tips Just apply, covid has made super easy for RTs to get jobs', metadata={'source': '11uekx5', 'comment_id': 'jcnvfnx', 'author': 'Crass_Cameron'}),\n",
+       " Document(page_content='Looking for freelance opportunities >\\tWhere would be the best place to look?\\n\\nPeople you’ve worked with professionally before. Either for jobs for them or for them to refer you to people they know.', metadata={'source': '11q0h1u', 'comment_id': 'jc0w8rp', 'author': 'dataguy24'}),\n",
+       " Document(page_content='Did Career Counseling services help you land a job after you graduated? I found it to be helpful in polishing my resume', metadata={'source': 'xbql4p', 'comment_id': 'io1141n', 'author': 'avo_cado'}),\n",
+       " Document(page_content=\"Does anyone have any good job search tips? I recommend keeping an up to date LinkedIn profile that indicates you're actively searching for roles. I've also had luck with Indeed.\\n\\nDepending on your field, I recommend a staffing agency. They can vouch for you and place you with a company that's a great fit.\", metadata={'source': '134wnu5', 'comment_id': 'jih9zzp', 'author': 'Carolinablue87'}),\n",
+       " Document(page_content='How did you find your job? Did someone you know tell you about it? Job listing aggregator (ie indeed, simplyhired)? Industry specific job listings? Somehow else? strong institutional connection between the employer and my grad school', metadata={'source': '88vquy', 'comment_id': 'dwnknrk', 'author': 'lollersauce914'}),\n",
+       " Document(page_content='I (17 F) am really confused about my career. Is there any way I can get career counseling services for free? Any resources or tips would be appreciated too! What are you confused about? I have no qualifications, just a working professional of 8 years. \\nDoes your work have a subreddit or do you have a mentor at work that you could reach out to?', metadata={'source': '10jv7c4', 'comment_id': 'j5mtxb7', 'author': 'bhop02'}),\n",
+       " Document(page_content=\"Debating whether or not to pay for access to sites that have specific job listings - any successes or tales of caution? I never had to look for my jobs through those sites. I don't think they can let the companies post job positions exclusively on their sites only. You can always find the same job post somewhere else if you keep trying. I hope this is helpful, good luck!\", metadata={'source': '11kppd0', 'comment_id': 'jb8qvjo', 'author': 'Whitney-Sweet'})]"
      ]
     },
-     "execution_count":
+     "execution_count": 98,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -296,66 +320,6 @@
    "source": [
     "result[\"source_documents\"]"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 297,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "new_query = \"Are most pizzas in Stockholm gluten-free?\"\n",
-    "\n",
-    "new_result = qa({\"question\": query + new_query, \"chat_history\": result[\"chat_history\"]})"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 298,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "' Most places have gluten free pizza. Only Giro and Meno Male have good gluten free pizza. Pretty much all pizzerias serve GF pizza, but most use store-bought pizza bases. Stockholm is like the capital of gluten intolerance so I think you will be fine in most (central) places. Meno Male is a great option for Napoletana style Italian pizza, they have a few restaurants throughout Stockholm.'"
-      ]
-     },
-     "execution_count": 298,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "new_result[\"answer\"]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 299,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "[Document(page_content='Gluten free pizza in Stockholm? Any pizzeria', metadata={'source': 'irirsw2', 'post_id': 'xyti59', 'author': 'Head-Commission-8222'}),\n",
-       " Document(page_content='Gluten free pizza in Stockholm? Thanks!', metadata={'source': 'irimidh', 'post_id': 'xyti59', 'author': 'Head-Commission-8222'}),\n",
-       " Document(page_content='Gluten free pizza in Stockholm? Crispy pizza', metadata={'source': 'iritand', 'post_id': 'xyti59', 'author': 'Head-Commission-8222'}),\n",
-       " Document(page_content='Gluten free pizza in Stockholm? Most places have gluten free pizza. Only Giro and Meno Male have good gluten free pizza.', metadata={'source': 'irm0583', 'post_id': 'xyti59', 'author': 'Head-Commission-8222'}),\n",
-       " Document(page_content='Var kan jag köpa glutenfri kebab pizza?? (Stockholm) Quick search:\\n\\n[https://www.findmeglutenfree.com/se/stockholm/pizza](https://www.findmeglutenfree.com/se/stockholm/pizza)', metadata={'source': 'jef789l', 'post_id': '127mh90', 'author': 'TheRoyalStork'}),\n",
-       " Document(page_content='Var kan jag köpa glutenfri kebab pizza?? (Stockholm) Tack!', metadata={'source': 'jeg12d4', 'post_id': '127mh90', 'author': 'TheRoyalStork'}),\n",
-       " Document(page_content='Gluten free pizza in Stockholm? Pretty much all pizzerias serve GF pizza, but most use store-bought pizza bases.', metadata={'source': 'irjap32', 'post_id': 'xyti59', 'author': 'Head-Commission-8222'}),\n",
-       " Document(page_content='Gluten free pizza in Stockholm? Stockholm is like the capital of gluten intolerance so I think you will be fine in most (central) places', metadata={'source': 'irioh1a', 'post_id': 'xyti59', 'author': 'Head-Commission-8222'}),\n",
-       " Document(page_content='Gluten free pizza in Stockholm? Meno Male. Super nice Napoletana style Italian pizza, they have a few restaurants throughout Stockholm', metadata={'source': 'irilvan', 'post_id': 'xyti59', 'author': 'Head-Commission-8222'}),\n",
-       " Document(page_content='Gluten free pizza in Stockholm? Has anyone here with a gluten intolerance diagnosis eaten a gf pizza from a pizzeria? Was it without issues? I imagine that all the flour handling, the oven and the spade would contaminate any gf pizza with gluten.', metadata={'source': 'irjmdif', 'post_id': 'xyti59', 'author': 'Head-Commission-8222'})]"
-      ]
-     },
-     "execution_count": 299,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "new_result[\"source_documents\"]"
-   ]
   }
  ],
 "metadata": {
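The removed pizza cells above passed query + new_query as the question; ConversationalRetrievalChain is designed to carry that context through chat_history instead. A minimal sketch of a follow-up turn against the notebook's qa chain, under that assumption (the follow-up question is illustrative):

# Follow-up turn: the chain condenses the new question with chat_history,
# so the earlier query does not need to be concatenated into the new one.
chat_history = [(query, result["answer"])]
followup = qa({"question": "Which of these should I try first?",
               "chat_history": chat_history})
print(followup["answer"])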
semantic_search.py DELETED
@@ -1,39 +0,0 @@
-from sklearn.neighbors import NearestNeighbors
-import tensorflow_hub as hub
-import numpy as np
-
-class SemanticSearch:
-
-    def __init__(self):
-        self.use = hub.load('https://tfhub.dev/google/universal-sentence-encoder/4')
-        self.fitted = False
-
-
-    def fit(self, data, batch=1000, n_neighbors=5):
-        print(f"Fitting with n={n_neighbors}...")
-        self.data = data
-        self.embeddings = self.get_text_embedding(data, batch=batch)
-        n_neighbors = min(n_neighbors, len(self.embeddings))
-        self.nn = NearestNeighbors(n_neighbors=n_neighbors)
-        self.nn.fit(self.embeddings)
-        self.fitted = True
-
-
-    def __call__(self, text, return_data=True):
-        inp_emb = self.use([text])
-        distances, neighbors = self.nn.kneighbors(inp_emb, return_distance=True)
-
-        if return_data:
-            return [self.data[i] for i in neighbors[0]], distances
-        else:
-            return neighbors[0], distances
-
-
-    def get_text_embedding(self, texts, batch=1000):
-        embeddings = []
-        for i in range(0, len(texts), batch):
-            text_batch = texts[i:(i+batch)]
-            emb_batch = self.use(text_batch)
-            embeddings.append(emb_batch)
-        embeddings = np.vstack(embeddings)
-        return embeddings
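For reference, the deleted class was driven roughly as below (a sketch; the corpus and query are illustrative, not from the repo). Its role is now covered by Chroma with TensorflowHubEmbeddings in app.py:

# Rough usage of the removed SemanticSearch class (corpus/query illustrative).
corpus = ["tune the strings first", "practice chords daily", "use a metronome"]

search = SemanticSearch()
search.fit(corpus, n_neighbors=2)          # embeds corpus, fits NearestNeighbors

matches, distances = search("how should I practice?")
print(matches, distances)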