sarahai committed
Commit a1af8f5 · verified · 1 Parent(s): bf19ee0

Update app.py

Files changed (1)
  1. app.py +37 -30
app.py CHANGED
@@ -1,44 +1,52 @@
 import streamlit as st
-from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration
-from transformers import NllbTokenizer, T5Tokenizer
+from transformers import AutoModelForSeq2SeqLM, T5ForConditionalGeneration, NllbTokenizer, T5Tokenizer
 
-# Load translation model and tokenizer
+# Initialize models and tokenizers
 translation_model_name = 'sarahai/nllb-uzbek-cyrillic-to-russian'
 translation_model = AutoModelForSeq2SeqLM.from_pretrained(translation_model_name)
 translation_tokenizer = NllbTokenizer.from_pretrained(translation_model_name)
 
-
-# Load the summarization model and tokenizer
 summarization_model_name = 'sarahai/ruT5-base-summarizer'
 summarization_model = T5ForConditionalGeneration.from_pretrained(summarization_model_name)
 summarization_tokenizer = T5Tokenizer.from_pretrained(summarization_model_name)
 
-def translate(text, translation_model, translation_tokenizer, src_lang='uzb_Cyrl', tgt_lang='rus_Cyrl', a=16, b=1.5, max_input_length=2048):
-    translation_tokenizer.src_lang = src_lang
-    translation_tokenizer.tgt_lang = tgt_lang
-    inputs = translation_tokenizer(text, return_tensors='pt', padding=True, truncation=True, max_length=max_input_length)
-    outputs = translation_model.generate(
-        inputs['input_ids'],
-        forced_bos_token_id=translation_tokenizer.lang_code_to_id[tgt_lang],
-        max_new_tokens=int(a + b * inputs.input_ids.shape[1])
-    )
-    translated_text = translation_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return translated_text
-
-def summarize(translated_text, summarization_model, summarization_tokenizer, max_length=250):
-    input_ids = summarization_tokenizer.encode("summarize: " + translated_text, return_tensors="pt", max_length=2048, truncation=True)
-    summary_ids = summarization_model.generate(
-        input_ids,
-        max_length=max_length,
-        min_length=min_length,
-        length_penalty=2.0,
-        num_beams=4,
-        early_stopping=True
-    )
-    summary = summarization_tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+def split_into_chunks(text, tokenizer, max_length=150):
+    # Tokenize the text and get ids
+    tokens = tokenizer.tokenize(text)
+    # Initialize chunks
+    chunks = []
+    current_chunk = []
+    current_length = 0
+    for token in tokens:
+        current_chunk.append(token)
+        current_length += 1
+        if current_length >= max_length:
+            chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
+            current_chunk = []
+            current_length = 0
+    # Add the last chunk if it's not empty
+    if current_chunk:
+        chunks.append(tokenizer.convert_tokens_to_string(current_chunk))
+    return chunks
+
+def translate(text, model, tokenizer, src_lang='uzb_Cyrl', tgt_lang='rus_Cyrl'):
+    tokenizer.src_lang = src_lang
+    tokenizer.tgt_lang = tgt_lang
+    chunks = split_into_chunks(text, tokenizer)
+    translated_chunks = []
+    for chunk in chunks:
+        inputs = tokenizer(chunk, return_tensors='pt', padding=True, truncation=True, max_length=128)
+        outputs = model.generate(inputs['input_ids'], forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])
+        translated_chunks.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
+    return ' '.join(translated_chunks)
+
+def summarize(text, model, tokenizer, max_length=250):
+    input_ids = tokenizer.encode("summarize: " + text, return_tensors="pt", max_length=2048, truncation=True)
+    summary_ids = model.generate(input_ids, max_length=max_length, length_penalty=2.0, num_beams=4, early_stopping=True)
+    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
     return summary
 
-# Set up the Streamlit interface
+# Streamlit UI
 st.title("Перевод с узбекского на русский и суммаризация")
 text = st.text_area("Введите текст на узбекском:", height=200)
 
@@ -53,4 +61,3 @@ if st.button("Перевести и суммаризировать"):
         st.text_area("Суммаризация (на русском):", value=summary_text, height=100)
     else:
         st.warning("Пожалуйста, введите текст на узбекском языке для перевода.")
-
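
For context, a minimal sketch of how the refactored helpers introduced in this commit chain together outside the Streamlit UI. It assumes the models, tokenizers, translate, and summarize defined above in app.py are already in scope; sample_text is a hypothetical placeholder, not taken from the app:

# Usage sketch (assumption: the definitions from app.py above are in scope).
sample_text = "..."  # placeholder: replace with Uzbek (Cyrillic) input text

# Translate chunk by chunk, then summarize the joined Russian text.
russian_text = translate(sample_text, translation_model, translation_tokenizer,
                         src_lang='uzb_Cyrl', tgt_lang='rus_Cyrl')
russian_summary = summarize(russian_text, summarization_model, summarization_tokenizer, max_length=250)
print(russian_summary)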