Linhz's picture
Upload 80 files
fd07025 verified
import torch
import logging
import os
logger = logging.getLogger(__name__)
from torchvision import transforms
from PIL import Image
class SBInputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b, img_id, label=None, auxlabel=None):
        """Constructs an SBInputExample.

        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
                sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
                Only must be specified for sequence pair tasks.
            img_id: file name of the image paired with this example
                (joined with an image directory by the feature converter).
            label: (Optional) the label of the example. This should be
                specified for train and dev examples, but not for test examples.
            auxlabel: (Optional) auxiliary label; stored but not used in SB.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.img_id = img_id
        # Previously `label` was accepted but silently dropped, so any code
        # reading `example.label` raised AttributeError.
        self.label = label
        # Please note that the auxlabel is not used in SB
        # it is just kept in order not to modify the original code
        self.auxlabel = auxlabel
class SBInputFeatures(object):
    """Container for the model inputs derived from one example.

    Holds the token ids, the attention masks, the segment ids and the
    image tensor produced for a single example.
    """

    def __init__(self, input_ids, input_mask, added_input_mask, segment_ids, img_feat):
        # Text-side inputs, stored in the same order as the signature.
        self.input_ids, self.input_mask = input_ids, input_mask
        self.added_input_mask = added_input_mask
        self.segment_ids = segment_ids
        # Image-side input.
        self.img_feat = img_feat
def sbreadfile(filename):
    """Read a column-formatted file into parallel lists of sentences and image names.

    The file is read as UTF-8, line by line, with surrounding whitespace
    stripped:

    * a line starting with ``IMGID:`` sets the image for the upcoming
      sentence (``.jpg`` is appended to the id);
    * a blank line closes the current sentence (if it has any words);
    * any other line contributes its first tab-separated field (the word);
      remaining columns, such as labels, are ignored.

    Args:
        filename: path of the file to read.

    Returns:
        ``(data, imgs)``: ``data`` is a list of sentences (each a list of
        words) and ``imgs`` holds the image file name for each sentence
        (``''`` when no ``IMGID:`` line preceded it).
    """
    print("Chuẩn bị dữ liệu từ", filename)
    sentences = []
    image_names = []
    words = []
    current_img = ''
    with open(filename, encoding='utf8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped.startswith('IMGID:'):
                current_img = stripped.split('IMGID:')[1] + '.jpg'
            elif not stripped:
                if words:
                    sentences.append(words)
                    image_names.append(current_img)
                    words = []
                    current_img = ''
            else:
                words.append(stripped.partition('\t')[0])
    # A file that does not end with a blank line still yields its last sentence.
    if words:
        sentences.append(words)
        image_names.append(current_img)
    print(f"Số lượng mẫu: {len(sentences)}")
    print(f"Số lượng hình ảnh: {len(image_names)}")
    return sentences, image_names
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Subclasses implement the split readers and the label list.
    """

    def get_train_examples(self, data_dir):
        """Gets a collection of `SBInputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `SBInputExample`s for the dev set."""
        raise NotImplementedError()

    def get_test_examples(self, data_dir):
        """Gets a collection of `SBInputExample`s for the test set."""
        # Declared here so the base interface matches what subclasses
        # (e.g. MNERProcessor) provide and what callers use.
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_sbtsv(cls, input_file, quotechar=None):
        """Read ``input_file`` with ``sbreadfile`` and return ``(data, imgs)``.

        ``quotechar`` is accepted only for signature compatibility and is ignored.
        """
        return sbreadfile(input_file)
class MNERProcessor(DataProcessor):
    """Processor for the multimodal (sentence + image) NER data used here.

    Each split is read from ``<data_dir>/<split>.txt`` with ``_read_sbtsv``
    and wrapped as ``SBInputExample`` objects. Label ids used by the
    start/stop helpers are 1-based positions in :meth:`get_labels`.
    """

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._examples_for_split(data_dir, "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._examples_for_split(data_dir, "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._examples_for_split(data_dir, "test")

    def _examples_for_split(self, data_dir, set_type):
        """Read ``<data_dir>/<set_type>.txt`` and build examples tagged with ``set_type``."""
        data, imgs = self._read_sbtsv(os.path.join(data_dir, set_type + ".txt"))
        return self._create_examples(data, imgs, set_type)

    def get_labels(self):
        """Return the active tag set (VLSP 2016 entity tags) plus extra tags.

        Order matters: ids are assigned by position (see ``_label_map``), so
        reordering presumably invalidates previously trained checkpoints.
        ``"X"``, ``"<s>"`` and ``"</s>"`` are appended after the entity tags.
        To switch tag sets, return one of the commented-out lists instead.
        """
        # Fine-grained tag set (dataset not named in the original; presumably
        # VLSP 2021 — confirm):
        # return [
        # "O","I-PRODUCT-AWARD",
        # "B-MISCELLANEOUS",
        # "B-QUANTITY-NUM",
        # "B-ORGANIZATION-SPORTS",
        # "B-DATETIME",
        # "I-ADDRESS",
        # "I-PERSON",
        # "I-EVENT-SPORT",
        # "B-ADDRESS",
        # "B-EVENT-NATURAL",
        # "I-LOCATION-GPE",
        # "B-EVENT-GAMESHOW",
        # "B-DATETIME-TIMERANGE",
        # "I-QUANTITY-NUM",
        # "I-QUANTITY-AGE",
        # "B-EVENT-CUL",
        # "I-QUANTITY-TEM",
        # "I-PRODUCT-LEGAL",
        # "I-LOCATION-STRUC",
        # "I-ORGANIZATION",
        # "B-PHONENUMBER",
        # "B-IP",
        # "B-QUANTITY-AGE",
        # "I-DATETIME-TIME",
        # "I-DATETIME",
        # "B-ORGANIZATION-MED",
        # "B-DATETIME-SET",
        # "I-EVENT-CUL",
        # "B-QUANTITY-DIM",
        # "I-QUANTITY-DIM",
        # "B-EVENT",
        # "B-DATETIME-DATERANGE",
        # "I-EVENT-GAMESHOW",
        # "B-PRODUCT-AWARD",
        # "B-LOCATION-STRUC",
        # "B-LOCATION",
        # "B-PRODUCT",
        # "I-MISCELLANEOUS",
        # "B-SKILL",
        # "I-QUANTITY-ORD",
        # "I-ORGANIZATION-STOCK",
        # "I-LOCATION-GEO",
        # "B-PERSON",
        # "B-PRODUCT-COM",
        # "B-PRODUCT-LEGAL",
        # "I-LOCATION",
        # "B-QUANTITY-TEM",
        # "I-PRODUCT",
        # "B-QUANTITY-CUR",
        # "I-QUANTITY-CUR",
        # "B-LOCATION-GPE",
        # "I-PHONENUMBER",
        # "I-ORGANIZATION-MED",
        # "I-EVENT-NATURAL",
        # "I-EMAIL",
        # "B-ORGANIZATION",
        # "B-URL",
        # "I-DATETIME-TIMERANGE",
        # "I-QUANTITY",
        # "I-IP",
        # "B-EVENT-SPORT",
        # "B-PERSONTYPE",
        # "B-QUANTITY-PER",
        # "I-QUANTITY-PER",
        # "I-PRODUCT-COM",
        # "I-DATETIME-DURATION",
        # "B-LOCATION-GPE-GEO",
        # "B-QUANTITY-ORD",
        # "I-EVENT",
        # "B-DATETIME-TIME",
        # "B-QUANTITY",
        # "I-DATETIME-SET",
        # "I-LOCATION-GPE-GEO",
        # "B-ORGANIZATION-STOCK",
        # "I-ORGANIZATION-SPORTS",
        # "I-SKILL",
        # "I-URL",
        # "B-DATETIME-DURATION",
        # "I-DATETIME-DATE",
        # "I-PERSONTYPE",
        # "B-DATETIME-DATE",
        # "I-DATETIME-DATERANGE",
        # "B-LOCATION-GEO",
        # "B-EMAIL","X","<s>", "</s>"]
        # vlsp2016
        return [
            "I-LOC", "B-MISC",
            "I-PER",
            "I-ORG",
            "B-LOC",
            "I-MISC",
            "B-ORG",
            "O",
            "B-PER",
            "X",
            "<s>",
            "</s>"]
        # vlsp2018
        # return [
        # "O","I-ORGANIZATION",
        # "B-ORGANIZATION",
        # "I-LOCATION",
        # "B-MISCELLANEOUS",
        # "I-PERSON",
        # "B-PERSON",
        # "I-MISCELLANEOUS",
        # "B-LOCATION",
        # "X",
        # "<s>",
        # "</s>"]

    def get_auxlabels(self):
        """Return the coarse B/I/O tag set plus the extra tags."""
        return ["O", "B", "I", "X", "<s>", "</s>"]

    def _label_map(self):
        """Map each label of :meth:`get_labels` to its 1-based position.

        Id 0 is left unassigned (presumably reserved for padding — confirm).
        """
        return {label: i for i, label in enumerate(self.get_labels(), 1)}

    def get_start_label_id(self):
        """Return the id of the ``<s>`` label."""
        return self._label_map()['<s>']

    def get_stop_label_id(self):
        """Return the id of the ``</s>`` label."""
        return self._label_map()['</s>']

    def _create_examples(self, lines, imgs, set_type):
        """Wrap sentences and their image names as ``SBInputExample`` objects.

        Delegates to the module-level ``create_examples`` so the two code
        paths cannot drift apart.
        """
        return create_examples(lines, imgs, set_type)
def create_examples(lines, imgs, set_type):
    """Build one ``SBInputExample`` per sentence.

    Args:
        lines: list of sentences, each a list of words; words are joined
            with single spaces to form ``text_a``.
        imgs: image file names, indexed in parallel with ``lines``.
        set_type: split name used as the guid prefix (``"<set_type>-<index>"``).

    Returns:
        A list of ``SBInputExample`` with ``text_b`` set to ``None``.
    """
    return [
        SBInputExample(
            guid="%s-%s" % (set_type, idx),
            text_a=' '.join(tokens),
            text_b=None,
            img_id=imgs[idx],
        )
        for idx, tokens in enumerate(lines)
    ]
def get_test_examples_predict(data_dir):
    """Load ``<data_dir>/test.txt`` as a list of ``SBInputExample``.

    Module-level counterpart of ``MNERProcessor.get_test_examples`` for use
    at prediction time; example guids are ``"test-<index>"``.
    """
    test_path = os.path.join(data_dir, "test.txt")
    sentences, image_names = sbreadfile(test_path)
    return create_examples(sentences, image_names, "test")
def image_process(image_path, transform):
    """Load an image file, convert it to RGB and apply ``transform``.

    Args:
        image_path: path to the image file.
        transform: callable applied to the RGB ``PIL.Image``
            (e.g. a torchvision ``transforms.Compose``).

    Returns:
        Whatever ``transform`` returns.

    Raises:
        Errors from opening or decoding the image, or from ``transform``,
        propagate to the caller.
    """
    # Open as a context manager so the underlying file handle is closed
    # deterministically instead of waiting for garbage collection.
    # convert() loads the pixels and returns a new image, so `rgb` stays
    # valid after the source image is closed.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB')
    return transform(rgb)
def _build_predict_transform(crop_size):
    """Deterministic image preprocessing for prediction: resize, center-crop, normalize."""
    return transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.CenterCrop(crop_size),  # args.crop_size, by default it is set to be 224
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406),
                             (0.229, 0.224, 0.225))])


def _encode_text_for_predict(text, max_seq_length, tokenizer):
    """Tokenize ``text`` word by word, wrap it in <s>/</s>, and pad to ``max_seq_length``.

    Returns:
        ``(tokens, input_ids, input_mask, added_input_mask, segment_ids)``,
        where ``tokens`` is the truncated sub-token list without the markers.
    """
    tokens = []
    for word in text.split(' '):
        tokens.extend(tokenizer.tokenize(word))
    # Keep room for the <s> and </s> markers.
    tokens = tokens[:max_seq_length - 2]
    ntokens = ["<s>"] + tokens + ["</s>"]
    input_ids = list(tokenizer.convert_tokens_to_ids(ntokens))
    input_mask = [1] * len(input_ids)
    # 1 or 49 is for encoding regional image representations
    # (49 presumably matches a 7x7 grid of visual features — confirm).
    added_input_mask = [1] * (len(input_ids) + 49)
    segment_ids = [0] * len(ntokens)
    padding = max_seq_length - len(input_ids)
    # NOTE(review): pads input_ids with 0; confirm this matches the id used
    # during training for this tokenizer.
    input_ids += [0] * padding
    input_mask += [0] * padding
    added_input_mask += [0] * padding
    segment_ids += [0] * padding
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length
    return tokens, input_ids, input_mask, added_input_mask, segment_ids


def convert_mm_examples_to_features_predict(examples,
                                            max_seq_length, tokenizer, crop_size, path_img,
                                            transform=None):
    """Convert examples into ``SBInputFeatures`` (text ids/masks + image tensor) for prediction.

    Args:
        examples: iterable of ``SBInputExample``.
        max_seq_length: fixed length that every text sequence is truncated/padded to.
        tokenizer: object providing ``tokenize(word)`` and ``convert_tokens_to_ids(tokens)``.
        crop_size: crop size used by the default image transform.
        path_img: directory containing the images named by ``example.img_id``.
        transform: optional image transform. Defaults to a deterministic
            resize + center-crop + normalize pipeline. (The previous version
            used RandomCrop/RandomHorizontalFlip here, which made predictions
            non-reproducible.)

    Returns:
        A list of ``SBInputFeatures``, one per example. Images that fail to
        load are replaced by ``<path_img>/background.jpg``; the number of such
        replacements is printed at the end.
    """
    if transform is None:
        transform = _build_predict_transform(crop_size)
    features = []
    count = 0
    for ex_index, example in enumerate(examples):
        tokens, input_ids, input_mask, added_input_mask, segment_ids = \
            _encode_text_for_predict(example.text_a, max_seq_length, tokenizer)

        image_path = os.path.join(path_img, example.img_id)
        if not os.path.exists(image_path) and 'NaN' not in image_path:
            print(image_path)
        try:
            image = image_process(image_path, transform)
        except Exception:
            # Best effort: an unreadable or missing image is replaced by a
            # shared placeholder so the example is still emitted.
            # (Previously a bare `except:` also swallowed KeyboardInterrupt.)
            count += 1
            image_path_fail = os.path.join(path_img, 'background.jpg')
            image = image_process(image_path_fail, transform)
        else:
            # Only the first example is logged, and only when its image loaded.
            if ex_index < 1:
                logger.info("*** Example ***")
                logger.info("guid: %s", example.guid)
                logger.info("tokens: %s", " ".join(str(x) for x in tokens))
                logger.info("input_ids: %s", " ".join(str(x) for x in input_ids))
                logger.info("input_mask: %s", " ".join(str(x) for x in input_mask))
                logger.info("segment_ids: %s", " ".join(str(x) for x in segment_ids))
        features.append(
            SBInputFeatures(input_ids=input_ids, input_mask=input_mask,
                            added_input_mask=added_input_mask,
                            segment_ids=segment_ids, img_feat=image))
    print('the number of problematic samples: ' + str(count))
    return features