import streamlit as st
import os
import fitz
import re
from transformers import AutoModelForSequenceClassification, BertTokenizer, BertModel, AutoTokenizer
import torch
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
from nltk.corpus import stopwords


def is_new_txt_file_upload(uploaded_txt_file):
    if 'last_uploaded_txt_file' in st.session_state:
        # Check if the newly uploaded file is different from the last one
        if (uploaded_txt_file.name != st.session_state.last_uploaded_txt_file['name'] or
                uploaded_txt_file.size != st.session_state.last_uploaded_txt_file['size']):
            # A new text file has been uploaded
            st.session_state.last_uploaded_txt_file = {'name': uploaded_txt_file.name,
                                                       'size': uploaded_txt_file.size}
            return True
        else:
            # The same text file has been re-uploaded
            return False
    else:
        # This is the first upload detected
        st.session_state.last_uploaded_txt_file = {'name': uploaded_txt_file.name,
                                                   'size': uploaded_txt_file.size}
        return True


def is_new_file_upload(uploaded_file):
    if 'last_uploaded_file' in st.session_state:
        # Check if the newly uploaded file is different from the last one
        if (uploaded_file.name != st.session_state.last_uploaded_file['name'] or
                uploaded_file.size != st.session_state.last_uploaded_file['size']):
            # A new PDF file has been uploaded
            st.session_state.last_uploaded_file = {'name': uploaded_file.name,
                                                   'size': uploaded_file.size}
            return True
        else:
            # The same PDF file has been re-uploaded
            return False
    else:
        # This is the first upload detected
        st.session_state.last_uploaded_file = {'name': uploaded_file.name,
                                               'size': uploaded_file.size}
        return True


def add_commonality_to_similarity_score(similarity, sentence_to_find_similarity_score,
                                        query_to_find_similarity_score):
    # Tokenize both the sentence and the query, dropping stop words
    sentence_words = set(word for word in sentence_to_find_similarity_score.split()
                         if word.lower() not in st.session_state.stop_words)
    query_words = set(word for word in query_to_find_similarity_score.split()
                      if word.lower() not in st.session_state.stop_words)

    # Count the words the sentence and the query share
    common_words = len(sentence_words.intersection(query_words))

    # Add the shared-word ratio (0..1, normalized by query length) to the cosine
    # similarity (-1..1); the combined score therefore falls in roughly [-1, 2]
    commonality = common_words / max(len(query_words), 1)
    combined_score = similarity + commonality
    return combined_score, similarity, commonality


def contradiction_detection(premise, hypothesis):
    inputs = st.session_state.roberta_tokenizer.encode_plus(premise, hypothesis,
                                                            return_tensors="pt", truncation=True)
    # Get model predictions
    outputs = st.session_state.roberta_model(**inputs)
    # Get the logits (raw predictions before softmax)
    logits = outputs.logits
    # Apply softmax to get probabilities for each class
    probabilities = torch.softmax(logits, dim=1)
    # roberta-large-mnli class order: 0 = contradiction, 1 = neutral, 2 = entailment
    predicted_class = torch.argmax(probabilities, dim=1).item()
    labels = ["Contradiction", "Neutral", "Entailment"]
    print(f"Prediction: {labels[predicted_class]}")
    return {labels[predicted_class]}


if 'is_initialized' not in st.session_state:
    st.session_state['is_initialized'] = True
    nltk.download('punkt')
    nltk.download('punkt_tab')
    nltk.download('stopwords')
    st.session_state.stop_words = set(stopwords.words('english'))
    # Fall back to CPU when no GPU is available
    st.session_state.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    st.session_state.bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased").to(st.session_state.device)
    st.session_state.roberta_tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
    st.session_state.roberta_model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli")
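# Illustrative sketch (not part of the app flow): once the models above are loaded,
# contradiction_detection() can be exercised directly. The pair below reuses the
# sample sentences shipped with this app; the expected label assumes roberta-large-mnli
# behaves as documented.
#
#   premise = ("A Member will be considered Actively at Work if he or she is able "
#              "and available for active performance of all of his or her regular duties")
#   hypothesis = ("A Member shall be deemed inactive at Work if he or she is capable "
#                 "and available to perform all of his or her regular responsibilities.")
#   contradiction_detection(premise, hypothesis)   # expected: {"Contradiction"}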
def encode_sentence(sentence_to_be_encoded):
    # Skip fragments too short to carry meaning
    if len(sentence_to_be_encoded.strip()) < 4:
        return None
    sentence_tokens = st.session_state.bert_tokenizer(sentence_to_be_encoded, return_tensors="pt",
                                                      padding=True, truncation=True).to(st.session_state.device)
    with torch.no_grad():
        # Use the [CLS] token embedding as the sentence vector
        sentence_encoding = st.session_state.bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
    return sentence_encoding


def encode_paragraph(paragraph_to_be_encoded):
    sentence_encodings_for_encoding_paragraph = []
    paragraph_without_newline = paragraph_to_be_encoded.replace("\n", "")
    sentences_for_encoding_paragraph = sent_tokenize(paragraph_without_newline)
    for sentence_for_encoding_paragraph in sentences_for_encoding_paragraph:
        sentence_encoding = encode_sentence(sentence_for_encoding_paragraph)
        sentence_encodings_for_encoding_paragraph.append([sentence_for_encoding_paragraph, sentence_encoding])
    return sentence_encodings_for_encoding_paragraph


if 'list_count' in st.session_state:
    st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
    if 'paragraph_sentence_encodings' not in st.session_state:
        print("start embedding paragraphs")
        read_progress_bar = st.progress(0)
        st.session_state.paragraph_sentence_encodings = []
        for index, paragraph in enumerate(st.session_state.restored_paragraphs):
            # Guard against division by zero when there is a single paragraph
            progress_percentage = index / max(st.session_state.list_count - 1, 1)
            read_progress_bar.progress(progress_percentage)
            sentence_encodings = encode_paragraph(paragraph['paragraph'])
            st.session_state.paragraph_sentence_encodings.append([paragraph, sentence_encodings])
        st.rerun()
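# Illustrative sketch (hypothetical sentences): encode_sentence() returns a (1, 768)
# numpy array for bert-base-uncased, so two sentences can be compared directly with
# scikit-learn:
#
#   a = encode_sentence("The policy covers routine dental care.")
#   b = encode_sentence("Routine dental care is covered by the policy.")
#   print(cosine_similarity(a, b)[0][0])   # close paraphrases typically score high,
#                                          # though untuned [CLS] vectors are coarse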

big_text = """
<h1 style="text-align: center;">Contradiction Detection</h1>
"""

# Display the styled title
st.markdown(big_text, unsafe_allow_html=True)


def convert_pdf_to_paragraph_list(pdf_doc_to_paragraph_list):
    paragraphs = []
    start_page = 1
    # start_page - 1 adjusts for PyMuPDF's 0-based page index
    for page_num in range(start_page - 1, len(pdf_doc_to_paragraph_list)):
        page = pdf_doc_to_paragraph_list.load_page(page_num)
        blocks = page.get_text("blocks")
        for block in blocks:
            # PyMuPDF block tuples are (x0, y0, x1, y1, text, block_no, block_type)
            x0, y0, x1, y1, text, block_no, block_type = block
            if text.strip() != "":
                text = text.strip()
                text = re.sub(r'\n\s+\n', '\n\n', text)
                # Detect numbered ("1."), lettered ("a.") or bulleted ("*", "-") list items
                list_pattern = re.compile(r'^\s*((?:\d+\.|[a-zA-Z]\.|[*-])\s+.+)', re.MULTILINE)
                containsList = bool(list_pattern.search(text))
                if re.search(r'\n{2,}', text):
                    # Blank lines split the block into separate paragraphs
                    substrings = re.split(r'\n{2,}', text)
                    for substring in substrings:
                        if substring.strip() != "":
                            paragraphs.append(
                                {"paragraph": substring, "containsList": containsList,
                                 "page_num": page_num, "text": text})
                else:
                    paragraphs.append(
                        {"paragraph": text, "containsList": containsList,
                         "page_num": page_num, "text": None})
    return paragraphs


uploaded_pdf_file = st.file_uploader("Upload a PDF file", type=['pdf'])
st.markdown('Sample Master PDF download and then upload to above', unsafe_allow_html=True)

if uploaded_pdf_file is not None:
    if is_new_file_upload(uploaded_pdf_file):
        print("is new file uploaded")
        if 'prev_query' in st.session_state:
            del st.session_state['prev_query']
        if 'paragraph_sentence_encodings' in st.session_state:
            del st.session_state['paragraph_sentence_encodings']
        save_path = './uploaded_files'
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        # Write the uploaded file to disk so PyMuPDF can open it
        with open(os.path.join(save_path, uploaded_pdf_file.name), "wb") as f:
            f.write(uploaded_pdf_file.getbuffer())
        st.success(f'Saved file {uploaded_pdf_file.name} in {save_path}')
        st.session_state.uploaded_path = os.path.join(save_path, uploaded_pdf_file.name)

        doc = fitz.open(st.session_state.uploaded_path)
        st.session_state.restored_paragraphs = convert_pdf_to_paragraph_list(doc)
        if isinstance(st.session_state.restored_paragraphs, list):
            # Count the top-level paragraph entries
            st.session_state.list_count = len(st.session_state.restored_paragraphs)
            st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
        st.rerun()


def contradiction_detection_for_sentence(cd_query, cd_query_line_number):
    query_encoding = encode_sentence(cd_query)
    if query_encoding is None:  # query too short to encode
        return
    total_count = len(st.session_state.paragraph_sentence_encodings)
    processing_progress_bar = st.progress(0)
    sentence_scores, paragraph_scores = find_sentences_scores(
        st.session_state.paragraph_sentence_encodings, query_encoding, cd_query,
        processing_progress_bar, total_count)
    sorted_paragraph_scores = sorted(paragraph_scores, key=lambda x: x[0], reverse=True)
    # Run NLI on the top sentences of the highest-scoring paragraphs
    for similarity_score, commonality_score, paragraph_from_sorted_paragraph_scores in \
            sorted_paragraph_scores[:3]:  # number of paragraphs to consider
        prev_contradiction_detected = True
        for top_sentence in paragraph_from_sorted_paragraph_scores['top_three_sentences']:
            if not prev_contradiction_detected:
                break
            contradiction_detection_result = contradiction_detection(cd_query, top_sentence[1])
            if contradiction_detection_result == {"Contradiction"}:
                if top_sentence[2] >= 0.25:  # require a minimum commonality score
                    st.write("master document page number ",
                             paragraph_from_sorted_paragraph_scores['original_text']['page_num'])
                    st.write("master document sentence: ", top_sentence[1])
                    st.write("secondary document line number", cd_query_line_number)
                    st.write("secondary document sentence: ", cd_query)
                    st.write(contradiction_detection_result)
                else:
                    prev_contradiction_detected = False
def find_sentences_scores(paragraph_sentence_encodings, query_encoding, query_plain,
                          processing_progress_bar, total_count):
    paragraph_scores = []
    sentence_scores = []
    for paragraph_sentence_encoding_index, paragraph_sentence_encoding in enumerate(paragraph_sentence_encodings):
        # Guard against division by zero when there is a single paragraph
        find_sentences_scores_progress_percentage = paragraph_sentence_encoding_index / max(total_count - 1, 1)
        processing_progress_bar.progress(find_sentences_scores_progress_percentage)
        sentence_similarities = []
        for sentence_encoding in paragraph_sentence_encoding[1]:
            # Skip sentences that were too short to encode (encoding is None)
            if sentence_encoding[1] is not None:
                similarity = cosine_similarity(query_encoding, sentence_encoding[1])[0][0]
                combined_score, similarity_score, commonality_score = add_commonality_to_similarity_score(
                    similarity, sentence_encoding[0], query_plain)
                sentence_similarities.append((combined_score, sentence_encoding[0], commonality_score))
                sentence_scores.append((combined_score, sentence_encoding[0]))

        sentence_similarities.sort(reverse=True, key=lambda x: x[0])

        # Average the top three sentences, or all of them if the paragraph has fewer
        if len(sentence_similarities) >= 3:
            top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities[:3]])
            top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities[:3]])
            top_three_sentences = sentence_similarities[:3]
        elif sentence_similarities:
            top_three_avg_similarity = np.mean([s[0] for s in sentence_similarities])
            top_three_avg_commonality = np.mean([s[2] for s in sentence_similarities])
            top_three_sentences = sentence_similarities
        else:
            top_three_avg_similarity = 0
            top_three_avg_commonality = 0
            top_three_sentences = []

        paragraph_scores.append(
            (top_three_avg_similarity, top_three_avg_commonality,
             {'top_three_sentences': top_three_sentences, 'original_text': paragraph_sentence_encoding[0]})
        )
    sentence_scores = sorted(sentence_scores, key=lambda x: x[0], reverse=True)
    return sentence_scores, paragraph_scores
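# Worked example (hypothetical numbers): if a sentence has cosine similarity 0.82 to the
# query and shares 2 of the query's 4 non-stop words, add_commonality_to_similarity_score()
# returns (0.82 + 2/4, 0.82, 0.5) == (1.32, 0.82, 0.5); paragraphs are then ranked by the
# average combined score of their top three sentences.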
if 'paragraph_sentence_encodings' in st.session_state:
    uploaded_text_file = st.file_uploader("Choose a .txt file", type="txt")
    st.markdown('Sample Secondary txt download and then upload to above', unsafe_allow_html=True)
    if uploaded_text_file is not None:
        if is_new_txt_file_upload(uploaded_text_file):
            lines = uploaded_text_file.readlines()
            # Collect (line number, text) pairs, decoding each line from bytes
            line_list = []
            for line_number, line in enumerate(lines, start=1):
                decoded_line = line.decode("utf-8").strip()
                line_list.append((line_number, decoded_line))
            # Check every sentence of every line against the master document
            for item in line_list:
                sentences = sent_tokenize(item[1])
                for sentence in sentences:
                    contradiction_detection_for_sentence(sentence, item[0])

    # Sample sentences that contradict each other:
    #   A Member will be considered Actively at Work if he or she is able and available
    #   for active performance of all of his or her regular duties
    #
    #   A Member will be considered as inactive at Work if he or she is able and available
    #   for active performance of all of his or her regular duties
    #
    #   A Member shall be deemed inactive at Work if he or she is capable and available
    #   to perform all of his or her regular responsibilities.

    toggle_single_sentence_input = st.checkbox(
        "optional, if you want to enter a single sentence instead of an entire text file", value=False)
    if toggle_single_sentence_input:
        st.markdown(
            "sample queries to invoke contradiction:<br/>"
            "A Member shall be deemed inactive at Work if he or she is capable and available "
            "to perform all of his or her regular responsibilities.",
            unsafe_allow_html=True)
        query = st.text_input("Enter your query")
        if query:
            if 'prev_query' not in st.session_state or st.session_state.prev_query != query:
                st.session_state.prev_query = query
                st.session_state.premise = query
                contradiction_detection_for_sentence(query, 1)
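# To try this app locally (assuming this file is saved as app.py -- the filename is an
# assumption -- and streamlit, PyMuPDF, transformers, torch, scikit-learn, numpy, and
# nltk are installed), launch it with:
#
#   streamlit run app.py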