import os
import time

from datasets import load_dataset
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.schema import (
    SystemMessage,
    HumanMessage,
    AIMessage
)
from pinecone import Pinecone, ServerlessSpec
from tqdm.auto import tqdm

# load the pre-chunked Llama 2 ArXiv dataset
dataset = load_dataset(
    "jamescalam/llama-2-arxiv-papers-chunked",
    split="train"
)

# initialize the chat model (OPENAI_API_KEY must be set in the environment)
chat = ChatOpenAI(
    openai_api_key=os.environ["OPENAI_API_KEY"],
    model='gpt-3.5-turbo'
)

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Hi AI, how are you today?"),
    AIMessage(content="I'm great thank you. How can I help you?"),
    HumanMessage(content="I'd like to understand string theory.")
]

res = chat.invoke(messages)

# add latest AI response to messages
messages.append(res)

# connect to pinecone
api_key = os.getenv('PINECONE_API_KEY')

# configure client
pc = Pinecone(api_key=api_key)

# connect to serverless
spec = ServerlessSpec(
    cloud="aws", region="us-east-1"
)

# initialize index
index_name = 'llama-2-rag'
existing_indexes = [
    index_info["name"] for index_info in pc.list_indexes()
]

# check if index already exists (it shouldn't if this is first time)
if index_name not in existing_indexes:
    # if it does not exist, create the index
    pc.create_index(
        index_name,
        dimension=1536,  # dimensionality of ada 002
        metric='dotproduct',
        spec=spec
    )
    # wait for index to be initialized
    while not pc.describe_index(index_name).status['ready']:
        time.sleep(1)

# connect to index
index = pc.Index(index_name)
time.sleep(1)
# view index stats
index.describe_index_stats()

# embedding model used to vectorize each text chunk
embed_model = OpenAIEmbeddings(model="text-embedding-ada-002")

# iterate over the dataset in batches: embed each batch and upsert to Pinecone
data = dataset.to_pandas()
batch_size = 100

for i in tqdm(range(0, len(data), batch_size)):
    i_end = min(len(data), i + batch_size)
    # get batch of data
    batch = data.iloc[i:i_end]
    # generate unique ids for each chunk
    ids = [f"{x['doi']}-{x['chunk-id']}" for _, x in batch.iterrows()]
    # get text to embed
    texts = [x['chunk'] for _, x in batch.iterrows()]
    # embed text
    embeds = embed_model.embed_documents(texts)
    # get metadata to store in Pinecone
    metadata = [
        {'text': x['chunk'],
         'source': x['source'],
         'title': x['title']} for _, x in batch.iterrows()
    ]
    # add to Pinecone
    index.upsert(vectors=zip(ids, embeds, metadata))

index.describe_index_stats()

#### Retrieval Augmented Generation

# from langchain_pinecone import PineconeVectorStore  # newer alternative
# alias the LangChain vector store so it doesn't shadow the Pinecone
# client class imported above
from langchain.vectorstores import Pinecone as LangchainPinecone

# the metadata field that contains our text
text_field = "text"

# initialize the vector store object
vectorstore = LangchainPinecone(
    index, embed_model.embed_query, text_field
)

query = "What is so special about Llama 2?"

vectorstore.similarity_search(query, k=3)

# connect the output from the vectorstore to chat
def augment_prompt(query: str):
    # get top 3 results from knowledge base
    results = vectorstore.similarity_search(query, k=3)
    # get the text from the results
    source_knowledge = "\n".join([x.page_content for x in results])
    # feed into an augmented prompt
    augmented_prompt = f"""Using the contexts below, answer the query.

    Contexts:
    {source_knowledge}

    Query: {query}"""
    return augmented_prompt

# create a new user prompt
prompt = HumanMessage(
    content=augment_prompt(query)
)
# add to messages
messages.append(prompt)

res = chat.invoke(messages)

print(res.content)
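
# Optional extension (a minimal sketch, not part of the original script):
# repeat the augment-then-invoke loop for a follow-up question so the
# conversation stays grounded in retrieved context. The follow-up question
# text below is illustrative.
messages.append(res)  # keep the assistant's grounded answer in the history

follow_up = "What safety measures were used in the development of Llama 2?"
messages.append(HumanMessage(content=augment_prompt(follow_up)))

res = chat.invoke(messages)
print(res.content)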