kgauvin603 committed
Commit 36d6a64 · verified · 1 Parent(s): 6a1d8dd

Create app.py

Files changed (1):
  app.py +222 -0
app.py ADDED
@@ -0,0 +1,222 @@
+ ## Setup
+ # (The ! shell magics below assume a Colab/Jupyter environment)
+ !pip -q install gradio
+
+ # Install the necessary libraries
+ !pip install -q openai==1.23.2 \
+     tiktoken==0.6.0 \
+     pypdf==4.0.1 \
+     langchain==0.1.1 \
+     langchain-community==0.0.13 \
+     chromadb==0.4.22 \
+     sentence-transformers==2.3.1
+
+ # Import the necessary libraries
+ import gradio as gr
+ import os
+ import uuid
+ import json
+ import tiktoken
+ import pandas as pd
+ from openai import OpenAI
+ from huggingface_hub import HfApi
+ from huggingface_hub import CommitScheduler
+ from langchain_community.embeddings.sentence_transformer import (
+     SentenceTransformerEmbeddings
+ )
+ from langchain_community.vectorstores import Chroma
+ from google.colab import userdata, drive
+ from pathlib import Path
+ from langchain.document_loaders import PyPDFDirectoryLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+
+ # Define the embedding model used by the vectorstore
+ embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
+
+ # If the dataset directory exists, remove it and all of its contents
+ if os.path.exists('dataset'):
+     !rm -rf dataset
+
+ # If collection_db exists, remove it and all of its contents
+ if os.path.exists('collection_db'):
+     !rm -rf collection_db
+
+ # Upload Dataset-10k.zip and unzip it into the dataset folder using the -d option
+ !unzip Dataset-10k.zip -d dataset
+
+ # Provide pdf_folder_location
+ pdf_folder_location = "dataset"
+
+ # Point the PDF loader at the directory
+ pdf_loader = PyPDFDirectoryLoader(pdf_folder_location)
+
+ # Create text_splitter using a recursive splitter over tiktoken tokens
+ text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+     encoding_name='cl100k_base',
+     chunk_size=512,
+     chunk_overlap=16
+ )
+
+ # Create chunks
+ report_chunks = pdf_loader.load_and_split(text_splitter)
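+
+ # Optional sanity check (illustrative, uses only names defined above):
+ # confirm the loader and splitter actually produced chunks
+ print(f"Split the reports into {len(report_chunks)} chunks")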
+
+ # Create a collection name
+ collection_name = 'collection'
+
+ # Create the vector database
+ vectorstore = Chroma.from_documents(
+     report_chunks,
+     embedding_model,
+     collection_name=collection_name,
+     persist_directory='./collection_db'
+ )
+
+ # Persist the DB
+ vectorstore.persist()
+
+ # Reload the persisted DB
+ vectorstore_persisted = Chroma(
+     collection_name=collection_name,
+     persist_directory='./collection_db',
+     embedding_function=embedding_model
+ )
+
+ # Expose the store as a similarity retriever returning the top 5 chunks
+ retriever = vectorstore_persisted.as_retriever(
+     search_type='similarity',
+     search_kwargs={'k': 5}
+ )
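+
+ # Optional sanity check with a hypothetical query (commented out so the app
+ # flow is unchanged); get_relevant_documents is the retriever call in langchain 0.1.x
+ # docs = retriever.get_relevant_documents("What was total revenue in fiscal 2023?")
+ # print(docs[0].page_content[:200])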
+
+ # Mount Google Drive
+ drive.mount('/content/drive')
+
+ # Copy the persisted database to your Drive
+ !cp -r collection_db /content/drive/MyDrive/
+
+ # Get the Anyscale API key from Colab secrets
+ anyscale_api_key = userdata.get('dev-work')
+
+ # Initialise the client against the Anyscale OpenAI-compatible endpoint
+ client = OpenAI(
+     base_url="https://api.endpoints.anyscale.com/v1",
+     api_key=anyscale_api_key
+ )
+
+ # Provide the model name
+ model_name = 'mlabonne/NeuralHermes-2.5-Mistral-7B'
+
+ # Initialise the embedding model
+ embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')
+
+ # Location of the persisted DB
+ persisted_vectordb_location = '/content/drive/MyDrive/collection_db'
+
+ # Create a collection name
+ collection_name = 'collection'
+
+ # Load the persisted DB
+ vectorstore_persisted = Chroma(
+     collection_name=collection_name,
+     persist_directory=persisted_vectordb_location,
+     embedding_function=embedding_model
+ )
+
+ # Prepare the logging functionality
+ log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
+ log_folder = log_file.parent
+
+ # NOTE: hf_token must be defined before the scheduler is created; the secret
+ # name used here is an assumption
+ hf_token = userdata.get('HF_TOKEN')
+
+ # Push the log folder to the dataset repo every 2 minutes
+ scheduler = CommitScheduler(
+     repo_id="kgauvin603/rag-10k-analysis",
+     repo_type="dataset",
+     folder_path=log_folder,
+     path_in_repo="data",
+     every=2,
+     token=hf_token
+ )
+
+ # Define the Q&A system message
+ qna_system_message = """You are an assistant to a financial services firm that answers user queries on annual reports.
+ User input will have the context required by you to answer user questions.
+ This context will begin with the token: ###Context.
+ The context contains references to specific portions of a document relevant to the user query.
+
+ User questions will begin with the token: ###Question.
+
+ Please answer only using the context provided in the input. Do not mention anything about the context in your final answer.
+
+ If the answer is not found in the context, respond "I don't know".
+ """
+
+ # Create a message template
+ qna_user_message_template = """
+ ###Context
+ Here are some documents that are relevant to the question mentioned below.
+ {context}
+
+ ###Question
+ {question}
+ """
+
+ # Define the predict function that runs when 'Submit' is clicked or when an API request is made
+ def predict(user_input, company):
+
+     # Restrict retrieval to the selected company's 10-K report
+     # (renamed from `filter` to avoid shadowing the Python builtin)
+     source_pdf = "dataset/" + company + "-10-k-2023.pdf"
+     relevant_document_chunks = vectorstore_persisted.similarity_search(
+         user_input, k=5, filter={"source": source_pdf}
+     )
+
+     # Create context_for_query
+     context_list = [d.page_content for d in relevant_document_chunks]
+     context_for_query = ". ".join(context_list)
+
+     # Create messages
+     prompt = [
+         {'role': 'system', 'content': qna_system_message},
+         {'role': 'user', 'content': qna_user_message_template.format(
+             context=context_for_query,
+             question=user_input
+         )}
+     ]
+
+     try:
+         response = client.chat.completions.create(
+             model=model_name,
+             messages=prompt,
+             temperature=0
+         )
+
+         prediction = response.choices[0].message.content.strip()
+     except Exception as e:
+         prediction = f'Sorry, I encountered the following error: \n{e}'
+
+     # Log both the inputs and outputs to a local log file
+     # Ensure that the commit scheduler is locked to avoid parallel access
+     with scheduler.lock:
+         with log_file.open("a") as f:
+             f.write(json.dumps(
+                 {
+                     'user_input': user_input,
+                     'retrieved_context': context_for_query,
+                     'model_response': prediction
+                 }
+             ))
+             f.write("\n")
+
+     return prediction
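+
+ # Example invocation (hypothetical question; commented out so it does not
+ # trigger logging outside the UI):
+ # predict("What was IBM's total revenue in fiscal 2023?", "IBM")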
+
+ # Set up the Gradio UI
+ # Add a text box and a dropdown to the interface
+ # The dropdown selects the company 10-K report from which the context is retrieved.
+
+ textbox = gr.Textbox(label="User Input")
+ #company = gr.List(label="Select Company", choices=["IBM", "Meta", "aws", "google","msft"])
+ company = gr.Dropdown(label="Select Company", choices=["IBM", "Meta", "aws", "google","msft"])
+
+ # Create the interface, passing [textbox, company] as inputs
+ demo = gr.Interface(fn=predict, inputs=[textbox, company], outputs="text")
+
+ demo.queue()
+ demo.launch(share=True)