import json
import sys
import os


sys.path.append("./")


def render_token(t: bytes) -> str:
    # Pretty-print a token, replacing undecodable bytes instead of raising.
    s = t.decode('utf-8', errors='replace')
    # s = replace_control_characters(s)
    return s


def get_freq_pairs(inp_toks):
  """Return a count of each adjacent pair of tokens in inp_toks."""
  count = {}
  for pair in zip(inp_toks, inp_toks[1:]):
    count[pair] = count.get(pair, 0) + 1
  return count


def merge(id_list, pair, replace_with_idx):
  """
  Replace every occurrence of 'pair' in 'id_list' with 'replace_with_idx'.

  id_list : list of token ids
  pair : tuple of 2 token ids
  replace_with_idx : int, the new token id that stands for the pair

  Returns a new list with the pair replaced.
  """
  i = 0
  new_ids_list = []
  while i < len(id_list):
    if i < len(id_list) - 1 and id_list[i] == pair[0] and id_list[i+1] == pair[1]:
      new_ids_list.append(replace_with_idx)
      i += 2
    else:
      new_ids_list.append(id_list[i])
      i += 1

  return new_ids_list

class Tokenizer:
  def __init__(self):
    self.merges = {}
    # vocab -> int : bytes, for every id (0-255 raw bytes, 256+ from merges)
    self.vocab = {}
    self.load()

  def save(self):
    with open('merges.txt', 'w') as f:
      # Write only the pairs, not the index of each merged pair; the index
      # can be chosen at load time (load() assigns 256, 257, ... in order).
      for p1, p2 in self.merges.keys():
        f.write(f"{p1} {p2}\n")

    with open('vocab.txt', 'w') as f:
      for idx, byte in self.vocab.items():
        s = render_token(byte)
        f.write(f"{idx} {s}\n")

  def _build_vocab(self):
    # Base vocab is one entry per raw byte; merges extend it from id 256 up.
    self.vocab = {idx: bytes([idx]) for idx in range(256)}
    try:
      for (tok0, tok1), idx in self.merges.items():
        self.vocab[idx] = self.vocab[tok0] + self.vocab[tok1]
    except Exception as e:
      print(e)

  def load(self):
    try:
      # merges.txt sits next to this file; merged ids start right after the raw bytes.
      with open(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'merges.txt'), 'r') as file:
        idx = 256
        for line in file:
          tok0, tok1 = map(int, line.split())
          self.merges[(tok0, tok1)] = idx
          idx += 1

      self._build_vocab()

    except Exception as e:
      print(e)


if __name__ == '__main__':
  # print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))
  tokenizer = Tokenizer()
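
  # Illustrative sketch (not part of the original script): one manual BPE merge
  # step using the helpers above, on a made-up id sequence.
  sample_ids = [5, 6, 6, 7, 9, 1]
  pair_counts = get_freq_pairs(sample_ids)           # counts of adjacent pairs
  best_pair = max(pair_counts, key=pair_counts.get)  # most frequent pair (ties: first seen)
  print(merge(sample_ids, best_pair, 256))           # pair replaced by a new id, 256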