## Setup
# Import the necessary libraries
import os
import uuid
import json
from pathlib import Path

import joblib
import tiktoken
import pandas as pd
import gradio as gr
from openai import OpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings
)
from langchain_community.vectorstores import Chroma
from huggingface_hub import CommitScheduler
# Create the client (key and base URL for an OpenAI-compatible proxy endpoint)
os.environ['OPENAI_API_KEY'] = "gl-U2FsdGVkX1+0bNWD6YsVLZUYsn0m1WfLxUzrP0xUFbtWFAfk9Z1Cz+mD8u1yqKtV"
os.environ["OPENAI_BASE_URL"] = "https://aibe.mygreatlearning.com/openai/v1"
client = OpenAI()
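# Optional sanity check for the credentials and base URL (a minimal sketch;
# assumes the proxy exposes the standard chat-completions endpoint).
# Uncomment to verify connectivity before launching the app:
# ping = client.chat.completions.create(
#     model="gpt-4o-mini",
#     messages=[{"role": "user", "content": "ping"}],
#     max_tokens=5
# )
# print(ping.choices[0].message.content)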
# Define the chat model and the embedding model
model_name = 'gpt-4o-mini'
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
# Load the persisted vector DB
persisted_vectordb_location = '10k-reports_db'
collection_name = '10k-reports'

vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory=persisted_vectordb_location,
    embedding_function=embedding_model
)

# Touch the collection once to confirm it loads
vectorstore_persisted.get()
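# Optional retrieval sanity check (a sketch; assumes the collection was
# indexed with the per-page 'source' and 'page' metadata used below).
# Uncomment to inspect a sample retrieval:
# sample = vectorstore_persisted.similarity_search("revenue by segment", k=2)
# for doc in sample:
#     print(doc.metadata.get("source"), doc.metadata.get("page"))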
# Prepare the logging functionality
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent

# Push the contents of the log folder to a Hugging Face dataset repo every 2 minutes
scheduler = CommitScheduler(
    repo_id="Keytaro/10K-reports-mlops-logs",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2
)
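# Each log file holds one JSON object per line, so it can be read back for
# analysis with pandas (a sketch; assumes at least one prediction has been
# logged and that the file exists locally):
# logs_df = pd.read_json(log_file, lines=True)
# print(logs_df[['user_input', 'model_response']].head())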
# Define the Q&A system message
qna_system_message = """
You are an assistant to a Gen AI Data Scientist. Your task is to automate the extraction, summarization, and analysis of information from 10-K reports.

User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for a context will begin with the token: ###Source.

When crafting your response:
1. Select only the context relevant to answering the question.
2. Include the source links in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant to 10-K reports, respond with: "I am an assistant for a Gen AI Data Scientist. I can only help you with questions related to 10-K reports."

Please adhere to the following guidelines:
- Your response should only be about the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, it is very important for you to respond with "I don't know. Please check the 10-K reports."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
- Do not make up sources. Use only the links provided in the sources section of the context. You are prohibited from providing other links/sources.

Here is an example of how to structure your response:

Answer:
[Answer]

Source:
[Source]
"""
# Define the user message template
qna_user_message_template = """
###Context
Here are some documents and their source links that are relevant to the question mentioned below.

{context}

###Question
{question}
"""
# Define the predict function that runs when 'Submit' is clicked or when an API request is made
def predict(user_input, company):
    companyfile = {
        "Amazon": "aws",
        "Google": "google",
        "Microsoft": "msft",
        "Meta": "Meta",
        "IBM": "IBM"
    }.get(company, None)

    # Guard against an unknown company; without this, the retrieval below
    # would fail with a NameError
    if companyfile is None:
        return ("Please select one of the supported companies.", user_input, "")

    user_input = user_input.replace("the company", company)
    source_path = "dataset/" + companyfile + "-10-k-2023.pdf"

    # Retrieve the 5 most similar chunks, restricted to the selected company's report
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input,
        k=5,
        filter={"source": source_path}
    )

    # Create context_for_query
    context_list = [
        d.page_content + f"\n###Source: '{d.metadata['source']}', p.{d.metadata['page']}\n\n"
        for d in relevant_document_chunks
    ]
    context_for_query = ". ".join(context_list)

    # Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input
        )}
    ]

    # Get response from the LLM
    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=prompt,
            temperature=0
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n {e}'

    # Log both the inputs and outputs to a local log file while holding the
    # commit scheduler's lock, so a commit never captures a half-written file
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return (prediction, user_input, context_for_query)
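# Example direct invocation (a sketch; bypasses the UI and assumes the vector
# store contains the Amazon 10-K under the filter path used above):
# answer, query, context = predict(
#     "How much capital has been allocated towards AI research and development by the company?",
#     "Amazon"
# )
# print(answer)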
# Set up the Gradio UI
# Two radio buttons: one to pick a predefined question, and one to select
# the company whose 10-K report the context is retrieved from
inputs = [
    gr.Radio(
        label="user_input",
        choices=[
            "Has the company made any significant acquisitions in the AI space, and how are these acquisitions being integrated into the company's strategy?",
            "How much capital has been allocated towards AI research and development by the company?",
            "What initiatives has the company implemented to address ethical concerns surrounding AI, such as fairness, accountability, and privacy?",
            "How does the company plan to differentiate itself in the AI space relative to competitors?",
            "What are the company’s policies and frameworks regarding AI ethics, governance, and responsible AI use as detailed in their 10-K reports?",
            "What are the primary business segments of the company, and how does each segment contribute to the overall revenue and profitability?",
            "What are the key risk factors identified in the 10-K report that could potentially impact the company’s business operations and financial performance?"
        ]
    ),
    gr.Radio(label="Company", choices=["Amazon", "Google", "Microsoft", "Meta", "IBM"]),
]
output = [
    gr.Textbox(label="Answer"),
    gr.Textbox(label="query"),
    gr.Textbox(label="context_for_query")
]
# Create the interface
demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=output,
    title="10-K reports RAG system",
    description="This API allows you to answer one of the predefined questions based on the 10-K reports.",
    allow_flagging="auto",
    concurrency_limit=8
)

demo.queue()
demo.launch()
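# Once launched, the app can also be called programmatically via gradio_client
# (a sketch; the Space id below is hypothetical and depends on where this app
# is hosted):
# from gradio_client import Client as GradioClient
# api = GradioClient("user/10K-reports-rag")  # hypothetical Space id
# result = api.predict(
#     "What are the primary business segments of the company, and how does each segment contribute to the overall revenue and profitability?",
#     "Microsoft",
#     api_name="/predict"
# )
# print(result[0])  # the Answer textbox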