kajila committed on
Commit
a4f4daa
·
verified ·
1 Parent(s): 3db88f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -51
app.py CHANGED
@@ -1,38 +1,30 @@
1
  import subprocess
2
  import sys
3
- # Install openai if it is not already installed
4
- subprocess.check_call([sys.executable, "-m", "pip", "install", "openai"])
5
- # Install langchain_community if it is not already installed
6
- subprocess.check_call([sys.executable, "-m", "pip", "install", "langchain_community"])
7
- # Install sentence-transformers if it is not already installed
8
- subprocess.check_call([sys.executable, "-m", "pip", "install", "sentence-transformers"])
9
- # Install sentence-transformers if it is not already installed
10
- subprocess.check_call([sys.executable, "-m", "pip", "install", "chromadb"])
11
- subprocess.check_call([sys.executable, "-m", "pip", "install", "huggingface_hub"])
12
- from huggingface_hub import login
13
- login("RAG")
14
- #huggingface-cli login
15
- import openai
16
  import os
17
  import uuid
18
  import json
 
 
 
19
  import gradio as gr
20
- #from openai import OpenAI
21
  from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
22
  from langchain_community.vectorstores import Chroma
23
- #from huggingface_hub import login
24
- #login("RAG")
25
- from huggingface_hub import CommitScheduler
26
- from pathlib import Path
27
- from dotenv import load_dotenv
28
 
29
  # Load environment variables from .env file
30
  load_dotenv()
31
 
32
- # Set OpenAI API key
33
- openai.api_key = os.getenv("OPENAI_API_KEY") # Make sure OPENAI_API_KEY is in your .env file
 
 
 
34
 
35
- # Initialize OpenAI client
 
36
  client = openai
37
 
38
  # Set up embeddings and vectorstore
@@ -50,7 +42,7 @@ retriever = vectorstore_persisted.as_retriever(
50
  search_kwargs={'k': 5}
51
  )
52
 
53
- # Set up logging
54
  log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
55
  log_folder = log_file.parent
56
 
@@ -64,26 +56,19 @@ scheduler = CommitScheduler(
64
 
65
  # Define the Q&A system message
66
  qna_system_message = """
67
- You are an AI assistant to help Finsights Grey Inc., an innovative financial technology firm, develop a Retrieval-Augmented Generation (RAG) system to automate the extraction, summarization, and analysis of information from 10-K reports. Your knowledge base was last updated in August 2023.
68
- User input will have the context required by you to answer user questions. This context will begin with the token: ###Context.
69
- The context contains references to specific portions of a 10-K report relevant to the user query.
70
- User questions will begin with the token: ###Question.
71
- Your response should only be about the question asked and the context provided.
72
- Answer only using the context provided.
73
- Do not mention anything about the context in your final answer.
74
- If the answer is not found in the context, it is very important for you to respond with "I don't know."
75
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
76
- Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.
77
  """
78
 
79
  qna_user_message_template = """
80
  ###Context
81
  Here are some documents that are relevant to the question.
82
  {context}
 
83
  {question}
84
-
85
- css
86
- Copy code
87
  """
88
 
89
  # Define the predict function
@@ -96,7 +81,10 @@ def predict(user_input, company):
96
  context_for_query = ".".join(context_list)
97
 
98
  # Create messages
99
- prompt = [ {'role': 'system', 'content': qna_system_message}, {'role': 'user', 'content': qna_user_message_template.format( context=context_for_query, question=user_input )} ]
 
 
 
100
 
101
  # Get response from the LLM
102
  try:
@@ -105,29 +93,24 @@ def predict(user_input, company):
105
  messages=prompt,
106
  temperature=0
107
  )
108
-
109
  prediction = response.choices[0].message.content
110
-
111
  except Exception as e:
112
  prediction = str(e)
113
 
114
  # Log inputs and outputs to a local log file
115
  with scheduler.lock:
116
  with log_file.open("a") as f:
117
- f.write(json.dumps(
118
- {
119
- 'user_input': user_input,
120
- 'retrieved_context': context_for_query,
121
- 'model_response': prediction
122
- }
123
- ))
124
  f.write("\n")
125
 
126
  return prediction
127
 
128
-
129
  def get_predict(question, company):
130
- # Implement your prediction logic here
131
  company_map = {
132
  "AWS": "aws",
133
  "IBM": "IBM",
@@ -135,14 +118,13 @@ def get_predict(question, company):
135
  "Meta": "meta",
136
  "Microsoft": "msft"
137
  }
138
-
139
  selected_company = company_map.get(company)
140
  if not selected_company:
141
  return "Invalid company selected"
142
 
143
  return predict(question, selected_company)
144
 
145
- # Set-up the Gradio UI
146
  with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
147
  with gr.Row():
148
  company = gr.Radio(["AWS", "IBM", "Google", "Meta", "Microsoft"], label="Select a company")
@@ -158,4 +140,6 @@ with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
158
  )
159
 
160
  demo.queue()
161
- demo.launch()
 
 
 
1
  import subprocess
2
  import sys
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  import uuid
5
  import json
6
+ from pathlib import Path
7
+ from dotenv import load_dotenv
8
+ from huggingface_hub import login, CommitScheduler
9
  import gradio as gr
 
10
  from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
11
  from langchain_community.vectorstores import Chroma
12
+ import openai
13
+
14
+ # Install required libraries if not already installed
15
+ subprocess.check_call([sys.executable, "-m", "pip", "install", "openai", "langchain_community", "sentence-transformers", "chromadb", "huggingface_hub", "python-dotenv"])
 
16
 
17
  # Load environment variables from .env file
18
  load_dotenv()
19
 
20
+ # Login to Hugging Face using token from environment variables
21
+ hf_token = os.getenv("HF_TOKEN")
22
+ if not hf_token:
23
+ raise ValueError("Hugging Face token not found in environment variables. Set HF_TOKEN in your .env file.")
24
+ login(hf_token)
25
 
26
+ # Set OpenAI API key from environment variables
27
+ openai.api_key = os.getenv("OPENAI_API_KEY") # Ensure OPENAI_API_KEY is in your .env file
28
  client = openai
29
 
30
  # Set up embeddings and vectorstore
 
42
  search_kwargs={'k': 5}
43
  )
44
 
45
+ # Define logging configuration
46
  log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
47
  log_folder = log_file.parent
48
 
 
56
 
57
  # Define the Q&A system message
58
  qna_system_message = """
59
+ You are an AI assistant helping Finsights Grey Inc., a financial technology firm, develop a Retrieval-Augmented Generation (RAG) system to automate extraction, summarization, and analysis of 10-K reports.
60
+ Your knowledge base was last updated in August 2023.
61
+ User questions will start with the token: ###Question.
62
+ Answer only based on the provided context.
63
+ If the answer is not found in the context, respond with "I don't know."
 
 
 
 
 
64
  """
65
 
66
  qna_user_message_template = """
67
  ###Context
68
  Here are some documents that are relevant to the question.
69
  {context}
70
+ ###Question
71
  {question}
 
 
 
72
  """
73
 
74
  # Define the predict function
 
81
  context_for_query = ".".join(context_list)
82
 
83
  # Create messages
84
+ prompt = [
85
+ {'role': 'system', 'content': qna_system_message},
86
+ {'role': 'user', 'content': qna_user_message_template.format(context=context_for_query, question=user_input)}
87
+ ]
88
 
89
  # Get response from the LLM
90
  try:
 
93
  messages=prompt,
94
  temperature=0
95
  )
 
96
  prediction = response.choices[0].message.content
 
97
  except Exception as e:
98
  prediction = str(e)
99
 
100
  # Log inputs and outputs to a local log file
101
  with scheduler.lock:
102
  with log_file.open("a") as f:
103
+ f.write(json.dumps({
104
+ 'user_input': user_input,
105
+ 'retrieved_context': context_for_query,
106
+ 'model_response': prediction
107
+ }))
 
 
108
  f.write("\n")
109
 
110
  return prediction
111
 
112
+ # Define the prediction interface function
113
  def get_predict(question, company):
 
114
  company_map = {
115
  "AWS": "aws",
116
  "IBM": "IBM",
 
118
  "Meta": "meta",
119
  "Microsoft": "msft"
120
  }
 
121
  selected_company = company_map.get(company)
122
  if not selected_company:
123
  return "Invalid company selected"
124
 
125
  return predict(question, selected_company)
126
 
127
+ # Set up the Gradio UI
128
  with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
129
  with gr.Row():
130
  company = gr.Radio(["AWS", "IBM", "Google", "Meta", "Microsoft"], label="Select a company")
 
140
  )
141
 
142
  demo.queue()
143
+ demo.launch()
144
+
145
+