zmbfeng committed on
Commit
40ddb2b
1 Parent(s): 4c2c5b7

encode sentence extracted

Files changed (1)
  1. app.py  +13 -9
app.py CHANGED
@@ -78,6 +78,15 @@ if 'is_initialized' not in st.session_state:
     st.session_state.bert_model = BertModel.from_pretrained("bert-base-uncased").to('cuda')
     st.session_state.roberta_tokenizer = AutoTokenizer.from_pretrained("roberta-large-mnli")
     st.session_state.roberta_model = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli")
+def encode_sentence(sentence):
+    if len(sentence.strip()) < 4:
+        return None
+
+    sentence_tokens = st.session_state.bert_tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to(
+        'cuda')
+    with torch.no_grad():
+        sentence_encoding = st.session_state.bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
+    return sentence_encoding

 if 'list_count' in st.session_state:
     st.write(f'The number of elements at the top level of the hierarchy: {st.session_state.list_count}')
@@ -96,15 +105,10 @@ if 'list_count' in st.session_state:
         paragraph_without_newline = paragraph['paragraph'].replace("\n", "")
         sentences = sent_tokenize(paragraph_without_newline)
         for sentence in sentences:
-            if sentence.strip().endswith('?'):
-                sentence_encodings.append(None)
-                continue
-            if len(sentence.strip()) < 4:
-                sentence_encodings.append(None)
-                continue
-            sentence_tokens = st.session_state.bert_tokenizer(sentence, return_tensors="pt", padding=True, truncation=True).to('cuda')
-            with torch.no_grad():
-                sentence_encoding = st.session_state.bert_model(**sentence_tokens).last_hidden_state[:, 0, :].cpu().numpy()
+            # if sentence.strip().endswith('?'):
+            #     sentence_encodings.append(None)
+            #     continue
+            sentence_encoding = encode_sentence(sentence)
             sentence_encodings.append([sentence, sentence_encoding])
             # sentence_encodings.append([sentence, bert_model(**sentence_tokens).last_hidden_state[:, 0, :].detach().numpy()])
         st.session_state.paragraph_sentence_encodings.append([paragraph, sentence_encodings])
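
For reference, a minimal standalone sketch of the extracted helper, runnable outside Streamlit. Assumptions not in this commit: the tokenizer and model live in module globals rather than st.session_state, AutoTokenizer stands in for whatever populated st.session_state.bert_tokenizer, and the device is chosen at runtime instead of the hard-coded 'cuda'.

# Sketch of encode_sentence from this commit, with the session-state
# plumbing replaced by module globals (an assumption, not part of the diff).
import nltk
import torch
from nltk.tokenize import sent_tokenize
from transformers import AutoTokenizer, BertModel

nltk.download("punkt", quiet=True)  # tokenizer data needed by sent_tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"  # the diff hard-codes 'cuda'
bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_model = BertModel.from_pretrained("bert-base-uncased").to(device)

def encode_sentence(sentence):
    # Same guard as the commit: fragments under 4 characters get no encoding.
    if len(sentence.strip()) < 4:
        return None
    tokens = bert_tokenizer(sentence, return_tensors="pt",
                            padding=True, truncation=True).to(device)
    with torch.no_grad():
        # [CLS] embedding from the last hidden layer as the sentence vector.
        return bert_model(**tokens).last_hidden_state[:, 0, :].cpu().numpy()

paragraph = "BERT maps each sentence to one vector. Ok. Short fragments return None."
sentence_encodings = [[s, encode_sentence(s)] for s in sent_tokenize(paragraph)]

One behavioral note: the old loop appended a bare None and skipped the pair for questions and for short sentences, while after this commit the question check is commented out and short sentences are stored as [sentence, None] pairs, so code that consumes sentence_encodings should handle a None encoding.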