import ast
import json
import os
from pathlib import Path

import chromadb
import gradio as gr
import pandas as pd
from llama_index.llms.anyscale import Anyscale
from sentence_transformers import SentenceTransformer
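
# Embedding model for queries. This must be the same model that produced the
# embeddings stored in collection_data.csv, or the similarity search is meaningless.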
model = SentenceTransformer('all-MiniLM-L6-v2')
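
# In-memory Chroma client; the collection is rebuilt from the CSV on every launch.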
chroma_client = chromadb.Client()


def build_database():
    """Load pre-computed chunks, metadata, and embeddings from CSV into a Chroma collection."""
    df = pd.read_csv('collection_data.csv')

    collection_name = 'Dataset-10k-companies'

    # get_or_create avoids a crash if the collection already exists (e.g. on a re-run).
    collection = chroma_client.get_or_create_collection(name=collection_name)

    # The CSV stores metadatas and embeddings as stringified Python literals;
    # ast.literal_eval parses them without the code-execution risk of eval().
    # The ',,' -> ',' replacement guards against doubled commas in the exported embedding strings.
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings'].apply(lambda x: ast.literal_eval(x.replace(',,', ','))).tolist()
    )

    return collection
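

# Build the vector store once at startup.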
collection = build_database()
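
# Anyscale-hosted Llama 2 chat model via the LlamaIndex wrapper.
# The key is read from the environment; avoid hard-coding secrets in source.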
anyscale_api_key = os.environ.get('anyscale_api_key')
client = Anyscale(api_key=anyscale_api_key, model="meta-llama/Llama-2-70b-chat-hf")


def get_relevant_chunks(query, collection, top_n=3):
    """Embed the query and return the top_n (chunk, source, page) triples from Chroma."""
    query_embedding = model.encode(query).tolist()
    results = collection.query(query_embeddings=[query_embedding], n_results=top_n)

    relevant_chunks = []
    for doc, meta in zip(results['documents'][0], results['metadatas'][0]):
        relevant_chunks.append((doc, meta['source'], meta['page']))

    return relevant_chunks
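

# System prompt: grounding rules the model must follow when answering from retrieved context.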
qna_system_message = """
You are an assistant to Finsights analysts. Your task is to provide relevant information about the financial performance of the companies followed by Finsights.

User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for a context will begin with the token: ###Source.

When crafting your response:
1. Select only the context relevant to answer the question.
2. Include the source links in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant to Finsights, respond with: "I am an assistant for Finsight Docs. I can only help you with questions related to Finsights."

Adhere to the following guidelines:
- Your response should address only the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, respond with: "I don't know."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section "Source:".
- Do not make up sources. Use only the links provided in the sources section of the context. You are prohibited from providing other links/sources.

Here is an example of how to structure your response:

Answer:
[Answer]

Source:
[Source]
"""
qna_user_message_template = """
###Context
Here are some documents and their source links that are relevant to the question mentioned below.
{context}

###Question
{question}
"""
def get_llm_response(prompt, max_attempts=3):
    full_response = ""
    for attempt in range(max_attempts):
        try:
            response = client.complete(prompt, max_tokens=1000)
            chunk = response.text.strip()
            full_response += chunk
            # Treat terminal punctuation as a finished answer; otherwise ask for a continuation.
            if chunk.endswith((".", "!", "?")):
                break
            prompt = "Please continue from where you left off:\n" + chunk[-100:]
        except Exception as e:
            print(f"Attempt {attempt + 1} failed with error: {e}")
    return full_response
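

# End-to-end handler for the Gradio app: retrieve context, build the prompt, query the LLM, log.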
def predict(company, user_query):
    try:
        # Scope retrieval to the selected company.
        modified_query = f"{user_query} for {company}"

        relevant_chunks = get_relevant_chunks(modified_query, collection)

        # Assemble the retrieved chunks and their sources into the ###Context block.
        context = ""
        for chunk, source, page in relevant_chunks:
            context += chunk + "\n"
            context += f"###Source {source}, Page {page}\n"

        user_message = qna_user_message_template.format(context=context, question=user_query)
        prompt = f"{qna_system_message}\n\n{user_message}"

        answer = get_llm_response(prompt)

        log_interaction(company, user_query, context, answer)

        return answer
    except Exception as e:
        return f"An error occurred: {str(e)}"
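

# Append each interaction as one JSON object per line (JSONL) for later review.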
def log_interaction(company, user_query, context, answer):
    log_file = Path("interaction_log.jsonl")
    with log_file.open("a") as f:
        json.dump({
            'company': company,
            'user_query': user_query,
            'context': context,
            'answer': answer
        }, f)
        f.write("\n")
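

# Gradio UI: pick a company, ask a question, get a grounded answer with sources.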
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Radio(company_list, label="Select Company"),
        gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")
    ],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection."
)

iface.launch()