Files changed (1) hide show
  1. app.py +1 -83
app.py CHANGED
@@ -1,83 +1 @@
1
- import os
2
- import torch
3
- import gradio as gr
4
- import time
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
- from flores200_codes import flores_codes
7
-
8
-
9
def load_models(model_name_dict=None):
    """Load translation model/tokenizer pairs from the Hugging Face hub.

    Parameters
    ----------
    model_name_dict : dict[str, str] | None
        Mapping of short model alias -> Hugging Face model id. Defaults to
        the single NLLB distilled 1.3B checkpoint, preserving the original
        behavior when called with no arguments.

    Returns
    -------
    dict
        For each alias ``name``, the result holds ``name + '_model'`` and
        ``name + '_tokenizer'`` entries (the convention `translation()`
        relies on when looking models up).
    """
    if model_name_dict is None:
        # Original hard-coded registry, kept as the default so existing
        # callers see identical behavior.
        model_name_dict = {'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B'}

    model_dict = {}

    for call_name, real_name in model_name_dict.items():
        print('\tLoading model: %s' % call_name)
        # from_pretrained downloads/caches weights — this is slow on first run.
        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
        tokenizer = AutoTokenizer.from_pretrained(real_name)
        model_dict[call_name + '_model'] = model
        model_dict[call_name + '_tokenizer'] = tokenizer

    return model_dict
23
-
24
-
25
def translation(source, target, text):
    """Translate ``text`` between two languages selected by display name.

    Parameters
    ----------
    source, target : str
        Human-readable language names; must be keys of ``flores_codes``
        (a ``KeyError`` is raised for unknown names).
    text : str
        The text to translate (truncated by the pipeline at 400 tokens).

    Returns
    -------
    dict
        ``{'inference_time': float seconds, 'source': <FLORES code>,
        'target': <FLORES code>, 'result': translated text}``.
    """
    # BUG FIX: previously model_name was bound only inside
    # `if len(model_dict) == 2:`, so any other registry size crashed with an
    # opaque NameError. Bind it unconditionally and fail with a clear message
    # if the model was never loaded.
    model_name = 'nllb-distilled-1.3B'
    try:
        model = model_dict[model_name + '_model']
        tokenizer = model_dict[model_name + '_tokenizer']
    except KeyError as err:
        raise RuntimeError(
            'model %r is not loaded; call load_models() first' % model_name
        ) from err

    start_time = time.time()
    source = flores_codes[source]
    target = flores_codes[target]

    # NOTE(review): the pipeline is rebuilt on every request; caching one per
    # (source, target) pair would cut latency. Left as-is to keep behavior
    # identical beyond the NameError fix above.
    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)

    end_time = time.time()

    return {
        'inference_time': end_time - start_time,
        'source': source,
        'target': target,
        'result': output[0]['translation_text'],
    }
47
-
48
-
49
if __name__ == '__main__':
    print('\tinit models')

    # A `global` statement at module scope is a no-op (removed); the plain
    # assignment already creates the module-level name translation() reads.
    model_dict = load_models()

    # Build the Gradio demo: source/target language dropdowns + input text.
    lang_codes = list(flores_codes.keys())
    inputs = [gr.inputs.Dropdown(lang_codes, default='English', label='Source'),
              gr.inputs.Dropdown(lang_codes, default='Korean', label='Target'),
              gr.inputs.Textbox(lines=5, label="Input text"),
              ]

    outputs = gr.outputs.JSON()

    title = "NLLB distilled 1.3B demo"

    demo_status = "Demo is running on CPU"
    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
    examples = [
        ['English', 'Korean', 'Hi. nice to meet you']
    ]

    # NOTE(review): gr.inputs / gr.outputs is the Gradio 2.x API, removed in
    # Gradio 3+/4 — confirm the pinned gradio version before upgrading; kept
    # unchanged here since the installed version is not visible from this file.
    gr.Interface(translation,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 examples=examples,
                 examples_per_page=50,
                 ).launch()
82
-
83
-
 
1
+ Who is LeBron?