Spaces:
Build error
Build error
import torch | |
import streamlit as st | |
from streamlit import components | |
import pandas as pd | |
from transformers import BartTokenizer, BartForConditionalGeneration | |
from transformers import T5Tokenizer, T5ForConditionalGeneration | |
import evaluate | |
from datasets import load_dataset | |
from transformers import AutoTokenizer, LongT5ForConditionalGeneration | |
import numpy as np | |
from math import ceil | |
import en_core_web_lg | |
from collections import Counter | |
from string import punctuation | |
# Gensim | |
import gensim | |
from gensim.summarization import summarize | |
import spacy | |
# spaCy pipeline, loaded once when the script starts.
nlp = en_core_web_lg.load()

# Page chrome: the same title is used for the browser tab and the heading.
_PAGE_TITLE = 'Clinical Note Summarization'
st.set_page_config(page_title=_PAGE_TITLE, layout='wide')
st.title(_PAGE_TITLE)
st.sidebar.markdown('Using transformer model')
# --- Dataset ----------------------------------------------------------------
# Demo notes are read from a CSV addressed by relative path.
# (An earlier variant of this script used 'mtsamples_small.csv' instead.)
df = pd.read_csv("demo_shpi_w_rouge25Nov.csv")

# Render HADM_ID as text and drop only a *trailing* ".0" (the float suffix).
# A plain str.replace('.0', '') would also remove a ".0" that occurs in the
# middle of an identifier, silently changing it.
df['HADM_ID'] = df['HADM_ID'].astype(str).str.replace(r'\.0$', '', regex=True)

# Expose the raw columns under the names the rest of the app uses.
df.rename(
    columns={
        'SUBJECT_ID': 'Patient_ID',
        'HADM_ID': 'Admission_ID',
        'hpi_input_text': 'Original_Text',
        'hpi_reference_summary': 'Reference_text',
    },
    inplace=True,
)
# --- Sidebar: record and model selection --------------------------------------
st.sidebar.header("Search for Patient:")

patientid = df['Patient_ID']
# Offer each patient ID once: a patient that appears on several rows would
# otherwise be listed once per row in the dropdown.
patient = st.sidebar.selectbox('Select Patient ID:', patientid.unique())

# Only admissions that belong to the chosen patient are selectable.
admissionid = df['Admission_ID'].loc[df['Patient_ID'] == patient]
HospitalAdmission = st.sidebar.selectbox(' ', admissionid)

# Summarization back-ends offered in the UI. Not every entry has an
# implementation in run_model().
model = st.sidebar.selectbox(
    'Select Model',
    (
        'BART', 'BERT', 'BertGPT2', 'Gensim', 'LexRank', 'Long T5', 'Luhn',
        'Pysummarization', 'SBERT Summary Tokenizer', 'T5', 'T5 Seq2Seq',
        'T5-Base', 'TextRank',
    ),
)
# Default generation settings. Beam count, n-gram blocking and early stopping
# are the same for every model; BART gets a lower length penalty and a shorter
# length window than the others.
_num_beams = 4
_no_repeat_ngram_size = 3
_early_stopping = True
_length_penalty, _min_length, _max_length = (
    (1, 12, 128) if model == 'BART' else (2, 30, 200)
)
# --- Main panel: identifiers, length controls and the note itself -------------
col3, col4 = st.columns(2)
# These calls are made for their display side effect only; their results are
# no longer bound to patientid/admissionid, which used to overwrite those names.
col3.write(f"Patient ID: {patient} ")
col4.write(f"Admission ID: {HospitalAdmission} ")

# The length defaults chosen for the selected model can be overridden here.
col1, col2 = st.columns(2)
_min_length = col1.number_input("Minimum Length", value=_min_length)
_max_length = col2.number_input("Maximun Length", value=_max_length)

# Rows matching the selected patient/admission pair.
original_text = df.query(
    "Patient_ID == @patient & Admission_ID == @HospitalAdmission"
)
original_text2 = original_text['Original_Text'].values

# Pre-fill the editor with the note text itself. str() on the array would
# insert its repr ("['...']", with quotes, brackets and escaped newlines),
# which is also what the summarizers would then receive. Joining covers zero,
# one or several matching rows.
_default_note = '\n\n'.join(str(note) for note in original_text2)
runtext = st.text_area('Input Clinical Note here:', _default_note, height=300)

reference_text = original_text['Reference_text'].values
def visualize(sentence_list, best_sentences):
    """Join sentences into one string, wrapping the selected ones in a span.

    Args:
        sentence_list: Iterable of sentence strings, in display order.
        best_sentences: Container of sentences to highlight; membership is
            tested with ``in``.

    Returns:
        A string in which every sentence is preceded by one space and each
        sentence found in ``best_sentences`` is wrapped in
        ``<span class='highlight yellow'>...</span>``. An empty
        ``sentence_list`` yields ``''``.
    """
    pieces = []
    for sentence in sentence_list:
        if sentence in best_sentences:
            sentence = f"<span class='highlight yellow'>{sentence}</span>"
        pieces.append(' ' + sentence)
    # Collect and join once instead of growing the string with repeated +=.
    return ''.join(pieces)
#===== Pysummarization ===== | |
from pysummarization.nlpbase.auto_abstractor import AutoAbstractor | |
from pysummarization.tokenizabledoc.simple_tokenizer import SimpleTokenizer | |
from pysummarization.abstractabledoc.top_n_rank_abstractor import TopNRankAbstractor | |
import regex as re | |
# Shared pysummarization objects used by pysummarizer().
abstractable_doc = TopNRankAbstractor()

auto_abstractor = AutoAbstractor()
# Sentences are delimited at full stops and line breaks.
auto_abstractor.delimiter_list = [".", "\n"]
auto_abstractor.tokenizable_doc = SimpleTokenizer()
def pysummarizer(input_text):
    """Summarize *input_text* with the module-level pysummarization objects.

    Args:
        input_text: The document to summarize.

    Returns:
        The sentences from the ``'summarize_result'`` entry of the
        abstractor output, each with whitespace runs collapsed to a single
        space, joined by one space.
    """
    summary = auto_abstractor.summarize(input_text, abstractable_doc)
    cleaned = (
        re.sub(r'\s+', ' ', sentence).strip()
        for sentence in summary['summarize_result']
    )
    # Each sentence has been stripped, so joining with '' would fuse the end
    # of one sentence to the start of the next. Empty fragments are skipped
    # so they do not produce doubled spaces.
    return ' '.join(sentence for sentence in cleaned if sentence)
##===== BERT Summary tokenizer ===== | |
def BertSummarizer(input_text):
    """Return an extractive summary of *input_text* from ``Summarizer``.

    Args:
        input_text: The document to summarize.

    Returns:
        Whatever the ``Summarizer`` instance returns when called with
        ``ratio=0.4``.
    """
    from summarizer import Summarizer

    # Build a single instance. The previous code constructed two and used
    # only one of them; the unused BigBirdTokenizer import is gone as well.
    summarizer_model = Summarizer()
    return summarizer_model(input_text, ratio=0.4)
##===== SBERT ===== | |
from summarizer.sbert import SBertSummarizer | |
# SBERT summarizer instance, created once when the script starts.
Sbertmodel = SBertSummarizer('paraphrase-MiniLM-L6-v2')


def Sbert(input_text):
    """Return the result of calling ``Sbertmodel`` on *input_text* with ratio 0.4."""
    return Sbertmodel(input_text, ratio=0.4)
##===== T5 Seq2Seq ===== | |
def t5seq2seq(input_text):
    """Summarize *input_text* with the ``t5-base`` seq2seq checkpoint.

    The text is prefixed with ``"summarize: "`` and truncated to 512 tokens.
    Generation uses fixed settings (4 beams, length 40-150, length penalty
    2.0), independent of the length controls in the UI.

    Args:
        input_text: The document to summarize, as a string.

    Returns:
        The decoded first generated sequence, special tokens removed.
    """
    # The unused torch / torch.nn.functional imports were dropped.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
    tokenizer = AutoTokenizer.from_pretrained("t5-base")

    inputs = tokenizer(
        "summarize: " + input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    )
    outputs = seq2seq_model.generate(
        inputs["input_ids"],
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def BertGPT2(input_text):
    """Summarize *input_text* with a BERT-encoder / GPT-2-decoder checkpoint.

    The note is tokenized with the Bio_ClinicalBERT tokenizer (cut off at 512
    tokens), fed to the ``bert2gpt2-cnn_dailymail-fp16`` encoder-decoder
    model, and the generated ids are decoded with the GPT-2 tokenizer.

    Fixes relative to the previous version, which returned ``None``:
      * the generation step is actually executed and its result returned
        (before, the helper was defined but never called and ``Sbert`` was
        invoked instead, with its result thrown away);
      * tensors go to the same device as the model instead of a hard-coded
        ``"cuda"``;
      * beam-search settings are passed to ``generate`` rather than being set
        as attributes on the model object.

    Args:
        input_text: The document to summarize, as a string.

    Returns:
        The decoded summary string for the single input.
    """
    from transformers import AutoTokenizer, EncoderDecoderModel, GPT2Tokenizer

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    enc_dec_model = EncoderDecoderModel.from_pretrained(
        "patrickvonplaten/bert2gpt2-cnn_dailymail-fp16"
    )
    enc_dec_model.to(device)

    bert_tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    # CLS token will work as BOS token, SEP token as EOS token.
    bert_tokenizer.bos_token = bert_tokenizer.cls_token
    bert_tokenizer.eos_token = bert_tokenizer.sep_token

    # Make GPT-2 add BOS/EOS around encoded sequences.
    # NOTE(review): this patches the GPT2Tokenizer class itself, so it affects
    # every GPT2Tokenizer in the process, not just the one created below.
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]

    GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    # Set pad_token to unk_token -> be careful here, as
    # unk_token_id == eos_token_id == bos_token_id for GPT-2.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token

    # Decoder start / end tokens come from the GPT-2 vocabulary.
    enc_dec_model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id
    enc_dec_model.config.eos_token_id = gpt2_tokenizer.eos_token_id

    # The tokenizer adds [BOS] <text> [EOS]; cut off at BERT's 512 tokens.
    inputs = bert_tokenizer(
        input_text,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt",
    )
    outputs = enc_dec_model.generate(
        inputs.input_ids.to(device),
        attention_mask=inputs.attention_mask.to(device),
        max_length=142,
        min_length=56,
        no_repeat_ngram_size=3,
        early_stopping=True,
        length_penalty=2.0,
        num_beams=4,
    )

    # One input goes in, so the first decoded string is the summary.
    decoded = gpt2_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return decoded[0]
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _generate_summary(seq2seq_model, tokenizer, input_ids):
    """Generate with the UI-selected settings and decode the first sequence."""
    summary_ids = seq2seq_model.generate(
        input_ids,
        num_beams=_num_beams,
        no_repeat_ngram_size=_no_repeat_ngram_size,
        length_penalty=_length_penalty,
        min_length=_min_length,
        max_length=_max_length,
        early_stopping=_early_stopping,
    )
    return tokenizer.decode(
        summary_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )


def _summarize_with_bart(input_text, target_device):
    """Summarize with ``facebook/bart-base`` on *target_device*."""
    # The model must live on the same device as its inputs.
    bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-base").to(target_device)
    bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    normalized = ' '.join(str(input_text).split())
    # Truncate to the tokenizer's model maximum so notes longer than the
    # model accepts do not fail inside generate().
    input_ids = bart_tokenizer.encode(
        normalized, return_tensors='pt', truncation=True
    ).to(target_device)
    return _generate_summary(bart_model, bart_tokenizer, input_ids)


def _summarize_with_t5(input_text, target_device):
    """Summarize with ``t5-base`` on *target_device*."""
    t5_model = T5ForConditionalGeneration.from_pretrained("t5-base").to(target_device)
    t5_tokenizer = T5Tokenizer.from_pretrained("t5-base")
    # split()/join already turns line breaks into single spaces. Deleting
    # '\n' beforehand glued the last word of a line to the first word of the
    # next one.
    normalized = ' '.join(str(input_text).split())
    input_ids = t5_tokenizer.encode(normalized, return_tensors="pt").to(target_device)
    # Token ids placed in front of the note as the task prefix.
    summary_task = torch.tensor([[21603, 10]]).to(target_device)
    input_ids = torch.cat([summary_task, input_ids], dim=-1)
    return _generate_summary(t5_model, t5_tokenizer, input_ids)


def run_model(input_text):
    """Summarize *input_text* with the selected model and display the result.

    Reads the module-level ``model`` selection and, for BART and T5, the
    module-level generation settings (``_num_beams``, ``_min_length``, ...).
    The summary is written to the page under a 'Summary' caption. If the
    selected model has no implementation, a warning is shown instead of
    silently doing nothing.

    Args:
        input_text: The clinical note to summarize.

    Returns:
        None. The result is rendered through Streamlit.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Back-ends that take the text and return the summary directly.
    simple_backends = {
        "Gensim": lambda text: summarize(str(text)),
        "Pysummarization": pysummarizer,
        "BERT": BertSummarizer,
        "SBERT Summary Tokenizer": Sbert,
        "T5 Seq2Seq": t5seq2seq,
        # Marked in the original source as not working correctly yet.
        "BertGPT2": BertGPT2,
    }

    if model == "BART":
        output = _summarize_with_bart(input_text, device)
    elif model == "T5":
        output = _summarize_with_t5(input_text, device)
    elif model in simple_backends:
        output = simple_backends[model](input_text)
    else:
        st.warning(f"The '{model}' model is not implemented yet.")
        return

    st.write('Summary')
    st.success(output)
# Run the selected summarizer on the (possibly edited) note when requested.
if st.button('Submit'):
    run_model(runtext)

st.markdown('<span style="background-color: #FFFF00">testing</span> if this **works**', unsafe_allow_html=True)

# Show the reference summary as plain text, not as the repr of the array
# holding it.
st.text_area('Reference text', '\n\n'.join(str(ref) for ref in reference_text))

# visualize() expects sentences. Handing it the raw note string made it walk
# the note character by character, so both texts are split into sentences
# with the spaCy pipeline first.
_note_sentences = [sent.text.strip() for sent in nlp(runtext).sents]
_reference_sentences = {
    sent.text.strip()
    for sent in nlp(' '.join(str(ref) for ref in reference_text)).sents
}

# st.text_area has no unsafe_allow_html parameter; HTML is rendered through
# st.markdown.
# NOTE(review): the 'highlight yellow' class emitted by visualize() needs
# matching CSS, and none is defined in the visible source - confirm.
st.markdown(visualize(_note_sentences, _reference_sentences), unsafe_allow_html=True)