|
import torch
|
|
import logging
|
|
import os
|
|
|
|
logger = logging.getLogger(__name__)
|
|
from torchvision import transforms
|
|
from PIL import Image
|
|
|
|
|
|
class SBInputExample(object):
|
|
"""A single training/test example for simple sequence classification."""
|
|
|
|
def __init__(self, guid, text_a, text_b, img_id, label=None, auxlabel=None):
|
|
"""Constructs a InputExample.
|
|
|
|
Args:
|
|
guid: Unique id for the example.
|
|
text_a: string. The untokenized text of the first sequence. For single
|
|
sequence tasks, only this sequence must be specified.
|
|
text_b: (Optional) string. The untokenized text of the second sequence.
|
|
Only must be specified for sequence pair tasks.
|
|
label: (Optional) string. The label of the example. This should be
|
|
specified for train and dev examples, but not for test examples.
|
|
"""
|
|
self.guid = guid
|
|
self.text_a = text_a
|
|
self.text_b = text_b
|
|
self.img_id = img_id
|
|
self.label = label
|
|
|
|
|
|
self.auxlabel = auxlabel
|
|
|
|
|
|
class SBInputFeatures(object):
|
|
"""A single set of features of data"""
|
|
|
|
def __init__(self, input_ids, input_mask, added_input_mask, segment_ids, img_feat, label_id, auxlabel_id):
|
|
self.input_ids = input_ids
|
|
self.input_mask = input_mask
|
|
self.added_input_mask = added_input_mask
|
|
self.segment_ids = segment_ids
|
|
self.img_feat = img_feat
|
|
self.label_id = label_id
|
|
self.auxlabel_id = auxlabel_id
|
|
|
|
|
|
def sbreadfile(filename):
|
|
'''
|
|
Đọc dữ liệu từ tệp và trả về dưới dạng danh sách các cặp từ và nhãn, cùng với danh sách hình ảnh và nhãn phụ.
|
|
'''
|
|
print("Chuẩn bị dữ liệu cho ", filename)
|
|
f = open(filename, encoding='utf8')
|
|
data = []
|
|
imgs = []
|
|
auxlabels = []
|
|
sentence = []
|
|
label = []
|
|
auxlabel = []
|
|
imgid = ''
|
|
|
|
for line in f:
|
|
line = line.strip()
|
|
if line.startswith('IMGID:'):
|
|
imgid = line.split('IMGID:')[1] + '.jpg'
|
|
continue
|
|
if line == '':
|
|
if len(sentence) > 0:
|
|
data.append((sentence, label))
|
|
imgs.append(imgid)
|
|
auxlabels.append(auxlabel)
|
|
sentence = []
|
|
label = []
|
|
auxlabel = []
|
|
imgid = ''
|
|
continue
|
|
splits = line.split('\t')
|
|
if len(splits) == 2:
|
|
word, cur_label = splits
|
|
sentence.append(word)
|
|
label.append(cur_label)
|
|
auxlabel.append(cur_label[0])
|
|
|
|
if len(sentence) > 0:
|
|
data.append((sentence, label))
|
|
imgs.append(imgid)
|
|
auxlabels.append(auxlabel)
|
|
|
|
print("Số lượng mẫu: " + str(len(data)))
|
|
print("Số lượng hình ảnh: " + str(len(imgs)))
|
|
return data, imgs, auxlabels
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataProcessor(object):
|
|
"""Base class for data converters for sequence classification data sets."""
|
|
|
|
def get_train_examples(self, data_dir):
|
|
"""Gets a collection of `InputExample`s for the train set."""
|
|
raise NotImplementedError()
|
|
|
|
def get_dev_examples(self, data_dir):
|
|
"""Gets a collection of `InputExample`s for the dev set."""
|
|
raise NotImplementedError()
|
|
|
|
def get_labels(self):
|
|
"""Gets the list of labels for this data set."""
|
|
raise NotImplementedError()
|
|
|
|
@classmethod
|
|
def _read_sbtsv(cls, input_file, quotechar=None):
|
|
"""Reads a tab separated value file."""
|
|
return sbreadfile(input_file)
|
|
|
|
|
|
class MNERProcessor_2016(DataProcessor):
|
|
"""Processor for the CoNLL-2003 data set."""
|
|
|
|
def get_train_examples(self, data_dir):
|
|
"""See base class."""
|
|
data, imgs, auxlabels = self._read_sbtsv(os.path.join(data_dir, "train.txt"))
|
|
return self._create_examples(data, imgs, auxlabels, "train")
|
|
|
|
def get_dev_examples(self, data_dir):
|
|
"""See base class."""
|
|
data, imgs, auxlabels = self._read_sbtsv(os.path.join(data_dir, "dev.txt"))
|
|
return self._create_examples(data, imgs, auxlabels, "dev")
|
|
|
|
def get_test_examples(self, data_dir):
|
|
"""See base class."""
|
|
data, imgs, auxlabels = self._read_sbtsv(os.path.join(data_dir, "test.txt"))
|
|
return self._create_examples(data, imgs, auxlabels, "test")
|
|
|
|
def get_labels(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return [
|
|
"B-ORG", "B-MISC",
|
|
"I-PER",
|
|
"I-ORG",
|
|
"B-LOC",
|
|
"I-MISC",
|
|
"I-LOC",
|
|
"O",
|
|
"B-PER",
|
|
"X",
|
|
"<s>",
|
|
"</s>"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_auxlabels(self):
|
|
return ["O", "B", "I", "X", "<s>", "</s>"]
|
|
|
|
def get_start_label_id(self):
|
|
label_list = self.get_labels()
|
|
label_map = {label: i for i, label in enumerate(label_list, 1)}
|
|
return label_map['<s>']
|
|
|
|
def get_stop_label_id(self):
|
|
label_list = self.get_labels()
|
|
label_map = {label: i for i, label in enumerate(label_list, 1)}
|
|
return label_map['</s>']
|
|
|
|
def _create_examples(self, lines, imgs, auxlabels, set_type):
|
|
examples = []
|
|
for i, (sentence, label) in enumerate(lines):
|
|
guid = "%s-%s" % (set_type, i)
|
|
text_a = ' '.join(sentence)
|
|
text_b = None
|
|
img_id = imgs[i]
|
|
label = label
|
|
auxlabel = auxlabels[i]
|
|
examples.append(
|
|
SBInputExample(guid=guid, text_a=text_a, text_b=text_b, img_id=img_id, label=label, auxlabel=auxlabel))
|
|
return examples
|
|
|
|
|
|
def image_process(image_path, transform):
|
|
image = Image.open(image_path).convert('RGB')
|
|
image = transform(image)
|
|
return image
|
|
|
|
|
|
def convert_mm_examples_to_features(examples, label_list, auxlabel_list,
|
|
max_seq_length, tokenizer, crop_size, path_img):
|
|
label_map = {label: i for i, label in enumerate(label_list, 1)}
|
|
auxlabel_map = {label: i for i, label in enumerate(auxlabel_list, 1)}
|
|
|
|
features = []
|
|
count = 0
|
|
|
|
transform = transforms.Compose([
|
|
transforms.Resize([256, 256]),
|
|
transforms.RandomCrop(crop_size),
|
|
transforms.RandomHorizontalFlip(),
|
|
transforms.ToTensor(),
|
|
transforms.Normalize((0.485, 0.456, 0.406),
|
|
(0.229, 0.224, 0.225))])
|
|
|
|
for (ex_index, example) in enumerate(examples):
|
|
textlist = example.text_a.split(' ')
|
|
labellist = example.label
|
|
auxlabellist = example.auxlabel
|
|
tokens = []
|
|
labels = []
|
|
auxlabels = []
|
|
for i, word in enumerate(textlist):
|
|
token = tokenizer.tokenize(word)
|
|
tokens.extend(token)
|
|
label_1 = labellist[i]
|
|
auxlabel_1 = auxlabellist[i]
|
|
for m in range(len(token)):
|
|
if m == 0:
|
|
labels.append(label_1)
|
|
auxlabels.append(auxlabel_1)
|
|
else:
|
|
labels.append("X")
|
|
auxlabels.append("X")
|
|
if len(tokens) >= max_seq_length - 1:
|
|
tokens = tokens[0:(max_seq_length - 2)]
|
|
labels = labels[0:(max_seq_length - 2)]
|
|
auxlabels = auxlabels[0:(max_seq_length - 2)]
|
|
ntokens = []
|
|
segment_ids = []
|
|
label_ids = []
|
|
auxlabel_ids = []
|
|
ntokens.append("<s>")
|
|
segment_ids.append(0)
|
|
label_ids.append(label_map["<s>"])
|
|
auxlabel_ids.append(auxlabel_map["<s>"])
|
|
for i, token in enumerate(tokens):
|
|
ntokens.append(token)
|
|
segment_ids.append(0)
|
|
label_ids.append(label_map[labels[i]])
|
|
auxlabel_ids.append(auxlabel_map[auxlabels[i]])
|
|
ntokens.append("</s>")
|
|
segment_ids.append(0)
|
|
label_ids.append(label_map["</s>"])
|
|
auxlabel_ids.append(auxlabel_map["</s>"])
|
|
input_ids = tokenizer.convert_tokens_to_ids(ntokens)
|
|
input_mask = [1] * len(input_ids)
|
|
added_input_mask = [1] * (len(input_ids) + 49)
|
|
|
|
while len(input_ids) < max_seq_length:
|
|
input_ids.append(0)
|
|
input_mask.append(0)
|
|
added_input_mask.append(0)
|
|
segment_ids.append(0)
|
|
label_ids.append(0)
|
|
auxlabel_ids.append(0)
|
|
|
|
assert len(input_ids) == max_seq_length
|
|
assert len(input_mask) == max_seq_length
|
|
assert len(segment_ids) == max_seq_length
|
|
assert len(label_ids) == max_seq_length
|
|
assert len(auxlabel_ids) == max_seq_length
|
|
|
|
image_name = example.img_id
|
|
image_path = os.path.join(path_img, image_name)
|
|
|
|
if not os.path.exists(image_path):
|
|
if 'NaN' not in image_path:
|
|
print(image_path)
|
|
try:
|
|
image = image_process(image_path, transform)
|
|
except:
|
|
count += 1
|
|
image_path_fail = os.path.join(path_img, 'background.jpg')
|
|
image = image_process(image_path_fail, transform)
|
|
|
|
else:
|
|
if ex_index < 2:
|
|
logger.info("*** Example ***")
|
|
logger.info("guid: %s" % (example.guid))
|
|
logger.info("tokens: %s" % " ".join(
|
|
[str(x) for x in tokens]))
|
|
logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
|
|
logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
|
logger.info(
|
|
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
|
logger.info("label: %s" % " ".join([str(x) for x in label_ids]))
|
|
logger.info("auxlabel: %s" % " ".join([str(x) for x in auxlabel_ids]))
|
|
|
|
features.append(
|
|
SBInputFeatures(input_ids=input_ids, input_mask=input_mask, added_input_mask=added_input_mask,
|
|
segment_ids=segment_ids, img_feat=image, label_id=label_ids, auxlabel_id=auxlabel_ids))
|
|
|
|
print('the number of problematic samples: ' + str(count))
|
|
return features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|