import streamlit as st
import os
import json
from transformers import BertTokenizer, BertModel
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords
def is_new_file_upload(uploaded_file):
    if 'last_uploaded_file' in st.session_state:
        # Check whether the newly uploaded file differs from the last one by name or size
        if (uploaded_file.name != st.session_state.last_uploaded_file['name'] or
                uploaded_file.size != st.session_state.last_uploaded_file['size']):
            st.session_state.last_uploaded_file = {'name': uploaded_file.name, 'size': uploaded_file.size}
            return True
        else:
            return False
    else:
        # First upload detected in this session
        st.session_state.last_uploaded_file = {'name': uploaded_file.name, 'size': uploaded_file.size}
        return True
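# A stricter check (a sketch, unused below): two different files can share a
# name and size, so hashing the uploaded bytes disambiguates re-uploads.
import hashlib
def file_digest(uploaded_file):
    # sha256 accepts any buffer-protocol object, including Streamlit's upload buffer
    return hashlib.sha256(uploaded_file.getbuffer()).hexdigest()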
def combined_similarity(similarity, sentence, query):
    # Tokenize both the sentence and the query, dropping stop words
    sentence_words = set(word for word in sentence.split() if word.lower() not in st.session_state.stop_words)
    query_words = set(word for word in query.split() if word.lower() not in st.session_state.stop_words)
    # Count words shared between the sentence and the query
    common_words = len(sentence_words.intersection(query_words))
    # Boost the cosine similarity by the fraction of query words found in the
    # sentence; normalizing by the query length keeps the boost within [0, 1]
    combined_score = similarity + (common_words / max(len(query_words), 1))
    return combined_score
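# Worked example (hypothetical numbers): with a cosine similarity of 0.62 and
# two of three non-stopword query words present in the sentence, the combined
# score is 0.62 + 2/3 ≈ 1.29 -- so lexical overlap can contribute up to 1.0
# on top of the embedding similarity.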
big_text = """
Knowledge Extraction A
"""
# Display the styled text
st.markdown(big_text, unsafe_allow_html=True)
uploaded_json_file = st.file_uploader("Upload a pre-processed file", type=['json'])
st.markdown('Sample 1 download and then upload to above', unsafe_allow_html=True)
st.markdown("Sample queries for the above file: "
            "What is death? What is a lucid dream? What is the seat of consciousness?",
            unsafe_allow_html=True)
st.markdown('Sample 2 download and then upload to above', unsafe_allow_html=True)
st.markdown("Sample queries for the above file: "
            "What do nontechnical managers worry about? What if you put all the knowledge, "
            "frameworks, and tips from this book to full use? Tell me about AI agents.",
            unsafe_allow_html=True)
if uploaded_json_file is not None:
    if is_new_file_upload(uploaded_json_file):
        print("new file uploaded")
        save_path = './uploaded_files'
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        with open(os.path.join(save_path, uploaded_json_file.name), "wb") as f:
            f.write(uploaded_json_file.getbuffer())  # Write the upload to disk
        st.success(f'Saved file {uploaded_json_file.name} in {save_path}')
        st.session_state.uploaded_path = os.path.join(save_path, uploaded_json_file.name)
        uploaded_json_file.seek(0)  # Ensure we read from the start of the upload
        content = uploaded_json_file.read()
        try:
            st.session_state.restored_paragraphs = json.loads(content)
            # The pre-processed file is expected to be a list of paragraph objects
            if isinstance(st.session_state.restored_paragraphs, list):
                st.session_state.list_count = len(st.session_state.restored_paragraphs)
                st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
            else:
                st.write('The JSON content is not a list.')
        except json.JSONDecodeError:
            st.write('Invalid JSON file.')
        st.rerun()
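# Expected shape of the pre-processed file, inferred from the fields used
# below: a JSON list of paragraph objects, each with at least a 'text' key,
# e.g. [{"text": "First paragraph..."}, {"text": "Second paragraph..."}]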
if 'is_initialized' not in st.session_state:
    st.session_state['is_initialized'] = True
    nltk.download('punkt')
    nltk.download('stopwords')
    st.session_state.stop_words = set(stopwords.words('english'))
    # Fall back to CPU when no GPU is available
    st.session_state.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    st.session_state.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased").to(st.session_state.device)
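# The app embeds text with the [CLS] token's final hidden state. Mean pooling
# over token embeddings is a common alternative; a minimal sketch, not wired
# into the app:
def mean_pool(last_hidden_state, attention_mask):
    # Zero out padding positions, then average the remaining token vectors
    mask = attention_mask.unsqueeze(-1).float()
    return (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)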
if 'list_count' in st.session_state:
    st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
if 'paragraph_sentence_encodings' not in st.session_state:
    print("start embedding paragraphs")
    read_progress_bar = st.progress(0)
    st.session_state.paragraph_sentence_encodings = []
    for index, paragraph in enumerate(st.session_state.restored_paragraphs):
        # Guard against division by zero when there is only one paragraph
        progress_percentage = index / max(st.session_state.list_count - 1, 1)
        read_progress_bar.progress(progress_percentage)
        sentence_encodings = []
        sentences = sent_tokenize(paragraph['text'])
        for sentence in sentences:
            # Skip questions and fragments too short to carry meaning
            if sentence.strip().endswith('?') or len(sentence.strip()) < 4:
                sentence_encodings.append(None)
                continue
            sentence_tokens = st.session_state.bert_tokenizer(sentence, return_tensors="pt",
                                                              padding=True, truncation=True).to(st.session_state.device)
            with torch.no_grad():
                # Use the [CLS] token's final hidden state as the sentence embedding
                sentence_encoding = st.session_state.bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
            sentence_encodings.append([sentence, sentence_encoding])
        st.session_state.paragraph_sentence_encodings.append([paragraph, sentence_encodings])
    st.rerun()
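# Encoding one sentence per forward pass is simple but slow. Batching all of a
# paragraph's sentences through the tokenizer/model in a single call (a sketch,
# assuming the same tokenizer and model as above) amortizes per-call overhead:
# tokens = st.session_state.bert_tokenizer(sentences, return_tensors="pt",
#                                          padding=True, truncation=True).to(st.session_state.device)
# with torch.no_grad():
#     batch_encodings = st.session_state.bert_model(**tokens).last_hidden_state[:, 0, :].cpu().numpy()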
if 'paragraph_sentence_encodings' in st.session_state:
    query = st.text_input("Enter your query")
    if query:
        query_tokens = st.session_state.bert_tokenizer(query, return_tensors="pt",
                                                       padding=True, truncation=True).to(st.session_state.device)
        with torch.no_grad():  # Disable gradient calculation for inference
            # Take the [CLS] embedding, move it to CPU, and convert to NumPy
            query_encoding = st.session_state.bert_model(**query_tokens).last_hidden_state[:, 0, :].cpu().numpy()
        paragraph_scores = []
        sentence_scores = []
        total_count = len(st.session_state.paragraph_sentence_encodings)
        processing_progress_bar = st.progress(0)
        for index, paragraph_sentence_encoding in enumerate(st.session_state.paragraph_sentence_encodings):
            progress_percentage = index / max(total_count - 1, 1)
            processing_progress_bar.progress(progress_percentage)
            sentence_similarities = []
            for sentence_encoding in paragraph_sentence_encoding[1]:
                if sentence_encoding is not None:
                    similarity = cosine_similarity(query_encoding, sentence_encoding[1])[0][0]
                    combined_score = combined_similarity(similarity, sentence_encoding[0], query)
                    sentence_similarities.append(combined_score)
                    sentence_scores.append((combined_score, sentence_encoding[0]))
            sentence_similarities.sort(reverse=True)
            # Score a paragraph by the average of its top three sentence scores;
            # the slice handles paragraphs with fewer than three scored sentences
            top_three_avg_similarity = np.mean(sentence_similarities[:3]) if sentence_similarities else 0
            paragraph_scores.append((top_three_avg_similarity, paragraph_sentence_encoding[0]))
        # Rank individual sentences too (kept for inspection/debugging)
        sentence_scores = sorted(sentence_scores, key=lambda x: x[0], reverse=True)
        # Sort the paragraphs by their top-three average score
        paragraph_scores = sorted(paragraph_scores, key=lambda x: x[0], reverse=True)
        st.write("Top scored paragraphs and their scores:")
        for score, paragraph in paragraph_scores[:5]:  # Show the top 5 matches
            st.write(f"Score: {score:.4f}, Paragraph: {paragraph['text']}")