veeps committed
Commit
63bb4b2
1 Parent(s): ab911b6
Files changed (1)
  1. rag.py +109 -1
rag.py CHANGED
@@ -7,6 +7,10 @@ from langchain.schema import (
 )
 from datasets import load_dataset
 from pinecone import Pinecone
+from pinecone import ServerlessSpec
+import time
+from langchain_openai import OpenAIEmbeddings
+from tqdm.auto import tqdm
 
 
 dataset = load_dataset(
@@ -35,4 +39,108 @@ res = chat(messages)
 # add latest AI response to messages
 messages.append(res)
 
-# connect to pinecone
+# connect to pinecone
+api_key = os.getenv('PINECONE_API_KEY')
+
+# configure client
+pc = Pinecone(api_key=api_key)
+
+# connect to serverless
+spec = ServerlessSpec(
+    cloud="aws", region="us-east-1"
+)
+
+# initialize index
+index_name = 'llama-2-rag'
+existing_indexes = [
+    index_info["name"] for index_info in pc.list_indexes()
+]
+
+# check if index already exists (it shouldn't if this is the first run)
+if index_name not in existing_indexes:
+    # if it does not exist, create the index
+    pc.create_index(
+        index_name,
+        dimension=1536,  # dimensionality of ada 002
+        metric='dotproduct',
+        spec=spec
+    )
+    # wait for index to be initialized
+    while not pc.describe_index(index_name).status['ready']:
+        time.sleep(1)
+
+# connect to index
+index = pc.Index(index_name)
+time.sleep(1)
+# view index stats
+index.describe_index_stats()
+
+# initialize the embedding model used to encode our documents
+embed_model = OpenAIEmbeddings(model="text-embedding-ada-002")
+
+# iterate over dataset in batches
+data = dataset.to_pandas()
+batch_size = 100
+
+for i in tqdm(range(0, len(data), batch_size)):
+    i_end = min(len(data), i+batch_size)
+    # get batch of data
+    batch = data.iloc[i:i_end]
+    # generate unique ids for each chunk
+    ids = [f"{x['doi']}-{x['chunk-id']}" for _, x in batch.iterrows()]
+    # get text to embed
+    texts = [x['chunk'] for _, x in batch.iterrows()]
+    # embed text
+    embeds = embed_model.embed_documents(texts)
+    # get metadata to store in Pinecone
+    metadata = [
+        {'text': x['chunk'],
+         'source': x['source'],
+         'title': x['title']} for _, x in batch.iterrows()
+    ]
+    # add to Pinecone
+    index.upsert(vectors=zip(ids, embeds, metadata))
+
+index.describe_index_stats()
+
+#### Retrieval Augmented Generation
+# from langchain_pinecone import PineconeVectorStore
+from langchain.vectorstores import Pinecone as LangChainPinecone
+
+# the metadata field that contains our text
+text_field = "text"
+
+# initialize the vector store object
+vectorstore = LangChainPinecone(
+    index, embed_model.embed_query, text_field
+)
+
+query = "What is so special about Llama 2?"
+
+vectorstore.similarity_search(query, k=3)
+
+# connect the output from the vectorstore to chat
+def augment_prompt(query: str):
+    # get top 3 results from knowledge base
+    results = vectorstore.similarity_search(query, k=3)
+    # get the text from the results
+    source_knowledge = "\n".join([x.page_content for x in results])
+    # feed into an augmented prompt
+    augmented_prompt = f"""Using the contexts below, answer the query.
+
+    Contexts:
+    {source_knowledge}
+
+    Query: {query}"""
+    return augmented_prompt
+
+# create a new user prompt
+prompt = HumanMessage(
+    content=augment_prompt(query)
+)
+# add to messages
+messages.append(prompt)
+
+res = chat(messages)
+
+print(res.content)
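
A note on the index configuration in this commit: dimension=1536 matches the output size of text-embedding-ada-002, and dot product is a reasonable metric choice because ada-002 embeddings are unit-length, so dot product ranks identically to cosine. A minimal sanity check (not part of the commit) could be:

# sanity check: confirm the embedding size matches the index's dimension=1536
sample = embed_model.embed_documents(["hello world"])
assert len(sample[0]) == 1536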
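
On the upsert step: index.upsert(vectors=zip(ids, embeds, metadata)) relies on the Pinecone client accepting (id, values, metadata) tuples. A more explicit equivalent sketch, using the record dicts the client also accepts:

# equivalent upsert with explicit record dicts (same data as the zip form)
records = [
    {"id": _id, "values": emb, "metadata": meta}
    for _id, emb, meta in zip(ids, embeds, metadata)
]
index.upsert(vectors=records)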
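
The commented-out import points at the newer langchain-pinecone package, which replaces the deprecated langchain.vectorstores.Pinecone class used here. For reference, a minimal sketch of the same vector store setup with that package (assuming langchain-pinecone is installed; constructor arguments may differ across versions):

# sketch with the newer langchain-pinecone package
from langchain_pinecone import PineconeVectorStore

vectorstore = PineconeVectorStore(
    index=index,            # the pc.Index created above
    embedding=embed_model,  # the embeddings object, not embed_query
    text_key="text"         # metadata field holding the chunk text
)
vectorstore.similarity_search("What is so special about Llama 2?", k=3)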