kgauvin603 commited on
Commit
db8c4ba
·
verified ·
1 Parent(s): 6919ca1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import ast
import json
import os
from pathlib import Path

import chromadb
import gradio as gr
import pandas as pd
from llama_index.llms.anyscale import Anyscale
from sentence_transformers import SentenceTransformer
9
+
10
# Load the sentence transformer model for embedding text.
# all-MiniLM-L6-v2 produces 384-dimensional embeddings; queries must be
# embedded with the same model that produced the stored vectors.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Initialize the ChromaDB client for managing the vector database.
# NOTE(review): chromadb.Client() with no settings is an in-memory/ephemeral
# client — the collection is rebuilt from CSV on every app start; confirm
# persistence is not required here.
chroma_client = chromadb.Client()
15
+
16
# Function to build the vector database from a CSV file
def build_database():
    """Build (or reuse) the ChromaDB collection from 'vector_store.csv'.

    The CSV is expected to contain the columns 'documents', 'ids',
    'metadatas' (stringified dicts) and 'embeddings' (stringified lists
    of floats, possibly with doubled commas from the export).

    Returns:
        The populated chromadb collection.
    """
    # Read the CSV file containing document data.
    df = pd.read_csv('vector_store.csv')

    # Name of the collection to store the data.
    collection_name = 'Dataset-10k-companies'

    # get_or_create_collection is idempotent, so re-running the app does not
    # crash with "collection already exists". ChromaDB infers the embedding
    # dimensionality (384 for all-MiniLM-L6-v2) from the vectors themselves;
    # create_collection() accepts no 'dimensionality' keyword, so the
    # original call raised a TypeError.
    collection = chroma_client.get_or_create_collection(name=collection_name)

    # ast.literal_eval safely parses the stringified Python literals stored
    # in the CSV; eval() would execute arbitrary code embedded in the data
    # file. The ',,' -> ',' replacement repairs doubled commas in the
    # exported embedding strings.
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        embeddings=df['embeddings']
            .apply(lambda x: ast.literal_eval(x.replace(',,', ',')))
            .tolist(),
    )

    return collection
39
+
40
# Build the database when the app starts (module-level side effect: reads
# vector_store.csv from the working directory).
collection = build_database()

# Access the Anyscale API key from environment variables.
# NOTE(review): the env var name is the lowercase 'anyscale_api_key'; if it
# is unset this silently yields None and the Anyscale client will only fail
# later at request time — confirm the deployment sets this exact
# (case-sensitive) name.
anyscale_api_key = os.environ.get('anyscale_api_key')

# Instantiate the Anyscale client for using the Llama language model.
client = Anyscale(api_key=anyscale_api_key, model="meta-llama/Llama-2-70b-chat-hf")
48
+
49
# Retrieve the chunks most relevant to the query from the vector database.
def get_relevant_chunks(query, collection, top_n=3):
    """Return the top_n most relevant (chunk, source, page) tuples for *query*.

    The query is embedded with the same sentence-transformer model used to
    build the collection, then matched by nearest-neighbour search.
    """
    # Embed the query text into the collection's vector space.
    embedded_query = model.encode(query).tolist()

    # Nearest-neighbour lookup in the vector store.
    hits = collection.query(query_embeddings=[embedded_query], n_results=top_n)

    # Pair each returned document with the source/page recorded in its
    # metadata; results come back as parallel lists for the single query.
    docs = hits['documents'][0]
    metas = hits['metadatas'][0]
    return [(doc, meta['source'], meta['page']) for doc, meta in zip(docs, metas)]
66
+
67
# System message template for the LLM to provide structured responses.
# This is a runtime prompt string sent verbatim to the model — its wording
# (tokens like ###Context/###Source/###Question and the canned refusal
# sentences) is part of the app's behavior; do not edit casually.
qna_system_message = """
You are an assistant to Finsights analysts. Your task is to provide relevant information about the financial performance of the companies followed by Finsights.
User input will include the necessary context for you to answer their questions. This context will begin with the token: ###Context.
The context contains references to specific portions of documents relevant to the user's query, along with source links.
The source for a context will begin with the token: ###Source.
When crafting your response:
1. Select only the context relevant to answer the question.
2. Include the source links in your response.
3. User questions will begin with the token: ###Question.
4. If the question is irrelevant to Finsights, respond with: "I am an assistant for Finsight Docs. I can only help you with questions related to Finsights."
Adhere to the following guidelines:
- Your response should only address the question asked and nothing else.
- Answer only using the context provided.
- Do not mention anything about the context in your final answer.
- If the answer is not found in the context, respond with: "I don't know."
- Always quote the source when you use the context. Cite the relevant source at the end of your response under the section - Source:
- Do not make up sources. Use only the links provided in the sources section of the context. You are prohibited from providing other links/sources.
Here is an example of how to structure your response:
Answer:
[Answer]
Source:
[Source]
"""
91
+
92
# User message template for passing context and question to the LLM.
# Filled via str.format with keyword arguments 'context' and 'question'
# (see predict()); the ###Context/###Question tokens match the markers the
# system message instructs the model to look for.
qna_user_message_template = """
###Context
Here are some documents and their source links that are relevant to the question mentioned below.
{context}
###Question
{question}
"""
100
+
101
# Function to get a response from the LLM with retries.
def get_llm_response(prompt, max_attempts=3):
    """Query the LLM, stitching continuations together until the text looks
    complete or the attempt budget runs out.

    Args:
        prompt: initial prompt sent to the model.
        max_attempts: upper bound on LLM calls; retries after errors and
            continuation requests share this budget.

    Returns:
        str: the accumulated response text; empty string if every attempt
        failed (best-effort — callers receive "" rather than an exception).
    """
    full_response = ""
    for attempt in range(max_attempts):
        # Keep the try narrow: only the network call is expected to raise.
        try:
            response = client.complete(prompt, max_tokens=1000)
        except Exception as e:
            # Best-effort: report and retry the same prompt on a later attempt.
            print(f"Attempt {attempt + 1} failed with error: {e}")
            continue

        chunk = response.text.strip()
        full_response += chunk

        # Terminal punctuation is the heuristic for "the model finished".
        if chunk.endswith((".", "!", "?")):
            break

        # Otherwise ask the model to continue, seeding it with the last
        # 100 characters of what it already produced.
        prompt = "Please continue from where you left off:\n" + chunk[-100:]

    return full_response
118
+
119
# Prediction function to handle user queries.
def predict(company, user_query):
    """Answer *user_query* about *company* via RAG over the vector store.

    Retrieves relevant chunks, builds a sourced context, queries the LLM,
    and logs the interaction.

    Returns:
        str: the LLM's answer, or "An error occurred: ..." on any failure
        (the function never raises — errors are returned as text for the UI).
    """
    try:
        # Scope the retrieval to the selected company.
        modified_query = f"{user_query} for {company}"

        # Get relevant chunks from the database.
        relevant_chunks = get_relevant_chunks(modified_query, collection)

        # Concatenate chunks, tagging each with its provenance so the LLM
        # can cite sources as the system message requires.
        context = ""
        for chunk, source, page in relevant_chunks:
            context += chunk + "\n"
            context += f"###Source {source}, Page {page}\n"

        # Format the user message once and reuse it in the final prompt
        # (the original built the identical string twice and discarded the
        # first copy).
        user_message = qna_user_message_template.format(context=context, question=user_query)
        prompt = f"{qna_system_message}\n\n{user_message}"

        # Generate the response using the Llama model through Anyscale.
        answer = get_llm_response(prompt)

        # Persist the interaction for auditing and debugging.
        log_interaction(company, user_query, context, answer)

        return answer
    except Exception as e:
        return f"An error occurred: {str(e)}"
149
+
150
# Append each interaction to a JSON Lines log file.
def log_interaction(company, user_query, context, answer):
    """Record one Q&A interaction in 'interaction_log.jsonl' (one JSON object
    per line, appended in the current working directory)."""
    record = {
        'company': company,
        'user_query': user_query,
        'context': context,
        'answer': answer,
    }
    with Path("interaction_log.jsonl").open("a") as log:
        log.write(json.dumps(record))
        log.write("\n")
161
+
162
# --- Gradio UI ---------------------------------------------------------------
# Companies the user may ask about.
company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]

# Input widgets: company selector plus a free-text query box.
company_input = gr.Radio(company_list, label="Select Company")
query_input = gr.Textbox(lines=2, placeholder="Enter your query here...", label="User Query")

# Wire the predict() function to the two inputs and a text output.
iface = gr.Interface(
    fn=predict,
    inputs=[company_input, query_input],
    outputs=gr.Textbox(label="Generated Answer"),
    title="Company Reports Q&A",
    description="Query the vector database and get an LLM response based on the documents in the collection.",
)

# Launch the Gradio interface (share=True also exposes a public tunnel URL).
iface.launch(share=True)