anirudhabokil committed
Commit f9149ca · verified · 1 Parent(s): 2241964

Update app.py

Files changed (1):
  1. app.py +147 -14
app.py CHANGED
@@ -1,23 +1,156 @@
-import os
-import time
 import json
 import uuid
 import gradio as gr

-user_question = gr.Textbox(label="Ask your question")
-model_output = gr.Label(label="Answer")

-def qa_bot(user_question):
-    return os.environ["OPENAI_BASE_URL"] + " Answer"

-demo = gr.Interface(
-    fn=qa_bot,
-    inputs=user_question,
-    outputs=model_output,
-    title="Ask your question",
-    flagging_mode="auto",
-    concurrency_limit=8)

 demo.queue()
 demo.launch(share=True, debug=True)
-# demo.close()
+
+## Setup
+# Import the necessary libraries
 import json
+import tiktoken
+import os
+import pandas as pd
 import uuid
 import gradio as gr
+from openai import OpenAI
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.document_loaders import PyPDFDirectoryLoader
+from langchain_community.embeddings.sentence_transformer import (
+    SentenceTransformerEmbeddings
+)
+from langchain_community.vectorstores import Chroma
+from huggingface_hub import CommitScheduler
+from pathlib import Path
+
+# Create the OpenAI client
+client = OpenAI()
+
+# Define the embedding model and the vector store
+collection_name = 'project3_rag_db'
+embedding_model_name = 'thenlper/gte-large'
+embedding_model = SentenceTransformerEmbeddings(model_name=embedding_model_name)
+persisted_vectordb_location = '/content/drive/MyDrive/project3_rag_db'
+model_name = 'gpt-4o-mini'
+# Load the persisted vector DB
+vectorstore_persisted = Chroma(
+    collection_name=collection_name,
+    persist_directory=persisted_vectordb_location,
+    embedding_function=embedding_model)
+
+# Prepare the logging functionality
+
+log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
+log_folder = log_file.parent
+
+scheduler = CommitScheduler(
+    repo_id="anirudhabokil/project3_rag_10K_chatbot_logs",
+    repo_type="dataset",
+    folder_path=log_folder,
+    path_in_repo="data",
+    every=2
+)
+
+# Define the Q&A system message
+qna_system_message = """
+You are an assistant to a Financial Analyst at a fintech company. Your task is to provide relevant information and analysis of key information from 10-K reports.
+10-K reports are comprehensive annual reports filed by publicly traded companies in the United States with the Securities and Exchange Commission (SEC).
+User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
+The context contains references to specific portions of documents relevant to the user's query, along with source links.
+The source for a context will begin with the token ###Source:
+
+When crafting your response:
+1. Select only context relevant to answer the question.
+2. Include the source links in your response.
+3. User questions will begin with the token: ###Question.
+4. If the question is irrelevant to 10-K reports, respond with: "I am an assistant to a Financial Analyst. I can only help you with questions related to 10-K reports."
+
+Please adhere to the following guidelines:
+- Your response should only be about the question asked and nothing else.
+- Answer only using the context provided.
+- Do not mention anything about the context in your final answer.
+- If the answer is not found in the context, it is very important that you respond with "I don't know."
+- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
+- Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.
+
+Here is an example of how to structure your response:
+
+Answer:
+[Answer]
+
+Source:
+[Use the ###Source provided in the context as is. Do not add an https prefix.]
+"""
+
+# Define the user message template
+qna_user_message_template = """
+###Context
+Here are some 10-K reports and their source links that are relevant to the question mentioned below.
+{context}
+
+###Question
+{question}
+"""
+
+# Define the predict function that runs when 'Submit' is clicked or when an API request is made
+def predict(user_input, company):
+
+    filter = "/content/dataset/" + company + "-10-k-2023.pdf"
+    print(filter)
+    relevant_document_chunks = vectorstore_persisted.similarity_search(user_input, k=5, filter={"source": filter})
+    print(relevant_document_chunks)
+    # Create context_for_query
+    context_list = [d.page_content + "\n ###Source: " + d.metadata['source'] + '\n\n ' for d in relevant_document_chunks]
+    context_for_query = ". ".join(context_list)
+    print(context_for_query)
+
+    # Create messages
+    prompt = [
+        {'role': 'system', 'content': qna_system_message},
+        {'role': 'user', 'content': qna_user_message_template.format(
+            context=context_for_query,
+            question=user_input
+        )
+        }]
+
+    print(prompt)
+    # Get the response from the LLM
+
+    try:
+        response = client.chat.completions.create(model=model_name, messages=prompt, temperature=0)
+        print(response)
+        answer = response.choices[0].message.content.strip()
+    # Handle errors using try-except
+    except Exception as e:
+        answer = f'Sorry, I encountered the following error: \n {e}'
+
+    # While the prediction is made, log both the inputs and outputs to a local log file
+    # While writing to the log file, ensure that the commit scheduler is locked
+    # to avoid parallel access
+
+    with scheduler.lock:
+        with log_file.open("a") as f:
+            f.write(json.dumps(
+                {
+                    'user_input': user_input,
+                    'retrieved_context': context_for_query,
+                    'model_response': answer
+                }
+            ))
+            f.write("\n")

+    return answer

+# Set up the Gradio UI
+# Add a text box and a dropdown to the interface
+# The dropdown is used to select the company whose 10-K report the context is retrieved from.

+user_input = gr.Textbox(label="Ask your question")
+company = gr.Dropdown(['aws','google','IBM','Meta','msft'], label="Company")
+answer = gr.Label(label="Answer")
+# Create the interface
+# For the inputs parameter of Interface, provide [user_input, company]
+demo = gr.Interface(fn=predict,
+                    inputs=[user_input, company],
+                    outputs=answer,
+                    title="10-K Chatbot",
+                    description="This API answers questions based on 10-K reports",
+                    flagging_mode="auto",
+                    concurrency_limit=8)

 demo.queue()
 demo.launch(share=True, debug=True)
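
Note: app.py imports PyPDFDirectoryLoader and RecursiveCharacterTextSplitter but never calls them, so the persisted Chroma collection was presumably built in a separate step. A minimal sketch of what that step might look like, reusing the same model, collection name, and paths as above (the chunk sizes are assumptions; this commit does not show the values used):

# Hypothetical build step, not part of this commit: create the persisted
# Chroma DB that app.py loads from /content/drive/MyDrive/project3_rag_db.
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFDirectoryLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma

# Load every 10-K PDF from the directory the app's source filter points at
docs = PyPDFDirectoryLoader("/content/dataset/").load()

# Chunk size and overlap are guesses, not values from the commit
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)

Chroma.from_documents(
    documents=chunks,
    embedding=SentenceTransformerEmbeddings(model_name='thenlper/gte-large'),
    collection_name='project3_rag_db',
    persist_directory='/content/drive/MyDrive/project3_rag_db',
)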
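
Since the interface description advertises it as an API, the deployed Space can also be queried programmatically. A sketch using gradio_client; the Space id below is hypothetical, as the commit does not name the Space:

from gradio_client import Client

# Space id is a guess based on the author's username; replace with the actual Space
api = Client("anirudhabokil/project3-rag-10k-chatbot")
result = api.predict(
    "What were the main risk factors?",  # user_input textbox
    "Meta",                              # company dropdown value
    api_name="/predict"
)
print(result)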
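
The CommitScheduler pushes the logs folder to the anirudhabokil/project3_rag_10K_chatbot_logs dataset roughly every 2 minutes. A sketch for reading those JSON-lines records back, assuming the record layout written by predict above:

import json
from pathlib import Path
from huggingface_hub import snapshot_download

# Download the dataset repo and walk the JSON-lines log files under data/
local_dir = snapshot_download(repo_id="anirudhabokil/project3_rag_10K_chatbot_logs", repo_type="dataset")
for log_path in Path(local_dir, "data").glob("*.json"):
    for line in log_path.read_text().splitlines():
        record = json.loads(line)
        print(record["user_input"], "->", record["model_response"][:80])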