emircanerol committed
Commit 67e22a8 · verified · 1 Parent(s): 1fa24d0

Update README.md

Files changed (1):
  1. README.md +89 -3
README.md CHANGED
@@ -1,3 +1,89 @@
- ---
- license: apache-2.0
- ---
+ ---
+ license: apache-2.0
+ datasets:
+ - cc100
+ language:
+ - tr
+ library_name: peft
+ pipeline_tag: token-classification
+ ---
+
+ ```python
+ from peft import PeftModel, prepare_model_for_kbit_training
+ from transformers import T5ForTokenClassification, BitsAndBytesConfig
+ import torch
+
+ model_id = "google/byt5-small"
+
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+ )
+
+ model = T5ForTokenClassification.from_pretrained(model_id,
+                                                  num_labels=2,
+                                                  torch_dtype=torch.bfloat16,
+                                                  quantization_config=bnb_config,
+                                                  device_map="auto",)
+ model = prepare_model_for_kbit_training(model)
+ model = PeftModel.from_pretrained(model, 'bite-the-byte/byt5-small-deASCIIfy-TR')
+
+ def test_mask(data):
+     """
+     Encodes the input strings as byte-level token ids and builds attention masks.
+     Args:
+         data (list): List of strings.
+     Returns:
+         dataset (list): List of dictionaries with 'input_ids' and 'attention_mask'.
+     """
+
+     dataset = list()
+     for sample in data:
+         new_sample = dict()
+
+         input_tokens = [i + 3 for i in sample.encode('utf-8')]
+         input_tokens.append(0)  # eos token
+         new_sample['input_ids'] = torch.tensor([input_tokens], dtype=torch.int64)
+
+         # Create attention mask
+         attention_mask = [1] * len(input_tokens)  # Attend to all tokens
+         new_sample['attention_mask'] = torch.tensor([attention_mask], dtype=torch.int64)
+
+         dataset.append(new_sample)
+
+     return dataset
+
+ def rewrite(model, data):
+     """
+     Rewrites the input text with the model.
+     Args:
+         model (torch.nn.Module): Model.
+         data (dict): Dictionary containing 'input_ids' and 'attention_mask'.
+     Returns:
+         output (str): Rewritten text.
+     """
+
+     with torch.no_grad():
+         data = {k: v.to(model.device) for k, v in data.items()}
+         pred = torch.argmax(model(**data).logits, dim=2)
+
+     output = list()  # collect the output as a list of byte values
+
+     # Conversion table from ASCII code points to the UTF-8 bytes of the
+     # corresponding Turkish characters, e.g. {99: [195, 167]} maps 'c' to 'ç'
+     en2tr = {en: tr for tr, en in zip(list(map(list, map(str.encode, list('ÜİĞŞÇÖüığşçö')))), list(map(ord, list('UIGSCOuigsco'))))}
+
+     for inp, lab in zip((data['input_ids'] - 3)[0].tolist(), pred[0].tolist()):
+         if lab and inp in en2tr:
+             # If the model predicts a diacritic, replace the character with
+             # the corresponding Turkish character
+             output.extend(en2tr[inp])
+         elif inp >= 0:
+             output.append(inp)
+     return bytes(output).decode()
+
+ def try_it(text, model):
+     sample = test_mask([text])
+     return rewrite(model, sample[0])
+
+ try_it('Cekoslovakyalilastiramadiklarimizdan misiniz?', model)
+ ```
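
For reference, a minimal sketch of the byte/id convention the helpers above rely on: ByT5 token ids are UTF-8 byte values shifted by 3 to leave room for the reserved special-token ids, which is why `test_mask` adds 3 and `rewrite` subtracts it again. The snippet below is illustrative only and does not touch the model:

```python
# Round trip of the encoding used by test_mask / rewrite:
# token id = UTF-8 byte value + 3; subtracting 3 recovers the raw bytes.
text = 'Cekoslovakyalilastiramadiklarimizdan misiniz?'
ids = [b + 3 for b in text.encode('utf-8')]
assert bytes(i - 3 for i in ids).decode('utf-8') == text
```

If the adapter performs as intended, the `try_it` call above should return the same sentence with its Turkish diacritics restored (e.g. `Çekoslovakyalılaştıramadıklarımızdan mısınız?`).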