Shaltiel commited on
Commit
8e97fcd
·
1 Parent(s): c073310

Linked python files to main dictabert-joint

Browse files
BertForJointParsing.py DELETED
@@ -1,523 +0,0 @@
1
- from dataclasses import dataclass
2
- import re
3
- from operator import itemgetter
4
- import torch
5
- from torch import nn
6
- from typing import Any, Dict, List, Literal, Optional, Tuple, Union
7
- from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
8
- from transformers.models.bert.modeling_bert import BertOnlyMLMHead
9
- from transformers.utils import ModelOutput
10
- from .BertForSyntaxParsing import BertSyntaxParsingHead, SyntaxLabels, SyntaxLogitsOutput, parse_logits as syntax_parse_logits
11
- from .BertForPrefixMarking import BertPrefixMarkingHead, parse_logits as prefix_parse_logits, encode_sentences_for_bert_for_prefix_marking, get_prefixes_from_str
12
- from .BertForMorphTagging import BertMorphTaggingHead, MorphLogitsOutput, MorphLabels, parse_logits as morph_parse_logits
13
-
14
- import warnings
15
-
16
- @dataclass
17
- class JointParsingOutput(ModelOutput):
18
- loss: Optional[torch.FloatTensor] = None
19
- # logits will contain the optional predictions for the given labels
20
- logits: Optional[Union[SyntaxLogitsOutput, None]] = None
21
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
22
- attentions: Optional[Tuple[torch.FloatTensor]] = None
23
- # if no labels are given, we will always include the syntax logits separately
24
- syntax_logits: Optional[SyntaxLogitsOutput] = None
25
- ner_logits: Optional[torch.FloatTensor] = None
26
- prefix_logits: Optional[torch.FloatTensor] = None
27
- lex_logits: Optional[torch.FloatTensor] = None
28
- morph_logits: Optional[MorphLogitsOutput] = None
29
-
30
- # wrapper class to wrap a torch.nn.Module so that you can store a module in multiple linked
31
- # properties without registering the parameter multiple times
32
- class ModuleRef:
33
- def __init__(self, module: torch.nn.Module):
34
- self.module = module
35
-
36
- def forward(self, *args, **kwargs):
37
- return self.module.forward(*args, **kwargs)
38
-
39
- def __call__(self, *args, **kwargs):
40
- return self.module(*args, **kwargs)
41
-
42
- class BertForJointParsing(BertPreTrainedModel):
43
- _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
44
-
45
- def __init__(self, config, do_syntax=None, do_ner=None, do_prefix=None, do_lex=None, do_morph=None, syntax_head_size=64):
46
- super().__init__(config)
47
-
48
- self.bert = BertModel(config, add_pooling_layer=False)
49
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
50
- # create all the heads as None, and then populate them as defined
51
- self.syntax, self.ner, self.prefix, self.lex, self.morph = (None,)*5
52
-
53
- if do_syntax is not None:
54
- config.do_syntax = do_syntax
55
- config.syntax_head_size = syntax_head_size
56
- if do_ner is not None: config.do_ner = do_ner
57
- if do_prefix is not None: config.do_prefix = do_prefix
58
- if do_lex is not None: config.do_lex = do_lex
59
- if do_morph is not None: config.do_morph = do_morph
60
-
61
- # add all the individual heads
62
- if config.do_syntax:
63
- self.syntax = BertSyntaxParsingHead(config)
64
- if config.do_ner:
65
- self.num_labels = config.num_labels
66
- self.classifier = nn.Linear(config.hidden_size, config.num_labels) # name it same as in BertForTokenClassification
67
- self.ner = ModuleRef(self.classifier)
68
- if config.do_prefix:
69
- self.prefix = BertPrefixMarkingHead(config)
70
- if config.do_lex:
71
- self.cls = BertOnlyMLMHead(config) # name it the same as in BertForMaskedLM
72
- self.lex = ModuleRef(self.cls)
73
- if config.do_morph:
74
- self.morph = BertMorphTaggingHead(config)
75
-
76
- # Initialize weights and apply final processing
77
- self.post_init()
78
-
79
- def get_output_embeddings(self):
80
- return self.cls.predictions.decoder if self.lex is not None else None
81
-
82
- def set_output_embeddings(self, new_embeddings):
83
- if self.lex is not None:
84
-
85
- self.cls.predictions.decoder = new_embeddings
86
-
87
- def forward(
88
- self,
89
- input_ids: Optional[torch.Tensor] = None,
90
- attention_mask: Optional[torch.Tensor] = None,
91
- token_type_ids: Optional[torch.Tensor] = None,
92
- position_ids: Optional[torch.Tensor] = None,
93
- prefix_class_id_options: Optional[torch.Tensor] = None,
94
- labels: Optional[Union[SyntaxLabels, MorphLabels, torch.Tensor]] = None,
95
- labels_type: Optional[Literal['syntax', 'ner', 'prefix', 'lex', 'morph']] = None,
96
- head_mask: Optional[torch.Tensor] = None,
97
- inputs_embeds: Optional[torch.Tensor] = None,
98
- output_attentions: Optional[bool] = None,
99
- output_hidden_states: Optional[bool] = None,
100
- return_dict: Optional[bool] = None,
101
- compute_syntax_mst: Optional[bool] = None
102
- ):
103
- if return_dict is False:
104
- warnings.warn("Specified `return_dict=False` but the flag is ignored and treated as always True in this model.")
105
-
106
- if labels is not None and labels_type is None:
107
- raise ValueError("Cannot specify labels without labels_type")
108
-
109
- if labels_type == 'seg' and prefix_class_id_options is None:
110
- raise ValueError('Cannot calculate prefix logits without prefix_class_id_options')
111
-
112
- if compute_syntax_mst is not None and self.syntax is None:
113
- raise ValueError("Cannot compute syntax MST when the syntax head isn't loaded")
114
-
115
-
116
- bert_outputs = self.bert(
117
- input_ids,
118
- attention_mask=attention_mask,
119
- token_type_ids=token_type_ids,
120
- position_ids=position_ids,
121
- head_mask=head_mask,
122
- inputs_embeds=inputs_embeds,
123
- output_attentions=output_attentions,
124
- output_hidden_states=output_hidden_states,
125
- return_dict=True,
126
- )
127
-
128
- # calculate the extended attention mask for any child that might need it
129
- extended_attention_mask = None
130
- if attention_mask is not None:
131
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())
132
-
133
- # extract the hidden states, and apply the dropout
134
- hidden_states = self.dropout(bert_outputs[0])
135
-
136
- logits = None
137
- syntax_logits = None
138
- ner_logits = None
139
- prefix_logits = None
140
- lex_logits = None
141
- morph_logits = None
142
-
143
- # Calculate the syntax
144
- if self.syntax is not None and (labels is None or labels_type == 'syntax'):
145
- # apply the syntax head
146
- loss, syntax_logits = self.syntax(hidden_states, extended_attention_mask, labels, compute_syntax_mst)
147
- logits = syntax_logits
148
-
149
- # Calculate the NER
150
- if self.ner is not None and (labels is None or labels_type == 'ner'):
151
- ner_logits = self.ner(hidden_states)
152
- logits = ner_logits
153
- if labels is not None:
154
- loss_fct = nn.CrossEntropyLoss()
155
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
156
-
157
- # Calculate the segmentation
158
- if self.prefix is not None and (labels is None or labels_type == 'prefix'):
159
- loss, prefix_logits = self.prefix(hidden_states, prefix_class_id_options, labels)
160
- logits = prefix_logits
161
-
162
- # Calculate the lexeme
163
- if self.lex is not None and (labels is None or labels_type == 'lex'):
164
- lex_logits = self.lex(hidden_states)
165
- logits = lex_logits
166
- if labels is not None:
167
- loss_fct = nn.CrossEntropyLoss() # -100 index = padding token
168
- loss = loss_fct(lex_logits.view(-1, self.config.vocab_size), labels.view(-1))
169
-
170
- if self.morph is not None and (labels is None or labels_type == 'morph'):
171
- loss, morph_logits = self.morph(hidden_states, labels)
172
- logits = morph_logits
173
-
174
- # no labels => logits = None
175
- if labels is None: logits = None
176
-
177
- return JointParsingOutput(
178
- loss,
179
- logits,
180
- hidden_states=bert_outputs.hidden_states,
181
- attentions=bert_outputs.attentions,
182
- # all the predicted logits section
183
- syntax_logits=syntax_logits,
184
- ner_logits=ner_logits,
185
- prefix_logits=prefix_logits,
186
- lex_logits=lex_logits,
187
- morph_logits=morph_logits
188
- )
189
-
190
- def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, padding='longest', truncation=True, compute_syntax_mst=True, per_token_ner=False, output_style: Literal['json', 'ud', 'iahlt_ud'] = 'json'):
191
- is_single_sentence = isinstance(sentences, str)
192
- if is_single_sentence:
193
- sentences = [sentences]
194
-
195
- if output_style not in ['json', 'ud', 'iahlt_ud']:
196
- raise ValueError('output_style must be in json/ud/iahlt_ud')
197
- if output_style in ['ud', 'iahlt_ud'] and (self.prefix is None or self.morph is None or self.syntax is None or self.lex is None):
198
- raise ValueError("Cannot output UD format when any of the prefix,morph,syntax, and lex heads aren't loaded.")
199
-
200
- # predict the logits for the sentence
201
- if self.prefix is not None:
202
- inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
203
- else:
204
- inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
205
-
206
- offset_mapping = inputs.pop('offset_mapping')
207
- # Copy the tensors to the right device, and parse!
208
- inputs = {k:v.to(self.device) for k,v in inputs.items()}
209
- output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
210
-
211
- input_ids = inputs['input_ids'].tolist() # convert once
212
- final_output = [dict(text=sentence, tokens=combine_token_wordpieces(ids, offsets, tokenizer)) for sentence, ids, offsets in zip(sentences, input_ids, offset_mapping)]
213
- # Syntax logits: each sentence gets a dict(tree: List[dict(word,dep_head,dep_head_idx,dep_func)], root_idx: int)
214
- if output.syntax_logits is not None:
215
- for sent_idx,parsed in enumerate(syntax_parse_logits(input_ids, sentences, tokenizer, output.syntax_logits)):
216
- merge_token_list(final_output[sent_idx]['tokens'], parsed['tree'], 'syntax')
217
- final_output[sent_idx]['root_idx'] = parsed['root_idx']
218
-
219
- # Prefix logits: each sentence gets a list([prefix_segment, word_without_prefix]) - **WITH CLS & SEP**
220
- if output.prefix_logits is not None:
221
- for sent_idx,parsed in enumerate(prefix_parse_logits(input_ids, sentences, tokenizer, output.prefix_logits)):
222
- merge_token_list(final_output[sent_idx]['tokens'], map(tuple, parsed[1:-1]), 'seg')
223
-
224
- # Lex logits each sentence gets a list(tuple(word, lexeme))
225
- if output.lex_logits is not None:
226
- for sent_idx, parsed in enumerate(lex_parse_logits(input_ids, sentences, tokenizer, output.lex_logits)):
227
- merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'lex')
228
-
229
- # morph logits each sentences get a dict(text=str, tokens=list(dict(token, pos, feats, prefixes, suffix, suffix_feats?)))
230
- if output.morph_logits is not None:
231
- for sent_idx,parsed in enumerate(morph_parse_logits(input_ids, sentences, tokenizer, output.morph_logits)):
232
- merge_token_list(final_output[sent_idx]['tokens'], parsed['tokens'], 'morph')
233
-
234
- # NER logits each sentence gets a list(tuple(word, ner))
235
- if output.ner_logits is not None:
236
- for sent_idx,parsed in enumerate(ner_parse_logits(input_ids, sentences, tokenizer, output.ner_logits, self.config.id2label)):
237
- if per_token_ner:
238
- merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
239
- final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
240
-
241
- if output_style in ['ud', 'iahlt_ud']:
242
- final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
243
-
244
- if is_single_sentence:
245
- final_output = final_output[0]
246
- return final_output
247
-
248
-
249
-
250
- def aggregate_ner_tokens(final_output, parsed):
251
- entities = []
252
- prev = None
253
- for token_idx, (d, (word, pred)) in enumerate(zip(final_output['tokens'], parsed)):
254
- # O does nothing
255
- if pred == 'O': prev = None
256
- # B- || I-entity != prev (different entity or none)
257
- elif pred.startswith('B-') or pred[2:] != prev:
258
- prev = pred[2:]
259
- entities.append([[word], dict(label=prev, start=d['offsets']['start'], end=d['offsets']['end'], token_start=token_idx, token_end=token_idx)])
260
- else:
261
- entities[-1][0].append(word)
262
- entities[-1][1]['end'] = d['offsets']['end']
263
- entities[-1][1]['token_end'] = token_idx
264
-
265
- return [dict(phrase=' '.join(words), **d) for words, d in entities]
266
-
267
- def merge_token_list(src, update, key):
268
- for token_src, token_update in zip(src, update):
269
- token_src[key] = token_update
270
-
271
- def combine_token_wordpieces(input_ids: List[int], offset_mapping: torch.Tensor, tokenizer: BertTokenizerFast):
272
- offset_mapping = offset_mapping.tolist()
273
- ret = []
274
- special_toks = tokenizer.all_special_tokens
275
- for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
276
- if token in special_toks: continue
277
- if token.startswith('##'):
278
- ret[-1]['token'] += token[2:]
279
- ret[-1]['offsets']['end'] = offsets[1]
280
- else: ret.append(dict(token=token, offsets=dict(start=offsets[0], end=offsets[1])))
281
- return ret
282
-
283
- def ner_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
284
- predictions = torch.argmax(logits, dim=-1).tolist()
285
- batch_ret = []
286
-
287
- special_toks = tokenizer.all_special_tokens
288
- for batch_idx in range(len(sentences)):
289
-
290
- ret = []
291
- batch_ret.append(ret)
292
-
293
- tokens = tokenizer.convert_ids_to_tokens(input_ids[batch_idx])
294
- for tok_idx in range(len(tokens)):
295
- token = tokens[tok_idx]
296
- if token in special_toks: continue
297
-
298
- # wordpieces should just be appended to the previous word
299
- # we modify the last token in ret
300
- # by discarding the original end position and replacing it with the new token's end position
301
- if token.startswith('##'):
302
- continue
303
- # for each token, we append a tuple containing: token, label, start position, end position
304
- ret.append((token, id2label[predictions[batch_idx][tok_idx]]))
305
-
306
- return batch_ret
307
-
308
- def lex_parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor):
309
-
310
- predictions = torch.argsort(logits, dim=-1, descending=True)[..., :3].tolist()
311
- batch_ret = []
312
-
313
- special_toks = tokenizer.all_special_tokens
314
- for batch_idx in range(len(sentences)):
315
- intermediate_ret = []
316
- tokens = tokenizer.convert_ids_to_tokens(input_ids[batch_idx])
317
- for tok_idx in range(len(tokens)):
318
- token = tokens[tok_idx]
319
- if token in special_toks: continue
320
-
321
- # wordpieces should just be appended to the previous word
322
- if token.startswith('##'):
323
- intermediate_ret[-1] = (intermediate_ret[-1][0] + token[2:], intermediate_ret[-1][1])
324
- continue
325
- intermediate_ret.append((token, tokenizer.convert_ids_to_tokens(predictions[batch_idx][tok_idx])))
326
-
327
- # build the final output taking into account valid letters
328
- ret = []
329
- batch_ret.append(ret)
330
- for (token, lexemes) in intermediate_ret:
331
- # must overlap on at least 2 non אהוי letters
332
- possible_lets = set(c for c in token if c not in 'אהוי')
333
- final_lex = '[BLANK]'
334
- for lex in lexemes:
335
- if sum(c in possible_lets for c in lex) >= min([2, len(possible_lets), len([c for c in lex if c not in 'אהוי'])]):
336
- final_lex = lex
337
- break
338
- ret.append((token, final_lex))
339
-
340
- return batch_ret
341
-
342
- ud_prefixes_to_pos = {
343
- 'ש': ['SCONJ'],
344
- 'מש': ['SCONJ'],
345
- 'כש': ['SCONJ'],
346
- 'לכש': ['SCONJ'],
347
- 'בש': ['SCONJ'],
348
- 'לש': ['SCONJ'],
349
- 'ו': ['CCONJ'],
350
- 'ל': ['ADP'],
351
- 'ה': ['DET', 'SCONJ'],
352
- 'מ': ['ADP', 'SCONJ'],
353
- 'ב': ['ADP'],
354
- 'כ': ['ADP', 'ADV'],
355
- }
356
- ud_suffix_to_htb_str = {
357
- 'Gender=Masc|Number=Sing|Person=3': '_הוא',
358
- 'Gender=Masc|Number=Plur|Person=3': '_הם',
359
- 'Gender=Fem|Number=Sing|Person=3': '_היא',
360
- 'Gender=Fem|Number=Plur|Person=3': '_הן',
361
- 'Gender=Fem,Masc|Number=Plur|Person=1': '_אנחנו',
362
- 'Gender=Fem,Masc|Number=Sing|Person=1': '_אני',
363
- 'Gender=Masc|Number=Plur|Person=2': '_אתם',
364
- 'Gender=Masc|Number=Sing|Person=3': '_הוא',
365
- 'Gender=Masc|Number=Sing|Person=2': '_אתה',
366
- 'Gender=Fem|Number=Sing|Person=2': '_את',
367
- 'Gender=Masc|Number=Plur|Person=3': '_הם'
368
- }
369
- def convert_output_to_ud(output_sentences, style: Literal['htb', 'iahlt']):
370
- if style not in ['htb', 'iahlt']:
371
- raise ValueError('style must be htb/iahlt')
372
-
373
- final_output = []
374
- for sent_idx, sentence in enumerate(output_sentences):
375
- # next, go through each word and insert it in the UD format. Store in a temp format for the post process
376
- intermediate_output = []
377
- ranges = []
378
- # store a mapping between each word index and the actual line it appears in
379
- idx_to_key = {-1: 0}
380
- for word_idx,word in enumerate(sentence['tokens']):
381
- try:
382
- # handle blank lexemes
383
- if word['lex'] == '[BLANK]':
384
- word['lex'] = word['seg'][-1]
385
- except KeyError:
386
- import json
387
- print(json.dumps(sentence, ensure_ascii=False, indent=2))
388
- exit(0)
389
-
390
- start = len(intermediate_output)
391
- # Add in all the prefixes
392
- if len(word['seg']) > 1:
393
- for pre in get_prefixes_from_str(word['seg'][0], greedy=True):
394
- # pos - just take the first valid pos that appears in the predicted prefixes list.
395
- pos = next((pos for pos in ud_prefixes_to_pos[pre] if pos in word['morph']['prefixes']), ud_prefixes_to_pos[pre][0])
396
- dep, func = ud_get_prefix_dep(pre, word, word_idx)
397
- intermediate_output.append(dict(word=pre, lex=pre, pos=pos, dep=dep, func=func, feats='_'))
398
-
399
- # if there was an implicit heh, add it in dependent on the method
400
- if not 'ה' in pre and intermediate_output[-1]['pos'] == 'ADP' and 'DET' in word['morph']['prefixes']:
401
- if style == 'htb':
402
- intermediate_output.append(dict(word='ה_', lex='ה', pos='DET', dep=word_idx, func='det', feats='_'))
403
- elif style == 'iahlt':
404
- intermediate_output[-1]['feats'] = 'Definite=Def|PronType=Art'
405
-
406
-
407
- idx_to_key[word_idx] = len(intermediate_output) + 1
408
- # add the main word in!
409
- intermediate_output.append(dict(
410
- word=word['seg'][-1], lex=word['lex'], pos=word['morph']['pos'],
411
- dep=word['syntax']['dep_head_idx'], func=word['syntax']['dep_func'],
412
- feats='|'.join(f'{k}={v}' for k,v in word['morph']['feats'].items())))
413
-
414
- # if we have suffixes, this changes things
415
- if word['morph']['suffix']:
416
- # first determine the dependency info:
417
- # For adp, num, det - they main word points to here, and the suffix points to the dependency
418
- entry_to_assign_suf_dep = None
419
- if word['morph']['pos'] in ['ADP', 'NUM', 'DET']:
420
- entry_to_assign_suf_dep = intermediate_output[-1]
421
- intermediate_output[-1]['func'] = 'case'
422
- dep = word['syntax']['dep_head_idx']
423
- func = word['syntax']['dep_func']
424
- else:
425
- # if pos is verb -> obj, num -> dep, default to -> nmod:poss
426
- dep = word_idx
427
- func = {'VERB': 'obj', 'NUM': 'dep'}.get(word['morph']['pos'], 'nmod:poss')
428
-
429
- s_word, s_lex = word['seg'][-1], word['lex']
430
- # update the word of the string and extract the string of the suffix!
431
- # for IAHLT:
432
- if style == 'iahlt':
433
- # we need to shorten the main word and extract the suffix
434
- # if it is longer than the lexeme - just take off the lexeme.
435
- if len(s_word) > len(s_lex):
436
- idx = len(s_lex)
437
- # Otherwise, try to find the last letter of the lexeme, and fail that just take the last letter
438
- else:
439
- # take either len-1, or the last occurence (which can be -1 === len-1)
440
- idx = min([len(s_word) - 1, s_word.rfind(s_lex[-1])])
441
- # extract the suffix and update the main word
442
- suf = s_word[idx:]
443
- intermediate_output[-1]['word'] = s_word[:idx]
444
- # for htb:
445
- elif style == 'htb':
446
- # main word becomes the lexeme, the suffix is based on the features
447
- intermediate_output[-1]['word'] = (s_lex if s_lex != s_word else s_word[:-1]) + '_'
448
- suf_feats = word['morph']['suffix_feats']
449
- suf = ud_suffix_to_htb_str.get(f"Gender={suf_feats.get('Gender', 'Fem,Masc')}|Number={suf_feats.get('Number', 'Sing')}|Person={suf_feats.get('Person', '3')}", "_הוא")
450
- # for HTB, if the function is poss, then add a shel pointing to the next word
451
- if func == 'nmod:poss' and s_lex != 'של':
452
- intermediate_output.append(dict(word='_של_', lex='של', pos='ADP', dep=len(intermediate_output) + 2, func='case', feats='_', absolute_dep=True))
453
- # add the main suffix in
454
- intermediate_output.append(dict(word=suf, lex='הוא', pos='PRON', dep=dep, func=func, feats='|'.join(f'{k}={v}' for k,v in word['morph']['suffix_feats'].items())))
455
- if entry_to_assign_suf_dep:
456
- entry_to_assign_suf_dep['dep'] = len(intermediate_output)
457
- entry_to_assign_suf_dep['absolute_dep'] = True
458
-
459
- end = len(intermediate_output)
460
- ranges.append((start, end, word['token']))
461
-
462
- # now that we have the intermediate output, combine it to the final output
463
- cur_output = []
464
- final_output.append(cur_output)
465
- # first, add the headers
466
- cur_output.append(f'# sent_id = {sent_idx + 1}')
467
- cur_output.append(f'# text = {sentence["text"]}')
468
-
469
- # add in all the actual entries
470
- for start,end,token in ranges:
471
- if end - start > 1:
472
- cur_output.append(f'{start + 1}-{end}\t{token}\t_\t_\t_\t_\t_\t_\t_\t_')
473
- for idx,output in enumerate(intermediate_output[start:end], start + 1):
474
- # compute the actual dependency location
475
- dep = output['dep'] if output.get('absolute_dep', False) else idx_to_key[output['dep']]
476
- func = normalize_dep_rel(output['func'], style)
477
- # and add the full ud string in
478
- cur_output.append('\t'.join([
479
- str(idx),
480
- output['word'],
481
- output['lex'],
482
- output['pos'],
483
- output['pos'],
484
- output['feats'],
485
- str(dep),
486
- func,
487
- '_', '_'
488
- ]))
489
- return final_output
490
-
491
- def normalize_dep_rel(dep, style: Literal['htb', 'iahlt']):
492
- if style == 'iahlt':
493
- if dep == 'compound:smixut': return 'compound'
494
- if dep == 'nsubj:cop': return 'nsubj'
495
- if dep == 'mark:q': return 'mark'
496
- if dep == 'case:gen' or dep == 'case:acc': return 'case'
497
- return dep
498
-
499
-
500
- def ud_get_prefix_dep(pre, word, word_idx):
501
- does_follow_main = False
502
-
503
- # shin goes to the main word for verbs, otherwise follows the word
504
- if pre.endswith('ש'):
505
- does_follow_main = word['morph']['pos'] != 'VERB'
506
- func = 'mark'
507
- # vuv goes to the main word if the function is in the list, otherwise follows
508
- elif pre == 'ו':
509
- does_follow_main = word['syntax']['dep_func'] not in ["conj", "acl:recl", "parataxis", "root", "acl", "amod", "list", "appos", "dep", "flatccomp"]
510
- func = 'cc'
511
- else:
512
- # for adj, noun, propn, pron, verb - prefixes go to the main word
513
- if word['morph']['pos'] in ["ADJ", "NOUN", "PROPN", "PRON", "VERB"]:
514
- does_follow_main = False
515
- # otherwise - prefix follows the word if the function is in the list
516
- else: does_follow_main = word['syntax']['dep_func'] in ["compound:affix", "det", "aux", "nummod", "advmod", "dep", "cop", "mark", "fixed"]
517
-
518
- func = 'case'
519
- if pre == 'ה':
520
- func = 'det' if 'DET' in word['morph']['prefixes'] else 'mark'
521
-
522
- return (word['syntax']['dep_head_idx'] if does_follow_main else word_idx), func
523
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
BertForMorphTagging.py DELETED
@@ -1,212 +0,0 @@
1
- from collections import OrderedDict
2
- from operator import itemgetter
3
- from transformers.utils import ModelOutput
4
- import torch
5
- from torch import nn
6
- from typing import Dict, List, Tuple, Optional
7
- from dataclasses import dataclass
8
- from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
9
-
10
- ALL_POS = ['DET', 'NOUN', 'VERB', 'CCONJ', 'ADP', 'PRON', 'PUNCT', 'ADJ', 'ADV', 'SCONJ', 'NUM', 'PROPN', 'AUX', 'X', 'INTJ', 'SYM']
11
- ALL_PREFIX_POS = ['SCONJ', 'DET', 'ADV', 'CCONJ', 'ADP', 'NUM']
12
- ALL_SUFFIX_POS = ['none', 'ADP_PRON', 'PRON']
13
- ALL_FEATURES = [
14
- ('Gender', ['none', 'Masc', 'Fem', 'Fem,Masc']),
15
- ('Number', ['none', 'Sing', 'Plur', 'Plur,Sing', 'Dual', 'Dual,Plur']),
16
- ('Person', ['none', '1', '2', '3', '1,2,3']),
17
- ('Tense', ['none', 'Past', 'Fut', 'Pres', 'Imp'])
18
- ]
19
-
20
- @dataclass
21
- class MorphLogitsOutput(ModelOutput):
22
- prefix_logits: torch.FloatTensor = None
23
- pos_logits: torch.FloatTensor = None
24
- features_logits: List[torch.FloatTensor] = None
25
- suffix_logits: torch.FloatTensor = None
26
- suffix_features_logits: List[torch.FloatTensor] = None
27
-
28
- def detach(self):
29
- return MorphLogitsOutput(self.prefix_logits.detach(), self.pos_logits.detach(), [logits.deatch() for logits in self.features_logits], self.suffix_logits.detach(), [logits.deatch() for logits in self.suffix_features_logits])
30
-
31
-
32
- @dataclass
33
- class MorphTaggingOutput(ModelOutput):
34
- loss: Optional[torch.FloatTensor] = None
35
- logits: Optional[MorphLogitsOutput] = None
36
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
37
- attentions: Optional[Tuple[torch.FloatTensor]] = None
38
-
39
- @dataclass
40
- class MorphLabels(ModelOutput):
41
- prefix_labels: Optional[torch.FloatTensor] = None
42
- pos_labels: Optional[torch.FloatTensor] = None
43
- features_labels: Optional[List[torch.FloatTensor]] = None
44
- suffix_labels: Optional[torch.FloatTensor] = None
45
- suffix_features_labels: Optional[List[torch.FloatTensor]] = None
46
-
47
- def detach(self):
48
- return MorphLabels(self.prefix_labels.detach(), self.pos_labels.detach(), [labels.detach() for labels in self.features_labels], self.suffix_labels.detach(), [labels.detach() for labels in self.suffix_features_labels])
49
-
50
- def to(self, device):
51
- return MorphLabels(self.prefix_labels.to(device), self.pos_labels.to(device), [feat.to(device) for feat in self.features_labels], self.suffix_labels.to(device), [feat.to(device) for feat in self.suffix_features_labels])
52
-
53
- class BertMorphTaggingHead(nn.Module):
54
- def __init__(self, config):
55
- super().__init__()
56
- self.config = config
57
-
58
- self.num_prefix_classes = len(ALL_PREFIX_POS)
59
- self.num_pos_classes = len(ALL_POS)
60
- self.num_suffix_classes = len(ALL_SUFFIX_POS)
61
- self.num_features_classes = list(map(len, map(itemgetter(1), ALL_FEATURES)))
62
- # we need a classifier for prefix cls and POS cls
63
- # the prefix will use BCEWithLogits for multiple labels cls
64
- self.prefix_cls = nn.Linear(config.hidden_size, self.num_prefix_classes)
65
- # and pos + feats will use good old cross entropy for single label
66
- self.pos_cls = nn.Linear(config.hidden_size, self.num_pos_classes)
67
- self.features_cls = nn.ModuleList([nn.Linear(config.hidden_size, len(features)) for _, features in ALL_FEATURES])
68
- # and suffix + feats will also be cross entropy
69
- self.suffix_cls = nn.Linear(config.hidden_size, self.num_suffix_classes)
70
- self.suffix_features_cls = nn.ModuleList([nn.Linear(config.hidden_size, len(features)) for _, features in ALL_FEATURES])
71
-
72
- def forward(
73
- self,
74
- hidden_states: torch.Tensor,
75
- labels: Optional[MorphLabels] = None):
76
- # run each of the classifiers on the transformed output
77
- prefix_logits = self.prefix_cls(hidden_states)
78
- pos_logits = self.pos_cls(hidden_states)
79
- suffix_logits = self.suffix_cls(hidden_states)
80
- features_logits = [cls(hidden_states) for cls in self.features_cls]
81
- suffix_features_logits = [cls(hidden_states) for cls in self.suffix_features_cls]
82
-
83
- loss = None
84
- if labels is not None:
85
- # step 1: prefix labels loss
86
- loss_fct = nn.BCEWithLogitsLoss(weight=(labels.prefix_labels != -100).float())
87
- loss = loss_fct(prefix_logits, labels.prefix_labels)
88
- # step 2: pos labels loss
89
- loss_fct = nn.CrossEntropyLoss()
90
- loss += loss_fct(pos_logits.view(-1, self.num_pos_classes), labels.pos_labels.view(-1))
91
- # step 2b: features
92
- for feat_logits,feat_labels,num_features in zip(features_logits, labels.features_labels, self.num_features_classes):
93
- loss += loss_fct(feat_logits.view(-1, num_features), feat_labels.view(-1))
94
- # step 3: suffix logits loss
95
- loss += loss_fct(suffix_logits.view(-1, self.num_suffix_classes), labels.suffix_labels.view(-1))
96
- # step 3b: suffix features
97
- for feat_logits,feat_labels,num_features in zip(suffix_features_logits, labels.suffix_features_labels, self.num_features_classes):
98
- loss += loss_fct(feat_logits.view(-1, num_features), feat_labels.view(-1))
99
-
100
- return loss, MorphLogitsOutput(prefix_logits, pos_logits, features_logits, suffix_logits, suffix_features_logits)
101
-
102
- class BertForMorphTagging(BertPreTrainedModel):
103
-
104
- def __init__(self, config):
105
- super().__init__(config)
106
-
107
- self.bert = BertModel(config, add_pooling_layer=False)
108
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
109
- self.morph = BertMorphTaggingHead(config)
110
-
111
- # Initialize weights and apply final processing
112
- self.post_init()
113
-
114
- def forward(
115
- self,
116
- input_ids: Optional[torch.Tensor] = None,
117
- attention_mask: Optional[torch.Tensor] = None,
118
- token_type_ids: Optional[torch.Tensor] = None,
119
- position_ids: Optional[torch.Tensor] = None,
120
- labels: Optional[MorphLabels] = None,
121
- head_mask: Optional[torch.Tensor] = None,
122
- inputs_embeds: Optional[torch.Tensor] = None,
123
- output_attentions: Optional[bool] = None,
124
- output_hidden_states: Optional[bool] = None,
125
- return_dict: Optional[bool] = None,
126
- ):
127
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
128
-
129
- bert_outputs = self.bert(
130
- input_ids,
131
- attention_mask=attention_mask,
132
- token_type_ids=token_type_ids,
133
- position_ids=position_ids,
134
- head_mask=head_mask,
135
- inputs_embeds=inputs_embeds,
136
- output_attentions=output_attentions,
137
- output_hidden_states=output_hidden_states,
138
- return_dict=return_dict,
139
- )
140
-
141
- hidden_states = bert_outputs[0]
142
- hidden_states = self.dropout(hidden_states)
143
-
144
- loss, logits = self.morph(hidden_states, labels)
145
-
146
- if not return_dict:
147
- return (loss,logits) + bert_outputs[2:]
148
-
149
- return MorphTaggingOutput(
150
- loss=loss,
151
- logits=logits,
152
- hidden_states=bert_outputs.hidden_states,
153
- attentions=bert_outputs.attentions,
154
- )
155
-
156
- def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
157
- # tokenize the inputs and convert them to relevant device
158
- inputs = tokenizer(sentences, padding=padding, truncation=True, return_tensors='pt')
159
- inputs = {k:v.to(self.device) for k,v in inputs.items()}
160
- # calculate the logits
161
- logits = self.forward(**inputs, return_dict=True).logits
162
- return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
163
-
164
- def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: MorphLogitsOutput):
165
- prefix_logits, pos_logits, feats_logits, suffix_logits, suffix_feats_logits = \
166
- logits.prefix_logits, logits.pos_logits, logits.features_logits, logits.suffix_logits, logits.suffix_features_logits
167
-
168
- prefix_predictions = (prefix_logits > 0.5).int().tolist() # Threshold at 0.5 for multi-label classification
169
- pos_predictions = pos_logits.argmax(axis=-1).tolist()
170
- suffix_predictions = suffix_logits.argmax(axis=-1).tolist()
171
- feats_predictions = [logits.argmax(axis=-1).tolist() for logits in feats_logits]
172
- suffix_feats_predictions = [logits.argmax(axis=-1).tolist() for logits in suffix_feats_logits]
173
-
174
- # create the return dictionary
175
- # for each sentence, return a dict object with the following files { text, tokens }
176
- # Where tokens is a list of dicts, where each dict is:
177
- # { pos: str, feats: dict, prefixes: List[str], suffix: str | bool, suffix_feats: dict | None}
178
- special_toks = tokenizer.all_special_tokens
179
- ret = []
180
- for sent_idx,sentence in enumerate(sentences):
181
- input_id_strs = tokenizer.convert_ids_to_tokens(input_ids[sent_idx])
182
- # iterate through each token in the sentence, ignoring special tokens
183
- tokens = []
184
- for token_idx,token_str in enumerate(input_id_strs):
185
- if token_str in special_toks: continue
186
- if token_str.startswith('##'):
187
- tokens[-1]['token'] += token_str[2:]
188
- continue
189
- tokens.append(dict(
190
- token=token_str,
191
- pos=ALL_POS[pos_predictions[sent_idx][token_idx]],
192
- feats=get_features_dict_from_predictions(feats_predictions, (sent_idx, token_idx)),
193
- prefixes=[ALL_PREFIX_POS[idx] for idx,i in enumerate(prefix_predictions[sent_idx][token_idx]) if i > 0],
194
- suffix=get_suffix_or_false(ALL_SUFFIX_POS[suffix_predictions[sent_idx][token_idx]]),
195
- ))
196
- if tokens[-1]['suffix']:
197
- tokens[-1]['suffix_feats'] = get_features_dict_from_predictions(suffix_feats_predictions, (sent_idx, token_idx))
198
- ret.append(dict(text=sentence, tokens=tokens))
199
- return ret
200
-
201
- def get_suffix_or_false(suffix):
202
- return False if suffix == 'none' else suffix
203
-
204
- def get_features_dict_from_predictions(predictions, idx):
205
- ret = {}
206
- for (feat_idx, (feat_name, feat_values)) in enumerate(ALL_FEATURES):
207
- val = feat_values[predictions[feat_idx][idx[0]][idx[1]]]
208
- if val != 'none':
209
- ret[feat_name] = val
210
- return ret
211
-
212
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
BertForPrefixMarking.py DELETED
@@ -1,248 +0,0 @@
1
- from transformers.utils import ModelOutput
2
- import torch
3
- from torch import nn
4
- from typing import Dict, List, Tuple, Optional
5
- from dataclasses import dataclass
6
- from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
7
-
8
- # define the classes, and the possible prefixes for each class
9
- POSSIBLE_PREFIX_CLASSES = [ ['לכש', 'כש', 'מש', 'בש', 'לש'], ['מ'], ['ש'], ['ה'], ['ו'], ['כ'], ['ל'], ['ב'] ]
10
- # map each individual prefix to it's class number
11
- PREFIXES_TO_CLASS = {w:i for i,l in enumerate(POSSIBLE_PREFIX_CLASSES) for w in l}
12
- # keep a list of all the prefixes, sorted by length, so that we can decompose
13
- # a given prefixes and figure out the classes
14
- ALL_PREFIX_ITEMS = list(sorted(PREFIXES_TO_CLASS.keys(), key=len, reverse=True))
15
- TOTAL_POSSIBLE_PREFIX_CLASSES = len(POSSIBLE_PREFIX_CLASSES)
16
-
17
- def get_prefixes_from_str(s, greedy=False):
18
- # keep trimming prefixes from the string
19
- while len(s) > 0 and s[0] in PREFIXES_TO_CLASS:
20
- # find the longest string to trim
21
- next_pre = next((pre for pre in ALL_PREFIX_ITEMS if s.startswith(pre)), None)
22
- if next_pre is None:
23
- return
24
- yield next_pre
25
- # if the chosen prefix is more than one letter, there is always an option that the
26
- # prefix is actually just the first letter of the prefix - so offer that up as a valid prefix
27
- # as well. We will still jump to the length of the longer one, since if the next two/three
28
- # letters are a prefix, they have to be the longest one
29
- if not greedy and len(next_pre) > 1:
30
- yield next_pre[0]
31
- s = s[len(next_pre):]
32
-
33
- def get_prefix_classes_from_str(s, greedy=False):
34
- for pre in get_prefixes_from_str(s, greedy):
35
- yield PREFIXES_TO_CLASS[pre]
36
-
37
- @dataclass
38
- class PrefixesClassifiersOutput(ModelOutput):
39
- loss: Optional[torch.FloatTensor] = None
40
- logits: Optional[torch.FloatTensor] = None
41
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
42
- attentions: Optional[Tuple[torch.FloatTensor]] = None
43
-
44
- class BertPrefixMarkingHead(nn.Module):
45
- def __init__(self, config) -> None:
46
- super().__init__()
47
- self.config = config
48
-
49
- # an embedding table containing an embedding for each prefix class + 1 for NONE
50
- # we will concatenate either the embedding/NONE for each class - and we want the concatenate
51
- # size to be the hidden_size
52
- prefix_class_embed = config.hidden_size // TOTAL_POSSIBLE_PREFIX_CLASSES
53
- self.prefix_class_embeddings = nn.Embedding(TOTAL_POSSIBLE_PREFIX_CLASSES + 1, prefix_class_embed)
54
-
55
- # one layer for transformation, apply an activation, then another N classifiers for each prefix class
56
- self.transform = nn.Linear(config.hidden_size + prefix_class_embed * TOTAL_POSSIBLE_PREFIX_CLASSES, config.hidden_size)
57
- self.activation = nn.Tanh()
58
- self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, 2) for _ in range(TOTAL_POSSIBLE_PREFIX_CLASSES)])
59
-
60
- def forward(
61
- self,
62
- hidden_states: torch.Tensor,
63
- prefix_class_id_options: torch.Tensor,
64
- labels: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
65
-
66
- # encode the prefix_class_id_options
67
- # If input_ids is batch x seq_len
68
- # Then sequence_output is batch x seq_len x hidden_dim
69
- # So prefix_class_id_options is batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES
70
- # Looking up the embeddings should give us batch x seq_len x TOTAL_POSSIBLE_PREFIX_CLASSES x hidden_dim / N
71
- possible_class_embed = self.prefix_class_embeddings(prefix_class_id_options)
72
- # then flatten the final dimension - now we have batch x seq_len x hidden_dim_2
73
- possible_class_embed = possible_class_embed.reshape(possible_class_embed.shape[:-2] + (-1,))
74
-
75
- # concatenate the new class embed into the sequence output before the transform
76
- pre_transform_output = torch.cat((hidden_states, possible_class_embed), dim=-1) # batch x seq_len x (hidden_dim + hidden_dim_2)
77
- pre_logits_output = self.activation(self.transform(pre_transform_output))# batch x seq_len x hidden_dim
78
-
79
- # run each of the classifiers on the transformed output
80
- logits = torch.cat([cls(pre_logits_output).unsqueeze(-2) for cls in self.classifiers], dim=-2)
81
-
82
- loss = None
83
- if labels is not None:
84
- loss_fct = nn.CrossEntropyLoss()
85
- loss = loss_fct(logits.view(-1, 2), labels.view(-1))
86
-
87
- return (loss, logits)
88
-
89
-
90
-
91
- class BertForPrefixMarking(BertPreTrainedModel):
92
-
93
- def __init__(self, config):
94
- super().__init__(config)
95
-
96
- self.bert = BertModel(config, add_pooling_layer=False)
97
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
98
- self.prefix = BertPrefixMarkingHead(config)
99
-
100
- # Initialize weights and apply final processing
101
- self.post_init()
102
-
103
- def forward(
104
- self,
105
- input_ids: Optional[torch.Tensor] = None,
106
- attention_mask: Optional[torch.Tensor] = None,
107
- token_type_ids: Optional[torch.Tensor] = None,
108
- prefix_class_id_options: Optional[torch.Tensor] = None,
109
- position_ids: Optional[torch.Tensor] = None,
110
- labels: Optional[torch.Tensor] = None,
111
- head_mask: Optional[torch.Tensor] = None,
112
- inputs_embeds: Optional[torch.Tensor] = None,
113
- output_attentions: Optional[bool] = None,
114
- output_hidden_states: Optional[bool] = None,
115
- return_dict: Optional[bool] = None,
116
- ):
117
- r"""
118
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
119
- Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
120
- """
121
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
122
-
123
- bert_outputs = self.bert(
124
- input_ids,
125
- attention_mask=attention_mask,
126
- token_type_ids=token_type_ids,
127
- position_ids=position_ids,
128
- head_mask=head_mask,
129
- inputs_embeds=inputs_embeds,
130
- output_attentions=output_attentions,
131
- output_hidden_states=output_hidden_states,
132
- return_dict=return_dict,
133
- )
134
-
135
- hidden_states = bert_outputs[0]
136
- hidden_states = self.dropout(hidden_states)
137
-
138
- loss, logits = self.prefix.forward(hidden_states, prefix_class_id_options, labels)
139
- if not return_dict:
140
- return (loss,logits,) + bert_outputs[2:]
141
-
142
- return PrefixesClassifiersOutput(
143
- loss=loss,
144
- logits=logits,
145
- hidden_states=bert_outputs.hidden_states,
146
- attentions=bert_outputs.attentions,
147
- )
148
-
149
- def predict(self, sentences: List[str], tokenizer: BertTokenizerFast, padding='longest'):
150
- # step 1: encode the sentences through using the tokenizer, and get the input tensors + prefix id tensors
151
- inputs = encode_sentences_for_bert_for_prefix_marking(tokenizer, sentences, padding)
152
- inputs.pop('offset_mapping')
153
- inputs = {k:v.to(self.device) for k,v in inputs.items()}
154
-
155
- # run through bert
156
- logits = self.forward(**inputs, return_dict=True).logits
157
- return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
158
-
159
- def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.FloatTensor):
160
- # extract the predictions by argmaxing the final dimension (batch x sequence x prefixes x prediction)
161
- logit_preds = torch.argmax(logits, axis=3).tolist()
162
-
163
- ret = []
164
-
165
- for sent_idx,sent_ids in enumerate(input_ids):
166
- tokens = tokenizer.convert_ids_to_tokens(sent_ids)
167
- ret.append([])
168
- for tok_idx,token in enumerate(tokens):
169
- # If we've reached the pad token, then we are at the end
170
- if token == tokenizer.pad_token: continue
171
- if token.startswith('##'): continue
172
-
173
- # combine the next tokens in? only if it's a breakup
174
- next_tok_idx = tok_idx + 1
175
- while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
176
- token += tokens[next_tok_idx][2:]
177
- next_tok_idx += 1
178
-
179
- prefix_len = get_predicted_prefix_len_from_logits(token, logit_preds[sent_idx][tok_idx])
180
-
181
- if not prefix_len:
182
- ret[-1].append([token])
183
- else:
184
- ret[-1].append([token[:prefix_len], token[prefix_len:]])
185
- return ret
186
-
187
- def encode_sentences_for_bert_for_prefix_marking(tokenizer: BertTokenizerFast, sentences: List[str], padding='longest', truncation=True):
188
- inputs = tokenizer(sentences, padding=padding, truncation=truncation, return_offsets_mapping=True, return_tensors='pt')
189
- # create our prefix_id_options array which will be like the input ids shape but with an addtional
190
- # dimension containing for each prefix whether it can be for that word
191
- prefix_id_options = torch.full(inputs['input_ids'].shape + (TOTAL_POSSIBLE_PREFIX_CLASSES,), TOTAL_POSSIBLE_PREFIX_CLASSES, dtype=torch.long)
192
-
193
- # go through each token, and fill in the vector accordingly
194
- for sent_idx, sent_ids in enumerate(inputs['input_ids']):
195
- tokens = tokenizer.convert_ids_to_tokens(sent_ids)
196
- for tok_idx, token in enumerate(tokens):
197
- # if the first letter isn't a valid prefix letter, nothing to talk about
198
- if len(token) < 2 or not token[0] in PREFIXES_TO_CLASS: continue
199
-
200
- # combine the next tokens in? only if it's a breakup
201
- next_tok_idx = tok_idx + 1
202
- while next_tok_idx < len(tokens) and tokens[next_tok_idx].startswith('##'):
203
- token += tokens[next_tok_idx][2:]
204
- next_tok_idx += 1
205
-
206
- # find all the possible prefixes - and mark them as 0 (and in the possible mark it as it's value for embed lookup)
207
- for pre_class in get_prefix_classes_from_str(token):
208
- prefix_id_options[sent_idx, tok_idx, pre_class] = pre_class
209
-
210
- inputs['prefix_class_id_options'] = prefix_id_options
211
- return inputs
212
-
213
- def get_predicted_prefix_len_from_logits(token, token_logits):
214
- # Go through each possible prefix, and check if the prefix is yes - and if
215
- # so increase the counter of the matched length, otherwise break out. That will solve cases
216
- # of predicting prefix combinations that don't exist on the word.
217
- # For example, if we have the word ושכשהלכתי and the model predict ו & כש, then we will only
218
- # take the vuv because in order to get the כש we need the ש as well.
219
- # Two extra items:
220
- # 1] Don't allow the same prefix multiple times
221
- # 2] Always check that the word starts with that prefix - otherwise it's bad
222
- # (except for the case of multi-letter prefix, where we force the next to be last)
223
- cur_len, skip_next, last_check, seen_prefixes = 0, False, False, set()
224
- for prefix in get_prefixes_from_str(token):
225
- # Are we skipping this prefix? This will be the case where we matched כש, don't allow ש
226
- if skip_next:
227
- skip_next = False
228
- continue
229
- # check for duplicate prefixes, we don't allow two of the same prefix
230
- # if it predicted two of the same, then we will break out
231
- if prefix in seen_prefixes: break
232
- seen_prefixes.add(prefix)
233
-
234
- # check if we predicted this prefix
235
- if token_logits[PREFIXES_TO_CLASS[prefix]]:
236
- cur_len += len(prefix)
237
- if last_check: break
238
- skip_next = len(prefix) > 1
239
- # Otherwise, we predicted no. If we didn't, then this is the end of the prefix
240
- # and time to break out. *Except* if it's a multi letter prefix, then we allow
241
- # just the next letter - e.g., if כש doesn't match, then we allow כ, but then we know
242
- # the word continues with a ש, and if it's not כש, then it's not כ-ש- (invalid)
243
- elif len(prefix) > 1:
244
- last_check = True
245
- else:
246
- break
247
-
248
- return cur_len
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
BertForSyntaxParsing.py DELETED
@@ -1,312 +0,0 @@
1
- import math
2
- from transformers.utils import ModelOutput
3
- import torch
4
- from torch import nn
5
- from typing import Dict, List, Tuple, Optional, Union
6
- from dataclasses import dataclass
7
- from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast
8
-
9
- ALL_FUNCTION_LABELS = ["nsubj", "nsubj:cop", "punct", "mark", "mark:q", "case", "case:gen", "case:acc", "fixed", "obl", "det", "amod", "acl:relcl", "nmod", "cc", "conj", "root", "compound:smixut", "cop", "compound:affix", "advmod", "nummod", "appos", "nsubj:pass", "nmod:poss", "xcomp", "obj", "aux", "parataxis", "advcl", "ccomp", "csubj", "acl", "obl:tmod", "csubj:pass", "dep", "dislocated", "nmod:tmod", "nmod:npmod", "flat", "obl:npmod", "goeswith", "reparandum", "orphan", "list", "discourse", "iobj", "vocative", "expl", "flat:name"]
10
-
11
- @dataclass
12
- class SyntaxLogitsOutput(ModelOutput):
13
- dependency_logits: torch.FloatTensor = None
14
- function_logits: torch.FloatTensor = None
15
- dependency_head_indices: torch.LongTensor = None
16
-
17
- def detach(self):
18
- return SyntaxTaggingOutput(self.dependency_logits.detach(), self.function_logits.detach(), self.dependency_head_indices.detach())
19
-
20
- @dataclass
21
- class SyntaxTaggingOutput(ModelOutput):
22
- loss: Optional[torch.FloatTensor] = None
23
- logits: Optional[SyntaxLogitsOutput] = None
24
- hidden_states: Optional[Tuple[torch.FloatTensor]] = None
25
- attentions: Optional[Tuple[torch.FloatTensor]] = None
26
-
27
- @dataclass
28
- class SyntaxLabels(ModelOutput):
29
- dependency_labels: Optional[torch.LongTensor] = None
30
- function_labels: Optional[torch.LongTensor] = None
31
-
32
- def detach(self):
33
- return SyntaxLabels(self.dependency_labels.detach(), self.function_labels.detach())
34
-
35
- def to(self, device):
36
- return SyntaxLabels(self.dependency_labels.to(device), self.function_labels.to(device))
37
-
38
- class BertSyntaxParsingHead(nn.Module):
39
- def __init__(self, config):
40
- super().__init__()
41
- self.config = config
42
-
43
- # the attention query & key values
44
- self.head_size = config.syntax_head_size# int(config.hidden_size / config.num_attention_heads * 2)
45
- self.query = nn.Linear(config.hidden_size, self.head_size)
46
- self.key = nn.Linear(config.hidden_size, self.head_size)
47
- # the function classifier gets two encoding values and predicts the labels
48
- self.num_function_classes = len(ALL_FUNCTION_LABELS)
49
- self.cls = nn.Linear(config.hidden_size * 2, self.num_function_classes)
50
-
51
- def forward(
52
- self,
53
- hidden_states: torch.Tensor,
54
- extended_attention_mask: Optional[torch.Tensor],
55
- labels: Optional[SyntaxLabels] = None,
56
- compute_mst: bool = False) -> Tuple[torch.Tensor, SyntaxLogitsOutput]:
57
-
58
- # Take the dot product between "query" and "key" to get the raw attention scores.
59
- query_layer = self.query(hidden_states)
60
- key_layer = self.key(hidden_states)
61
- attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.head_size)
62
-
63
- # add in the attention mask
64
- if extended_attention_mask is not None:
65
- if extended_attention_mask.ndim == 4:
66
- extended_attention_mask = extended_attention_mask.squeeze(1)
67
- attention_scores += extended_attention_mask# batch x seq x seq
68
-
69
- # At this point take the hidden_state of the word and of the dependency word, and predict the function
70
- # If labels are provided, use the labels.
71
- if self.training and labels is not None:
72
- # Note that the labels can have -100, so just set those to zero with a max
73
- dep_indices = labels.dependency_labels.clamp_min(0)
74
- # Otherwise - check if he wants the MST or just the argmax
75
- elif compute_mst:
76
- dep_indices = compute_mst_tree(attention_scores, extended_attention_mask)
77
- else:
78
- dep_indices = torch.argmax(attention_scores, dim=-1)
79
-
80
- # After we retrieved the dependency indicies, create a tensor of teh batch indices, and and retrieve the vectors of the heads to calculate the function
81
- batch_indices = torch.arange(dep_indices.size(0)).view(-1, 1).expand(-1, dep_indices.size(1)).to(dep_indices.device)
82
- dep_vectors = hidden_states[batch_indices, dep_indices, :] # batch x seq x dim
83
-
84
- # concatenate that with the last hidden states, and send to the classifier output
85
- cls_inputs = torch.cat((hidden_states, dep_vectors), dim=-1)
86
- function_logits = self.cls(cls_inputs)
87
-
88
- loss = None
89
- if labels is not None:
90
- loss_fct = nn.CrossEntropyLoss()
91
- # step 1: dependency scores loss - this is applied to the attention scores
92
- loss = loss_fct(attention_scores.view(-1, hidden_states.size(-2)), labels.dependency_labels.view(-1))
93
- # step 2: function loss
94
- loss += loss_fct(function_logits.view(-1, self.num_function_classes), labels.function_labels.view(-1))
95
-
96
- return (loss, SyntaxLogitsOutput(attention_scores, function_logits, dep_indices))
97
-
98
-
99
- class BertForSyntaxParsing(BertPreTrainedModel):
100
-
101
- def __init__(self, config):
102
- super().__init__(config)
103
-
104
- self.bert = BertModel(config, add_pooling_layer=False)
105
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
106
- self.syntax = BertSyntaxParsingHead(config)
107
-
108
- # Initialize weights and apply final processing
109
- self.post_init()
110
-
111
- def forward(
112
- self,
113
- input_ids: Optional[torch.Tensor] = None,
114
- attention_mask: Optional[torch.Tensor] = None,
115
- token_type_ids: Optional[torch.Tensor] = None,
116
- position_ids: Optional[torch.Tensor] = None,
117
- labels: Optional[SyntaxLabels] = None,
118
- head_mask: Optional[torch.Tensor] = None,
119
- inputs_embeds: Optional[torch.Tensor] = None,
120
- output_attentions: Optional[bool] = None,
121
- output_hidden_states: Optional[bool] = None,
122
- return_dict: Optional[bool] = None,
123
- compute_syntax_mst: Optional[bool] = None,
124
- ):
125
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
126
-
127
- bert_outputs = self.bert(
128
- input_ids,
129
- attention_mask=attention_mask,
130
- token_type_ids=token_type_ids,
131
- position_ids=position_ids,
132
- head_mask=head_mask,
133
- inputs_embeds=inputs_embeds,
134
- output_attentions=output_attentions,
135
- output_hidden_states=output_hidden_states,
136
- return_dict=return_dict,
137
- )
138
-
139
- extended_attention_mask = None
140
- if attention_mask is not None:
141
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_ids.size())
142
- # apply the syntax head
143
- loss, logits = self.syntax(self.dropout(bert_outputs[0]), extended_attention_mask, labels, compute_syntax_mst)
144
-
145
- if not return_dict:
146
- return (loss,(logits.dependency_logits, logits.function_logits)) + bert_outputs[2:]
147
-
148
- return SyntaxTaggingOutput(
149
- loss=loss,
150
- logits=logits,
151
- hidden_states=bert_outputs.hidden_states,
152
- attentions=bert_outputs.attentions,
153
- )
154
-
155
- def predict(self, sentences: Union[str, List[str]], tokenizer: BertTokenizerFast, compute_mst=True):
156
- if isinstance(sentences, str):
157
- sentences = [sentences]
158
-
159
- # predict the logits for the sentence
160
- inputs = tokenizer(sentences, padding='longest', truncation=True, return_tensors='pt')
161
- inputs = {k:v.to(self.device) for k,v in inputs.items()}
162
- logits = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_mst).logits
163
- return parse_logits(inputs['input_ids'].tolist(), sentences, tokenizer, logits)
164
-
165
- def parse_logits(input_ids: List[List[int]], sentences: List[str], tokenizer: BertTokenizerFast, logits: SyntaxLogitsOutput):
166
- outputs = []
167
-
168
- special_toks = tokenizer.all_special_tokens
169
- for i in range(len(sentences)):
170
- deps = logits.dependency_head_indices[i].tolist()
171
- funcs = logits.function_logits.argmax(-1)[i].tolist()
172
- toks = [tok for tok in tokenizer.convert_ids_to_tokens(input_ids[i]) if tok not in special_toks]
173
-
174
- # first, go through the tokens and create a mapping between each dependency index and the index without wordpieces
175
- # wordpieces. At the same time, append the wordpieces in
176
- idx_mapping = {-1:-1} # default root
177
- real_idx = -1
178
- for i in range(len(toks)):
179
- if not toks[i].startswith('##'):
180
- real_idx += 1
181
- idx_mapping[i] = real_idx
182
-
183
- # build our tree, keeping tracking of the root idx
184
- tree = []
185
- root_idx = 0
186
- for i in range(len(toks)):
187
- if toks[i].startswith('##'):
188
- tree[-1]['word'] += toks[i][2:]
189
- continue
190
-
191
- dep_idx = deps[i + 1] - 1 # increase 1 for cls, decrease 1 for cls
192
- if dep_idx == len(toks): dep_idx = i - 1 # if he predicts sep, then just point to the previous word
193
-
194
- dep_head = 'root' if dep_idx == -1 else toks[dep_idx]
195
- dep_func = ALL_FUNCTION_LABELS[funcs[i + 1]]
196
-
197
- if dep_head == 'root': root_idx = len(tree)
198
- tree.append(dict(word=toks[i], dep_head_idx=idx_mapping[dep_idx], dep_func=dep_func))
199
- # append the head word
200
- for d in tree:
201
- d['dep_head'] = tree[d['dep_head_idx']]['word']
202
-
203
- outputs.append(dict(tree=tree, root_idx=root_idx))
204
- return outputs
205
-
206
-
207
- def compute_mst_tree(attention_scores: torch.Tensor, extended_attention_mask: torch.LongTensor):
208
- # attention scores should be 3 dimensions - batch x seq x seq (if it is 2 - just unsqueeze)
209
- if attention_scores.ndim == 2: attention_scores = attention_scores.unsqueeze(0)
210
- if attention_scores.ndim != 3 or attention_scores.shape[1] != attention_scores.shape[2]:
211
- raise ValueError(f'Expected attention scores to be of shape batch x seq x seq, instead got {attention_scores.shape}')
212
-
213
- batch_size, seq_len, _ = attention_scores.shape
214
- # start by softmaxing so the scores are comparable
215
- attention_scores = attention_scores.softmax(dim=-1)
216
-
217
- batch_indices = torch.arange(batch_size, device=attention_scores.device)
218
- seq_indices = torch.arange(seq_len, device=attention_scores.device)
219
-
220
- seq_lens = torch.full((batch_size,), seq_len)
221
-
222
- if extended_attention_mask is not None:
223
- seq_lens = torch.argmax((extended_attention_mask != 0).int(), dim=2).squeeze(1)
224
- # zero out any padding
225
- attention_scores[extended_attention_mask.squeeze(1) != 0] = 0
226
-
227
- # set the values for the CLS and sep to all by very low, so they never get chosen as a replacement arc
228
- attention_scores[:, 0, :] = 0
229
- attention_scores[batch_indices, seq_lens - 1, :] = 0
230
- attention_scores[batch_indices, :, seq_lens - 1] = 0 # can never predict sep
231
- # set the values for each token pointing to itself be 0
232
- attention_scores[:, seq_indices, seq_indices] = 0
233
-
234
- # find the root, and make him super high so we never have a conflict
235
- root_cands = torch.argsort(attention_scores[:, :, 0], dim=-1)
236
- attention_scores[batch_indices.unsqueeze(1), root_cands, 0] = 0
237
- attention_scores[batch_indices, root_cands[:, -1], 0] = 1.0
238
-
239
- # we start by getting the argmax for each score, and then computing the cycles and contracting them
240
- sorted_indices = torch.argsort(attention_scores, dim=-1, descending=True)
241
- indices = sorted_indices[:, :, 0].clone() # take the argmax
242
-
243
- attention_scores = attention_scores.tolist()
244
- seq_lens = seq_lens.tolist()
245
- sorted_indices = [[sub_l[:slen] for sub_l in l[:slen]] for l,slen in zip(sorted_indices.tolist(), seq_lens)]
246
-
247
-
248
- # go through each batch item and make sure our tree works
249
- for batch_idx in range(batch_size):
250
- # We have one root - detect the cycles and contract them. A cycle can never contain the root so really
251
- # for every cycle, we look at all the nodes, and find the highest arc out of the cycle for any values. Replace that and tada
252
- has_cycle, cycle_nodes = detect_cycle(indices[batch_idx], seq_lens[batch_idx])
253
- contracted_arcs = set()
254
- while has_cycle:
255
- base_idx, head_idx = choose_contracting_arc(indices[batch_idx], sorted_indices[batch_idx], cycle_nodes, contracted_arcs, seq_lens[batch_idx], attention_scores[batch_idx])
256
- indices[batch_idx, base_idx] = head_idx
257
- contracted_arcs.add(base_idx)
258
- # find the next cycle
259
- has_cycle, cycle_nodes = detect_cycle(indices[batch_idx], seq_lens[batch_idx])
260
-
261
- return indices
262
-
263
- def detect_cycle(indices: torch.LongTensor, seq_len: int):
264
- # Simple cycle detection algorithm
265
- # Returns a boolean indicating if a cycle is detected and the nodes involved in the cycle
266
- visited = set()
267
- for node in range(1, seq_len - 1): # ignore the CLS/SEP tokens
268
- if node in visited:
269
- continue
270
- current_path = set()
271
- while node not in visited:
272
- visited.add(node)
273
- current_path.add(node)
274
- node = indices[node].item()
275
- if node == 0: break # roots never point to anything
276
- if node in current_path:
277
- return True, current_path # Cycle detected
278
- return False, None
279
-
280
- def choose_contracting_arc(indices: torch.LongTensor, sorted_indices: List[List[int]], cycle_nodes: set, contracted_arcs: set, seq_len: int, scores: List[List[float]]):
281
- # Chooses the highest-scoring, non-cycling arc from a graph. Iterates through 'cycle_nodes' to find
282
- # the best arc based on 'scores', avoiding cycles and zero node connections.
283
- # For each node, we only look at the next highest scoring non-cycling arc
284
- best_base_idx, best_head_idx = -1, -1
285
- score = 0
286
-
287
- # convert the indices to a list once, to avoid multiple conversions (saves a few seconds)
288
- currents = indices.tolist()
289
- for base_node in cycle_nodes:
290
- if base_node in contracted_arcs: continue
291
- # we don't want to take anything that has a higher score than the current value - we can end up in an endless loop
292
- # Since the indices are sorted, as soon as we find our current item, we can move on to the next.
293
- current = currents[base_node]
294
- found_current = False
295
-
296
- for head_node in sorted_indices[base_node]:
297
- if head_node == current:
298
- found_current = True
299
- continue
300
- if head_node in contracted_arcs: continue
301
- if not found_current or head_node in cycle_nodes or head_node == 0:
302
- continue
303
-
304
- current_score = scores[base_node][head_node]
305
- if current_score > score:
306
- best_base_idx, best_head_idx, score = base_node, head_node, current_score
307
- break
308
-
309
- if best_base_idx == -1:
310
- raise ValueError('Stuck in endless loop trying to compute syntax mst. Please try again setting compute_syntax_mst=False')
311
-
312
- return best_base_idx, best_head_idx