import os
import time

from datasets import load_dataset
from langchain.chat_models import ChatOpenAI
from langchain.schema import (
    SystemMessage,
    HumanMessage,
    AIMessage
)
from langchain_openai import OpenAIEmbeddings
from pinecone import Pinecone, ServerlessSpec
from tqdm.auto import tqdm
# load the pre-chunked Llama 2 arXiv dataset from the Hugging Face Hub
dataset = load_dataset(
    "jamescalam/llama-2-arxiv-papers-chunked",
    split="train"
)
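# optional sanity check: each record carries the fields used further
# down ('doi', 'chunk-id', 'chunk', 'source', 'title')
print(dataset[0].keys())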
# the OpenAI API key must already be set in the environment
chat = ChatOpenAI(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    model='gpt-3.5-turbo'
)
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hi AI, how are you today?"),
    AIMessage(content="I'm great thank you. How can I help you?"),
    HumanMessage(content="I'd like to understand string theory.")
]
res = chat(messages)
# add latest AI response to messages
messages.append(res)
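# optional: print the reply to confirm the chat loop works end to end
print(res.content)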
# connect to Pinecone
api_key = os.getenv('PINECONE_API_KEY')
# configure client
pc = Pinecone(api_key=api_key)
# define a serverless spec for the index
spec = ServerlessSpec(
    cloud="aws", region="us-east-1"
)
# initialize index
index_name = 'llama-2-rag'
existing_indexes = [
    index_info["name"] for index_info in pc.list_indexes()
]
# check if index already exists (it shouldn't if this is the first run)
if index_name not in existing_indexes:
    # if it does not exist, create the index
    pc.create_index(
        index_name,
        dimension=1536,  # dimensionality of text-embedding-ada-002
        metric='dotproduct',
        spec=spec
    )
    # wait for the index to finish initializing
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)
# connect to index
index = pc.Index(index_name)
time.sleep(1)
# view index stats
index.describe_index_stats()
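# at this point the index exists but is empty; the stats should report
# a total_vector_count of 0 before any upserts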
# create the embedding model used to encode our documents
embed_model = OpenAIEmbeddings(model="text-embedding-ada-002")
# iterate over the dataset in batches
data = dataset.to_pandas()
batch_size = 100
for i in tqdm(range(0, len(data), batch_size)):
    i_end = min(len(data), i+batch_size)
    # get batch of data
    batch = data.iloc[i:i_end]
    # generate unique ids for each chunk
    ids = [f"{x['doi']}-{x['chunk-id']}" for _, x in batch.iterrows()]
    # get text to embed
    texts = [x['chunk'] for _, x in batch.iterrows()]
    # embed text
    embeds = embed_model.embed_documents(texts)
    # get metadata to store in Pinecone
    metadata = [
        {'text': x['chunk'],
         'source': x['source'],
         'title': x['title']} for _, x in batch.iterrows()
    ]
    # add to Pinecone
    index.upsert(vectors=list(zip(ids, embeds, metadata)))
index.describe_index_stats()
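# optional sanity check: query the index directly with the Pinecone
# client before wiring it into LangChain (a minimal sketch; the query
# text is just an example)
xq = embed_model.embed_query("What is so special about Llama 2?")
query_res = index.query(vector=xq, top_k=3, include_metadata=True)
for match in query_res["matches"]:
    print(match["id"], round(match["score"], 3))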
#### Retrieval Augmented Generation
# newer releases provide: from langchain_pinecone import PineconeVectorStore
# alias the import so it does not shadow the Pinecone client class above
from langchain.vectorstores import Pinecone as PineconeVectorStore
# the metadata field that contains our text
text_field = "text"
# initialize the vector store object
vectorstore = PineconeVectorStore(
    index, embed_model.embed_query, text_field
)
query = "What is so special about Llama 2?"
vectorstore.similarity_search(query, k=3)
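# each result is a LangChain Document; a quick look at the metadata and
# text that came back (assuming the upserts above completed)
for doc in vectorstore.similarity_search(query, k=3):
    print(doc.metadata['title'])
    print(doc.page_content[:200])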
# connect the output from the vector store to the chat model
def augment_prompt(query: str):
    # get top 3 results from the knowledge base
    results = vectorstore.similarity_search(query, k=3)
    # get the text from the results
    source_knowledge = "\n".join([x.page_content for x in results])
    # feed into an augmented prompt
    augmented_prompt = f"""Using the contexts below, answer the query.
    Contexts:
    {source_knowledge}
    Query: {query}"""
    return augmented_prompt
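# sanity check: print the augmented prompt to see the retrieved contexts
# before sending it to the model
print(augment_prompt(query))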
# create a new user prompt
prompt = HumanMessage(
    content=augment_prompt(query)
)
# add to messages
messages.append(prompt)
res = chat(messages)
print(res.content)
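# the same pattern handles follow-up turns: augment, append, send
# (the question below is only an illustrative example)
followup = HumanMessage(content=augment_prompt(
    "What safety measures were used in the development of Llama 2?"
))
messages.append(followup)
res = chat(messages)
print(res.content)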