import torch import logging import os logger = logging.getLogger(__name__) from torchvision import transforms from PIL import Image class SBInputExample(object): """A single training/test example for simple sequence classification.""" def __init__(self, guid, text_a, text_b, img_id, label=None, auxlabel=None): """Constructs a InputExample. Args: guid: Unique id for the example. text_a: string. The untokenized text of the first sequence. For single sequence tasks, only this sequence must be specified. text_b: (Optional) string. The untokenized text of the second sequence. Only must be specified for sequence pair tasks. label: (Optional) string. The label of the example. This should be specified for train and dev examples, but not for test examples. """ self.guid = guid self.text_a = text_a self.text_b = text_b self.img_id = img_id # Please note that the auxlabel is not used in SB # it is just kept in order not to modify the original code class SBInputFeatures(object): """A single set of features of data""" def __init__(self, input_ids, input_mask, added_input_mask, segment_ids, img_feat): self.input_ids = input_ids self.input_mask = input_mask self.added_input_mask = added_input_mask self.segment_ids = segment_ids self.img_feat = img_feat def sbreadfile(filename): ''' Đọc dữ liệu từ tệp và trả về dưới dạng danh sách các từ và danh sách hình ảnh. ''' print("Chuẩn bị dữ liệu từ", filename) with open(filename, encoding='utf8') as f: data = [] imgs = [] sentence = [] imgid = '' for line in f: line = line.strip() # Loại bỏ các dấu cách thừa ở đầu và cuối dòng if line.startswith('IMGID:'): imgid = line.split('IMGID:')[1] + '.jpg' continue if line == '': if len(sentence) > 0: data.append(sentence) imgs.append(imgid) sentence = [] imgid = '' continue word = line.split('\t')[0] # Chỉ lấy từ (không lấy nhãn) sentence.append(word) if len(sentence) > 0: # Xử lý dữ liệu cuối cùng trong tệp data.append(sentence) imgs.append(imgid) print("Số lượng mẫu: " + str(len(data))) print("Số lượng hình ảnh: " + str(len(imgs))) return data, imgs class DataProcessor(object): """Base class for data converters for sequence classification data sets.""" def get_train_examples(self, data_dir): """Gets a collection of `InputExample`s for the train set.""" raise NotImplementedError() def get_dev_examples(self, data_dir): """Gets a collection of `InputExample`s for the dev set.""" raise NotImplementedError() def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() @classmethod def _read_sbtsv(cls, input_file, quotechar=None): """Reads a tab separated value file.""" return sbreadfile(input_file) class MNERProcessor(DataProcessor): """Processor for the CoNLL-2003 data set.""" def get_train_examples(self, data_dir): """See base class.""" data, imgs = self._read_sbtsv(os.path.join(data_dir, "train.txt")) return self._create_examples(data, imgs, "train") def get_dev_examples(self, data_dir): """See base class.""" data, imgs = self._read_sbtsv(os.path.join(data_dir, "dev.txt")) return self._create_examples(data, imgs, "dev") def get_test_examples(self, data_dir): """See base class.""" data, imgs = self._read_sbtsv(os.path.join(data_dir, "test.txt")) return self._create_examples(data, imgs, "test") def get_labels(self): # return [ # "O","I-PRODUCT-AWARD", # "B-MISCELLANEOUS", # "B-QUANTITY-NUM", # "B-ORGANIZATION-SPORTS", # "B-DATETIME", # "I-ADDRESS", # "I-PERSON", # "I-EVENT-SPORT", # "B-ADDRESS", # "B-EVENT-NATURAL", # "I-LOCATION-GPE", # "B-EVENT-GAMESHOW", # "B-DATETIME-TIMERANGE", # "I-QUANTITY-NUM", # "I-QUANTITY-AGE", # "B-EVENT-CUL", # "I-QUANTITY-TEM", # "I-PRODUCT-LEGAL", # "I-LOCATION-STRUC", # "I-ORGANIZATION", # "B-PHONENUMBER", # "B-IP", # "B-QUANTITY-AGE", # "I-DATETIME-TIME", # "I-DATETIME", # "B-ORGANIZATION-MED", # "B-DATETIME-SET", # "I-EVENT-CUL", # "B-QUANTITY-DIM", # "I-QUANTITY-DIM", # "B-EVENT", # "B-DATETIME-DATERANGE", # "I-EVENT-GAMESHOW", # "B-PRODUCT-AWARD", # "B-LOCATION-STRUC", # "B-LOCATION", # "B-PRODUCT", # "I-MISCELLANEOUS", # "B-SKILL", # "I-QUANTITY-ORD", # "I-ORGANIZATION-STOCK", # "I-LOCATION-GEO", # "B-PERSON", # "B-PRODUCT-COM", # "B-PRODUCT-LEGAL", # "I-LOCATION", # "B-QUANTITY-TEM", # "I-PRODUCT", # "B-QUANTITY-CUR", # "I-QUANTITY-CUR", # "B-LOCATION-GPE", # "I-PHONENUMBER", # "I-ORGANIZATION-MED", # "I-EVENT-NATURAL", # "I-EMAIL", # "B-ORGANIZATION", # "B-URL", # "I-DATETIME-TIMERANGE", # "I-QUANTITY", # "I-IP", # "B-EVENT-SPORT", # "B-PERSONTYPE", # "B-QUANTITY-PER", # "I-QUANTITY-PER", # "I-PRODUCT-COM", # "I-DATETIME-DURATION", # "B-LOCATION-GPE-GEO", # "B-QUANTITY-ORD", # "I-EVENT", # "B-DATETIME-TIME", # "B-QUANTITY", # "I-DATETIME-SET", # "I-LOCATION-GPE-GEO", # "B-ORGANIZATION-STOCK", # "I-ORGANIZATION-SPORTS", # "I-SKILL", # "I-URL", # "B-DATETIME-DURATION", # "I-DATETIME-DATE", # "I-PERSONTYPE", # "B-DATETIME-DATE", # "I-DATETIME-DATERANGE", # "B-LOCATION-GEO", # "B-EMAIL","X","", ""] # vlsp2016 return [ "I-LOC", "B-MISC", "I-PER", "I-ORG", "B-LOC", "I-MISC", "B-ORG", "O", "B-PER", "X", "", ""] # vlsp2018 # return [ # "O","I-ORGANIZATION", # "B-ORGANIZATION", # "I-LOCATION", # "B-MISCELLANEOUS", # "I-PERSON", # "B-PERSON", # "I-MISCELLANEOUS", # "B-LOCATION", # "X", # "", # ""] def get_auxlabels(self): return ["O", "B", "I", "X", "", ""] def get_start_label_id(self): label_list = self.get_labels() label_map = {label: i for i, label in enumerate(label_list, 1)} return label_map[''] def get_stop_label_id(self): label_list = self.get_labels() label_map = {label: i for i, label in enumerate(label_list, 1)} return label_map[''] def _create_examples(self, lines, imgs, set_type): examples = [] for i, (sentence) in enumerate(lines): guid = "%s-%s" % (set_type, i) text_a = ' '.join(sentence) text_b = None img_id = imgs[i] examples.append( SBInputExample(guid=guid, text_a=text_a, text_b=text_b, img_id=img_id)) return examples def create_examples(lines, imgs, set_type): examples = [] for i, (sentence) in enumerate(lines): guid = "%s-%s" % (set_type, i) text_a = ' '.join(sentence) text_b = None img_id = imgs[i] examples.append( SBInputExample(guid=guid, text_a=text_a, text_b=text_b, img_id=img_id)) return examples def get_test_examples_predict(data_dir): """See base class.""" data, imgs = sbreadfile(os.path.join(data_dir, "test.txt")) return create_examples(data, imgs, "test") def image_process(image_path, transform): image = Image.open(image_path).convert('RGB') image = transform(image) return image def convert_mm_examples_to_features_predict(examples, max_seq_length, tokenizer, crop_size, path_img): features = [] count = 0 transform = transforms.Compose([ transforms.Resize([256, 256]), transforms.RandomCrop(crop_size), # args.crop_size, by default it is set to be 224 transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]) for (ex_index, example) in enumerate(examples): textlist = example.text_a.split(' ') tokens = [] for i, word in enumerate(textlist): token = tokenizer.tokenize(word) tokens.extend(token) if len(tokens) >= max_seq_length - 1: tokens = tokens[0:(max_seq_length - 2)] ntokens = [] segment_ids = [] ntokens.append("") segment_ids.append(0) for i, token in enumerate(tokens): ntokens.append(token) segment_ids.append(0) ntokens.append("") segment_ids.append(0) input_ids = tokenizer.convert_tokens_to_ids(ntokens) input_mask = [1] * len(input_ids) added_input_mask = [1] * (len(input_ids) + 49) # 1 or 49 is for encoding regional image representations while len(input_ids) < max_seq_length: input_ids.append(0) input_mask.append(0) added_input_mask.append(0) segment_ids.append(0) assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length assert len(segment_ids) == max_seq_length image_name = example.img_id image_path = os.path.join(path_img, image_name) if not os.path.exists(image_path): if 'NaN' not in image_path: print(image_path) try: image = image_process(image_path, transform) except: count += 1 image_path_fail = os.path.join(path_img, 'background.jpg') image = image_process(image_path_fail, transform) else: if ex_index < 1: logger.info("*** Example ***") logger.info("guid: %s" % (example.guid)) logger.info("tokens: %s" % " ".join( [str(x) for x in tokens])) logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids])) logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask])) logger.info( "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) features.append( SBInputFeatures(input_ids=input_ids, input_mask=input_mask, added_input_mask=added_input_mask, segment_ids=segment_ids, img_feat=image)) print('the number of problematic samples: ' + str(count)) return features