Hyma7 committed on
Commit e354c74 · verified · 1 Parent(s): f3b7c20

Create app.py

Files changed (1):
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
+ # Import necessary libraries
+ import streamlit as st
+ import pandas as pd
+ from sentence_transformers import SentenceTransformer
+ from langchain.vectorstores import FAISS
+ from langchain.embeddings.base import Embeddings
+ from transformers import pipeline
+ from langchain.llms.huggingface_pipeline import HuggingFacePipeline
+ from langchain.chains import RetrievalQA
+
+ # Define a LangChain-compatible wrapper for SentenceTransformer
+ class SentenceTransformerEmbeddings(Embeddings):
+     """
+     Wrapper for SentenceTransformer to integrate with LangChain.
+     """
+     def __init__(self, model_name: str):
+         self.model = SentenceTransformer(model_name)
+
+     def embed_documents(self, texts):
+         """
+         Generates embeddings for a list of documents.
+         Args:
+             texts (list): List of strings to embed.
+
+         Returns:
+             np.ndarray: Embedding vectors.
+         """
+         return self.model.encode(texts, show_progress_bar=False)
+
+     def embed_query(self, text):
+         """
+         Generates an embedding for a single query.
+         Args:
+             text (str): Query string to embed.
+
+         Returns:
+             np.ndarray: Embedding vector.
+         """
+         return self.model.encode([text], show_progress_bar=False)[0]
+
+ # Initialize the embedding model
+ embedding_model = SentenceTransformerEmbeddings('all-MiniLM-L6-v2')
+
+ # Preprocess data into descriptive text entries
+ def preprocess_data(data):
+     """
+     Combines multiple dataset columns into descriptive text entries for embedding.
+
+     Args:
+         data (pd.DataFrame): The input dataset containing participant details.
+
+     Returns:
+         list: A list of combined textual descriptions for each row in the dataset.
+     """
+     combined_entries = []
+     for _, row in data.iterrows():
+         entry = f"Participant {row['ID']}:\n"
+         entry += f"- AI Knowledge Level: {row['Q1.AI_knowledge']}\n"
+         entry += f"- Sources of AI Knowledge: {row['Q2.AI_sources']}\n"
+         entry += f"- Perspectives on AI: Dehumanization ({row['Q3#1.AI_dehumanization']}), "
+         entry += f"Job Replacement ({row['Q3#2.Job_replacement']})\n"
+         entry += f"- Domains Impacted by AI: {row['Q6.Domains']}\n"
+         entry += f"- Utility Grade for AI: {row['Q7.Utility_grade']}\n"
+         entry += f"- GPA: {row['Q16.GPA']}\n"
+         combined_entries.append(entry)
+     return combined_entries
+
+ # App logic
+ def main():
+     # Set up the Streamlit UI
+     st.title("RAG Chatbot")
+     st.write("This chatbot answers questions based on the dataset.")
+
+     # Load the dataset directly from the Space's root directory
+     dataset_path = "Survey_AI.csv"
+     try:
+         data = pd.read_csv(dataset_path)
+         st.write("Dataset successfully loaded!")
+
+         # Preprocess data and create vector store
+         combined_texts = preprocess_data(data)
+         vector_store = FAISS.from_texts(combined_texts, embedding_model)
+         retriever = vector_store.as_retriever()
+
+         # Set up QA chain
+         flan_t5 = pipeline("text2text-generation", model="google/flan-t5-base", device=-1)  # CPU mode
+         llm = HuggingFacePipeline(pipeline=flan_t5)
+         qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever)
+
+         # Default sample questions
+         sample_questions = [
+             "What are the sources of AI knowledge for participants?",
+             "Which domains are impacted by AI?",
+             "What are participants' perspectives on job replacement due to AI?",
+             "What is the average GPA of participants?",
+             "What is the utility grade for AI?",
+             "Which participants view AI as highly beneficial in their domain?"
+         ]
+
+         st.subheader("Sample Questions")
+         selected_question = st.selectbox("Select a question to see the response:", [""] + sample_questions)
+
+         if selected_question:
+             response = qa_chain.run(selected_question)
+             st.write("Question:", selected_question)
+             st.write("Answer:", response)
+
+         # Custom user query
+         st.subheader("Custom Query")
+         query = st.text_input("Or, enter your own question:")
+         if query:
+             response = qa_chain.run(query)
+             st.write("Question:", query)
+             st.write("Answer:", response)
+
+     except FileNotFoundError:
+         st.error("Dataset file not found. Please ensure the file is named 'Survey_AI.csv' and uploaded to the root directory.")
+
+ # Run the app
+ if __name__ == "__main__":
+     main()
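
A few notes on the pieces above. The SentenceTransformerEmbeddings class only has to satisfy LangChain's Embeddings interface: embed_documents for a list of texts, embed_query for one string. A minimal sanity check of the underlying model, run outside the app (all-MiniLM-L6-v2 produces 384-dimensional vectors):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer('all-MiniLM-L6-v2')
    doc_vectors = model.encode(["text one", "text two"], show_progress_bar=False)
    query_vector = model.encode(["a query"], show_progress_bar=False)[0]
    print(doc_vectors.shape)   # (2, 384)
    print(query_vector.shape)  # (384,)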
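
preprocess_data serializes each survey row into one self-contained chunk, so FAISS retrieval happens per participant. A sketch of the output for a single row, run in the same module (the values below are invented; only the column names come from the code above):

    import pandas as pd

    sample = pd.DataFrame([{
        "ID": 1,
        "Q1.AI_knowledge": 4,
        "Q2.AI_sources": "Online courses",
        "Q3#1.AI_dehumanization": 2,
        "Q3#2.Job_replacement": 3,
        "Q6.Domains": "Education, Healthcare",
        "Q7.Utility_grade": 5,
        "Q16.GPA": 8.7,
    }])
    print(preprocess_data(sample)[0])
    # Participant 1:
    # - AI Knowledge Level: 4
    # - Sources of AI Knowledge: Online courses
    # - Perspectives on AI: Dehumanization (2), Job Replacement (3)
    # - Domains Impacted by AI: Education, Healthcare
    # - Utility Grade for AI: 5
    # - GPA: 8.7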
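
The retriever and chain are built with defaults; LangChain's standard API also lets the retrieval depth and chain type be set explicitly. One possible variation (the k value is illustrative; keeping it small matters because flan-t5-base truncates its input at 512 tokens):

    retriever = vector_store.as_retriever(search_kwargs={"k": 3})  # fetch 3 nearest participant chunks
    qa_chain = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")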
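
The import paths (langchain.vectorstores, langchain.llms.huggingface_pipeline) date from pre-0.2 LangChain; newer releases move them into langchain-community. A plausible requirements.txt for this Space (the version pin is a guess to match those paths):

    streamlit
    pandas
    sentence-transformers
    transformers
    torch
    faiss-cpu
    langchain<0.2

Locally, the app starts with streamlit run app.py, with Survey_AI.csv sitting next to app.py; on a Hugging Face Space both files belong in the repository root.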