import json
import uuid
from pathlib import Path

import gradio as gr
from huggingface_hub import CommitScheduler
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from openai import OpenAI

# The OpenAI client reads the OPENAI_API_KEY environment variable.
client = OpenAI()

# Configuration: the embedding model must be the same one used when the
# vector store was built, so query vectors match the indexed document vectors.
collection_name = 'project3_rag_db'
embedding_model_name = 'thenlper/gte-large'
embedding_model = SentenceTransformerEmbeddings(model_name=embedding_model_name)
persisted_vectordb_location = './project3_rag_db'
model_name = 'gpt-4o-mini'

# Load the persisted Chroma collection built during ingestion.
vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory=persisted_vectordb_location,
    embedding_function=embedding_model,
)
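
# The collection above is assumed to have been populated in a separate
# ingestion step. A minimal sketch of that step (file paths, chunk size,
# and overlap are assumptions, not taken from this script):
#
#   from langchain.text_splitter import RecursiveCharacterTextSplitter
#   from langchain_community.document_loaders import PyPDFDirectoryLoader
#
#   docs = PyPDFDirectoryLoader("/content/dataset/").load()
#   chunks = RecursiveCharacterTextSplitter(
#       chunk_size=1000, chunk_overlap=200
#   ).split_documents(docs)
#   Chroma.from_documents(
#       chunks,
#       embedding_model,
#       collection_name=collection_name,
#       persist_directory=persisted_vectordb_location,
#   )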

# Each app instance writes to its own uniquely named log file.
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent

# Push the log folder to a Hugging Face dataset repo every 2 minutes.
scheduler = CommitScheduler(
    repo_id="anirudhabokil/project3_rag_10K_chatbot_logs",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2,
)
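
# CommitScheduler commits in a background thread, so any write to files in
# log_folder should be guarded with `scheduler.lock` (as done in predict
# below) to avoid uploading a partially written line.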

qna_system_message = """
You are an assistant to a financial analyst at a fintech company. Your task is to provide relevant information and analysis of key information from 10-K reports.
10-K reports are comprehensive annual reports filed with the Securities and Exchange Commission (SEC) by publicly traded companies in the United States.
User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for each piece of context will begin with the token: ###Source.

When crafting your response:
1. Select only the context relevant to answering the question.
2. Include the source links in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant to 10-K reports, respond with: "I am an assistant to a Financial Analyst. I can only help you with questions related to 10-K reports."

Please adhere to the following guidelines:
- Your response should address only the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, it is very important that you respond with "I don't know."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
- Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.

Here is an example of how to structure your response:

Answer:
[Answer]

Source:
[Use the ###Source provided in the context as is. Do not add an https prefix.]
"""

qna_user_message_template = """
###Context
Here are some excerpts from 10-K reports, with their source links, that are relevant to the question below.
{context}

###Question
{question}
"""


def predict(user_input, company):
    # Restrict retrieval to the selected company's 10-K. The path must match
    # the `source` metadata recorded when the PDFs were ingested.
    source_path = f"/content/dataset/{company}-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input,
        k=5,
        filter={"source": source_path},
    )

    # Assemble the retrieved chunks, each tagged with its source, into a
    # single context string for the prompt.
    context_list = [
        f"{d.page_content}\n###Source: {d.metadata['source']}\n\n"
        for d in relevant_document_chunks
    ]
    context_for_query = "".join(context_list)

    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input,
        )},
    ]

    try:
        response = client.chat.completions.create(
            model=model_name,
            messages=prompt,
            temperature=0,
        )
        answer = response.choices[0].message.content.strip()
    except Exception as e:
        answer = f'Sorry, I encountered the following error: \n {e}'

    # Log the interaction; the lock prevents the background commit from
    # uploading a partially written file.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': answer,
                }
            ))
            f.write("\n")

    return answer
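
# Example direct invocation (hypothetical question; the company value must be
# one of the dropdown choices below):
#   predict("What were Meta's main revenue sources in 2023?", "Meta")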


# Gradio UI: the dropdown values must match the ingested PDF filenames,
# e.g. aws-10-k-2023.pdf.
user_input = gr.Textbox(label="Ask your question")
company = gr.Dropdown(['aws', 'google', 'IBM', 'Meta', 'msft'], label="Company")
answer = gr.Textbox(label="Answer")  # Textbox suits free-form answers better than gr.Label

demo = gr.Interface(fn=predict,
                    inputs=[user_input, company],
                    outputs=answer,
                    title="10-K Chatbot",
                    description="This app answers questions based on 10-K reports",
                    flagging_mode="auto",
                    concurrency_limit=8)

demo.queue()
demo.launch(share=True, debug=True)