kgauvin603 committed on
Commit
d37ce17
·
verified ·
1 Parent(s): 509d6e0

Rename 02JUL24app.py to app.py

Browse files
Files changed (1) hide show
  1. 02JUL24app.py → app.py +31 -35
02JUL24app.py → app.py RENAMED
@@ -1,66 +1,62 @@
 
1
  import gradio as gr
2
  from sentence_transformers import SentenceTransformer
 
 
 
 
3
  import chromadb
4
  import pandas as pd
5
  import os
 
6
  import json
7
  from pathlib import Path
8
  from llama_index.llms.anyscale import Anyscale
9
 
10
- # Load the sentence transformer model for embedding text
11
- model = SentenceTransformer('all-MiniLM-L6-v2')
 
12
 
13
- # Initialize the ChromaDB client for managing the vector database
14
  chroma_client = chromadb.Client()
15
 
16
- # Function to re-encode embeddings
17
- def reencode_embeddings(embeddings):
18
- return [model.encode(eval(embedding.replace(',,', ','))).tolist() for embedding in embeddings]
19
-
20
- # Function to build the vector database from a CSV file
21
  def build_database():
22
- # Read the CSV file containing document data
23
  df = pd.read_csv('vector_store.csv')
24
-
25
- # Name of the collection to store the data
26
  collection_name = 'Dataset-10k-companies'
27
-
28
- # Uncomment the line below to delete the existing collection if needed
29
- # chroma_client.delete_collection(name=collection_name)
30
-
31
- # Create a new collection in ChromaDB
32
  collection = chroma_client.create_collection(name=collection_name)
33
-
34
- # Re-encode the embeddings to match the model's dimensionality
35
- embeddings = reencode_embeddings(df['embeddings'].tolist())
36
 
37
- # Add data from the DataFrame to the collection
38
  collection.add(
39
  documents=df['documents'].tolist(),
40
  ids=df['ids'].tolist(),
41
  metadatas=df['metadatas'].apply(eval).tolist(),
42
- embeddings=embeddings
43
  )
44
-
45
  return collection
46
 
47
- # Build the database when the app starts
48
  collection = build_database()
49
 
50
- # Access the Anyscale API key from environment variables
51
  anyscale_api_key = os.environ.get('anyscale_api_key')
52
 
53
- # Instantiate the Anyscale client for using the Llama language model
54
  client = Anyscale(api_key=anyscale_api_key, model="meta-llama/Llama-2-70b-chat-hf")
55
 
56
  # Function to get relevant chunks from the database based on the query
57
  def get_relevant_chunks(query, collection, top_n=3):
58
  # Encode the query into an embedding
59
  query_embedding = model.encode(query).tolist()
60
-
61
  # Query the collection to get the top_n most relevant results
62
  results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
63
-
64
  relevant_chunks = []
65
  # Extract relevant chunks and their metadata
66
  for i in range(len(results['documents'][0])):
@@ -68,7 +64,7 @@ def get_relevant_chunks(query, collection, top_n=3):
68
  source = results['metadatas'][0][i]['source']
69
  page = results['metadatas'][0][i]['page']
70
  relevant_chunks.append((chunk, source, page))
71
-
72
  return relevant_chunks
73
 
74
  # System message template for the LLM to provide structured responses
@@ -128,10 +124,10 @@ def predict(company, user_query):
128
  try:
129
  # Modify the query to include the company name
130
  modified_query = f"{user_query} for {company}"
131
-
132
  # Get relevant chunks from the database
133
  relevant_chunks = get_relevant_chunks(modified_query, collection)
134
-
135
  # Prepare the context string from the relevant chunks
136
  context = ""
137
  for chunk, source, page in relevant_chunks:
@@ -149,7 +145,7 @@ def predict(company, user_query):
149
 
150
  # Log the interaction for future reference
151
  log_interaction(company, user_query, context, answer)
152
-
153
  return answer
154
  except Exception as e:
155
  return f"An error occurred: {str(e)}"
@@ -167,8 +163,8 @@ def log_interaction(company, user_query, context, answer):
167
  f.write("\n")
168
 
169
  # Create Gradio interface for user interaction
170
- company_list = ["MSFT", "AWS", "Meta", "Google", "IBM"]
171
- iface = gr.Interface(
172
  fn=predict,
173
  inputs=[
174
  gr.Radio(company_list, label="Select Company"),
@@ -179,5 +175,5 @@ iface = gr.Interface(
179
  description="Query the vector database and get an LLM response based on the documents in the collection."
180
  )
181
 
182
- # Launch the Gradio interface
183
- iface.launch(share=True)
 
1
+
2
import ast
import csv
import json
import os
from pathlib import Path

import chromadb
import gradio as gr
import pandas as pd
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from llama_index.llms.anyscale import Anyscale
from sentence_transformers import SentenceTransformer
15
 
16
# Embedding model used to encode queries at retrieval time.
# NOTE(review): get_relevant_chunks() below calls model.encode(...), but the
# LangChain SentenceTransformerEmbeddings wrapper has no .encode method (it
# exposes embed_query/embed_documents), so the wrapper crashes at query time
# with AttributeError. Instantiate SentenceTransformer directly instead —
# same underlying 'thenlper/gte-large' model, but with the .encode interface
# the rest of this file uses.
model = SentenceTransformer('thenlper/gte-large')

# ChromaDB client for managing the vector database
chroma_client = chromadb.Client()
22
 
23
+ # Function to build the vdb from csv
 
 
 
 
24
def build_database():
    """Build and return a ChromaDB collection from ``vector_store.csv``.

    The CSV is expected to contain the columns:
      - ``documents``: raw text chunks
      - ``ids``: unique string ids
      - ``metadatas``: stringified dict literals (e.g. ``"{'source': ..., 'page': ...}"``)
      - ``embeddings``: stringified list literals, some containing ``',,'`` typos

    Returns:
        chromadb Collection populated with the CSV rows.

    Raises:
        Whatever ``chroma_client.create_collection`` raises if the collection
        already exists (e.g. on a second call in the same process).
    """
    # Read the document/embedding data exported to CSV
    df = pd.read_csv('vector_store.csv')

    collection_name = 'Dataset-10k-companies'

    # Create a new collection in ChromaDB.
    # NOTE(review): create_collection fails if the collection already exists;
    # consider chroma_client.get_or_create_collection for repeated runs.
    collection = chroma_client.create_collection(name=collection_name)

    # Parse the stringified columns with ast.literal_eval instead of eval():
    # literal_eval only accepts Python literals, so malformed or malicious CSV
    # content cannot execute arbitrary code.
    collection.add(
        documents=df['documents'].tolist(),
        ids=df['ids'].tolist(),
        metadatas=df['metadatas'].apply(ast.literal_eval).tolist(),
        # Some stored embedding strings contain a ',,' typo; normalize it
        # before parsing the list literal.
        embeddings=df['embeddings'].apply(
            lambda s: ast.literal_eval(s.replace(',,', ','))
        ).tolist(),
    )
    return collection
42
 
43
+ # Build the database
44
  collection = build_database()
45
 
46
+ # Get API key from hf environment variables
47
  anyscale_api_key = os.environ.get('anyscale_api_key')
48
 
49
+ # Anyscale client for using the Llama language model
50
  client = Anyscale(api_key=anyscale_api_key, model="meta-llama/Llama-2-70b-chat-hf")
51
 
52
  # Function to get relevant chunks from the database based on the query
53
  def get_relevant_chunks(query, collection, top_n=3):
54
  # Encode the query into an embedding
55
  query_embedding = model.encode(query).tolist()
56
+
57
  # Query the collection to get the top_n most relevant results
58
  results = collection.query(query_embeddings=[query_embedding], n_results=top_n)
59
+
60
  relevant_chunks = []
61
  # Extract relevant chunks and their metadata
62
  for i in range(len(results['documents'][0])):
 
64
  source = results['metadatas'][0][i]['source']
65
  page = results['metadatas'][0][i]['page']
66
  relevant_chunks.append((chunk, source, page))
67
+
68
  return relevant_chunks
69
 
70
  # System message template for the LLM to provide structured responses
 
124
  try:
125
  # Modify the query to include the company name
126
  modified_query = f"{user_query} for {company}"
127
+
128
  # Get relevant chunks from the database
129
  relevant_chunks = get_relevant_chunks(modified_query, collection)
130
+
131
  # Prepare the context string from the relevant chunks
132
  context = ""
133
  for chunk, source, page in relevant_chunks:
 
145
 
146
  # Log the interaction for future reference
147
  log_interaction(company, user_query, context, answer)
148
+
149
  return answer
150
  except Exception as e:
151
  return f"An error occurred: {str(e)}"
 
163
  f.write("\n")
164
 
165
  # Create Gradio interface for user interaction
166
+ company_list = ["Meta", "IBM", "MSFT", "Google", "AWS"]
167
+ interface = gr.Interface(
168
  fn=predict,
169
  inputs=[
170
  gr.Radio(company_list, label="Select Company"),
 
175
  description="Query the vector database and get an LLM response based on the documents in the collection."
176
  )
177
 
178
+ # Launch the Gradio interface (no public share link in this revision;
+ # pass share=True to launch() if one is needed)
179
+ interface.launch()