import subprocess
import sys
import os

import gradio as gr
def install_packages():
    # Pin openai to 0.28 so the legacy openai.ChatCompletion API used below keeps working.
    packages = [
        "openai==0.28", "langchain_community", "sentence-transformers",
        "chromadb", "huggingface_hub", "python-dotenv",
        "numpy", "scipy", "scikit-learn",
    ]
    for package in packages:
        subprocess.check_call([sys.executable, "-m", "pip", "install", "--upgrade", package])

install_packages()
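# Note: installing at launch keeps this script self-contained; on Hugging Face
# Spaces the same pins would more typically live in a requirements.txt file.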
from dotenv import load_dotenv
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from huggingface_hub import login
import openai
# Load environment variables from .env file
load_dotenv()

# Get API tokens from environment variables
openai.api_key = os.getenv("OPENAI_API_KEY")  # Ensure OPENAI_API_KEY is set in your .env file
hf_token = os.getenv("hf_token")
if not hf_token:
    raise ValueError("Hugging Face token is missing. Please set 'hf_token' as an environment variable.")

# Log in to Hugging Face
login(hf_token)
print("Logged in to Hugging Face successfully.")
# Set up embeddings and vector store
embeddings = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")

collection_name = 'report-10k-2024'
vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory='./report_10kdb',
    embedding_function=embeddings
)
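# A minimal sketch of how this collection could have been populated upstream
# (illustrative only -- the loader, chunk size, and file names are assumptions,
# not part of this script):
#   from langchain_community.document_loaders import PyPDFLoader
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#   docs = PyPDFLoader("dataset/aws-10-k-2023.pdf").load()
#   chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)
#   Chroma.from_documents(chunks, embeddings, collection_name=collection_name,
#                         persist_directory='./report_10kdb')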
# Set up the retriever
retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
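# Note: predict() below calls similarity_search() directly so it can pass a
# per-company metadata filter; the retriever above performs the same top-5
# search without a filter, e.g.:
#   docs = retriever.get_relevant_documents("What are the key risk factors?")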
# Define Q&A system message
qna_system_message = """
You are an AI assistant helping Finsights Grey Inc., an innovative financial technology firm, develop a Retrieval-Augmented Generation (RAG) system to automate the extraction, summarization, and analysis of information from 10-K reports. Your knowledge base was last updated in August 2023.

User input will contain the context required to answer the question. This context will begin with the token: ###Context.
The context contains references to specific portions of a 10-K report relevant to the user's query.
User questions will begin with the token: ###Question.

Your response should address only the question asked, using only the context provided.
Do not mention the context itself in your final answer.
If the answer is not found in the context, it is very important that you respond with "I don't know."

Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.

Here is an example of how to structure your response:

Answer:
[Answer]

Source:
[Source]
"""
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}

###Question
{question}
"""
# Define the predict function
def predict(user_input, company):
    # Build a metadata filter for the company's 10-K and fetch relevant chunks
    source_filter = f"dataset/{company}-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter={"source": source_filter}
    )

    # Assemble the retrieved chunks into a single context string
    context_list = [d.page_content for d in relevant_document_chunks]
    context_for_query = "\n".join(context_list)

    # Create messages for the OpenAI model
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query, question=user_input
        )}
    ]

    # Get response from the OpenAI LLM (legacy 0.28 API, matching the pin above)
    try:
        response = openai.ChatCompletion.create(
            model='gpt-3.5-turbo',
            messages=prompt,
            temperature=0
        )
        prediction = response['choices'][0]['message']['content']
    except Exception as e:
        prediction = f"Error: {str(e)}"

    return prediction
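# Example call (illustrative; assumes the AWS chunks exist in the collection
# under the source path dataset/aws-10-k-2023.pdf):
#   print(predict("What are the key risk factors?", "aws"))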
# Example set of questions and company names
examples = [
    ["What are the company's policies and frameworks regarding AI ethics, governance, and responsible AI use as detailed in their 10-K reports?", "AWS"],
    ["What are the primary business segments of the company, and how does each segment contribute to the overall revenue and profitability?", "AWS"],
    ["What are the key risk factors identified in the 10-K report that could potentially impact the company's business operations and financial performance?", "AWS"],
    ["Has the company made any significant acquisitions in the AI space, and how are these acquisitions being integrated into the company's strategy?", "Microsoft"],
    ["How much capital has been allocated towards AI research and development?", "Google"],
    ["What initiatives has the company implemented to address ethical concerns surrounding AI, such as fairness, accountability, and privacy?", "IBM"],
    ["How does the company plan to differentiate itself in the AI space relative to competitors?", "Meta"]
]
# Define function to handle the prediction process based on user input
def get_predict(question, company):
    # Map the displayed company name to the file prefix used in the
    # vector store's source metadata
    company_map = {
        "AWS": "aws",
        "IBM": "IBM",
        "Google": "Google",
        "Meta": "meta",
        "Microsoft": "msft",
    }
    if company not in company_map:
        return "Invalid company selected"
    selected_company = company_map[company]

    # Short-circuit with a canned response when the input matches a bundled
    # example (compare against the displayed company name, not the mapped prefix)
    for example_question, example_company in examples:
        if question == example_question and company == example_company:
            return f"This is the output for the example question: {example_question}"

    # Perform prediction
    output = predict(question, selected_company)
    return output
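# Example call (illustrative): route a question through the company mapping.
#   print(get_predict("How much capital has been allocated towards AI research and development?", "Google"))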
# Set up the Gradio UI
# Add a textbox and radio button to the interface.
# The radio button selects the company 10-K report from which context is retrieved.
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
    with gr.Row():
        company = gr.Radio(["AWS", "IBM", "Google", "Meta", "Microsoft"], label="Select a company")
    with gr.Row():
        question = gr.Textbox(label="Enter your question")
    submit = gr.Button("Submit")
    output = gr.Textbox(label="Output")
    submit.click(
        fn=get_predict,
        inputs=[question, company],
        outputs=output
    )
    examples_component = gr.Examples(examples=examples, inputs=[question, company])

demo.queue()
demo.launch()