star_nox commited on
Commit
c1d7a66
·
1 Parent(s): e5f680f

added context retrieval

Browse files
Files changed (3) hide show
  1. __pycache__/retrieval.cpython-310.pyc +0 -0
  2. app.py +11 -3
  3. retrieval.py +66 -0
__pycache__/retrieval.cpython-310.pyc ADDED
Binary file (2.91 kB). View file
 
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import os
2
-
3
  import gradio as gr
4
-
5
  from text_generation import Client, InferenceAPIClient
6
 
 
 
 
 
7
  openchat_preprompt = (
8
  "\n<human>: Hi!\n<bot>: My name is Bot, model version is 0.15, part of an open-source kit for "
9
  "fine-tuning new bots! I was created by Together, LAION, and Ontocord.ai and the open-source "
@@ -110,7 +112,13 @@ def predict(
110
  for i in range(0, len(history) - 1, 2)
111
  ]
112
  yield chat, history
113
-
 
 
 
 
 
 
114
 
115
  def reset_textbox():
116
  return gr.update(value="")
 
1
  import os
 
2
  import gradio as gr
 
3
  from text_generation import Client, InferenceAPIClient
4
 
5
+ import retrieval
6
+
7
+ NUM_ANSWERS_GENERATED = 3
8
+
9
  openchat_preprompt = (
10
  "\n<human>: Hi!\n<bot>: My name is Bot, model version is 0.15, part of an open-source kit for "
11
  "fine-tuning new bots! I was created by Together, LAION, and Ontocord.ai and the open-source "
 
112
  for i in range(0, len(history) - 1, 2)
113
  ]
114
  yield chat, history
115
+
116
+ # add context retrieval part here
117
+ ta = retrieval.Retrieval()
118
+ ta._load_pinecone_vectorstore()
119
+ question = inputs
120
+ top_context_list = ta.retrieve_contexts_from_pinecone(user_question=question, topk=NUM_ANSWERS_GENERATED)
121
+ print(top_context_list)
122
 
123
  def reset_textbox():
124
  return gr.update(value="")
retrieval.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import pathlib
4
+ import sys
5
+ import time
6
+ from typing import Any, Dict, List
7
+
8
+ import pinecone # cloud-hosted vector database for context retrieval
9
+ # for vector search
10
+ from langchain.embeddings import HuggingFaceEmbeddings
11
+ from langchain.vectorstores import Pinecone
12
+
13
+ from dotenv import load_dotenv
14
+
15
+ from PIL import Image
16
+ from transformers import (AutoModelForSequenceClassification, AutoTokenizer, GPT2Tokenizer, OPTForCausalLM, T5ForConditionalGeneration)
17
+
18
+ PINECONE_API_KEY="insert your pinecone api key here"
19
+
20
+ class Retrieval:
21
+ def __init__(self,
22
+ device='cuda',
23
+ use_clip=True):
24
+
25
+ self.user_question = ''
26
+ self.max_text_length = None
27
+ self.pinecone_index_name = 'uiuc-chatbot' # uiuc-chatbot-v2
28
+ self.use_clip = use_clip
29
+
30
+ # init parameters
31
+ self.device = device
32
+ self.num_answers_generated = 3
33
+
34
+ self.vectorstore = None
35
+
36
+ def _load_pinecone_vectorstore(self,):
37
+ model_name = "intfloat/e5-large" # best text embedding model. 1024 dims.
38
+ pincecone_index = pinecone.Index("uiuc-chatbot")
39
+ embeddings = HuggingFaceEmbeddings(model_name=model_name)
40
+ #pinecone.init(api_key=os.environ['PINECONE_API_KEY'], environment="us-west1-gcp")
41
+ pinecone.init(api_key=PINECONE_API_KEY, environment="us-west1-gcp")
42
+
43
+ print(pinecone.list_indexes())
44
+
45
+ self.vectorstore = Pinecone(index=pincecone_index, embedding_function=embeddings.embed_query, text_key="text")
46
+
47
+
48
+ def retrieve_contexts_from_pinecone(self, user_question: str, topk: int = None) -> List[Any]:
49
+ '''
50
+ Invoke Pinecone for vector search. These vector databases are created in the notebook `data_formatting_patel.ipynb` and `data_formatting_student_notes.ipynb`.
51
+ Returns a list of LangChain Documents. They have properties: `doc.page_content`: str, doc.metadata['page_number']: int, doc.metadata['textbook_name']: str.
52
+ '''
53
+ print("USER QUESTION: ", user_question)
54
+ print("TOPK: ", topk)
55
+
56
+
57
+ if topk is None:
58
+ topk = self.num_answers_generated
59
+
60
+ # similarity search
61
+ top_context_list = self.vectorstore.similarity_search(user_question, k=topk)
62
+
63
+ # add the source info to the bottom of the context.
64
+ top_context_metadata = [f"Source: page {doc.metadata['page_number']} in {doc.metadata['textbook_name']}" for doc in top_context_list]
65
+ relevant_context_list = [f"{text.page_content}. {meta}" for text, meta in zip(top_context_list, top_context_metadata)]
66
+ return relevant_context_list