kajila committed
Commit 9bccc31 · verified · 1 Parent(s): e3da40f

Update app.py

Files changed (1):
  app.py (+30 −76)
app.py CHANGED
@@ -1,33 +1,26 @@
-
-# Import the necessary Libraries
 import os
 import uuid
 import json
 
 import gradio as gr
-#import openai
-#import load_dotenv
-!pip install openai
 from openai import OpenAI
-
 from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
 from langchain_community.vectorstores import Chroma
-
 from huggingface_hub import CommitScheduler
 from pathlib import Path
 from dotenv import load_dotenv
 
-
-# Create Client
+# Load environment variables from .env file
 load_dotenv()
 
-os.environ["OPENAI_API_KEY"] = "sk-proj-..."
+# Set the OpenAI API key: OpenAI() reads OPENAI_API_KEY from the environment
+# (make sure OPENAI_API_KEY is in your .env file)
 
+# Initialize OpenAI client
 client = OpenAI()
 
+# Set up embeddings and vectorstore
 embeddings = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")
-# Define the embedding model and the vectorstore
-
 collection_name = 'report-10k-2024'
 
 vectorstore_persisted = Chroma(
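Note: the commit removes the hard-coded `os.environ["OPENAI_API_KEY"] = "sk-proj-..."` line (key redacted here) in favour of `.env`-based configuration. A minimal sketch of the intended pattern, assuming a local `.env` file containing `OPENAI_API_KEY=...`:

```python
from dotenv import load_dotenv
from openai import OpenAI

load_dotenv()      # copies OPENAI_API_KEY from .env into os.environ
client = OpenAI()  # the client reads the key from os.environ automatically
```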
@@ -36,16 +29,12 @@ vectorstore_persisted = Chroma(
     embedding_function=embeddings
 )
 
-# Load the persisted vectorDB
-
 retriever = vectorstore_persisted.as_retriever(
     search_type='similarity',
     search_kwargs={'k': 5}
 )
 
-
-# Prepare the logging functionality
-
+# Set up logging
+
 log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
 log_folder = log_file.parent
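Note: the `CommitScheduler(...)` construction sits in the unchanged lines between these two hunks. For reference, a typical setup looks like the sketch below; the `repo_id` and `every` values are illustrative, not taken from this commit:

```python
from huggingface_hub import CommitScheduler

scheduler = CommitScheduler(
    repo_id="report-10k-logs",  # hypothetical dataset repo for the logs
    repo_type="dataset",
    folder_path=log_folder,     # local folder synced to the Hub
    path_in_repo="data",
    every=2,                    # push roughly every 2 minutes
)
```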
 
@@ -58,13 +47,10 @@ scheduler = CommitScheduler(
 )
 
 # Define the Q&A system message
-
 qna_system_message = """
 You are an AI assistant to help Finsights Grey Inc., an innovative financial technology firm, develop a Retrieval-Augmented Generation (RAG) system to automate the extraction, summarization, and analysis of information from 10-K reports. Your knowledge base was last updated in August 2023.
-
 User input will have the context required by you to answer user questions. This context will begin with the token: ###Context.
 The context contains references to specific portions of a 10-K report relevant to the user query.
-
 User questions will begin with the token: ###Question.
 Your response should only be about the question asked and the context provided.
 Answer only using the context provided.
@@ -72,44 +58,29 @@ Do not mention anything about the context in your final answer.
 If the answer is not found in the context, it is very important for you to respond with "I don't know."
 Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
 Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.
-Here is an example of how to structure your response:
-
-Answer:
-[Answer]
-
-Source:
-[Source]
 """
 
-# Define the user message template
 qna_user_message_template = """
 ###Context
 Here are some documents that are relevant to the question.
 {context}
-```
+
+###Question
 {question}
-```
-"""
+"""
 
-# Define the predict function that runs when 'Submit' is clicked or when an API request is made
-def predict(user_input,company):
-
-    filter = "dataset/"+company+"-10-k-2023.pdf"
-    relevant_document_chunks = vectorstore_persisted.similarity_search(user_input, k=5, filter={"source":filter})
-
-    # Create context_for_query
+# Define the predict function
+def predict(user_input, company):
+    filter = "dataset/" + company + "-10-k-2023.pdf"
+    relevant_document_chunks = vectorstore_persisted.similarity_search(user_input, k=5, filter={"source": filter})
 
+    # Create context for query
     context_list = [d.page_content for d in relevant_document_chunks]
     context_for_query = ".".join(context_list)
 
     # Create messages
-    prompt = [
-        {'role':'system', 'content': qna_system_message},
-        {'role': 'user', 'content': qna_user_message_template.format(
-            context=context_for_query,
-            question=user_input
-        )
-        }
-    ]
+    prompt = [
+        {'role': 'system', 'content': qna_system_message},
+        {'role': 'user', 'content': qna_user_message_template.format(context=context_for_query, question=user_input)}
+    ]
 
     # Get response from the LLM
     try:
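Note: the chat-completion call inside the `try:` block lives in unchanged lines that this diff elides. A sketch of the usual shape of that call; the model name and parameters are assumptions, not taken from this commit:

```python
response = client.chat.completions.create(
    model="gpt-3.5-turbo",  # assumed model, not confirmed by the diff
    messages=prompt,
    temperature=0,
)
prediction = response.choices[0].message.content
```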
@@ -122,12 +93,9 @@ def predict(user_input,company):
         prediction = response.choices[0].message.content
 
     except Exception as e:
-        prediction = e
-
-    # While the prediction is made, log both the inputs and outputs to a local log file
-    # While writing to the log file, ensure that the commit scheduler is locked to avoid parallel
-    # access
+        prediction = str(e)
 
+    # Log inputs and outputs to a local log file
     with scheduler.lock:
         with log_file.open("a") as f:
             f.write(json.dumps(
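Note: the `json.dumps(...)` payload continues in unchanged lines. The lock-and-append pattern guards the log file against writes while `CommitScheduler` is pushing; the field names below are illustrative:

```python
with scheduler.lock:  # don't write while a commit is in flight
    with log_file.open("a") as f:
        f.write(json.dumps({
            "user_input": user_input,  # hypothetical field names
            "company": company,
            "prediction": prediction,
        }))
        f.write("\n")  # one JSON object per line
```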
@@ -144,40 +112,26 @@ def predict(user_input,company):
 
 def get_predict(question, company):
     # Implement your prediction logic here
-    if company == "AWS":
-        # Perform prediction for AWS
-        selectedCompany = "aws"
-    elif company == "IBM":
-        # Perform prediction for IBM
-        selectedCompany = "IBM"
-    elif company == "Google":
-        # Perform prediction for Google
-        selectedCompany = "Google"
-    elif company == "Meta":
-        # Perform prediction for Meta
-        selectedCompany = "meta"
-    elif company == "Microsoft":
-        # Perform prediction for Microsoft
-        selectedCompany = "msft"
-    else:
+    company_map = {
+        "AWS": "aws",
+        "IBM": "IBM",
+        "Google": "Google",
+        "Meta": "meta",
+        "Microsoft": "msft"
+    }
+
+    selected_company = company_map.get(company)
+    if not selected_company:
         return "Invalid company selected"
-
-    output = predict(question, selectedCompany)
-    return output
+
+    return predict(question, selected_company)
 
 # Set-up the Gradio UI
-# Add text box and radio button to the interface
-# The radio button is used to select the company 10k report in which the context needs to be retrieved.
-
-# Create the interface
-# For the inputs parameter of Interface provide [textbox,company]
-
 with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
     with gr.Row():
         company = gr.Radio(["AWS", "IBM", "Google", "Meta", "Microsoft"], label="Select a company")
         question = gr.Textbox(label="Enter your question")
 
-
     submit = gr.Button("Submit")
     output = gr.Textbox(label="Output")
 
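Note: each mapped value feeds the metadata filter in `predict`, so the lowercase keys must match the PDF file names under `dataset/`. A quick trace, reusing the names from the diff:

```python
selected_company = company_map.get("Microsoft")            # -> "msft"
filter = "dataset/" + selected_company + "-10-k-2023.pdf"  # -> "dataset/msft-10-k-2023.pdf"
# similarity_search(..., filter={"source": filter}) then only returns chunks from that report
```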
@@ -188,4 +142,4 @@ with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
     )
 
     demo.queue()
-    demo.launch()
+demo.launch()
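Note: the `)` context line in the last hunk closes the button wiring from an unchanged region. It presumably follows the standard `gr.Button.click` pattern; the exact arguments are an assumption:

```python
submit.click(
    fn=get_predict,              # assumed handler wiring
    inputs=[question, company],
    outputs=output,
)
```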