mayankchugh-learning committed on
Commit
872912a
·
verified ·
1 Parent(s): fe230fc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py CHANGED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import the necessary Libraries
2
+ import os
3
+ import uuid
4
+ import json
5
+
6
+ import gradio as gr
7
+
8
+ from openai import OpenAI
9
+
10
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
11
+ from langchain_community.vectorstores import Chroma
12
+
13
+ from huggingface_hub import CommitScheduler
14
+ from pathlib import Path
15
+ from dotenv import load_dotenv
16
+
17
+
18
# Create the Anyscale-backed OpenAI client.
load_dotenv()

# Fail fast with a clear message if the key is missing. The original did
# os.environ[...] = os.getenv(...), which raises an opaque TypeError when
# the variable is unset (environ values must be strings, not None).
anyscale_api_key = os.getenv("ANYSCALE_API_KEY")
if anyscale_api_key is None:
    raise RuntimeError(
        "ANYSCALE_API_KEY is not set; add it to the environment or a .env file."
    )
os.environ["ANYSCALE_API_KEY"] = anyscale_api_key

client = OpenAI(
    base_url="https://api.endpoints.anyscale.com/v1",
    api_key=anyscale_api_key
)
27
+
28
# Embedding model used for both indexing and querying.
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-large')

# Load the persisted vector DB for the 10-K report collection.
collection_name = 'report-10k-2024'

vectorstore_persisted = Chroma(
    persist_directory='./dataset-10k',
    collection_name=collection_name,
    embedding_function=embedding_model
)

# Top-5 similarity retriever over the persisted store.
# NOTE(review): `retriever` is not referenced elsewhere in this file —
# predict() calls vectorstore_persisted.similarity_search() directly.
retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
45
+
46
+
47
# Prepare the logging functionality: every app session appends JSON lines
# to a uniquely-named file under logs/, and a CommitScheduler periodically
# pushes that folder to a Hugging Face dataset repo.
log_folder = Path("logs/")
log_file = log_folder / f"data_{uuid.uuid4()}.json"

# NOTE(review): repo_id is a placeholder ("---------") — replace with a
# real dataset repo id before deploying, or the commits will fail.
scheduler = CommitScheduler(
    repo_id="---------",
    repo_type="dataset",
    folder_path=log_folder,
    path_in_repo="data",
    every=2
)
59
+
60
# Define the Q&A system message.
# NOTE(review): the original prompt described a Streamlit-documentation
# assistant, which contradicts this app's actual purpose (Q&A over company
# 10-K reports — see the retrieval filter and the UI's company selector).
# Rewritten to match the app while keeping the same token conventions.
qna_system_message = """
You are an assistant to a financial analyst. Your task is to provide relevant information from companies' 10-K annual reports.

User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for a context will begin with the token ###Source

When crafting your response:
1. Select the most relevant context or contexts to answer the question.
2. Include the source links in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant to the 10-K reports respond with - "I am an assistant for 10-K annual reports. I can only help you with questions related to the provided 10-K reports"

Please adhere to the following guidelines:
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, it is very very important for you to respond with "I don't know."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Sources:
- Do not make up sources. Use the links provided in the sources section of the context and nothing else. You are prohibited from providing other links/sources.

Here is an example of how to structure your response:

Answer:
[Answer]

Source
[Source]
"""
90
+
91
# Define the user message template.
# Fix: the system prompt states user questions begin with the token
# ###Question, but the original template wrapped the question in bare
# triple-backticks instead. Use the expected token so the model can
# reliably separate context from question.
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}

###Question
{question}
"""
100
+
101
# Define the predict function that runs when 'Submit' is clicked or when an API request is made.
def predict(user_input, company):
    """Answer a user question using context retrieved from the selected
    company's 10-K report.

    Args:
        user_input: The question typed by the user.
        company: One of the UI's company keys (e.g. "aws", "IBM"); used to
            restrict retrieval to that company's 10-K PDF.

    Returns:
        The model's answer as a string, or an error message string if the
        LLM call fails.
    """
    # Restrict retrieval to chunks sourced from the selected company's PDF.
    # Renamed from `filter` to avoid shadowing the builtin.
    source_filter = f"dataset/{company}-10-k-2023.pdf"
    relevant_document_chunks = vectorstore_persisted.similarity_search(
        user_input, k=5, filter={"source": source_filter}
    )

    # Join the retrieved chunk texts into a single context string.
    context_for_query = ".".join(d.page_content for d in relevant_document_chunks)

    # Build the chat messages from the system prompt and the user template.
    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input
        )}
    ]

    # Query the LLM. On failure, fall back to an error *string* so both the
    # UI and the JSON log below always receive serializable text — the
    # original stored the Exception object itself, which is not JSON
    # serializable and would crash json.dumps in the logging step.
    try:
        response = client.chat.completions.create(
            model='mistralai/Mixtral-8x7B-Instruct-v0.1',
            messages=prompt,
            temperature=0
        )
        prediction = response.choices[0].message.content
    except Exception as e:
        prediction = f"Error: {e}"

    # Log inputs and outputs to the local log file. Hold the scheduler lock
    # while writing so the background commit job does not read the file
    # mid-write.
    with scheduler.lock:
        with log_file.open("a") as f:
            f.write(json.dumps(
                {
                    'user_input': user_input,
                    'retrieved_context': context_for_query,
                    'model_response': prediction
                }
            ))
            f.write("\n")

    return prediction
150
+
151
# Set up the Gradio UI: a question textbox plus a radio button that selects
# which company's 10-K report to retrieve context from.
with gr.Blocks() as demo:
    with gr.Row():
        question = gr.Textbox(label="Enter your question")
        company = gr.Radio(["aws", "IBM", "google", "meta", "msft"], label="Select a company")

    submit = gr.Button("Submit")
    output = gr.Textbox(label="Output")

    submit.click(
        fn=predict,
        inputs=[question, company],
        outputs=output
    )

# Enable request queuing, then launch exactly once. The original called
# demo.launch() twice (the first before demo.queue()): launch() blocks, so
# the queue()/second launch() lines were dead code — and queuing must be
# enabled before launching to take effect.
demo.queue()
demo.launch()