kajila committed on
Commit 22b25c2 · verified · Parent: a821671

Create app.py

Files changed (1): app.py (+149, -0)
app.py ADDED
@@ -0,0 +1,149 @@
import subprocess
import sys
import os
import uuid
import json
from pathlib import Path

# Install third-party dependencies before importing them (on Hugging Face
# Spaces, listing these in requirements.txt is the usual alternative)
def install_packages():
    packages = ["openai", "langchain_community", "sentence-transformers", "chromadb", "huggingface_hub", "python-dotenv"]
    for package in packages:
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

install_packages()

from dotenv import load_dotenv
from huggingface_hub import login, CommitScheduler
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from openai import OpenAI
import gradio as gr

# Load environment variables from .env file
load_dotenv()

# Get API tokens from environment variables
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))  # Ensure OPENAI_API_KEY is in your .env file
hf_token = os.getenv("hf_token")

if not hf_token:
    raise ValueError("Hugging Face token is missing. Please set 'hf_token' as an environment variable.")

# Log in to Hugging Face
login(hf_token)

print("Logged in to Hugging Face successfully.")
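
# A minimal .env for local runs might look like this (variable names taken
# from the os.getenv calls above; the values are placeholders):
#
#   OPENAI_API_KEY=sk-...
#   hf_token=hf_...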

# Set up embeddings and the persisted vector store (the ./report_10kdb
# directory must already contain the indexed 10-K chunks)
embeddings = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")
collection_name = 'report-10k-2024'

vectorstore_persisted = Chroma(
    collection_name=collection_name,
    persist_directory='./report_10kdb',
    embedding_function=embeddings
)

# Set up the retriever (not used by predict() below, which calls
# similarity_search directly so it can filter by source document)
retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)

# Define Q&A system messages
qna_system_message = """
You are an AI assistant for Finsights Grey Inc., helping automate extraction, summarization, and analysis of 10-K reports.
Your responses should be based solely on the context provided.
If an answer is not found in the context, respond with "I don't know."
"""

qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}
###Question
{question}
"""

# Set up logging: one log file per app session, pushed to a Hugging Face
# dataset repo in the background by a single CommitScheduler (created once
# at module level rather than on every request)
log_file = Path("logs") / f"data_{uuid.uuid4()}.json"
log_file.parent.mkdir(parents=True, exist_ok=True)  # Create log directory if it doesn't exist

scheduler = CommitScheduler(
    repo_id="RAGREPORTS-log",
    repo_type="dataset",
    folder_path=log_file.parent,
    path_in_repo="data",
    every=2  # push every 2 minutes
)

# Define the predict function
def predict(user_input, company):
    source_filter = f"dataset/{company}-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter={"source": source_filter}
    )

    # Create context for the query
    context_list = [d.page_content for d in relevant_document_chunks]
    context_for_query = "\n".join(context_list)

    # Create messages
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(context=context_for_query, question=user_input)}
    ]

    try:
        # Get response from the LLM (gpt-3.5-turbo is a chat model, so use
        # the chat completions API)
        response = openai_client.chat.completions.create(
            model='gpt-3.5-turbo',
            messages=prompt,
            temperature=0
        )
        prediction = response.choices[0].message.content

    except Exception as e:
        prediction = f"Error: {str(e)}"

    # Log inputs and outputs; the scheduler lock keeps writes from racing
    # with an in-progress upload
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction

def get_predict(question, company):
    company_map = {
        "AWS": "aws",
        "IBM": "IBM",
        "Google": "Google",
        "Meta": "meta",
        "Microsoft": "msft"
    }
    selected_company = company_map.get(company)
    if not selected_company:
        return "Invalid company selected"

    return predict(question, selected_company)
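
# Example (hypothetical call, mirroring what the UI does on Submit):
#   get_predict("What was total revenue in fiscal 2023?", "Microsoft")
#   -> answer grounded in dataset/msft-10-k-2023.pdf, or "I don't know."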

# Set up the Gradio UI
with gr.Blocks(theme="gradio/seafoam@>=0.0.1,<0.1.0") as demo:
    with gr.Row():
        company = gr.Radio(["AWS", "IBM", "Google", "Meta", "Microsoft"], label="Select a company")
        question = gr.Textbox(label="Enter your question")

    submit = gr.Button("Submit")
    output = gr.Textbox(label="Output")

    submit.click(
        fn=get_predict,
        inputs=[question, company],
        outputs=output
    )

demo.queue()
demo.launch()
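
Note: app.py reads from a Chroma collection that must be built beforehand, with each chunk's source metadata set to a path like dataset/aws-10-k-2023.pdf (the value predict() filters on). A minimal indexing sketch follows; the dataset/ filenames, loader choice, and chunking parameters are assumptions, not part of this commit:

from pathlib import Path

from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter

embeddings = SentenceTransformerEmbeddings(model_name="thenlper/gte-large")
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

chunks = []
for pdf in Path("dataset").glob("*-10-k-2023.pdf"):
    # PyPDFLoader records the file path in each chunk's "source" metadata,
    # which is what predict() filters on
    chunks.extend(splitter.split_documents(PyPDFLoader(str(pdf)).load()))

# Persist the collection under the directory app.py reads from
Chroma.from_documents(
    documents=chunks,
    embedding=embeddings,
    collection_name='report-10k-2024',
    persist_directory='./report_10kdb'
)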