# rag-10k-analysis / 01JUL24v3app.py
# Last commit by kgauvin603: "Rename app.py to 01JUL24v3app.py" (adbfa0a, verified), 7.32 kB
import ast
import json
import os
import re
from pathlib import Path

import chromadb
import gradio as gr
import numpy as np
import pandas as pd
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from llama_index.llms.anyscale import Anyscale
from sentence_transformers import SentenceTransformer
# Sentence-embedding model used to encode user queries at request time.
# NOTE: queries are compared against the vectors stored in vector_store.csv,
# so this must be the same model that produced those vectors.
model = SentenceTransformer(model_name_or_path='all-MiniLM-L6-v2')
embedding_dim = model.get_sentence_embedding_dimension()
print(f"Embedding Dimension: {embedding_dim}")

# ChromaDB client created with default settings; build_database() creates
# the document collection on it.
chroma_client = chromadb.Client()
# Define a function to pad embeddings to the desired dimensionality
def pad_embedding(embedding, target_dim=1024):
embedding = np.array(eval(embedding.replace(',,', ',')))
if len(embedding) < target_dim:
embedding = np.pad(embedding, (0, target_dim - len(embedding)))
elif len(embedding) > target_dim:
embedding = embedding[:target_dim]
return embedding.tolist()
# Function to build the database from CSV
def build_database():
# Read the CSV file
df = pd.read_csv('vector_store.csv')
# Create a collection
collection_name = 'Dataset-10k-companies'
# Delete the existing collection if it exists
# chroma_client.delete_collection(name=collection_name)
# Create a new collection
collection = chroma_client.create_collection(
name=collection_name,
dimension=384
)
# Add the data from the DataFrame to the collection
collection.add(
documents=df['documents'].tolist(),
ids=df['ids'].tolist(),
metadatas=df['metadatas'].apply(eval).tolist(),
embeddings=df['embeddings'].apply(lambda x: eval(x.replace(',,', ','))).tolist()
#embeddings=df['embeddings'].apply(pad_embedding).tolist()
)
return collection
# Populate the vector store once at startup; predict() searches it per request.
collection = build_database()

# Anyscale credentials come from the environment (None when the variable is unset).
anyscale_api_key = os.getenv('anyscale_api_key')

# LLM client that get_llm_response() uses to generate answers.
client = Anyscale(
    api_key=anyscale_api_key,
    model="meta-llama/Llama-2-70b-chat-hf",
)
# Function to get relevant chunks
def get_relevant_chunks(query, collection, top_n=3):
query_embedding = model.encode(query).tolist()
results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
relevant_chunks = []
for i in range(len(results['documents'][0])):
chunk = results['documents'][0][i]
source = results['metadatas'][0][i]['source']
page = results['metadatas'][0][i]['page']
relevant_chunks.append((chunk, source, page))
return relevant_chunks
# System prompt for the LLM. It restricts answers to the supplied context,
# requires source citations under a "Source:" section, and defines the fixed
# refusal and "I don't know." replies. predict() prepends it to the
# formatted user message.
# NOTE(review): the refusal text says "Finsight Docs" while the rest of the
# prompt says "Finsights" — confirm which name is intended.
qna_system_message = """
You are an assistant to Finsights analysts. Your task is to provide relevant information about the financial performance of the companies followed by Finsights.
User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for a context will begin with the token: ###Source.
When crafting your response:
1. Select only the context relevant to answer the question.
2. Include the source links in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant to Finsights, respond with: "I am an assistant for Finsight Docs. I can only help you with questions related to Finsights."
Adhere to the following guidelines:
- Your response should only address the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, respond with: "I don't know."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
- Do not make up sources. Use only the links provided in the sources section of the context. You are prohibited from providing other links/sources.
Here is an example of how to structure your response:
Answer:
[Answer]
Source:
[Source]
"""
# User-message template, filled via str.format(context=..., question=...).
# Its "###Context" and "###Question" markers are the tokens the system prompt
# tells the model to look for.
qna_user_message_template = """
###Context
Here are some documents and their source links that are relevant to the question mentioned below.
{context}
###Question
{question}
"""
# Function to get LLM response
def get_llm_response(prompt, max_attempts=3):
full_response = ""
for attempt in range(max_attempts):
try:
response = client.complete(prompt, max_tokens=1000) # Increase max_tokens if possible
chunk = response.text.strip()
full_response += chunk
if chunk.endswith((".", "!", "?")): # Check if response seems complete
break
else:
prompt = "Please continue from where you left off:\n" + chunk[-100:] # Use the last 100 chars as context
except Exception as e:
print(f"Attempt {attempt + 1} failed with error: {e}")
return full_response
# Prediction function
def predict(company, user_query):
try:
# Modify the query to include the company name
modified_query = f"{user_query} for {company}"
# Get relevant chunks
relevant_chunks = get_relevant_chunks(modified_query, collection)
# Prepare the context string
context = ""
for chunk, source, page in relevant_chunks:
context += chunk + "\n"
context += f"###Source {source}, Page {page}\n"
# Prepare the user message
user_message = qna_user_message_template.format(context=context, question=user_query)
# Craft the prompt to pass to the Llama model
prompt = f"{qna_system_message}\n\n{qna_user_message_template.format(context=context, question=user_query)}"
# Generate the response using the Llama model through Anyscale
answer = get_llm_response(prompt)
# Extract the generated response
# answer = response.text.strip()
# Log the interaction
log_interaction(company, user_query, context, answer)
return answer
except Exception as e:
return f"An error occurred: {str(e)}"
# Function to log interactions
def log_interaction(company, user_query, context, answer):
log_file = Path("interaction_log.jsonl")
with log_file.open("a") as f:
json.dump({
'company': company,
'user_query': user_query,
'context': context,
'answer': answer
}, f)
f.write("\n")
# Company labels offered in the UI; the chosen label is passed to predict()
# as its `company` argument.
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]

company_input = gr.Radio(choices=company_list, label="Select Company")
query_input = gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
answer_output = gr.Textbox(label="Generated Answer")

# Inputs are listed in the order predict(company, user_query) expects them.
iface = gr.Interface(
    fn=predict,
    inputs=[company_input, query_input],
    outputs=answer_output,
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection.",
)

# Launch the interface; share=True asks Gradio for a public share link.
iface.launch(share=True)