emircanerol committed on
Commit
0d155e5
·
verified ·
1 Parent(s): 67e22a8

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +21 -41
README.md CHANGED
@@ -9,28 +9,14 @@ pipeline_tag: token-classification
9
  ---
10
 
11
  ```python
12
- from peft import PeftModel, prepare_model_for_kbit_training
13
- from transformers import T5ForTokenClassification, BitsAndBytesConfig
14
- import torch
15
 
16
- model_id = "google/byt5-small"
 
 
17
 
18
- bnb_config = BitsAndBytesConfig(
19
- load_in_4bit=True,
20
- bnb_4bit_quant_type="nf4",
21
- bnb_4bit_use_double_quant=True,
22
- bnb_4bit_compute_dtype=torch.bfloat16,
23
- )
24
-
25
- model = T5ForTokenClassification.from_pretrained(model_id,
26
- num_labels=2,
27
- torch_dtype=torch.bfloat16,
28
- quantization_config=bnb_config,
29
- device_map="auto",)
30
- model = prepare_model_for_kbit_training(model)
31
- model = PeftModel.from_pretrained(model, 'bite-the-byte/byt5-small-deASCIIfy-TR')
32
-
33
- def test_mask(data):
34
  """
35
  Masks the padded tokens in the input.
36
  Args:
@@ -39,21 +25,16 @@ def test_mask(data):
39
  dataset (list): List of dictionaries.
40
  """
41
 
42
- dataset = list()
43
- for sample in data:
44
- new_sample = dict()
45
 
46
- input_tokens = [i + 3 for i in sample.encode('utf-8')]
47
- input_tokens.append(0) # eos token
48
- new_sample['input_ids'] = torch.tensor([input_tokens], dtype=torch.int64)
49
-
50
- # Create attention mask
51
- attention_mask = [1] * len(input_tokens) # Attend to all tokens
52
- new_sample['attention_mask'] = torch.tensor([attention_mask], dtype=torch.int64)
53
-
54
- dataset.append(new_sample)
55
-
56
- return dataset
57
 
58
  def rewrite(model, data):
59
  """
@@ -66,24 +47,23 @@ def rewrite(model, data):
66
  """
67
 
68
  with torch.no_grad():
69
- data = {k: v.to(model.device) for k, v in data.items()}
70
- pred = torch.argmax(model(**data).logits, dim=2)
71
-
72
  output = list() # save the indices of the characters as list of integers
73
 
74
  # Conversion table for Turkish characters {100: [300, 350], ...}
75
  en2tr = {en: tr for tr, en in zip(list(map(list, map(str.encode, list('ÜİĞŞÇÖüığşçö')))), list(map(ord, list('UIGSCOuigsco'))))}
76
 
77
- for inp, lab in zip((data['input_ids'] - 3)[0].tolist(), pred[0].tolist()):
78
  if lab and inp in en2tr:
79
  # if the model predicts a diacritic, replace it with the corresponding Turkish character
80
  output.extend(en2tr[inp])
81
  elif inp >= 0: output.append(inp)
82
  return bytes(output).decode()
83
 
84
- def try_it(text, model):#=model):
85
- sample = test_mask([text])
86
- return rewrite(model, sample[0])
87
 
88
  try_it('Cekoslovakyalilastiramadiklarimizdan misiniz?', model)
89
  ```
 
9
  ---
10
 
11
  ```python
12
+ from peft import PeftModel, PeftConfig
13
+ from transformers import AutoModelForTokenClassification
 
14
 
15
+ config = PeftConfig.from_pretrained("bite-the-byte/byt5-small-deASCIIfy-TR")
16
+ model = AutoModelForTokenClassification.from_pretrained("google/byt5-small")
17
+ model = PeftModel.from_pretrained(model, "bite-the-byte/byt5-small-deASCIIfy-TR")
18
 
19
+ def test_mask(device, sample):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  """
21
  Masks the padded tokens in the input.
22
  Args:
 
25
  dataset (list): List of dictionaries.
26
  """
27
 
28
+ tokens = dict()
 
 
29
 
30
+ input_tokens = [i + 3 for i in sample.encode('utf-8')]
31
+ input_tokens.append(0) # eos token
32
+ tokens['input_ids'] = torch.tensor([input_tokens], dtype=torch.int64, device=device)
33
+
34
+ # Create attention mask
35
+ tokens['attention_mask'] = torch.ones_like(tokens['input_ids'], dtype=torch.int64, device=device)
36
+
37
+ return tokens
 
 
 
38
 
39
  def rewrite(model, data):
40
  """
 
47
  """
48
 
49
  with torch.no_grad():
50
+ pred = torch.argmax(model(**data).logits, dim=2).squeeze(0)
51
+
 
52
  output = list() # save the indices of the characters as list of integers
53
 
54
  # Conversion table for Turkish characters {100: [300, 350], ...}
55
  en2tr = {en: tr for tr, en in zip(list(map(list, map(str.encode, list('ÜİĞŞÇÖüığşçö')))), list(map(ord, list('UIGSCOuigsco'))))}
56
 
57
+ for inp, lab in zip((data['input_ids'].squeeze(0) - 3).tolist(), pred.tolist()):
58
  if lab and inp in en2tr:
59
  # if the model predicts a diacritic, replace it with the corresponding Turkish character
60
  output.extend(en2tr[inp])
61
  elif inp >= 0: output.append(inp)
62
  return bytes(output).decode()
63
 
64
+ def try_it(text, model):
65
+ sample = test_mask(model.device, text)
66
+ return rewrite(model, sample)
67
 
68
  try_it('Cekoslovakyalilastiramadiklarimizdan misiniz?', model)
69
  ```