werent4 committed
Commit 483c402 · verified · 1 Parent(s): b892dfb

Update README.md

Files changed (1): README.md (+13 -8)

README.md CHANGED
````diff
@@ -21,12 +21,17 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 from transformers import T5Tokenizer, MT5ForConditionalGeneration
 
-tokenizer = T5Tokenizer.from_pretrained('google/mt5-small')
+tokenizer = T5Tokenizer.from_pretrained('werent4/mt5TranslatorLT')
 model = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
 model.to(device)
-
-def translate(text, model, tokenizer, device):
-    input_text = f"translate English to Lithuanian: {text}"
+def translate(text, model, tokenizer, device, translation_way = "en-lt"):
+    translations_ways = {
+        "en-lt": "<EN2LT>",
+        "lt-en": "<LT2EN>"
+    }
+    if translation_way not in translations_ways:
+        raise ValueError(f"Invalid translation way. Supported ways: {list(translations_ways.keys())}")
+    input_text = f"{translations_ways[translation_way]} {text}"
     encoded_input = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
     with torch.no_grad():
         output_tokens = model.generate(
@@ -40,10 +45,6 @@ def translate(text, model, tokenizer, device):
     translated_text = tokenizer.decode(output_tokens[0], skip_special_tokens=True)
     return translated_text
 
-text = "women"
-translate(text, model, tokenizer, device)
-`moteris`
-
 text = "How are you?"
 translate(text, model, tokenizer, device)
 `Kaip esate?`
@@ -51,6 +52,10 @@ translate(text, model, tokenizer, device)
 text = "I live in Kaunas"
 translate(text, model, tokenizer, device)
 `Aš gyvenu Kaunas`
+
+text = "Mano vardas yra Karolis"
+translate(text, model, tokenizer, device, translation_way= "lt-en")
+`My name is Karolis`
 ```
 
 
````
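For reference, here is a minimal end-to-end sketch of how the README example reads after this commit, reassembled from the diff hunks above. The arguments actually passed to `model.generate(...)` fall outside the changed hunks and are not visible in this diff, so the `max_length` and `num_beams` values below are placeholder assumptions rather than the README's real settings.

```python
# Sketch of the post-commit README example, reassembled from the diff above.
# NOTE: the real generate() arguments are not shown in this diff; max_length
# and num_beams below are assumptions, not the README's actual values.
import torch
from transformers import T5Tokenizer, MT5ForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The commit switches the tokenizer source from 'google/mt5-small' to the model repo itself.
tokenizer = T5Tokenizer.from_pretrained("werent4/mt5TranslatorLT")
model = MT5ForConditionalGeneration.from_pretrained("werent4/mt5TranslatorLT")
model.to(device)

def translate(text, model, tokenizer, device, translation_way="en-lt"):
    # Direction prefixes introduced by this commit.
    translations_ways = {
        "en-lt": "<EN2LT>",
        "lt-en": "<LT2EN>",
    }
    if translation_way not in translations_ways:
        raise ValueError(f"Invalid translation way. Supported ways: {list(translations_ways.keys())}")
    input_text = f"{translations_ways[translation_way]} {text}"
    encoded_input = tokenizer(input_text, return_tensors="pt", padding=True,
                              truncation=True, max_length=128).to(device)
    with torch.no_grad():
        output_tokens = model.generate(
            **encoded_input,
            max_length=128,  # assumption: not visible in the diff hunks
            num_beams=4,     # assumption
        )
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

print(translate("How are you?", model, tokenizer, device))                      # "Kaip esate?" per the README
print(translate("Mano vardas yra Karolis", model, tokenizer, device, "lt-en"))  # "My name is Karolis" per the README
```

The expected outputs in the comments are the ones quoted in the README examples above. Loading the tokenizer from `werent4/mt5TranslatorLT` instead of `google/mt5-small` presumably ensures the `<EN2LT>` and `<LT2EN>` direction markers are handled by the same tokenizer the model was fine-tuned with; the diff itself does not state this rationale, so treat it as an inference.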