Timing0311 committed on
Commit
b849606
·
1 Parent(s): 401530d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -3
app.py CHANGED
@@ -1,10 +1,17 @@
1
- from transformers import MT5ForConditionalGeneration, AutoTokenizer, Text2TextGenerationPipeline
2
  import gradio as gr
 
3
 
 
4
  trans_mdl = MT5ForConditionalGeneration.from_pretrained("K024/mt5-zh-ja-en-trimmed")
5
  trans_tokenizer = AutoTokenizer.from_pretrained("K024/mt5-zh-ja-en-trimmed")
6
  trans_pipe = Text2TextGenerationPipeline(model=trans_mdl, tokenizer=trans_tokenizer)
7
 
 
 
 
 
 
8
  def translation_job(job, text):
9
  # 设置翻译任务和提示语的映射
10
  job_key = ["中译日", "中译英", "日译中", "英译中", "日译英", "英译日"]
@@ -15,7 +22,34 @@ def translation_job(job, text):
15
  print(input)
16
  response = trans_pipe(input, max_length=100, num_beams=4)
17
  return response[0]['generated_text']
18
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
  with gr.Blocks() as app:
@@ -29,8 +63,16 @@ with gr.Blocks() as app:
29
  source_text = gr.Textbox(lines=1, label="翻译文本", placeholder="请输入要翻译的文本")
30
  trans_result = gr.Textbox(lines=1, label="翻译结果")
31
  trans_btn = gr.Button("翻译")
32
-
 
 
 
 
 
 
33
  trans_btn.click(translation_job, inputs=[job_name, source_text], outputs=trans_result)
 
 
34
  app.launch()
35
 
36
 
 
1
+ from transformers import MT5ForConditionalGeneration, AutoTokenizer, Text2TextGenerationPipeline, AutoModelForSeq2SeqLM
2
  import gradio as gr
3
+ import re
4
 
5
# Translation task setup: the model and tokenizer come from the same checkpoint
# and are wrapped in a text2text pipeline.
_TRANS_CHECKPOINT = "K024/mt5-zh-ja-en-trimmed"
trans_mdl = MT5ForConditionalGeneration.from_pretrained(_TRANS_CHECKPOINT)
trans_tokenizer = AutoTokenizer.from_pretrained(_TRANS_CHECKPOINT)
trans_pipe = Text2TextGenerationPipeline(
    model=trans_mdl,
    tokenizer=trans_tokenizer,
)

# Summarization task setup: the model and tokenizer are used directly, without a pipeline.
_SUM_CHECKPOINT = "csebuetnlp/mT5_multilingual_XLSum"
sum_mdl = AutoModelForSeq2SeqLM.from_pretrained(_SUM_CHECKPOINT)
sum_tokenizer = AutoTokenizer.from_pretrained(_SUM_CHECKPOINT)
13
+
14
+
15
  def translation_job(job, text):
16
  # 设置翻译任务和提示语的映射
17
  job_key = ["中译日", "中译英", "日译中", "英译中", "日译英", "英译日"]
 
22
  print(input)
23
  response = trans_pipe(input, max_length=100, num_beams=4)
24
  return response[0]['generated_text']
25
+
26
+
27
def sum_job(text):
    """Summarize *text* with the module-level ``sum_mdl`` / ``sum_tokenizer``.

    Args:
        text: Source text to summarize. ``None``, an empty string, or a
            whitespace-only string short-circuits to an empty summary.

    Returns:
        The decoded summary string (first sequence returned by ``generate``),
        or ``""`` when there is nothing to summarize.
    """
    # Nothing to summarize: skip the model call instead of generating
    # text from an all-padding input.
    if not text or not text.strip():
        return ""

    # Collapse every whitespace run (spaces, tabs, newlines) into one space.
    # A raw string is required so that "\s" is a regex class rather than an
    # invalid Python string escape.
    cleaned = re.sub(r"\s+", " ", text.strip())

    encoded = sum_tokenizer(
        [cleaned],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=512,
    )

    # Pass the attention mask explicitly so the padded positions are masked
    # out rather than relying on generate() to infer the mask.
    output_ids = sum_mdl.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=84,
        no_repeat_ngram_size=2,
        num_beams=4,
    )[0]

    return sum_tokenizer.decode(
        output_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
53
 
54
 
55
  with gr.Blocks() as app:
 
63
  source_text = gr.Textbox(lines=1, label="翻译文本", placeholder="请输入要翻译的文本")
64
  trans_result = gr.Textbox(lines=1, label="翻译结果")
65
  trans_btn = gr.Button("翻译")
66
+
67
+ # 多语言自动摘要任务
68
+ with gr.Tab("多语言自动摘要"):
69
+ article_text = gr.Textbox(lines=8, label="待总结文本", placeholder="请输入要进行摘要的文本")
70
+ sum_result = gr.Textbox(lines=2, label="摘要结果")
71
+ sum_btn = gr.Button("摘要")
72
+
73
  trans_btn.click(translation_job, inputs=[job_name, source_text], outputs=trans_result)
74
+ sum_btn.click(sum_job, inputs=article_text, outputs=sum_result)
75
+
76
  app.launch()
77
 
78