liamcripwell committed on
Commit 3ab7934
1 Parent(s): 8521c8e

Update README.md

Files changed (1)
  1. README.md +72 -0
README.md CHANGED
@@ -91,4 +91,76 @@ template = """{
  prediction = predict_NuExtract(model, tokenizer, [text], template)[0]
  print(prediction)

+ ```
+
+ Sliding window prompting:
+
+ ```python
+ import json
+
+ MAX_INPUT_SIZE = 20_000
+ MAX_NEW_TOKENS = 6000
+
+ def clean_json_text(text):
+     text = text.strip()
+     # un-escape characters the model sometimes escapes in its JSON output
+     text = text.replace("\\#", "#").replace("\\&", "&")
+     return text
+
+ def predict_chunk(text, template, current, model, tokenizer):
+     current = clean_json_text(current)
+
+     input_llm = f"<|input|>\n### Template:\n{template}\n### Current:\n{current}\n### Text:\n{text}\n\n<|output|>" + "{"
+     input_ids = tokenizer(input_llm, return_tensors="pt", truncation=True, max_length=MAX_INPUT_SIZE).to("cuda")
+     output = tokenizer.decode(model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS)[0], skip_special_tokens=True)
+
+     return clean_json_text(output.split("<|output|>")[1])
+
+ def split_document(document, window_size, overlap):
+     # relies on the module-level tokenizer loaded in the earlier example
+     tokens = tokenizer.tokenize(document)
+     print(f"\tLength of document: {len(tokens)} tokens")
+
+     chunks = []
+     if len(tokens) > window_size:
+         for i in range(0, len(tokens), window_size - overlap):
+             print(f"\t{i} to {i + len(tokens[i:i + window_size])}")
+             chunk = tokenizer.convert_tokens_to_string(tokens[i:i + window_size])
+             chunks.append(chunk)
+
+             if i + len(tokens[i:i + window_size]) >= len(tokens):
+                 break
+     else:
+         chunks.append(document)
+     print(f"\tSplit into {len(chunks)} chunks")
+
+     return chunks
+
+ def handle_broken_output(pred, prev):
+     try:
+         if all([(v in ["", []]) for v in json.loads(pred).values()]):
+             # if empty json, return previous
+             pred = prev
+     except:
+         # if broken json, return previous
+         pred = prev
+
+     return pred
+
+ def sliding_window_prediction(text, template, model, tokenizer, window_size=4000, overlap=128):
+     # split text into chunks of n tokens
+     tokens = tokenizer.tokenize(text)
+     chunks = split_document(text, window_size, overlap)
+
+     # iterate over text chunks
+     prev = template
+     for i, chunk in enumerate(chunks):
+         print(f"Processing chunk {i}...")
+         pred = predict_chunk(chunk, template, prev, model, tokenizer)
+
+         # handle broken output
+         pred = handle_broken_output(pred, prev)
+
+         # iterate
+         prev = pred
+
+     return pred
  ```
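
For reference, a minimal usage sketch (not part of the commit): it assumes `model`, `tokenizer`, and `template` are already defined as in the earlier README example, and that the model sits on a CUDA device, since `predict_chunk` moves its inputs to `"cuda"`. The input file name is a placeholder. Each chunk's prediction is fed back in as the `### Current:` state for the next window, so the final `pred` accumulates fields found across the whole document.

```python
# Usage sketch (assumptions: model, tokenizer, and template are defined as in the
# earlier example above; "long_report.txt" is a placeholder for a document longer
# than a single context window).
with open("long_report.txt") as f:
    long_text = f.read()

prediction = sliding_window_prediction(
    long_text,
    template,
    model,
    tokenizer,
    window_size=4000,  # tokens fed to the model per chunk
    overlap=128,       # tokens shared between consecutive chunks
)
print(prediction)
```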