import subprocess
import sys
import os
import uuid
import json
from pathlib import Path
# Install dependencies if not already installed
def install_packages():
    packages = [
        "openai",
        "langchain_community",
        "sentence-transformers",
        "chromadb",
        "huggingface_hub",
        "python-dotenv",
    ]
    subprocess.check_call([sys.executable, "-m", "pip", "install", *packages])

install_packages()
# Import installed modules
from huggingface_hub import login, CommitScheduler
import openai
import gradio as gr
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from dotenv import load_dotenv
# Load environment variables from .env file
load_dotenv()

# Get the OpenAI API key from environment variables
openai.api_key = os.getenv("OPENAI_API_KEY")  # Ensure OPENAI_API_KEY is in your .env file
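# Expected .env layout for this script (placeholder values, not real secrets):
#   OPENAI_API_KEY=sk-xxxxxxxxxxxxxxxx
#   hf_token=hf_xxxxxxxxxxxxxxxx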
# Retrieve the Hugging Face token and log in
hf_token = os.getenv("hf_token")
if not hf_token:
    raise ValueError("Hugging Face token is missing. Please set 'hf_token' as an environment variable.")

login(hf_token)
print("Logged in to Hugging Face successfully.")
# Set up embeddings and the persisted Chroma vector store
embeddings = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")
collection_name = 'report-10k-2024'

vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory='./report_10kdb',
    embedding_function=embeddings
)

retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
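
# Note: chunks in this collection are assumed to carry a "source" metadata field
# of the form "dataset/<company>-10-k-2023.pdf", which predict() filters on below.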
# Define the Q&A system message and user message template
qna_system_message = """
You are an AI assistant for Finsights Grey Inc., helping automate extraction, summarization, and analysis of 10-K reports.
Your responses should be based solely on the context provided.
If an answer is not found in the context, respond with "I don't know."
"""

qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}

###Question
{question}
"""
# Define the predict function
def predict(user_input, company):
    # Restrict retrieval to the selected company's 10-K filing via a metadata filter
    filter = "dataset/" + company + "-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter={"source": filter}
    )

    # Create the context for the query from the retrieved chunks
    context_list = [d.page_content for d in relevant_document_chunks]
    context_for_query = ".".join(context_list)

    # Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(context=context_for_query, question=user_input)}
    ]

    # Get response from the LLM using the chat completions API
    try:
        response = openai.chat.completions.create(
            model='gpt-3.5-turbo',  # Specify the model you want to use
            messages=prompt,        # Pass the prompt (context and user message)
            temperature=0           # Deterministic, low-variance responses
        )
        # Extract the prediction from the response
        prediction = response.choices[0].message.content
    except Exception as e:
        # Surface any API error as the response text
        prediction = str(e)

    # Print the prediction or error
    print(prediction)
    # Log inputs and outputs to a local log file; the CommitScheduler periodically
    # pushes the log folder to a Hugging Face dataset repo
    log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
    log_folder = log_file.parent
    log_folder.mkdir(parents=True, exist_ok=True)

    scheduler = CommitScheduler(
        repo_id="RAGREPORTS-log",
        repo_type="dataset",
        folder_path=log_folder,
        path_in_repo="data",
        every=2  # push accumulated logs every 2 minutes
    )

    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    # Return the prediction after logging
    return prediction
def get_predict(question, company):
    # Map user selection to the company name used in the dataset file names
    company_map = {
        "AWS": "aws",
        "IBM": "IBM",
        "Google": "Google",
        "Meta": "meta",
        "Microsoft": "msft"
    }
    selected_company = company_map.get(company)
    if not selected_company:
        return "Invalid company selected"
    return predict(question, selected_company)
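
# Example (hypothetical) call, bypassing the UI:
#   print(get_predict("What was the total revenue reported for fiscal year 2023?", "Microsoft"))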
# Set up the Gradio UI
with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
    with gr.Row():
        company = gr.Radio(["AWS", "IBM", "Google", "Meta", "Microsoft"], label="Select a company")
        question = gr.Textbox(label="Enter your question")
    submit = gr.Button("Submit")
    output = gr.Textbox(label="Output")
    submit.click(
        fn=get_predict,
        inputs=[question, company],
        outputs=output
    )

demo.queue()
demo.launch()