shubh2014shiv commited on
Commit
a40678b
·
1 Parent(s): 18556fd

Update app.py

Browse files

Added comment while creating new folder for downloading the Japanese English translation model

Files changed (1) hide show
  1. app.py +361 -361
app.py CHANGED
@@ -1,361 +1,361 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import plotly.express as px
4
- import plotly.graph_objects as go
5
- from st_aggrid import AgGrid
6
- from st_aggrid.grid_options_builder import GridOptionsBuilder
7
- from st_aggrid.shared import JsCode
8
- from st_aggrid.shared import GridUpdateMode
9
- from transformers import T5Tokenizer, BertForSequenceClassification,AutoTokenizer, AutoModelForSeq2SeqLM
10
- import torch
11
- import numpy as np
12
- import json
13
- from transformers import AutoTokenizer, BertTokenizer, AutoModelWithLMHead
14
- import pytorch_lightning as pl
15
- from pathlib import Path
16
-
17
- # Defining some functions for caching purpose by streamlit
18
- class TranslationModel(pl.LightningModule):
19
- def __init__(self):
20
- super().__init__()
21
- self.model = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-ja-en", return_dict=True)
22
-
23
-
24
- @st.experimental_singleton
25
- def loadFineTunedJaEn_NMT_Model():
26
- save_dest = Path('model')
27
- save_dest.mkdir(exist_ok=True)
28
-
29
- f_checkpoint = Path("model/best-checkpoint.ckpt")
30
-
31
- if not f_checkpoint.exists():
32
- with st.spinner("Downloading model.This may take a while! \n Don't refresh or close this page!"):
33
- from GD_download import download_file_from_google_drive
34
- download_file_from_google_drive('1CZQKGj9hSqj7kEuJp_jm7bNVXrbcFsgP', f_checkpoint)
35
-
36
- trained_model = TranslationModel.load_from_checkpoint(f_checkpoint)
37
-
38
- return trained_model
39
-
40
- @st.experimental_singleton
41
- def getJpEn_Tokenizers():
42
- try:
43
- with st.spinner("Downloading English and Japanese Transformer Tokenizers"):
44
- ja_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ja-en")
45
- en_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
46
- except:
47
- st.error("Issue with downloading tokenizers")
48
-
49
- return ja_tokenizer, en_tokenizer
50
-
51
- st.set_page_config(layout="wide")
52
- st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
53
- st.sidebar.subheader("自然言語処理 トピック")
54
- topic = st.sidebar.radio(label="Select the NLP project topics", options=["Sentiment Analysis","Text Summarization","Japanese to English Translation"])
55
-
56
- st.write("-" * 5)
57
- jp_review_text = None
58
- #JAPANESE_SENTIMENT_PROJECT_PATH = './Japanese Amazon reviews sentiments/'
59
-
60
- if topic == "Sentiment Analysis":
61
- st.markdown(
62
- "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Transfer Learning based Japanese Sentiments Analysis using BERT<b></h2>",
63
- unsafe_allow_html=True)
64
- st.markdown(
65
- "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese Amazon Reviews Data (日本のAmazonレビューデータ)<b></h3>",
66
- unsafe_allow_html=True)
67
-
68
- amazon_jp_reviews = pd.read_csv("review_val.csv").sample(frac=1,random_state=10).iloc[:16000]
69
-
70
- cellstyle_jscode = JsCode(
71
- """
72
- function(params) {
73
- if (params.value.includes('positive')) {
74
- return {
75
- 'color': 'black',
76
- 'backgroundColor': '#32CD32'
77
- }
78
- } else {
79
- return {
80
- 'color': 'black',
81
- 'backgroundColor': '#FF7F7F'
82
- }
83
- }
84
- };
85
- """
86
- )
87
- st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: center;} </style>',
88
- unsafe_allow_html=True)
89
-
90
- st.write('<style>div.st-bf{flex-direction:column;} div.st-ag{font-weight:bold;padding-left:2px;}</style>',
91
- unsafe_allow_html=True)
92
-
93
- choose = st.radio("", ("Choose a review from the dataframe below", "Manually write review"))
94
-
95
- SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
96
- WRITE_REVIEW = "Manually write review"
97
-
98
- gb = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
99
- gb.configure_column("sentiment", cellStyle=cellstyle_jscode)
100
- gb.configure_pagination()
101
- if choose == SELECT_ONE_REVIEW:
102
- gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
103
- gridOptions = gb.build()
104
-
105
- if choose == SELECT_ONE_REVIEW:
106
- jp_review_choice = AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
107
- enable_enterprise_modules=True,
108
- allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
109
- st.info("Select any one the Japanese Reviews by clicking the checkbox. Reviews can be navigated from each page.")
110
- if len(jp_review_choice['selected_rows']) != 0:
111
- jp_review_text = jp_review_choice['selected_rows'][0]['review']
112
- st.markdown(
113
- "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Selected Review in JSON (JSONで選択されたレビュー)<b></h3>",
114
- unsafe_allow_html=True)
115
- st.write(jp_review_choice['selected_rows'])
116
-
117
- if choose == WRITE_REVIEW:
118
-
119
- AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
120
- enable_enterprise_modules=True,
121
- allow_unsafe_jscode=True)
122
- with open("test_reviews_jp.csv", "rb") as file:
123
- st.download_button(label="Download Additional Japanese Reviews", data=file,
124
- file_name="Additional Japanese Reviews.csv")
125
- st.info("Additional subset of Japanese Reviews can be downloaded and any review can be copied & pasted in text area.")
126
- sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
127
- jp_review_text = st.text_area(label="Press 'Ctrl+Enter' after writing review in below text area",
128
- value=sample_japanese_review_input)
129
- if len(jp_review_text) == 0:
130
- st.error("Input text cannot empty. Either write the japanese review in text area manually or select the review from the grid.")
131
-
132
- if jp_review_text:
133
- st.markdown(
134
- "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Sentence-Piece based Japanese Tokenizer using RoBERTA<b></h3>",
135
- unsafe_allow_html=True)
136
- tokens_column, tokenID_column = st.columns(2)
137
- tokenizer = T5Tokenizer.from_pretrained('rinna/japanese-roberta-base')
138
- tokens = tokenizer.tokenize(jp_review_text)
139
- token_ids = tokenizer.convert_tokens_to_ids(tokens)
140
- with tokens_column:
141
- token_expander = st.expander("Expand to see the tokens", expanded=False)
142
- with token_expander:
143
- st.write(tokens)
144
- with tokenID_column:
145
- tokenID_expander = st.expander("Expand to see the token IDs", expanded=False)
146
- with tokenID_expander:
147
- st.write(token_ids)
148
-
149
- st.markdown(
150
- "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Encoded Japanese Review Text to get Input IDs and attention masks as PyTorch Tensor<b></h3>",
151
- unsafe_allow_html=True)
152
- encoded_data = tokenizer.batch_encode_plus(np.array([jp_review_text]).astype('object'),
153
- add_special_tokens=True,
154
- return_attention_mask=True,
155
- padding=True,
156
- max_length=200,
157
- return_tensors='pt',
158
- truncation=True)
159
- input_ids = encoded_data['input_ids']
160
- attention_masks = encoded_data['attention_mask']
161
- input_ids_column, attention_masks_column = st.columns(2)
162
- with input_ids_column:
163
- input_ids_expander = st.expander("Expand to see the input IDs tensor")
164
- with input_ids_expander:
165
- st.write(input_ids)
166
- with attention_masks_column:
167
- attention_masks_expander = st.expander("Expand to see the attention mask tensor")
168
- with attention_masks_expander:
169
- st.write(attention_masks)
170
-
171
- st.markdown(
172
- "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Predict Sentiment of review using Fine-Tuned Japanese BERT<b></h3>",
173
- unsafe_allow_html=True)
174
-
175
- label_dict = {'positive': 1, 'negative': 0}
176
- if st.button("Predict Sentiment"):
177
- with st.spinner("Wait.."):
178
- predictions = []
179
- model = BertForSequenceClassification.from_pretrained("shubh2014shiv/jp_review_sentiments_amzn",
180
- num_labels=len(label_dict),
181
- output_attentions=False,
182
- output_hidden_states=False)
183
- #model.load_state_dict(
184
- # torch.load(JAPANESE_SENTIMENT_PROJECT_PATH + 'FineTuneJapaneseBert_AmazonReviewSentiments.pt',
185
- # map_location=torch.device('cpu')))
186
-
187
- model.load_state_dict(
188
- torch.load('reviewSentiments_jp.pt',
189
- map_location=torch.device('cpu')))
190
-
191
- inputs = {
192
- 'input_ids': input_ids,
193
- 'attention_mask': attention_masks
194
- }
195
-
196
- with torch.no_grad():
197
- outputs = model(**inputs)
198
-
199
- logits = outputs.logits
200
- logits = logits.detach().cpu().numpy()
201
- scores = 1 / (1 + np.exp(-1 * logits))
202
-
203
- result = {"TEXT (文章)": jp_review_text,'NEGATIVE (ネガティブ)': scores[0][0], 'POSITIVE (ポジティブ)': scores[0][1]}
204
-
205
- result_col,graph_col = st.columns(2)
206
- with result_col:
207
- st.write(result)
208
- with graph_col:
209
- fig = px.bar(x=['NEGATIVE (ネガティブ)','POSITIVE (ポジティブ)'],y=[result['NEGATIVE (ネガティブ)'],result['POSITIVE (ポジティブ)']])
210
- fig.update_layout(title="Probability distribution of Sentiment for the given text",\
211
- yaxis_title="Probability (確率)")
212
- fig.update_traces(marker_color=['#FF7F7F','#32CD32'])
213
- st.plotly_chart(fig)
214
-
215
- elif topic == "Text Summarization":
216
- st.markdown(
217
- "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Summarizing Japanese News Article using multi-Lingual T5 (mT5)<b></h2>",
218
- unsafe_allow_html=True)
219
- st.markdown(
220
- "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese News Article Data<b></h3>",
221
- unsafe_allow_html=True)
222
-
223
- news_articles = pd.read_csv("jp_news_articles_val.csv").sample(frac=0.75,
224
- random_state=42)
225
- gb = GridOptionsBuilder.from_dataframe(news_articles)
226
- gb.configure_pagination()
227
- gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
228
- gridOptions = gb.build()
229
- jp_article = AgGrid(news_articles, gridOptions=gridOptions, theme='material',
230
- enable_enterprise_modules=True,
231
- allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
232
-
233
- # WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
234
- if len(jp_article['selected_rows']) == 0:
235
- st.info("Pick any one Japanese News Article by selecting the checkbox. News articles can be navigated by clicking on page navigator at right-bottom")
236
- else:
237
- article_text = jp_article['selected_rows'][0]['News Articles']
238
-
239
- text = st.text_area(label="Text from selected Japanese News Article(ニュース記事)", value=article_text, height=500)
240
- summary_length = st.slider(label="Select the maximum length of summary (要約の最大長を選択します )", min_value=120,max_value=160,step=5)
241
-
242
- if text and st.button("Summarize it! (要約しよう)"):
243
- waitPlaceholder = st.image("wait.gif")
244
- summarization_model_name = "csebuetnlp/mT5_multilingual_XLSum"
245
- tokenizer = AutoTokenizer.from_pretrained(summarization_model_name )
246
- model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name )
247
-
248
- input_ids = tokenizer(
249
- article_text,
250
- return_tensors="pt",
251
- padding="max_length",
252
- truncation=True,
253
- max_length=512
254
- )["input_ids"]
255
-
256
- output_ids = model.generate(
257
- input_ids=input_ids,
258
- max_length=summary_length,
259
- no_repeat_ngram_size=2,
260
- num_beams=4
261
- )[0]
262
-
263
- summary = tokenizer.decode(
264
- output_ids,
265
- skip_special_tokens=True,
266
- clean_up_tokenization_spaces=False
267
- )
268
-
269
- waitPlaceholder.empty()
270
-
271
- st.markdown(
272
- "<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Summary (要約文)<b></h2>",
273
- unsafe_allow_html=True)
274
-
275
- st.write(summary)
276
- elif topic == "Japanese to English Translation":
277
- st.markdown(
278
- "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Japanese to English translation (for short sentences)<b></h2>",
279
- unsafe_allow_html=True)
280
- st.markdown(
281
- "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Business Scene Dialog Japanese-English Corpus<b></h3>",
282
- unsafe_allow_html=True)
283
-
284
- st.write("Below given Japanese-English pair is from 'Business Scene Dialog Corpus' by the University of Tokyo")
285
- link = '[Corpus GitHub Link](https://github.com/tsuruoka-lab/BSD)'
286
- st.markdown(link, unsafe_allow_html=True)
287
-
288
- bsd_more_info = st.expander(label="Expand to get more information on data and training report")
289
- with bsd_more_info:
290
- st.markdown(
291
- "<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Dataset<b></h3>",
292
- unsafe_allow_html=True)
293
- st.write("The corpus has total 20,000 Japanese-English Business Dialog pairs. The fined-tuned Transformer model is validated on 670 Japanese-English Business Dialog pairs")
294
-
295
- st.markdown(
296
- "<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Report<b></h3>",
297
- unsafe_allow_html=True)
298
- st.write(
299
- "The Dashboard for training result on Tensorboard is [here](https://tensorboard.dev/experiment/eWhxt1i2RuaU64krYtORhw/)")
300
-
301
- with open("./BSD_ja-en_val.json", encoding='utf-8') as f:
302
- bsd_sample_data = json.load(f)
303
-
304
- en, ja = [], []
305
- for i in range(len(bsd_sample_data)):
306
- for j in range(len(bsd_sample_data[i]['conversation'])):
307
- en.append(bsd_sample_data[i]['conversation'][j]['en_sentence'])
308
- ja.append(bsd_sample_data[i]['conversation'][j]['ja_sentence'])
309
-
310
- df = pd.DataFrame.from_dict({'Japanese': ja, 'English': en})
311
- gb = GridOptionsBuilder.from_dataframe(df)
312
- gb.configure_pagination()
313
- gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
314
- gridOptions = gb.build()
315
- translation_text = AgGrid(df, gridOptions=gridOptions, theme='material',
316
- enable_enterprise_modules=True,
317
- allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
318
- if len(translation_text['selected_rows']) != 0:
319
- bsd_jp = translation_text['selected_rows'][0]['Japanese']
320
- st.markdown(
321
- "<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Business Scene Dialog in Japanese (日本語でのビジネスシーンダイアログ)<b></h2>",
322
- unsafe_allow_html=True)
323
- st.write(bsd_jp)
324
-
325
- if st.button("Translate"):
326
- ja_tokenizer, en_tokenizer = getJpEn_Tokenizers()
327
- trained_model = loadFineTunedJaEn_NMT_Model()
328
- trained_model.freeze()
329
-
330
-
331
- def translate(text):
332
- text_encoding = ja_tokenizer(
333
- text,
334
- max_length=100,
335
- padding="max_length",
336
- truncation=True,
337
- return_attention_mask=True,
338
- add_special_tokens=True,
339
- return_tensors='pt'
340
- )
341
-
342
- generated_ids = trained_model.model.generate(
343
- input_ids=text_encoding['input_ids'],
344
- attention_mask=text_encoding['attention_mask'],
345
- max_length=100,
346
- num_beams=2,
347
- repetition_penalty=2.5,
348
- length_penalty=1.0,
349
- early_stopping=True
350
- )
351
-
352
- preds = [en_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True) for
353
- gen_id in generated_ids]
354
-
355
- return "".join(preds)[5:]
356
-
357
-
358
- st.markdown(
359
- "<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Translated Dialog in English (英語の翻訳されたダイアログ)<b></h2>",
360
- unsafe_allow_html=True)
361
- st.write(translate(bsd_jp))
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import plotly.express as px
4
+ import plotly.graph_objects as go
5
+ from st_aggrid import AgGrid
6
+ from st_aggrid.grid_options_builder import GridOptionsBuilder
7
+ from st_aggrid.shared import JsCode
8
+ from st_aggrid.shared import GridUpdateMode
9
+ from transformers import T5Tokenizer, BertForSequenceClassification,AutoTokenizer, AutoModelForSeq2SeqLM
10
+ import torch
11
+ import numpy as np
12
+ import json
13
+ from transformers import AutoTokenizer, BertTokenizer, AutoModelWithLMHead
14
+ import pytorch_lightning as pl
15
+ from pathlib import Path
16
+
17
+ # Defining some functions for caching purpose by streamlit
18
+ class TranslationModel(pl.LightningModule):
19
+ def __init__(self):
20
+ super().__init__()
21
+ self.model = AutoModelWithLMHead.from_pretrained("Helsinki-NLP/opus-mt-ja-en", return_dict=True)
22
+
23
+
24
+ @st.experimental_singleton
25
+ def loadFineTunedJaEn_NMT_Model():
26
+ save_dest = Path('model')
27
+ save_dest.mkdir(exist_ok=True)
28
+ st.write("Creating new folder for downloading the Japanese to English Translation Model. ")
29
+ f_checkpoint = Path("model/best-checkpoint.ckpt")
30
+ st.write("'Folder: model/best-checkpoint.ckpt' created.")
31
+ if not f_checkpoint.exists():
32
+ with st.spinner("Downloading model.This may take a while! \n Don't refresh or close this page!"):
33
+ from GD_download import download_file_from_google_drive
34
+ download_file_from_google_drive('1CZQKGj9hSqj7kEuJp_jm7bNVXrbcFsgP', f_checkpoint)
35
+
36
+ trained_model = TranslationModel.load_from_checkpoint(f_checkpoint)
37
+
38
+ return trained_model
39
+
40
+ @st.experimental_singleton
41
+ def getJpEn_Tokenizers():
42
+ try:
43
+ with st.spinner("Downloading English and Japanese Transformer Tokenizers"):
44
+ ja_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ja-en")
45
+ en_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
46
+ except:
47
+ st.error("Issue with downloading tokenizers")
48
+
49
+ return ja_tokenizer, en_tokenizer
50
+
51
+ st.set_page_config(layout="wide")
52
+ st.title("Project - Japanese Natural Language Processing (自然言語処理) using Transformers")
53
+ st.sidebar.subheader("自然言語処理 トピック")
54
+ topic = st.sidebar.radio(label="Select the NLP project topics", options=["Sentiment Analysis","Text Summarization","Japanese to English Translation"])
55
+
56
+ st.write("-" * 5)
57
+ jp_review_text = None
58
+ #JAPANESE_SENTIMENT_PROJECT_PATH = './Japanese Amazon reviews sentiments/'
59
+
60
+ if topic == "Sentiment Analysis":
61
+ st.markdown(
62
+ "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Transfer Learning based Japanese Sentiments Analysis using BERT<b></h2>",
63
+ unsafe_allow_html=True)
64
+ st.markdown(
65
+ "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese Amazon Reviews Data (日本のAmazonレビューデータ)<b></h3>",
66
+ unsafe_allow_html=True)
67
+
68
+ amazon_jp_reviews = pd.read_csv("review_val.csv").sample(frac=1,random_state=10).iloc[:16000]
69
+
70
+ cellstyle_jscode = JsCode(
71
+ """
72
+ function(params) {
73
+ if (params.value.includes('positive')) {
74
+ return {
75
+ 'color': 'black',
76
+ 'backgroundColor': '#32CD32'
77
+ }
78
+ } else {
79
+ return {
80
+ 'color': 'black',
81
+ 'backgroundColor': '#FF7F7F'
82
+ }
83
+ }
84
+ };
85
+ """
86
+ )
87
+ st.write('<style>div.row-widget.stRadio > div{flex-direction:row;justify-content: center;} </style>',
88
+ unsafe_allow_html=True)
89
+
90
+ st.write('<style>div.st-bf{flex-direction:column;} div.st-ag{font-weight:bold;padding-left:2px;}</style>',
91
+ unsafe_allow_html=True)
92
+
93
+ choose = st.radio("", ("Choose a review from the dataframe below", "Manually write review"))
94
+
95
+ SELECT_ONE_REVIEW = "Choose a review from the dataframe below"
96
+ WRITE_REVIEW = "Manually write review"
97
+
98
+ gb = GridOptionsBuilder.from_dataframe(amazon_jp_reviews)
99
+ gb.configure_column("sentiment", cellStyle=cellstyle_jscode)
100
+ gb.configure_pagination()
101
+ if choose == SELECT_ONE_REVIEW:
102
+ gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
103
+ gridOptions = gb.build()
104
+
105
+ if choose == SELECT_ONE_REVIEW:
106
+ jp_review_choice = AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
107
+ enable_enterprise_modules=True,
108
+ allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
109
+ st.info("Select any one the Japanese Reviews by clicking the checkbox. Reviews can be navigated from each page.")
110
+ if len(jp_review_choice['selected_rows']) != 0:
111
+ jp_review_text = jp_review_choice['selected_rows'][0]['review']
112
+ st.markdown(
113
+ "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Selected Review in JSON (JSONで選択されたレビュー)<b></h3>",
114
+ unsafe_allow_html=True)
115
+ st.write(jp_review_choice['selected_rows'])
116
+
117
+ if choose == WRITE_REVIEW:
118
+
119
+ AgGrid(amazon_jp_reviews, gridOptions=gridOptions, theme='material',
120
+ enable_enterprise_modules=True,
121
+ allow_unsafe_jscode=True)
122
+ with open("test_reviews_jp.csv", "rb") as file:
123
+ st.download_button(label="Download Additional Japanese Reviews", data=file,
124
+ file_name="Additional Japanese Reviews.csv")
125
+ st.info("Additional subset of Japanese Reviews can be downloaded and any review can be copied & pasted in text area.")
126
+ sample_japanese_review_input = "子供のレッスンバッグ用に購入。 思ったより大きく、ピアノ教本を入れるには充分でした。中は汚れてました。 何より驚いたのは、商品の梱包。 2つ折は許せるが、透明ビニール袋の底思いっきり空いてますけど? 何これ?包むっていうか挟んで終わり?底が全開している。 引っ張れば誰でも中身の注文書も、商品も見れる状態って何なの? 個人情報が晒されて、商品も粗末な扱いで嫌な気持ちでした。 郵送で中身が無事のが奇跡じゃないでしょうか? ありえない"
127
+ jp_review_text = st.text_area(label="Press 'Ctrl+Enter' after writing review in below text area",
128
+ value=sample_japanese_review_input)
129
+ if len(jp_review_text) == 0:
130
+ st.error("Input text cannot empty. Either write the japanese review in text area manually or select the review from the grid.")
131
+
132
+ if jp_review_text:
133
+ st.markdown(
134
+ "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Sentence-Piece based Japanese Tokenizer using RoBERTA<b></h3>",
135
+ unsafe_allow_html=True)
136
+ tokens_column, tokenID_column = st.columns(2)
137
+ tokenizer = T5Tokenizer.from_pretrained('rinna/japanese-roberta-base')
138
+ tokens = tokenizer.tokenize(jp_review_text)
139
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
140
+ with tokens_column:
141
+ token_expander = st.expander("Expand to see the tokens", expanded=False)
142
+ with token_expander:
143
+ st.write(tokens)
144
+ with tokenID_column:
145
+ tokenID_expander = st.expander("Expand to see the token IDs", expanded=False)
146
+ with tokenID_expander:
147
+ st.write(token_ids)
148
+
149
+ st.markdown(
150
+ "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Encoded Japanese Review Text to get Input IDs and attention masks as PyTorch Tensor<b></h3>",
151
+ unsafe_allow_html=True)
152
+ encoded_data = tokenizer.batch_encode_plus(np.array([jp_review_text]).astype('object'),
153
+ add_special_tokens=True,
154
+ return_attention_mask=True,
155
+ padding=True,
156
+ max_length=200,
157
+ return_tensors='pt',
158
+ truncation=True)
159
+ input_ids = encoded_data['input_ids']
160
+ attention_masks = encoded_data['attention_mask']
161
+ input_ids_column, attention_masks_column = st.columns(2)
162
+ with input_ids_column:
163
+ input_ids_expander = st.expander("Expand to see the input IDs tensor")
164
+ with input_ids_expander:
165
+ st.write(input_ids)
166
+ with attention_masks_column:
167
+ attention_masks_expander = st.expander("Expand to see the attention mask tensor")
168
+ with attention_masks_expander:
169
+ st.write(attention_masks)
170
+
171
+ st.markdown(
172
+ "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Predict Sentiment of review using Fine-Tuned Japanese BERT<b></h3>",
173
+ unsafe_allow_html=True)
174
+
175
+ label_dict = {'positive': 1, 'negative': 0}
176
+ if st.button("Predict Sentiment"):
177
+ with st.spinner("Wait.."):
178
+ predictions = []
179
+ model = BertForSequenceClassification.from_pretrained("shubh2014shiv/jp_review_sentiments_amzn",
180
+ num_labels=len(label_dict),
181
+ output_attentions=False,
182
+ output_hidden_states=False)
183
+ #model.load_state_dict(
184
+ # torch.load(JAPANESE_SENTIMENT_PROJECT_PATH + 'FineTuneJapaneseBert_AmazonReviewSentiments.pt',
185
+ # map_location=torch.device('cpu')))
186
+
187
+ model.load_state_dict(
188
+ torch.load('reviewSentiments_jp.pt',
189
+ map_location=torch.device('cpu')))
190
+
191
+ inputs = {
192
+ 'input_ids': input_ids,
193
+ 'attention_mask': attention_masks
194
+ }
195
+
196
+ with torch.no_grad():
197
+ outputs = model(**inputs)
198
+
199
+ logits = outputs.logits
200
+ logits = logits.detach().cpu().numpy()
201
+ scores = 1 / (1 + np.exp(-1 * logits))
202
+
203
+ result = {"TEXT (文章)": jp_review_text,'NEGATIVE (ネガティブ)': scores[0][0], 'POSITIVE (ポジティブ)': scores[0][1]}
204
+
205
+ result_col,graph_col = st.columns(2)
206
+ with result_col:
207
+ st.write(result)
208
+ with graph_col:
209
+ fig = px.bar(x=['NEGATIVE (ネガティブ)','POSITIVE (ポジティブ)'],y=[result['NEGATIVE (ネガティブ)'],result['POSITIVE (ポジティブ)']])
210
+ fig.update_layout(title="Probability distribution of Sentiment for the given text",\
211
+ yaxis_title="Probability (確率)")
212
+ fig.update_traces(marker_color=['#FF7F7F','#32CD32'])
213
+ st.plotly_chart(fig)
214
+
215
+ elif topic == "Text Summarization":
216
+ st.markdown(
217
+ "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Summarizing Japanese News Article using multi-Lingual T5 (mT5)<b></h2>",
218
+ unsafe_allow_html=True)
219
+ st.markdown(
220
+ "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Japanese News Article Data<b></h3>",
221
+ unsafe_allow_html=True)
222
+
223
+ news_articles = pd.read_csv("jp_news_articles_val.csv").sample(frac=0.75,
224
+ random_state=42)
225
+ gb = GridOptionsBuilder.from_dataframe(news_articles)
226
+ gb.configure_pagination()
227
+ gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
228
+ gridOptions = gb.build()
229
+ jp_article = AgGrid(news_articles, gridOptions=gridOptions, theme='material',
230
+ enable_enterprise_modules=True,
231
+ allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
232
+
233
+ # WHITESPACE_HANDLER = lambda k: re.sub('\s+', ' ', re.sub('\n+', ' ', k.strip()))
234
+ if len(jp_article['selected_rows']) == 0:
235
+ st.info("Pick any one Japanese News Article by selecting the checkbox. News articles can be navigated by clicking on page navigator at right-bottom")
236
+ else:
237
+ article_text = jp_article['selected_rows'][0]['News Articles']
238
+
239
+ text = st.text_area(label="Text from selected Japanese News Article(ニュース記事)", value=article_text, height=500)
240
+ summary_length = st.slider(label="Select the maximum length of summary (要約の最大長を選択します )", min_value=120,max_value=160,step=5)
241
+
242
+ if text and st.button("Summarize it! (要約しよう)"):
243
+ waitPlaceholder = st.image("wait.gif")
244
+ summarization_model_name = "csebuetnlp/mT5_multilingual_XLSum"
245
+ tokenizer = AutoTokenizer.from_pretrained(summarization_model_name )
246
+ model = AutoModelForSeq2SeqLM.from_pretrained(summarization_model_name )
247
+
248
+ input_ids = tokenizer(
249
+ article_text,
250
+ return_tensors="pt",
251
+ padding="max_length",
252
+ truncation=True,
253
+ max_length=512
254
+ )["input_ids"]
255
+
256
+ output_ids = model.generate(
257
+ input_ids=input_ids,
258
+ max_length=summary_length,
259
+ no_repeat_ngram_size=2,
260
+ num_beams=4
261
+ )[0]
262
+
263
+ summary = tokenizer.decode(
264
+ output_ids,
265
+ skip_special_tokens=True,
266
+ clean_up_tokenization_spaces=False
267
+ )
268
+
269
+ waitPlaceholder.empty()
270
+
271
+ st.markdown(
272
+ "<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Summary (要約文)<b></h2>",
273
+ unsafe_allow_html=True)
274
+
275
+ st.write(summary)
276
+ elif topic == "Japanese to English Translation":
277
+ st.markdown(
278
+ "<h2 style='text-align: left; color:#EE82EE; font-size:25px;'><b>Japanese to English translation (for short sentences)<b></h2>",
279
+ unsafe_allow_html=True)
280
+ st.markdown(
281
+ "<h3 style='text-align: center; color:#F63366; font-size:18px;'><b>Business Scene Dialog Japanese-English Corpus<b></h3>",
282
+ unsafe_allow_html=True)
283
+
284
+ st.write("Below given Japanese-English pair is from 'Business Scene Dialog Corpus' by the University of Tokyo")
285
+ link = '[Corpus GitHub Link](https://github.com/tsuruoka-lab/BSD)'
286
+ st.markdown(link, unsafe_allow_html=True)
287
+
288
+ bsd_more_info = st.expander(label="Expand to get more information on data and training report")
289
+ with bsd_more_info:
290
+ st.markdown(
291
+ "<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Dataset<b></h3>",
292
+ unsafe_allow_html=True)
293
+ st.write("The corpus has total 20,000 Japanese-English Business Dialog pairs. The fined-tuned Transformer model is validated on 670 Japanese-English Business Dialog pairs")
294
+
295
+ st.markdown(
296
+ "<h3 style='text-align: left; color:#F63366; font-size:12px;'><b>Training Report<b></h3>",
297
+ unsafe_allow_html=True)
298
+ st.write(
299
+ "The Dashboard for training result on Tensorboard is [here](https://tensorboard.dev/experiment/eWhxt1i2RuaU64krYtORhw/)")
300
+
301
+ with open("./BSD_ja-en_val.json", encoding='utf-8') as f:
302
+ bsd_sample_data = json.load(f)
303
+
304
+ en, ja = [], []
305
+ for i in range(len(bsd_sample_data)):
306
+ for j in range(len(bsd_sample_data[i]['conversation'])):
307
+ en.append(bsd_sample_data[i]['conversation'][j]['en_sentence'])
308
+ ja.append(bsd_sample_data[i]['conversation'][j]['ja_sentence'])
309
+
310
+ df = pd.DataFrame.from_dict({'Japanese': ja, 'English': en})
311
+ gb = GridOptionsBuilder.from_dataframe(df)
312
+ gb.configure_pagination()
313
+ gb.configure_selection(selection_mode="single", use_checkbox=True, suppressRowDeselection=False)
314
+ gridOptions = gb.build()
315
+ translation_text = AgGrid(df, gridOptions=gridOptions, theme='material',
316
+ enable_enterprise_modules=True,
317
+ allow_unsafe_jscode=True, update_mode=GridUpdateMode.SELECTION_CHANGED)
318
+ if len(translation_text['selected_rows']) != 0:
319
+ bsd_jp = translation_text['selected_rows'][0]['Japanese']
320
+ st.markdown(
321
+ "<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Business Scene Dialog in Japanese (日本語でのビジネスシーンダイアログ)<b></h2>",
322
+ unsafe_allow_html=True)
323
+ st.write(bsd_jp)
324
+
325
+ if st.button("Translate"):
326
+ ja_tokenizer, en_tokenizer = getJpEn_Tokenizers()
327
+ trained_model = loadFineTunedJaEn_NMT_Model()
328
+ trained_model.freeze()
329
+
330
+
331
+ def translate(text):
332
+ text_encoding = ja_tokenizer(
333
+ text,
334
+ max_length=100,
335
+ padding="max_length",
336
+ truncation=True,
337
+ return_attention_mask=True,
338
+ add_special_tokens=True,
339
+ return_tensors='pt'
340
+ )
341
+
342
+ generated_ids = trained_model.model.generate(
343
+ input_ids=text_encoding['input_ids'],
344
+ attention_mask=text_encoding['attention_mask'],
345
+ max_length=100,
346
+ num_beams=2,
347
+ repetition_penalty=2.5,
348
+ length_penalty=1.0,
349
+ early_stopping=True
350
+ )
351
+
352
+ preds = [en_tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True) for
353
+ gen_id in generated_ids]
354
+
355
+ return "".join(preds)[5:]
356
+
357
+
358
+ st.markdown(
359
+ "<h2 style='text-align: left; color:#32CD32; font-size:25px;'><b>Translated Dialog in English (英語の翻訳されたダイアログ)<b></h2>",
360
+ unsafe_allow_html=True)
361
+ st.write(translate(bsd_jp))