|
import torch |
|
from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel |
|
from torch import nn |
|
from itertools import chain |
|
from torch.nn import MSELoss, CrossEntropyLoss |
|
from cleantext import clean |
|
from num2words import num2words |
|
import re |
|
import string |
|
|
|
# Characters treated as punctuation (e.g. by ``get_num_words``): all ASCII
# punctuation plus common typographic variants. The non-ASCII characters are
# written as escapes so the source stays ASCII-only and cannot be mangled by a
# wrong encoding on read/write.
punct_chars = list(set(string.punctuation) | {
    "\u2019",  # right single quotation mark
    "\u2018",  # left single quotation mark
    "\u2013",  # en dash
    "\u2014",  # em dash
    "~",
    "|",
    "\u201c",  # left double quotation mark
    "\u201d",  # right double quotation mark
    "\u2026",  # horizontal ellipsis
    "'",
    "`",
    "_",
})
# Sort for a deterministic character class (set iteration order is arbitrary).
punct_chars.sort()

punctuation = "".join(punct_chars)

# Matches any single punctuation character; re.escape keeps "]", "\\", "^", "-"
# literal inside the character class.
replace = re.compile("[%s]" % re.escape(punctuation))
|
|
|
# Short math vocabulary stems (all lower-case). Every entry also occurs in
# MATH_WORDS; how these are matched against text is decided by callers not
# visible here (presumably as word prefixes -- TODO confirm).
MATH_PREFIXES = [
    "sum", "arc", "mass", "digit", "graph", "liter", "gram",
    "add", "angle", "scale", "data", "array", "ruler", "meter",
    "total", "unit", "prism", "median", "ratio", "area",
]
|
|
|
# Lower-case math vocabulary (terms and stems such as "simplif" / "multipl").
# Order, overlapping entries and the trailing spaces in "rate multiplication "
# and "vanishing " are kept exactly as authored, since matching results may
# depend on them (presumably substring matching by callers -- TODO confirm).
MATH_WORDS = [
    "absolute value", "algebra", "area", "average", "base of",
    "box plot", "categorical", "coefficient", "common factor", "common multiple",
    "compose", "coordinate", "cubed", "decompose", "dependent variable",
    "distribution", "dot plot", "double number line diagram", "equivalent",
    "equivalent expression",
    "ratio", "exponent", "frequency", "greatest common factor", "gcd",
    "height of", "histogram", "independent variable", "integer", "interquartile range",
    "iqr", "least common multiple", "long division", "mean absolute deviation", "median",
    "negative number", "opposite vertex", "parallelogram", "percent", "polygon",
    "polyhedron", "positive number", "prism", "pyramid", "quadrant",
    "quadrilateral", "quartile", "rational number", "reciprocal", "equality",
    "inequality", "squared", "statistic", "surface area", "identity property",
    "addend", "unit", "number sentence", "make ten", "take from ten",
    "number bond", "total", "estimate", "hashmark", "meter",
    "number line", "ruler", "centimeter", "base ten", "expanded form",
    "hundred", "thousand", "place value", "number disk", "standard form",
    "unit form", "word form", "tens place", "algorithm", "equation",
    "simplif", "addition", "subtract", "array", "even number",
    "odd number", "repeated addition", "tessellat", "whole number", "number path",
    "rectangle", "square", "bar graph", "data", "degree",
    "line plot", "picture graph", "scale", "survey", "thermometer",
    "estimat", "tape diagram", "value", "analog", "angle",
    "parallel", "partition", "pentagon", "right angle", "cube",
    "digital", "quarter of", "tangram", "circle", "hexagon",
    "half circle", "half-circle", "quarter circle", "quarter-circle", "semicircle",
    "semi-circle", "rectang", "rhombus", "trapezoid", "triangle",
    "commutative", "equal group", "distributive", "divide", "division",
    "multipl", "parentheses", "quotient", "rotate", "unknown",
    "add", "capacity", "continuous", "endpoint", "gram",
    "interval", "kilogram", "volume", "liter", "milliliter",
    "approximate", "area model", "square unit", "unit square", "geometr",
    "equivalent fraction", "fraction form", "fractional unit", "unit fraction",
    "unit interval",
    "measur", "graph", "scaled graph", "diagonal", "perimeter",
    "regular polygon", "tessellate", "tetromino", "heptagon", "octagon",
    "digit", "expression", "sum", "kilometer", "mass",
    "mixed unit", "length", "measure", "simplify", "associative",
    "composite", "divisible", "divisor", "partial product", "prime number",
    "remainder", "acute", "arc", "collinear", "equilateral",
    "intersect", "isosceles", "symmetry", "line segment", "line",
    "obtuse", "perpendicular", "protractor", "scalene", "straight angle",
    "supplementary angle", "vertex", "common denominator", "denominator", "fraction",
    "mixed number", "numerator", "whole", "decimal expanded form", "decimal",
    "hundredth", "tenth", "customary system of measurement", "customary unit",
    "gallon",
    "metric", "metric unit", "ounce", "pint", "quart",
    "convert", "distance", "millimeter", "thousandth", "hundredths",
    "conversion factor", "decimal fraction", "multiplier", "equivalence", "multiple",
    "product", "benchmark fraction", "cup", "pound", "yard",
    "whole unit", "decimal divisor", "factors", "bisect", "cubic units",
    "hierarchy", "unit cube", "attribute", "kite", "bisector",
    "solid figure", "square units", "dimension", "axis", "ordered pair",
    "angle measure", "horizontal", "vertical", "categorical data", "lcm",
    "measure of center", "meters per second", "numerical", "solution", "unit price",
    "unit rate", "variability", "variable", "abundant number", "accurate",
    "acre", "addition fact", "algebraic", "altitude", "apex",
    "arithmetic facts", "associative property", "astronomical unit", "base",
    "baseline",
    "billion", "celsius", "census", "cent", "center of a circle",
    "center of a sphere", "chance", "circle graph", "column", "combine",
    "common fraction", "comparison diagram", "comparison story", "compass",
    "complement",
    "concave polygon", "concentric circles", "consecutive", "constant",
    "continuous model of area",
    "continuous model of volume", "contour", "conversion fact", "convex polygon",
    "counting numbers",
    "counting up subtraction", "cover-up method", "cross multiplication", "cubic",
    "cubit",
    "curved surface", "cylinder", "decagon", "decimeter", "deficient number",
    "density", "discrete model", "displacement method", "divisibility test",
    "divisible by",
    "dodecahedron", "double stem plot", "doubles fact", "egyptian multiplication",
    "elevation",
    "embed figure", "end point", "enlarge", "equal", "equal groups",
    "equal parts", "equidistant marks", "equilateral polygon", "equivalent fractions",
    "european subtraction",
    "expanded notation", "expected outcome", "exponential", "extended facts",
    "fact power",
    "fact triangle", "factor", "factors of numbers", "fahrenheit",
    "false number sentence",
    "figurate numbers", "flowchart", "fluid ounce", "fractional part", "fulcrum",
    "function machine", "furlong", "genus", "geoboard", "geometric solid",
    "geometry template", "girth", "golden ratio", "golden rectangle", "graph key",
    "grouping symbol", "hemisphere", "icosahedron", "improper fraction", "inch",
    "index of locations", "indirect measurement", "input", "inscribed polygon",
    "instance of a pattern",
    "interior of a figure", "interpolate", "irrational", "isometry transformation",
    "isosceles trapezoid",
    "juxtapose", "key sequence", "label", "landmark", "latitude",
    "lattice multiplication", "left to right subtraction", "leg of a right triangle",
    "like terms", "line graph",
    "line of reflection", "line of symmetry", "line symmetry", "lines of latitude",
    "lines of longitude",
    "longitude", "magnitude estimate", "map legend", "map scale", "maximum",
    "measurement division", "measurement unit", "meridian bar", "metric system",
    "midpoint",
    "mile", "millisecond", "minimum", "minuend", "mirror image",
    "mobius", "modal", "multiplication counting principle", "multiplication diagram",
    "multiplication fact",
    "multiplication symbols", "multiplication use class", "negative rational numbers",
    "nested parentheses", "net score",
    "net weight", "nonagon", "nonconvex polygon", "normal span", "number grid",
    "number sequence", "numeral", "numeration", "octahedron", "open proportion",
    "operation", "operation symbol", "opposite angle", "opposite change rule",
    "opposite of a number",
    "opposite side", "order of magnitude", "order of operations",
    "order of rotation symmetry", "ordinal number",
    "pan balance", "parabola", "parallel lines", "parallel planes", "part to part ratio",
    "part to whole ratio", "part whole fraction", "partial differences subtraction",
    "partial products multiplication", "partial quotients division",
    "partial sums addition", "partitive division", "parts and total diagram",
    "per capita", "per unit rate",
    "percent circle", "perfect number", "perpetual calendar", "pie graph", "plane",
    "plane figure", "point symmetry", "population density", "precise", "predict",
    "prediction line", "preimage", "prime factorization", "prime meridian",
    "probability",
    "probability meter", "probability tree diagram", "proper factor",
    "proper fraction", "property",
    "quadrangle", "quick common denominator", "quotitive division", "random draw",
    "random experiment",
    "random number", "random sample", "rank", "rate diagram", "rate multiplication ",
    "rate unit", "recall survey", "rectangular array", "rectangular coordinate grid",
    "rectangular prism",
    "rectangular pyramid", "rectilinear figure", "reflection", "reflex angle",
    "regular polyhedron",
    "regular tessellation", "relation symbol", "revolution", "right cone",
    "right cylinder",
    "right prism", "right pyramid", "right triangle", "roman numerals",
    "rotation symmetry",
    "same change rule for subtraction", "scale model", "scale of a map",
    "scale of a number line", "sector",
    "segment", "sequence", "significant digits", "similar figures", "simpler form",
    "situation diagram", "skew lines", "slanted", "slide rule", "span",
    "stacked bar graph", "standard unit", "stem and leaf plot", "step graph",
    "straightedge",
    "substitute", "subtrahend", "surface", "symmetric", "tally",
    "tangent", "tangent circles", "temperature", "template", "tetrahedron",
    "theorem", "tile", "tiling", "time graph", "timeline",
    "top heavy fraction", "topological", "topology", "trade first subtraction",
    "tree diagram",
    "triangular", "true number sentence", "truncate", "twin primes",
    "unlike denominators",
    "unlike fractions", "vanishing ", "venn diagram", "vernal equinox", "weight",
    "width", "base of a prism", "base of a pyramid", "face", "numerical data",
    "opposite", "pace", "per", "region", "sign",
    "alternate interior angles", "base of an exponent", "cone", "congruent",
    "counterclockwise",
    "cube root", "hypotenuse", "irrational number", "linear relationship",
    "positive association",
    "rate of change", "translation", "transversal", "circumference", "corresponding",
    "expand", "population", "proportion", "radius", "random",
    "repeating decimal", "representative", "scaled", "withdrawal", "center",
    "edge", "height of a parallelogram or triangle", "net", "speed", "table",
    "term", "adjacent", "complementary", "cross-section", "cross section",
    "deposit", "event", "measurement error", "proportional", "simulation",
    "center of a dilation", "clockwise", "dilation", "function",
    "negative association",
    "pythagorean theorem", "relative frequency", "rigid transformation",
    "scale factor", "scatter plot",
    "similar", "sphere", "two-way table", "additive identity", "additive inverse",
    "box and whisker plot", "cartesian coordinates", "central angle", "chord",
    "combination",
    "commutative property", "coplanar", "cross product", "dependent events",
    "difference",
    "dividend", "equilateral triangle", "error of measurement", "factorial", "formula",
    "identity property of", "independent events", "infinity", "inscribed angle",
    "intercept",
    "intercepted arc", "inverse", "inverse operations", "isosceles triangle",
    "least common denominator",
    "like fractions", "locus", "logic", "lowest terms", "mode",
    "multiplicative identity", "multiplicative inverse", "mutually exclusive events",
    "natural numbers", "normal",
    "permutation", "pi", "point", "power", "range",
    "rate", "ray", "real numbers", "rectangular", "root",
    "rotation", "scalene triangle", "scattergram", "set", "statistics",
    "terminating decimal", "transformation", "x intercept", "x-axis", "x-intercept",
    "y intercept", "y-axis", "y-intercept", "zero", "zero property of multiplication",
    "base of a parallelogram", "base of a triangle", "height", "chance experiment",
    "diameter",
    "mean", "percentage", "sample", "legs", "outlier",
    "slope", "square root", "system of equations", "tessellation",
]
|
|
|
def get_num_words(text): |
|
if not isinstance(text, str): |
|
print("%s is not a string" % text) |
|
text = replace.sub(' ', text) |
|
text = re.sub(r'\s+', ' ', text) |
|
text = text.strip() |
|
text = re.sub(r'\[.+\]', " ", text) |
|
return len(text.split()) |
|
|
|
def number_to_words(num):
    """Spell out a numeric string in words using ``num2words``.

    Thousands separators (commas) are removed before conversion, so ``"1,000"``
    is converted as ``"1000"``. If conversion fails for any reason (non-numeric
    or unsupported input), ``num`` is returned unchanged so that text cleaning
    never aborts on an odd token.
    """
    try:
        return num2words(num.replace(",", ""))
    except Exception:
        # Deliberate best-effort fallback, but unlike a bare ``except:`` this
        # no longer swallows KeyboardInterrupt / SystemExit.
        return num
|
|
|
|
|
def _clean_text(s, no_punct):
    """Run cleantext's ``clean`` on ``s`` with this module's shared settings.

    Per the flags below: fixes unicode, transliterates to ASCII, lower-cases,
    removes line breaks, replaces URLs / e-mails / phone numbers with
    ``<URL>`` / ``<EMAIL>`` / ``<PHONE>``, and spells numbers out via
    ``number_to_words``. ``no_punct`` selects whether punctuation is removed.
    """
    return clean(
        s,
        fix_unicode=True,
        to_ascii=True,
        lower=True,
        no_line_breaks=True,
        no_urls=True,
        no_emails=True,
        no_phone_numbers=True,
        no_numbers=True,
        no_digits=False,
        no_currency_symbols=False,
        no_punct=no_punct,
        replace_with_url="<URL>",
        replace_with_email="<EMAIL>",
        replace_with_phone_number="<PHONE>",
        replace_with_number=lambda m: number_to_words(m.group()),
        replace_with_digit="0",
        replace_with_currency_symbol="<CUR>",
        lang="en",
    )


def clean_str(s):
    """Normalize ``s`` (see ``_clean_text``), keeping punctuation."""
    return _clean_text(s, no_punct=False)


def clean_str_nopunct(s):
    """Normalize ``s`` (see ``_clean_text``) and remove punctuation."""
    return _clean_text(s, no_punct=True)
|
|
|
|
|
|
|
class MultiHeadModel(BertPreTrainedModel):
    """Pre-trained BERT encoder with one linear head per named task.

    Every head reads the dropout-regularized BERT ``pooler_output``. Heads with
    a single output are trained with mean-squared error (optionally masked);
    heads with several outputs are trained with cross-entropy.
    """

    def __init__(self, config, head2size):
        """Build the encoder and the task heads.

        Args:
            config: BERT config. Its ``num_labels`` is overwritten with 1.
            head2size: mapping ``head name -> number of outputs`` of that head.
        """
        # PreTrainedModel only needs the config; head2size is consumed here.
        super().__init__(config)
        config.num_labels = 1
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.heads = nn.ModuleDict({
            head_name: nn.Linear(config.hidden_size, num_labels)
            for head_name, num_labels in head2size.items()
        })

        self.init_weights()

    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                head2labels=None, return_pooler_output=False, head2mask=None,
                nsp_loss_weights=None):
        """Run BERT and all heads; compute a loss for every labelled head.

        Args:
            input_ids, token_type_ids, attention_mask: forwarded to ``BertModel``.
            head2labels: optional mapping ``head name -> label tensor``. Every
                key must be a head name.
            return_pooler_output: also return the pooler output (before dropout).
            head2mask: optional mapping ``head name -> another key of
                head2labels``. For single-output heads, that entry's labels
                weight the per-example MSE and the sum is divided by their
                total (presumably a 0/1 mask -- confirm with callers).
            nsp_loss_weights: optional class-weight tensor passed to
                ``CrossEntropyLoss`` for every multi-output head.

        Returns:
            dict with ``"<head>_logits"`` for every head, ``"<head>_loss"`` for
            every head in ``head2labels`` and, if requested, ``"pooler_output"``.
        """
        output = self.bert(
            input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
            output_attentions=False, output_hidden_states=False, return_dict=True)
        # Keep tensors on the model's own device: forcing "cuda" whenever a GPU
        # exists breaks models that live on the CPU of a GPU machine.
        pooled_output = self.dropout(output["pooler_output"])

        head2logits = {}
        return_dict = {}
        for head_name, head in self.heads.items():
            logits = head(pooled_output).float()
            head2logits[head_name] = logits
            return_dict[head_name + "_logits"] = logits

        if head2labels is not None:
            for head_name, labels in head2labels.items():
                logits = head2logits[head_name]
                num_classes = logits.shape[1]

                if num_classes == 1:
                    if head2mask is not None and head_name in head2mask:
                        mask = head2labels[head2mask[head_name]]
                        num_positives = mask.sum()
                        if num_positives == 0:
                            # Float scalar like every other loss, on the
                            # logits' device, so losses can be summed/stacked.
                            loss = logits.new_zeros(())
                        else:
                            per_example = MSELoss(reduction="none")(
                                logits.view(-1), labels.float().view(-1))
                            loss = per_example.dot(mask.float().view(-1)) / num_positives
                    else:
                        loss = MSELoss()(logits.view(-1), labels.float().view(-1))
                else:
                    # nsp_loss_weights defaults to None; unweighted loss then.
                    weight = None
                    if nsp_loss_weights is not None:
                        weight = nsp_loss_weights.float().to(logits.device)
                    loss = CrossEntropyLoss(weight=weight)(logits, labels.view(-1))

                return_dict[head_name + "_loss"] = loss

        if return_pooler_output:
            return_dict["pooler_output"] = output["pooler_output"]

        return return_dict
|
|
|
class InputBuilder:
    """Base class for building model inputs from a history and a reply."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # One-token segment that ``mask_seq`` substitutes for a whole segment.
        self.mask = [tokenizer.mask_token_id]

    def build_inputs(self, history, reply, max_length):
        """Build model inputs; must be implemented by subclasses."""
        raise NotImplementedError

    def mask_seq(self, sequence, seq_id):
        """Replace segment ``seq_id`` of ``sequence`` with the mask token.

        ``sequence`` is modified in place and also returned. A copy of
        ``self.mask`` is stored so that later in-place edits of that segment
        cannot corrupt the builder's own mask list.
        """
        sequence[seq_id] = list(self.mask)
        return sequence

    @classmethod
    def _combine_sequence(cls, history, reply, max_length, flipped=False):
        """Truncate every segment to ``max_length`` items and join them.

        Returns ``history + [reply]``, or ``[reply] + history`` if ``flipped``.
        """
        history = [segment[:max_length] for segment in history]
        reply = reply[:max_length]
        if flipped:
            return [reply] + history
        return history + [reply]
|
|
|
|
|
class BertInputBuilder(InputBuilder):
    """Builds BERT inputs of the form ``[CLS] seg [SEP] seg [SEP] ...``.

    Segments come from ``_combine_sequence`` (history then reply, or reply
    first when ``self.flipped``). Token type ids alternate per segment,
    counted backwards so that the last segment always gets type 0.
    """

    def __init__(self, tokenizer):
        InputBuilder.__init__(self, tokenizer)
        self.cls = [tokenizer.cls_token_id]
        self.sep = [tokenizer.sep_token_id]
        # Input names; padded_inputs is presumably consumed by collation code
        # elsewhere -- confirm with callers.
        self.model_inputs = ["input_ids", "token_type_ids", "attention_mask"]
        self.padded_inputs = ["input_ids", "token_type_ids"]
        self.flipped = False

    def build_inputs(self, history, reply, max_length, input_str=True):
        """See base class.

        Args:
            history: earlier segments; raw strings when ``input_str`` is true,
                otherwise lists of token ids.
            reply: the reply segment, in the same form as the history items.
            max_length: per-segment truncation length, applied before the
                special tokens are added.
            input_str: tokenize ``history`` and ``reply`` first if true.

        Returns:
            dict with equally long ``"input_ids"`` and ``"token_type_ids"``.
        """
        if input_str:
            tok = self.tokenizer

            def to_ids(text):
                return tok.convert_tokens_to_ids(tok.tokenize(text))

            history = [to_ids(turn) for turn in history]
            reply = to_ids(reply)

        combined = self._combine_sequence(history, reply, max_length, self.flipped)
        segments = [segment + self.sep for segment in combined]
        segments[0] = self.cls + segments[0]

        n_segments = len(segments)
        token_type_ids = []
        for idx, segment in enumerate(segments):
            # Last segment, and every second one before it, is type 0.
            type_id = 0 if (n_segments - idx) % 2 == 1 else 1
            token_type_ids.extend([type_id] * len(segment))

        return {
            "input_ids": list(chain.from_iterable(segments)),
            "token_type_ids": token_type_ids,
        }