pankilshah commited on
Commit
3f7b2e0
·
verified ·
1 Parent(s): c34ada1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +132 -0
app.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+ import json
4
+
5
+ import gradio as gr
6
+
7
+ import openai
8
+
9
+ from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
10
+ from langchain_community.vectorstores import Chroma
11
+
12
+ from huggingface_hub import CommitScheduler
13
+ from pathlib import Path
14
+
15
# --- Environment / model configuration --------------------------------------
# SECURITY NOTE(review): a real-looking API credential is committed here and is
# now in git history — it should be revoked and injected via Space secrets.
# Using setdefault (instead of plain assignment) is backward compatible but
# lets an externally configured credential take precedence over the fallback.
os.environ.setdefault("OPENAI_API_KEY", "gl-U2FsdGVkX1//b9yZ7Ti6gIdFjv8A8Hps+MnZFMrEO+MRGYjjWkE6E1+6evFAmltp")
os.environ.setdefault("OPENAI_BASE_URL", "https://aibe.mygreatlearning.com/openai/v1")

# openai.api_key = os.environ['OPENAI_API_KEY']
# openai.api_base = os.environ['OPENAI_BASE_URL']

# Embedding model for query-time encoding; it must be the same model that was
# used when the ./tesla_db collection was built, or similarity search is
# meaningless — TODO confirm against the indexing script.
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-small')

tesla_10k_collection = 'tesla-10k-2019-to-2023'

# Re-open the Chroma collection persisted on disk (built offline).
vectorstore_persisted = Chroma(
    collection_name=tesla_10k_collection,
    persist_directory='./tesla_db',
    embedding_function=embedding_model
)

# Retrieve the 5 most similar chunks for each incoming query.
retriever = vectorstore_persisted.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)

# Prepare the logging functionality: a uniquely named JSON-lines file per
# process, so concurrent replicas never collide on the same file.
log_file = Path("logs/") / f"data_{uuid.uuid4()}.json"
log_folder = log_file.parent

# scheduler = CommitScheduler(
#     repo_id="pankilshah/RAG_logs",
#     repo_type="dataset",
#     folder_path=log_folder,
#     path_in_repo="data",
#     every=2
# )
49
+
# Prompt scaffolding for the RAG QnA chain.
# The user-turn template is filled with the retrieved context and the raw
# question; the system turn pins the model to context-grounded answers only.
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}
```
{question}
```
"""

qna_system_message = """
You are an assistant to a financial services firm who answers user queries on annual reports.
Users will ask questions delimited by triple backticks, that is, ```.
User input will have the context required by you to answer user questions.
This context will begin with the token: ###Context.
The context contains references to specific portions of a document relevant to the user query.
Please answer only using the context provided in the input. However, do not mention anything about the context in your answer.
If the answer is not found in the context, respond "I don't know".
"""
68
+
69
# Define the predict function that runs when 'Submit' is clicked or when an API request is made
def predict(user_input):
    """Answer a user question about the Tesla 10-K filings via RAG.

    Retrieves the top-k relevant chunks from the persisted vector store,
    formats them into the QnA prompt, and asks the chat model for a
    context-grounded answer.

    Parameters
    ----------
    user_input : str
        The user's question in plain text.

    Returns
    -------
    str
        The model's answer, or the error message if the API call failed.
    """
    relevant_document_chunks = retriever.invoke(user_input)
    context_list = [d.page_content for d in relevant_document_chunks]
    # NOTE(review): chunks are glued with a bare '.'; a newline separator would
    # likely read better to the model — kept as-is to preserve behavior.
    context_for_query = ".".join(context_list)

    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input
        )}
    ]

    try:
        response = openai.chat.completions.create(
            model='gpt-4o-mini',
            messages=prompt,
            temperature=0  # deterministic output for a factual QA task
        )
        prediction = response.choices[0].message.content.strip()
    except Exception as e:
        # Bug fix: the original assigned the exception *object* to the return
        # value, but the Gradio output component expects text — return the
        # error message string instead.
        prediction = str(e)

    # While the prediction is made, log both the inputs and outputs to a local log file
    # While writing to the log file, ensure that the commit scheduler is locked
    # to avoid parallel access
    # with scheduler.lock:
    #     with log_file.open("a") as f:
    #         f.write(json.dumps(
    #             {
    #                 'user_input': user_input,
    #                 'retrieved_context': context_for_query,
    #                 'model_response': prediction
    #             }
    #         ))
    #         f.write("\n")

    return prediction
113
+
114
+
115
textbox = gr.Textbox(placeholder="Enter your query here", lines=6)

# Create the interface.
# Bug fix: the Interface has exactly one input component, but each example row
# carried two values (question + expected answer). Gradio validates that every
# example row matches the number of input components, so the extra value made
# the app fail to build — the rows are reduced to the question only.
demo = gr.Interface(
    inputs=textbox, fn=predict, outputs="text",
    title="AMA on Tesla 10-K statements",
    description="This web API presents an interface to ask questions on contents of the Tesla 10-K reports for the period 2019 - 2023.",
    article="Note that questions that are not relevant to the Tesla 10-K report will not be answered.",
    examples=[["What was the total revenue of the company in 2022?"],
              ["Summarize the Management Discussion and Analysis section of the 2021 report in 50 words."],
              ["What was the company's debt level in 2020?"],
              ["Identify five key risks identified in the 2019 10k report? Respond with bullet point summaries."]
              ],
    concurrency_limit=16
)

demo.queue()
demo.launch(share=True)