# RAGREPORTS / app.py
# kajila's picture
# Update app.py
# e082d15 verified
# raw
# history blame
# 4.71 kB
import subprocess
import sys
import os
import uuid
import json
from pathlib import Path
from dotenv import load_dotenv
# Install dependencies if not already installed
def install_packages(packages=None):
    """Install the app's third-party dependencies that are not importable yet.

    Each requirement is checked with ``importlib.util.find_spec`` first, so an
    environment that already has everything starts without spawning pip at
    all. Whatever is missing is installed with a single ``pip install`` call
    instead of one subprocess per package.

    Args:
        packages: Optional mapping of pip distribution name to the top-level
            module name it provides. Defaults to the dependencies this app
            needs. Module names must be top-level (undotted) names.

    Returns:
        The list of distribution names that had to be installed (empty when
        nothing was missing).

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    import importlib.util

    if packages is None:
        # pip distribution name -> importable module name (they differ for
        # sentence-transformers and python-dotenv).
        packages = {
            "openai": "openai",
            "langchain_community": "langchain_community",
            "sentence-transformers": "sentence_transformers",
            "chromadb": "chromadb",
            "huggingface_hub": "huggingface_hub",
            "python-dotenv": "dotenv",
        }

    missing = [
        dist_name
        for dist_name, module_name in packages.items()
        if importlib.util.find_spec(module_name) is None
    ]
    if missing:
        # One resolver run for everything that is missing.
        subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])
    return missing
# Run the installer at import time, before the third-party imports that follow.
# NOTE(review): `from dotenv import load_dotenv` near the top of the file executes
# before this call, and gradio is imported below but is not in the installer's
# package list - both presumably have to be pre-installed in the environment; confirm.
install_packages()
# Import installed modules
from huggingface_hub import login, CommitScheduler
import openai
import gradio as gr
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
# Pull settings from a local .env file into the process environment.
load_dotenv()

# Read both API credentials from the environment.
hf_token = os.environ.get("HUGGINGFACE_TOKEN")
openai.api_key = os.environ.get("OPENAI_API_KEY")  # expected to be defined in the .env file

# Only the Hugging Face token is validated here; stop immediately when it is absent.
if hf_token is None:
    raise ValueError("Hugging Face token is missing. Please check your .env file.")

# Authenticate against the Hugging Face Hub.
login(hf_token)

# Embedding model and the Chroma collection persisted under ./report_10kdb.
collection_name = 'report-10k-2024'
embeddings = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")
vectorstore_persisted = Chroma(
    embedding_function=embeddings,
    persist_directory='./report_10kdb',
    collection_name=collection_name,
)

# Similarity retriever over the collection that returns the 5 closest chunks.
retriever = vectorstore_persisted.as_retriever(
    search_kwargs=dict(k=5),
    search_type='similarity',
)
# Define Q&A system message
qna_system_message = """
You are an AI assistant for Finsights Grey Inc., helping automate extraction, summarization, and analysis of 10-K reports.
Your responses should be based solely on the context provided.
If an answer is not found in the context, respond with "I don't know."
"""
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}
###Question
{question}
"""
# Define the predict function
# Local folder that holds the per-request log files and is synced to the Hub.
_LOG_DIR = Path("logs/")

# Process-wide CommitScheduler, created on first use by _get_log_scheduler().
_log_scheduler = None


def _get_log_scheduler():
    """Return the shared CommitScheduler, creating it on the first call.

    A CommitScheduler starts a background job that periodically uploads its
    folder. Building one per request would start one more such job on every
    call, and each request would hold a different ``lock``, so the lock would
    not actually serialize log writes against uploads.
    """
    global _log_scheduler
    if _log_scheduler is None:
        _log_scheduler = CommitScheduler(
            repo_id="RAGREPORTS-log",
            repo_type="dataset",
            folder_path=_LOG_DIR,
            path_in_repo="data",
            every=2
        )
    return _log_scheduler


def _chat_completion(messages):
    """Send ``messages`` to the chat model via whichever openai API is installed.

    openai>=1.0 removed ``openai.ChatCompletion`` and exposes the call as
    ``openai.chat.completions.create``; older releases only have the former.
    """
    params = {
        'model': 'mistralai/Mixtral-8x7B-Instruct-v0.1',
        'messages': messages,
        'temperature': 0,
    }
    if hasattr(openai, "chat"):
        return openai.chat.completions.create(**params)
    return openai.ChatCompletion.create(**params)


def predict(user_input, company):
    """Answer a question about one company's 10-K report using retrieved context.

    Retrieves the 5 chunks most similar to the question from the vector store,
    restricted to the company's source PDF, asks the chat model to answer from
    that context only, and appends the exchange to a JSON log file.

    Args:
        user_input: The user's question.
        company: File-name prefix of the company; the source filter becomes
            ``dataset/<company>-10-k-2023.pdf``.

    Returns:
        The model's answer, or the exception text if the model call failed.
    """
    # Named source_path rather than `filter` to avoid shadowing the builtin.
    source_path = "dataset/" + company + "-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter={"source": source_path}
    )

    # Create context for query
    context_for_query = ".".join(d.page_content for d in relevant_document_chunks)

    # Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query, question=user_input)}
    ]

    # Get response from the LLM. Best-effort by design: a failure is shown to
    # the user (and logged below) as the exception text instead of raising.
    try:
        response = _chat_completion(prompt)
        prediction = response.choices[0].message.content
    except Exception as e:
        prediction = str(e)

    # Log inputs and outputs to a local log file; hold the scheduler lock so
    # the write does not overlap with an upload of the log folder.
    log_file = _LOG_DIR / f"data_{uuid.uuid4()}.json"
    scheduler = _get_log_scheduler()
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction
def get_predict(question, company):
    """Resolve the selected company label and delegate the question to ``predict``.

    Args:
        question: The question text entered by the user.
        company: The company label chosen in the UI (may be ``None`` when
            nothing is selected).

    Returns:
        The result of ``predict`` for a known label, otherwise the string
        "Invalid company selected".
    """
    # UI label -> file-name prefix passed on to predict.
    dataset_keys = dict(
        AWS="aws",
        IBM="IBM",
        Google="Google",
        Meta="meta",
        Microsoft="msft",
    )
    dataset_key = dataset_keys.get(company)
    return predict(question, dataset_key) if dataset_key else "Invalid company selected"
# Gradio front end: company picker and question box, a submit button and the answer box.
# NOTE(review): the pasted source had lost its indentation; the grouping of the two
# inputs inside the Row is assumed - confirm against the deployed layout.
with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
    with gr.Row():
        company = gr.Radio(
            choices=["AWS", "IBM", "Google", "Meta", "Microsoft"],
            label="Select a company",
        )
        question = gr.Textbox(label="Enter your question")

    submit = gr.Button(value="Submit")
    output = gr.Textbox(label="Output")

    # Clicking the button sends (question, company) to get_predict and shows its result.
    submit.click(get_predict, inputs=[question, company], outputs=output)

# Enable request queuing, then start the app.
demo.queue()
demo.launch()