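# Streamlit chatbot over Minutes-of-Meeting (MoM) JSON transcripts.
# Pipeline: BGE-M3 embeddings -> FAISS retrieval -> FlashRank reranking ->
# Gemini 1.5 Flash answer generation, with LangChain conversation memory.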
import os

import streamlit as st
from streamlit_chat import message
import json
import torch
from transformers import AutoTokenizer, AutoModel
import faiss
import google.generativeai as genai
from flashrank.Ranker import Ranker, RerankRequest
from langchain.memory import ConversationBufferMemory
from pydantic import ConfigDict

# Read the Google Generative AI API key from the environment instead of
# hard-coding it in the source.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))


# ConversationBufferMemory subclass that allows arbitrary (non-pydantic) field types.
class CustomMemory(ConversationBufferMemory):
    model_config = ConfigDict(arbitrary_types_allowed=True)
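
# Parse the uploaded MoM JSON and turn each non-empty utterance into a
# "Speaker: ... Text: ..." passage for retrieval.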
def load_and_preprocess(uploaded_file):
    data = json.load(uploaded_file)
    passages = [f"Speaker: {item['speaker']}. Text: {item['text']}"
                for item in data if item["text"].strip()]
    return data, passages
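
# Load the BAAI/bge-m3 embedding model and its tokenizer from Hugging Face.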
def load_model(model_name="BAAI/bge-m3"):
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)
    return tokenizer, model
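
# Embed passages in batches; each passage vector is the mean of the model's
# last hidden states (mean pooling). Returns an array of shape (num_passages, hidden_dim).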
def generate_embeddings(passages, tokenizer, model, batch_size=10, device="cuda" if torch.cuda.is_available() else "cpu"):
    model.to(device)
    embeddings = []
    for i in range(0, len(passages), batch_size):
        batch = passages[i:i + batch_size]
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = model(**inputs).last_hidden_state.mean(dim=1)
        embeddings.append(outputs.cpu())
    embeddings = torch.cat(embeddings, dim=0)
    return embeddings.numpy()
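
# Store the passage embeddings in an exact (flat) L2 FAISS index.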
def store_in_faiss(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
    index.add(embeddings)
    return index
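
# Embed the query with the same mean-pooling scheme and return its k nearest
# passages from the FAISS index.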
def retrieve_top_k(query, tokenizer, model, faiss_index, passages, k=5, device="cuda" if torch.cuda.is_available() else "cpu"):
    model.to(device)
    inputs = tokenizer([query], return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
    with torch.no_grad():
        query_embedding = model(**inputs).last_hidden_state.mean(dim=1).cpu().numpy()
    distances, indices = faiss_index.search(query_embedding, k)
    retrieved_passages = [passages[i] for i in indices[0]]
    return retrieved_passages
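
# Rerank the retrieved passages against the query with FlashRank's rank-T5-flan reranker.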
def rerank_passages(query, passages):
    formatted_passages = [{"text": passage} for passage in passages]
    ranker = Ranker(model_name="rank-T5-flan", cache_dir="/my_cache_dir")  # Adjust cache directory as needed
    rerank_request = RerankRequest(query=query, passages=formatted_passages)
    results = ranker.rerank(rerank_request)
    return results
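
# Build a context + question prompt and answer it with Gemini 1.5 Flash.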
def generate_response(context, query):
    input_text = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
    model = genai.GenerativeModel("gemini-1.5-flash")
    response = model.generate_content(input_text)
    return response.text
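
# One full RAG turn: retrieve top-k passages, rerank them, generate an answer,
# and record the exchange in conversation memory.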
def handle_userinput(user_question):
    top_k_passages = retrieve_top_k(user_question, st.session_state.tokenizer, st.session_state.model, st.session_state.faiss_index, st.session_state.passages)
    reranked_passages = rerank_passages(user_question, top_k_passages)
    context = " ".join([passage["text"] for passage in reranked_passages])
    response = generate_response(context, user_question)
    st.session_state.memory.chat_memory.add_user_message(user_question)
    st.session_state.memory.chat_memory.add_ai_message(response)
    return response
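
# Streamlit UI: upload a MoM JSON file, build the FAISS index once, then chat over it.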
def main():
    st.set_page_config(page_title="Chatbot with MoM Document Upload", layout="wide")
    st.title("Chatbot for Minutes of Meeting")
    if "memory" not in st.session_state:
        st.session_state.memory = CustomMemory(memory_key='chat_history', return_messages=True)
    if "faiss_index" not in st.session_state:
        st.session_state.faiss_index = None
    if "passages" not in st.session_state:
        st.session_state.passages = None
    if "tokenizer" not in st.session_state or "model" not in st.session_state:
        st.session_state.tokenizer, st.session_state.model = load_model()
    uploaded_file = st.file_uploader("Upload a JSON file for processing", type=["json"])
    # Embed the file only once; without this guard every Streamlit rerun
    # (e.g. typing a question) would re-read and re-embed the whole document.
    if uploaded_file and st.session_state.faiss_index is None:
        st.write("Processing the file...")
        data, passages = load_and_preprocess(uploaded_file)
        st.session_state.passages = passages
        tokenizer, model = st.session_state.tokenizer, st.session_state.model
        embeddings = generate_embeddings(passages, tokenizer, model)
        st.session_state.faiss_index = store_in_faiss(embeddings)
        st.success("File processed and embeddings generated successfully!")
    if st.session_state.faiss_index:
        st.header("Ask a Question")
        user_query = st.text_input("Type your question here:")
        # Only answer when the question changes, so reruns don't duplicate
        # history entries or re-call the LLM for the same query.
        if user_query and user_query != st.session_state.get("last_query"):
            st.session_state.last_query = user_query
            response = handle_userinput(user_query)
            if "chat_history_ui" not in st.session_state:
                st.session_state.chat_history_ui = []
            st.session_state.chat_history_ui.append({"role": "user", "content": user_query})
            st.session_state.chat_history_ui.append({"role": "bot", "content": response})
    if "chat_history_ui" in st.session_state:
        for i, chat in enumerate(st.session_state.chat_history_ui):
            if chat["role"] == "user":
                message(chat["content"], is_user=True, key=f"user_{i}")
            else:
                message(chat["content"], is_user=False, key=f"bot_{i}")


if __name__ == "__main__":
    main()