## Setup

# Import the necessary libraries
import os
import uuid
import json
from pathlib import Path

import gradio as gr
from openai import OpenAI
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from huggingface_hub import CommitScheduler
# Create the client
# The base_url must point at Anyscale Endpoints: the key below is an Anyscale
# key and the model called in predict() is served there, not by api.openai.com.
client = OpenAI(
    base_url="https://api.endpoints.anyscale.com/v1",
    api_key=os.environ["anyscale_api_key"]
)
# Define the embedding model
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
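# Optional sanity check (a sketch, not part of the app flow): embed a sample
# query and inspect the vector; gte-large is expected to produce
# 1024-dimensional embeddings.
# vec = embedding_model.embed_query("revenue growth in fiscal 2023")
# print(len(vec))  # expected: 1024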
# Load the persisted vector DB
collection_name = 'finsights-grey-10k-2023'

vectorstore_persisted = Chroma(
    collection_name=collection_name,
    embedding_function=embedding_model,
    persist_directory='finsights_db'
)
retriever = vectorstore_persisted.as_retriever(
    search_type="similarity",
    search_kwargs={'k': 5},
)
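# Note: predict() below calls similarity_search directly so it can pass a
# per-company metadata filter; the unfiltered retriever above is useful for
# ad-hoc queries across all reports, e.g. (sketch):
# docs = retriever.invoke("What are the main risk factors?")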
# Prepare the logging functionality
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent

scheduler = CommitScheduler(
    repo_id="finsight-qna",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2
)
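# `every=2` asks CommitScheduler to push the contents of logs/ to the dataset
# repo roughly every 2 minutes (the unit is minutes). Each app instance writes
# to its own uniquely named file, so the repo accumulates JSON-lines files
# under data/, e.g. data/data_<uuid>.json.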
# Define the Q&A system message
qna_system_message = """
You are an assistant to a financial technology firm that answers user queries about 10-K reports from various industry players. These reports contain detailed information about financial performance, risk factors, market trends, and strategic initiatives.

User input will include the context required to answer the question. The context will begin with the token: ###Context. When crafting your response, select the most relevant context or contexts to answer the question.

User questions will begin with the token: ###Question.

Answer only using the context provided in the input. Do not mention the context in your final answer. If the answer is not found in the context, respond with "I don't know".
"""
# Define the user message template
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question mentioned below:
{context}

###Question
{question}
"""
# Define the predict function that runs when 'Submit' is clicked or when an API request is made
def predict(user_input, company):
    # Restrict retrieval to the selected company's 10-K report
    metadata_filter = {"source": f"/content/dataset/{company}-10-k-2023.pdf"}
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter=metadata_filter
    )

    # Create context_for_query: join the retrieved chunks, separated by blank
    # lines so chunk boundaries stay visible to the model
    context_list = [f"Page {doc.metadata['page']}: {doc.page_content}" for doc in relevant_document_chunks]
    context_for_query = "\n\n".join(context_list)
    # Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(context=context_for_query, question=user_input)}
    ]

    # Get response from the LLM
    try:
        response = client.chat.completions.create(
            model="mlabonne/NeuralHermes-2.5-Mistral-7B",
            messages=prompt,
            temperature=0
        )
        prediction = response.choices[0].message.content
    except Exception as e:
        prediction = f'Sorry, I encountered the following error: \n {e}'
        print(prediction)

    # Log both the inputs and outputs to a local log file;
    # holding scheduler.lock prevents a commit from running mid-write
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps({
                'user_input': user_input,
                'retrieved_context': context_for_query,
                'model_response': prediction
            }))
            f.write("\n")

    return prediction
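# Local smoke test (a sketch; assumes the vector store is populated and the
# Anyscale key is set in the environment):
# print(predict("What was the company's total revenue in fiscal 2023?", "IBM"))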
# Set up the Gradio UI
# Add a text box and a radio button to the interface.
# The radio button selects the company 10-K report from which context is
# retrieved; the choice values must match the source file names used in the
# metadata filter above, hence the mixed casing.
textbox = gr.Textbox(placeholder="Enter your query here")
company = gr.Radio(choices=["IBM", "META", "aws", "google", "msft"], label="Company")
# Create the interface
demo = gr.Interface(
    inputs=[textbox, company],
    fn=predict,
    outputs="text",
    description="Ask questions about the contents of the IBM, META, AWS, GOOGLE and MSFT 10-K reports for the year 2023.",
    article="Note that questions that are not relevant to the aforementioned companies' 10-K reports will not be answered.",
    title="Q&A for IBM, META, AWS, GOOG & MSFT 10-K Statements",
    examples=[
        ["Has the company made any significant acquisitions in the AI space, and how are these acquisitions being integrated into the company's strategy?", "IBM"],
        ["How much capital has been allocated towards AI research and development?", "META"],
        ["What initiatives has the company implemented to address ethical concerns surrounding AI, such as fairness, accountability, and privacy?", "aws"],
        ["How does the company plan to differentiate itself in the AI space relative to competitors?", "google"]
    ],
    concurrency_limit=16
)
demo.queue()
demo.launch(share=True, debug=False)