import gradio as gr
from sentence_transformers import SentenceTransformer
import chromadb
import pandas as pd
import os
import json
from pathlib import Path
import numpy as np
from llama_index.llms.anyscale import Anyscale
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings
)
# Load the sentence transformer model
model = SentenceTransformer('all-MiniLM-L6-v2')
# model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
embedding_dim = model.get_sentence_embedding_dimension()
print(f"Embedding Dimension: {embedding_dim}")
# Initialize the ChromaDB client
chroma_client = chromadb.Client()
# Define a function to pad embeddings to the desired dimensionality
def pad_embedding(embedding, target_dim=1024):
    embedding = np.array(eval(embedding.replace(',,', ',')))
    if len(embedding) < target_dim:
        embedding = np.pad(embedding, (0, target_dim - len(embedding)))
    elif len(embedding) > target_dim:
        embedding = embedding[:target_dim]
    return embedding.tolist()
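# pad_embedding is currently unused (see the commented-out call in build_database); it is
# only needed if the stored vectors come from a model with a different dimensionality,
# e.g. the 1024-dimensional thenlper/gte-large embeddings.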
# Function to build the database from CSV
def build_database():
    # Read the CSV file
    df = pd.read_csv('vector_store.csv')
    # Create a collection
    collection_name = 'Dataset-10k-companies'
    # Delete the existing collection if it exists
    # chroma_client.delete_collection(name=collection_name)
    # Create a new collection
    # create_collection does not take a dimension argument; Chroma infers the
    # embedding dimensionality from the vectors added below
    collection = chroma_client.create_collection(name=collection_name)
    # Add the data from the DataFrame to the collection
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(eval).tolist(),
        embeddings=df['embeddings'].apply(lambda x: eval(x.replace(',,', ','))).tolist()
        # embeddings=df['embeddings'].apply(pad_embedding).tolist()
    )
    return collection
# Build the database when the app starts
collection = build_database()
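# chromadb.Client() is ephemeral (in-memory), so the collection is rebuilt from the CSV
# on every app start rather than persisted to disk.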
# Access the Anyscale API key from environment variables
anyscale_api_key = os.environ.get('anyscale_api_key')
# Instantiate the Anyscale client
client = Anyscale(api_key=anyscale_api_key, model="meta-llama/Llama-2-70b-chat-hf")
# Function to get relevant chunks
def get_relevant_chunks(query, collection, top_n=3):
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
    relevant_chunks = []
    for i in range(len(results['documents'][0])):
        chunk = results['documents'][0][i]
        source = results['metadatas'][0][i]['source']
        page = results['metadatas'][0][i]['page']
        relevant_chunks.append((chunk, source, page))
    return relevant_chunks
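# Each chunk is returned as a (text, source, page) tuple; the 'source' and 'page' keys are
# assumed to be present in the metadata stored in vector_store.csv.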
# Define system message for LLM
qna_system_message = """
You are an assistant to Finsights analysts. Your task is to provide relevant information about the financial performance of the companies followed by Finsights.

User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for a context will begin with the token: ###Source.

When crafting your response:
1. Select only the context relevant to answer the question.
2. Include the source links in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant to Finsights, respond with: "I am an assistant for Finsight Docs. I can only help you with questions related to Finsights."

Adhere to the following guidelines:
- Your response should only address the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, respond with: "I don't know."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
- Do not make up sources. Use only the links provided in the sources section of the context. You are prohibited from providing other links/sources.

Here is an example of how to structure your response:

Answer:
[Answer]

Source:
[Source]
"""
# Create a user message template
qna_user_message_template = """
###Context
Here are some documents and their source links that are relevant to the question mentioned below.
{context}

###Question
{question}
"""
# Function to get LLM response
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            response = client.complete(prompt, max_tokens=1000)  # Increase max_tokens if possible
            chunk = response.text.strip()
            full_response += chunk
            if chunk.endswith((".", "!", "?")):  # Check if the response seems complete
                break
            else:
                prompt = "Please continue from where you left off:\n" + chunk[-100:]  # Use the last 100 chars as context
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response
# Prediction function
def predict(company, user_query):
    try:
        # Modify the query to include the company name
        modified_query = f"{user_query} for {company}"
        # Get relevant chunks
        relevant_chunks = get_relevant_chunks(modified_query, collection)
        # Prepare the context string
        context = ""
        for chunk, source, page in relevant_chunks:
            context += chunk + "\n"
            context += f"###Source {source}, Page {page}\n"
        # Prepare the user message
        user_message = qna_user_message_template.format(context=context, question=user_query)
        # Craft the prompt to pass to the Llama model
        prompt = f"{qna_system_message}\n\n{user_message}"
        # Generate the response using the Llama model through Anyscale
        answer = get_llm_response(prompt)
        # Log the interaction
        log_interaction(company, user_query, context, answer)
        return answer
    except Exception as e:
        return f"An error occurred: {str(e)}"
# Function to log interactions
def log_interaction(company, user_query, context, answer):
    log_file = Path("interaction_log.jsonl")
    with log_file.open("a") as f:
        json.dump({
            'company': company,
            'user_query': user_query,
            'context': context,
            'answer': answer
        }, f)
        f.write("\n")
# Create Gradio interface
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)
# Launch the interface
iface.launch(share=True)
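# share=True creates a temporary public link when the app runs locally; on a hosted
# platform such as Hugging Face Spaces it is typically unnecessary.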