# unstoppable_app/rag.py
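"""Build a simple RAG pipeline: embed chunked Llama 2 arXiv papers into a
Pinecone serverless index, then answer queries with GPT-3.5 using retrieved
context."""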
import os
import time

from datasets import load_dataset
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone, ServerlessSpec
from tqdm.auto import tqdm
dataset = load_dataset(
    "jamescalam/llama-2-arxiv-papers-chunked",
    split="train"
)
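# each record carries the fields used below: 'doi', 'chunk-id', 'chunk',
# 'source', and 'title'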
# initialize the chat model (expects OPENAI_API_KEY in the environment)
chat = ChatOpenAI(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    model="gpt-3.5-turbo"
)
messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hi AI, how are you today?"),
    AIMessage(content="I'm great thank you. How can I help you?"),
    HumanMessage(content="I'd like to understand string theory.")
]
res = chat.invoke(messages)
# add latest AI response to messages
messages.append(res)
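# `res` is an AIMessage, so appending it extends the running conversation
# history that later turns build on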
# connect to pinecone
api_key = os.getenv("PINECONE_API_KEY")
# configure client
pc = Pinecone(api_key=api_key)
# serverless deployment spec for the index
spec = ServerlessSpec(cloud="aws", region="us-east-1")
# initialize index
index_name = "llama-2-rag"
existing_indexes = [
    index_info["name"] for index_info in pc.list_indexes()
]
# check if index already exists (it shouldn't if this is the first run)
if index_name not in existing_indexes:
    # if it does not exist, create it
    pc.create_index(
        index_name,
        dimension=1536,  # dimensionality of ada-002
        metric="dotproduct",
        spec=spec
    )
    # wait for the index to be initialized
    while not pc.describe_index(index_name).status["ready"]:
        time.sleep(1)
# connect to index
index = pc.Index(index_name)
time.sleep(1)
# view index stats
print(index.describe_index_stats())
# embedding model used to encode both the document chunks and, later, queries
embed_model = OpenAIEmbeddings(model="text-embedding-ada-002")
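# optional sanity check: ada-002 vectors should match the 1536-dimension
# index created above (note this makes one embedding API call)
assert len(embed_model.embed_query("dimension check")) == 1536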
# iterate over the dataset in batches, embedding and upserting as we go
data = dataset.to_pandas()
batch_size = 100
for i in tqdm(range(0, len(data), batch_size)):
    i_end = min(len(data), i + batch_size)
    # get batch of data
    batch = data.iloc[i:i_end]
    # generate unique ids for each chunk
    ids = [f"{x['doi']}-{x['chunk-id']}" for _, x in batch.iterrows()]
    # get text to embed
    texts = [x["chunk"] for _, x in batch.iterrows()]
    # embed text
    embeds = embed_model.embed_documents(texts)
    # get metadata to store in Pinecone
    metadata = [
        {"text": x["chunk"],
         "source": x["source"],
         "title": x["title"]} for _, x in batch.iterrows()
    ]
    # add to Pinecone as (id, vector, metadata) tuples
    index.upsert(vectors=list(zip(ids, embeds, metadata)))
print(index.describe_index_stats())
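# the reported vector count should now match the number of dataset rows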
#### Retrieval Augmented Generation
# the metadata field that contains our text
text_field = "text"
# initialize the vector store object
vectorstore = PineconeVectorStore(
    index=index, embedding=embed_model, text_key=text_field
)
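# similarity_search returns langchain Document objects whose page_content is
# read back from the "text" metadata field we stored during upsert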
query = "What is so special about Llama 2?"
print(vectorstore.similarity_search(query, k=3))
# connect the output from the vectorstore to the chat model
def augment_prompt(query: str) -> str:
    # get the top 3 results from the knowledge base
    results = vectorstore.similarity_search(query, k=3)
    # get the text from the results
    source_knowledge = "\n".join([x.page_content for x in results])
    # feed into an augmented prompt
    augmented_prompt = f"""Using the contexts below, answer the query.

Contexts:
{source_knowledge}

Query: {query}"""
    return augmented_prompt
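# to inspect the grounded prompt the model actually sees:
# print(augment_prompt(query))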
# create a new user prompt
prompt = HumanMessage(content=augment_prompt(query))
# add to messages
messages.append(prompt)
res = chat.invoke(messages)
print(res.content)
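# a further (hypothetical) usage sketch: later turns can be grounded the same
# way, so each follow-up is answered from retrieved context rather than the
# model's parametric memory
follow_up = "What are the key differences between Llama 2 and Llama 1?"
messages.append(res)  # keep the assistant's answer in the history
messages.append(HumanMessage(content=augment_prompt(follow_up)))
print(chat.invoke(messages).content)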