# demo_ner/pipeline.py
import torch
from typing import Callable, List, Tuple, Union
from functools import partial
import itertools
from seqeval.scheme import Tokens, IOB2, IOBES
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase
from pythainlp.tokenize import word_tokenize as pythainlp_word_tokenize
from thai2transformers.preprocess import rm_useless_spaces

# newmm word tokenizer that keeps whitespace tokens; used as the default pre-tokenizer
newmm_word_tokenizer = partial(pythainlp_word_tokenize, keep_whitespace=True, engine='newmm')

# SentencePiece whitespace marker prefixed to word-initial subword pieces
SPIECE = '▁'

class TokenClassificationPipeline:
def __init__(self,
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerBase,
pretokenizer: Callable[[str], List[str]] = newmm_word_tokenizer,
lowercase=False,
space_token='<_>',
device: int = -1,
group_entities: bool = False,
strict: bool = False,
tag_delimiter: str = '-',
scheme: str = 'IOB',
use_crf=False,
remove_spiece=True):
super().__init__()
assert isinstance(tokenizer, PreTrainedTokenizerBase)
# assert isinstance(model, PreTrainedModel)
self.model = model
self.tokenizer = tokenizer
self.pretokenizer = pretokenizer
self.lowercase = lowercase
self.space_token = space_token
self.device = 'cpu' if device == -1 or not torch.cuda.is_available() else f'cuda:{device}'
self.group_entities = group_entities
self.strict = strict
self.tag_delimiter = tag_delimiter
self.scheme = scheme
self.id2label = self.model.config.id2label
self.label2id = self.model.config.label2id
self.use_crf = use_crf
self.remove_spiece = remove_spiece
self.model.to(self.device)
def preprocess(self, inputs: Union[str, List[str]]) -> Union[List[str], List[List[str]]]:
        if self.lowercase:
            inputs = inputs.lower() if isinstance(inputs, str) else list(map(str.lower, inputs))
        inputs = rm_useless_spaces(inputs) if isinstance(inputs, str) else list(map(rm_useless_spaces, inputs))
        tokens = self.pretokenizer(inputs) if isinstance(inputs, str) else list(map(self.pretokenizer, inputs))
        tokens = list(map(lambda x: x.replace(' ', self.space_token), tokens)) if isinstance(inputs, str) else \
            list(map(lambda _tokens: list(map(lambda x: x.replace(' ', self.space_token), _tokens)), tokens))
return tokens
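
    # A small sketch of what `preprocess` produces (the exact newmm word
    # boundaries below are an assumption): an input such as 'ไปเที่ยว เชียงใหม่'
    # is pre-tokenized into word-level tokens, with each literal space replaced
    # by `space_token`, e.g. ['ไป', 'เที่ยว', '<_>', 'เชียงใหม่'].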
    def _inference(self, text: str):
        # wrap the pre-tokens with BOS/EOS and tokenize each pre-token into subword pieces
        tokens = [[self.tokenizer.bos_token]] + \
                 [self.tokenizer.tokenize(tok) if tok != SPIECE else [SPIECE] for tok in self.preprocess(text)] + \
                 [[self.tokenizer.eos_token]]
ids = [self.tokenizer.convert_tokens_to_ids(token) for token in tokens]
flatten_tokens = list(itertools.chain(*tokens))
flatten_ids = list(itertools.chain(*ids))
input_ids = torch.LongTensor([flatten_ids]).to(self.device)
if self.use_crf:
out = self.model(input_ids=input_ids)
else:
out = self.model(input_ids=input_ids, return_dict=True)
probs = torch.softmax(out['logits'], dim=-1)
vals, indices = probs.topk(1)
indices_np = indices.detach().cpu().numpy().reshape(-1)
list_of_token_label_tuple = list(zip(flatten_tokens, [ self.id2label[idx] for idx in indices_np] ))
merged_preds = self._merged_pred(preds=list_of_token_label_tuple, ids=ids)
if self.remove_spiece:
merged_preds = list(map(lambda x: (x[0].replace(SPIECE, ''), x[1]), merged_preds))
# remove start and end tokens
merged_preds_removed_bos_eos = merged_preds[1:-1]
# convert to list of Dict objects
        merged_preds_return_dict = [{'word': word if word != self.space_token else ' ', 'entity': tag, 'index': idx}
                                    for idx, (word, tag) in enumerate(merged_preds_removed_bos_eos)]
        if (not self.group_entities or self.scheme is None) and self.strict:
            return merged_preds_return_dict
        elif not self.group_entities and not self.strict:
tags = list(map(lambda x: x['entity'], merged_preds_return_dict))
processed_tags = self._fix_incorrect_tags(tags)
for i, item in enumerate(merged_preds_return_dict):
merged_preds_return_dict[i]['entity'] = processed_tags[i]
return merged_preds_return_dict
elif self.group_entities:
return self._group_entities(merged_preds_removed_bos_eos)
def __call__(self, inputs: Union[str, List[str]]):
"""
"""
        if isinstance(inputs, str):
            return self._inference(inputs)
        if isinstance(inputs, list):
            results = [self._inference(text) for text in inputs]
            return results
def _merged_pred(self, preds: List[Tuple[str, str]], ids: List[List[int]]):
        # map each flattened subtoken position back to the index of its source pre-token
        token_mapping = []
        for i in range(len(ids)):
            token_mapping.extend([i] * len(ids[i]))
grouped_subtokens = []
_subtoken = []
prev_idx = 0
for i, (subtoken, label) in enumerate(preds):
current_idx = token_mapping[i]
            if prev_idx != current_idx:
                grouped_subtokens.append(_subtoken)
                _subtoken = [(subtoken, label)]
                if i == len(preds) - 1:
                    grouped_subtokens.append(_subtoken)
elif i == len(preds) -1:
_subtoken += [(subtoken, label)]
grouped_subtokens.append(_subtoken)
else:
_subtoken += [(subtoken, label)]
prev_idx = current_idx
merged_subtokens = []
_merged_subtoken = ''
for subtoken_group in grouped_subtokens:
first_token_pred = subtoken_group[0][1]
_merged_subtoken = ''.join(list(map(lambda x: x[0], subtoken_group)))
merged_subtokens.append((_merged_subtoken, first_token_pred))
return merged_subtokens
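
    # Sketch of `_merged_pred` behaviour (the subword split is illustrative,
    # not the tokenizer's actual output): a pre-token split into
    # [('▁เชียง', 'B-LOC'), ('ใหม่', 'I-LOC')] is re-joined into a single
    # prediction ('▁เชียงใหม่', 'B-LOC'), keeping the first subword's label.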
def _fix_incorrect_tags(self, tags: List[str]) -> List[str]:
I_PREFIX = f'I{self.tag_delimiter}'
E_PREFIX = f'E{self.tag_delimiter}'
B_PREFIX = f'B{self.tag_delimiter}'
O_PREFIX = 'O'
previous_tag_ne = None
for i, current_tag in enumerate(tags):
current_tag_ne = current_tag.split(self.tag_delimiter)[-1] if current_tag != O_PREFIX else O_PREFIX
if i == 0 and (current_tag.startswith(I_PREFIX) or \
current_tag.startswith(E_PREFIX)):
                # if an NE tag (with an I- or E- prefix) occurs at the beginning of the sentence,
                # e.g. (I-LOC, I-LOC), (E-LOC, B-PER), (I-LOC, O, O),
                # then change the prefix of the current tag to B{tag_delimiter}
tags[i] = B_PREFIX + tags[i][2:]
elif i >= 1 and tags[i-1] == O_PREFIX and (
current_tag.startswith(I_PREFIX) or \
current_tag.startswith(E_PREFIX)):
                # if an NE tag (with an I- or E- prefix) occurs after an O tag,
                # e.g. (O, I-LOC, I-LOC), (O, E-LOC, B-PER), (O, I-LOC, O, O),
                # then change the prefix of the current tag to B{tag_delimiter}
tags[i] = B_PREFIX + tags[i][2:]
elif i >= 1 and ( tags[i-1].startswith(I_PREFIX) or \
tags[i-1].startswith(E_PREFIX) or \
tags[i-1].startswith(B_PREFIX)) and \
( current_tag.startswith(I_PREFIX) or current_tag.startswith(E_PREFIX) ) and \
previous_tag_ne != current_tag_ne:
                # if an NE tag (with an I- or E- prefix) occurs after an NE tag of a different type,
                # e.g. (B-LOC, I-PER), (B-LOC, E-LOC, E-PER), (B-LOC, I-LOC, I-PER),
                # then change the prefix of the current tag to B{tag_delimiter}
tags[i] = B_PREFIX + tags[i][2:]
elif i == len(tags) - 1 and tags[i-1] == O_PREFIX and (
current_tag.startswith(I_PREFIX) or \
current_tag.startswith(E_PREFIX)):
                # if an NE tag (with an I- or E- prefix) occurs at the end of the sentence after an O tag,
                # e.g. (O, O, I-LOC), (O, O, E-LOC),
                # then change the prefix of the current tag to B{tag_delimiter}
tags[i] = B_PREFIX + tags[i][2:]
previous_tag_ne = current_tag_ne
return tags
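
    # Example of the repair performed by `_fix_incorrect_tags` (a sketch,
    # assuming the default '-' tag delimiter):
    #   input : ['I-LOC', 'I-LOC', 'O', 'E-PER']
    #   output: ['B-LOC', 'I-LOC', 'O', 'B-PER']
    # i.e. chunks that begin with an I-/E- prefix are rewritten to start with B-.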
def _group_entities(self, ner_tags: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
if self.scheme not in ['IOB', 'IOBES', 'IOBE']:
            raise AttributeError(f'Unsupported tagging scheme: {self.scheme}; expected one of IOB, IOBE, or IOBES.')
tokens, tags = zip(*ner_tags)
tokens, tags = list(tokens), list(tags)
if self.scheme == 'IOBE':
# Replace E prefix with I prefix
tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
if self.scheme == 'IOBES':
# Replace E prefix with I prefix and replace S prefix with B
tags = list(map(lambda x: x.replace(f'E{self.tag_delimiter}', f'I{self.tag_delimiter}'), tags))
tags = list(map(lambda x: x.replace(f'S{self.tag_delimiter}', f'B{self.tag_delimiter}'), tags))
if not self.strict:
tags = self._fix_incorrect_tags(tags)
ent = Tokens(tokens=tags, scheme=IOB2,
suffix=False, delimiter=self.tag_delimiter)
ne_position_mappings = ent.entities
token_positions = []
curr_len = 0
        # restore spaces from space_token and normalize decomposed Thai sara am (nikhahit + sara aa) to 'ำ'
        tokens = list(map(lambda x: x.replace(self.space_token, ' ').replace('ํา', 'ำ'), tokens))
for i, token in enumerate(tokens):
token_len = len(token)
if i == 0:
token_positions.append((0, curr_len + token_len))
else:
token_positions.append((curr_len, curr_len + token_len ))
curr_len += token_len
        # print(f'token_positions: {list(zip(tokens, token_positions))}')
begin_end_pos = []
begin_end_char_pos = []
accum_char_len = 0
for i, ne_position_mapping in enumerate(ne_position_mappings):
            # print(f'ne_position_mapping.start: {ne_position_mapping.start}')
            # print(f'ne_position_mapping.end: {ne_position_mapping.end}\n')
begin_end_pos.append((ne_position_mapping.start, ne_position_mapping.end))
begin_end_char_pos.append((token_positions[ne_position_mapping.start][0], token_positions[ne_position_mapping.end-1][1]))
        # print(f'begin_end_pos: {begin_end_pos}')
        # print(f'begin_end_char_pos: {begin_end_char_pos}')
j = 0
# print(f'tokens: {tokens}')
for i, pos_tuple in enumerate(begin_end_pos):
# print(f'j = {j}')
if pos_tuple[0] > 0 and i == 0:
ne_position_mappings.insert(0, (None, 'O', 0, pos_tuple[0]))
j += 1
            if i > 0 and len(begin_end_pos) > 1 and begin_end_pos[i-1][1] != begin_end_pos[i][0]:
ne_position_mappings.insert(j, (None, 'O', begin_end_pos[i-1][1], begin_end_pos[i][0]))
j += 1
j += 1
        # print('ne_position_mappings', ne_position_mappings)
groups = []
k = 0
for i, ne_position_mapping in enumerate(ne_position_mappings):
if type(ne_position_mapping) != tuple:
ne_position_mapping = ne_position_mapping.to_tuple()
ne = ne_position_mapping[1]
text = ''
for ne_position in range(ne_position_mapping[2], ne_position_mapping[3]):
_token = tokens[ne_position]
text += _token if _token != self.space_token else ' '
if ne.lower() != 'o':
groups.append({
'entity_group': ne,
'word': text,
'begin_char_index': begin_end_char_pos[k][0]
})
k+=1
else:
groups.append({
'entity_group': ne,
'word': text,
})
return groups
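

# Minimal usage sketch (not executed when the module is imported). The
# checkpoint name below is an assumption; substitute any Thai token-
# classification checkpoint whose config provides id2label/label2id and
# whose tokenizer defines bos_token/eos_token.
if __name__ == '__main__':
    from transformers import AutoModelForTokenClassification, AutoTokenizer

    checkpoint = 'airesearch/wangchanberta-base-att-spm-uncased'  # hypothetical; use a fine-tuned NER checkpoint
    ner_model = AutoModelForTokenClassification.from_pretrained(checkpoint)
    ner_tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    ner_pipeline = TokenClassificationPipeline(
        model=ner_model,
        tokenizer=ner_tokenizer,
        lowercase=True,
        group_entities=True,
        scheme='IOBES',
    )
    print(ner_pipeline('วันนี้ไปเที่ยวเชียงใหม่'))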