import os

# Optional: install dependencies if they are not already available.
# os.system("pip install faiss-cpu")
# os.system("pip install tqdm")
# os.system("pip install numpy")
# os.system("pip install pandas")

from transformers import pipeline
from tqdm.auto import tqdm
import faiss
import numpy as np
import torch
import pandas as pd

# The Hugging Face access token is assumed to be provided through the environment
# (e.g. a Space secret) rather than hard-coded in the source.
SECRET_TOKEN = os.environ.get("SECRET_TOKEN")

pipe_UAE = pipeline("feature-extraction", model="WhereIsAI/UAE-Large-V1")
pipe_llama = pipeline("text-generation", model="meta-llama/Llama-2-7b-chat-hf", token=SECRET_TOKEN)
def encode_query(query):
    """
    Encodes a query into an embedding using the "WhereIsAI/UAE-Large-V1" feature-extraction pipeline.

    Args:
        query (str): Query to be encoded.

    Returns:
        numpy.ndarray: The embedding vector for the query.
    """
    # The pipeline accepts a string directly and returns nested lists
    # of shape (batch, sequence length, embedding dimension).
    result = pipe_UAE(query)
    query_embedding = np.array(result)
    # Mean-pool across the sequence length to get a single vector per query.
    if query_embedding.ndim > 2:
        query_embedding = query_embedding.mean(axis=1)
    return query_embedding.squeeze()  # 1D vector of shape (embedding_dim,)
def construct_prompt(query, contexts):
    """
    Constructs a prompt for the model by combining the query with the retrieved contexts.

    Args:
        query (str): The user's query.
        contexts (DataFrame): The retrieved document chunks (must contain a "chunk" column).

    Returns:
        str: The constructed prompt.
    """
    return f"[CONTEXT] {' '.join(contexts['chunk'])} [/CONTEXT] [QUESTION] {query} [/QUESTION]"
def clean_output_text(output):
    """
    Cleans the model output to keep only the latest answer.

    Args:
        output (str): The text generated by the model (prompt included).

    Returns:
        str: The cleaned response.
    """
    # Keep only the text that follows the question part of the prompt.
    output = output.split("[/QUESTION]")[-1].strip()
    # Ensure the answer ends with a full sentence.
    last_dot_index = output.rfind(".")
    if last_dot_index != -1:
        output = output[:last_dot_index + 1]
    return output.strip()
def generate_response(prompt):
    """
    Generates a response to the given prompt using the loaded LLaMA model.

    Args:
        prompt (str): The prompt to generate a response for.

    Returns:
        str: The generated text (prompt included, as returned by the pipeline).
    """
    response = pipe_llama(prompt)
    # The text-generation pipeline returns a list of dicts; extract the generated string.
    return response[0]["generated_text"]
def build_index(embeddings):
    """
    Builds a FAISS index for efficient similarity search.

    Args:
        embeddings (List[torch.Tensor]): List of document chunk embeddings.

    Returns:
        faiss.IndexFlatL2: FAISS index containing the embeddings.
    """
    # Stack the embeddings into a single float32 matrix (FAISS expects float32).
    embeddings_np = np.vstack([emb.squeeze().numpy() for emb in embeddings]).astype(np.float32)
    dim = embeddings_np.shape[1]    # Dimension of the embeddings
    index = faiss.IndexFlatL2(dim)  # FlatL2 -> exact search with Euclidean distance
    index.add(embeddings_np)        # Add the embeddings to the index
    return index
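
# The document-chunk embeddings passed to build_index are not produced in this file.
# The helper below is a minimal sketch of how they could be generated, assuming the
# chunks live in a "chunk" column of the preprocessed DataFrame and are embedded with
# the same UAE pipeline and mean pooling as the queries (the name "encode_chunks" is
# illustrative, not part of the original).
def encode_chunks(doc_chunks):
    embeddings = []
    for chunk in doc_chunks["chunk"]:
        emb = np.array(pipe_UAE(chunk))           # shape (1, seq_len, dim)
        emb = emb.mean(axis=1)                    # mean-pool over the sequence length
        embeddings.append(torch.from_numpy(emb))  # torch tensor, as build_index expects
    return embeddings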
def retrieve_documents(query, index, doc_chunks, k=5):
    """
    Retrieves the top-k relevant document chunks for a given query, along with the original document text.

    Args:
        query (str): The query string.
        index (faiss.IndexFlatL2): The FAISS index of document chunk embeddings.
        doc_chunks (DataFrame): DataFrame containing document chunks and their original context.
        k (int): Number of top documents to retrieve.

    Returns:
        DataFrame: Top-k relevant document chunks, including the original context and link.
    """
    query_embedding = encode_query(query)  # Encode the query text
    # FAISS expects a 2D float32 array of shape (n_queries, dim).
    query_embedding_np = query_embedding.reshape(1, -1).astype(np.float32)
    D, I = index.search(query_embedding_np, k)  # D: distances, I: indices of the nearest chunks
    return doc_chunks.iloc[I[0]]  # Return the rows of the most similar chunks
def prompt_model(query, index, preprocessed_docs):
    """
    Runs the full retrieval-augmented generation pipeline for a single query.
    """
    # Retrieve the most relevant document chunks
    retrieved_docs = retrieve_documents(query, index, preprocessed_docs, k=5)
    # Create the prompt for the model
    prompt = construct_prompt(query, retrieved_docs)
    # Generate the response
    response = generate_response(prompt)
    # Clean the output text
    cleaned_response = clean_output_text(response)
    # Extract the source links from the retrieved chunks (not from the generated text)
    links = retrieved_docs['links'].tolist()
    # Format the links as a single string with each link on a new line
    formatted_links = "\n".join([f"[{link}]" for link in links])
    # Append the formatted links to the answer
    final_answer = f"{cleaned_response}\nSources:\n{formatted_links}"
    return final_answer
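
# Minimal usage sketch (illustrative, not part of the original Space): assumes a
# preprocessed CSV with "chunk" and "links" columns; the file name is hypothetical.
if __name__ == "__main__":
    preprocessed_docs = pd.read_csv("preprocessed_docs.csv")  # hypothetical file
    chunk_embeddings = encode_chunks(preprocessed_docs)       # embed every chunk
    faiss_index = build_index(chunk_embeddings)               # build the FAISS index
    answer = prompt_model("What is RAG?", faiss_index, preprocessed_docs)
    print(answer)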