import os
#os.system("pip install faiss-cpu")
#os.system("pip install tqdm")
#os.system("pip install numpy")
#os.system("pip install pandas")
from transformers import pipeline
from tqdm.auto import tqdm
import faiss
import numpy as np
import torch
import pandas as pd
pipe_UAE = pipeline("feature-extraction", model="WhereIsAI/UAE-Large-V1")
SECRET_TOKEN = os.environ.get("SECRET_TOKEN")  # Hugging Face access token, assumed to be provided as an environment variable (e.g. a Space secret)
pipe_llama = pipeline("text-generation", model="meta-llama/Llama-2-7b-chat-hf", token=SECRET_TOKEN)
def encode_query(query):
"""
Encodes a query into embeddings using the "WhereIsAI/UAE-Large-V1" model pipeline.
Args:
        query (str): The query string to encode.
Returns:
numpy.ndarray: An array containing the embeddings for the query.
"""
# Since the pipeline can process strings directly, we assume it returns a list of lists (one per query)
result = pipe_UAE(query)
# Convert result to a numpy array and handle dimensionality
query_embedding = np.array(result)
    # The result typically carries an extra batch/sequence dimension that we don't need
    if query_embedding.ndim > 2:
        query_embedding = query_embedding.mean(axis=1)  # Mean-pool across the sequence length
    return query_embedding.squeeze()  # Collapse to a 1D vector of shape (embedding_dim,)
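# Hedged sketch: the original file never shows how the document-chunk embeddings
# that feed build_index() are produced. The helper below is one plausible way to
# do it, assuming the chunks live in a DataFrame with a 'chunk' column (an
# assumption for illustration, not part of the original code).
def encode_chunks(doc_chunks):
    """Encodes every chunk in the DataFrame and stacks the vectors into a 2D float32 array."""
    vectors = [encode_query(chunk) for chunk in tqdm(doc_chunks["chunk"])]
    return np.vstack(vectors).astype(np.float32)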
def construct_prompt(query, contexts):
    """
    Constructs a prompt for the model by combining the query with the retrieved contexts.
    Args:
        query (str): The user's query.
        contexts (DataFrame): The retrieved document chunks, with the text in a 'chunk' column.
    Returns:
        str: The constructed prompt.
    """
    return f"[CONTEXT] {' '.join(contexts['chunk'])} [/CONTEXT] [QUESTION] {query} [/QUESTION]"
def clean_output_text(output):
"""
    Cleans the model output so that only the generated answer remains.
    Args:
        output (str): The raw text generated by the model (prompt included).
Returns:
str: The cleaned response.
"""
output = output.split("[/QUESTION]")[-1].strip()
# Ensuring the answer ends with a full sentence
last_dot_index = output.rfind(".")
if last_dot_index != -1:
output = output[:last_dot_index + 1]
return output.strip()
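# Illustrative behaviour (made-up strings, not taken from the original file):
#   raw = "[CONTEXT] c [/CONTEXT] [QUESTION] q [/QUESTION] The answer is 42. Trailing frag"
#   clean_output_text(raw)  ->  "The answer is 42."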
def generate_response(prompt):
"""
Generates a response to the given prompt using the loaded LLaMA model.
Args:
prompt (str): The prompt to generate a response for.
Returns:
        str: The generated answer text.
    """
    response = pipe_llama(prompt)
    return response[0]["generated_text"]  # The text-generation pipeline returns a list of dicts
def build_index(embeddings):
    """
    Builds a FAISS index for efficient similarity search.
    Args:
        embeddings (numpy.ndarray): 2D array of document chunk embeddings, shape (n_chunks, dim).
    Returns:
        faiss.IndexFlatL2: FAISS index over the embeddings.
    """
    dim = embeddings.shape[1]  # Dimension of the embeddings
    index = faiss.IndexFlatL2(dim)  # FlatL2 -> exact nearest-neighbour search with Euclidean distance
    embeddings_np = np.ascontiguousarray(embeddings, dtype=np.float32)  # FAISS requires contiguous float32
    index.add(embeddings_np)  # Add the embeddings to the index
    return index
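# Example usage (sketch): with `doc_embeddings` being the (n_chunks, dim) float32
# matrix produced by encode_chunks() above:
#   index = build_index(doc_embeddings)
#   assert index.ntotal == doc_embeddings.shape[0]
# IndexFlatL2 performs exhaustive exact search; for large corpora an approximate
# index such as faiss.IndexIVFFlat would trade some accuracy for speed.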
def retrieve_documents(query, index, doc_chunks, k=5):
"""
Retrieves the top-k relevant document chunks for a given query, along with the original document text.
Args:
query (str): The query string.
index (faiss.IndexFlatL2): The FAISS index for document chunk embeddings.
doc_chunks (DataFrame): DataFrame containing document chunks and their original context.
k (int): Number of top documents to retrieve.
Returns:
        DataFrame: Top-k relevant document chunks, including the original context and links.
"""
    query_embedding = encode_query(query)  # Encode the query text
    query_embedding_np = query_embedding.reshape(1, -1).astype(np.float32)  # FAISS expects a 2D float32 array
    # Perform the search
    D, I = index.search(query_embedding_np, k)  # D: distances, I: indices of the nearest chunks
    return doc_chunks.iloc[I[0]]  # Return the rows of the most similar chunks
def prompt_model(query, index, preprocessed_docs):
    """
    Answers a query end to end: retrieves the relevant chunks, builds the prompt,
    generates a response, cleans it, and appends the source links.
    """
    # Retrieve the relevant document chunks
    retrieved_docs = retrieve_documents(query, index, preprocessed_docs, k=5)
    # Create the prompt for the model
    prompt = construct_prompt(query, retrieved_docs)
    # Generate the response
    response = generate_response(prompt)
    # Clean the output text
    cleaned_response = clean_output_text(response)
    # Extract the links from the retrieved chunks (not from the cleaned string, which is plain text)
    links = retrieved_docs['links'].tolist()
    # Format the links as a single string with each link on a new line
    formatted_links = "\n".join([f"[{link}]" for link in links])
    # Append the formatted links to the answer
    final_answer = f"{cleaned_response}\nSources:\n{formatted_links}"
    return final_answer
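# Minimal end-to-end sketch of how the functions above fit together. The CSV path
# and its 'chunk'/'links' columns are assumptions for illustration, not confirmed
# by the original file.
if __name__ == "__main__":
    preprocessed_docs = pd.read_csv("preprocessed_docs.csv")  # hypothetical input file
    doc_embeddings = encode_chunks(preprocessed_docs)  # (n_chunks, dim) float32 matrix
    index = build_index(doc_embeddings)
    print(prompt_model("What does the indexed corpus say about X?", index, preprocessed_docs))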