# Hosting-site file-viewer header (not part of the script):
# KoichiYasuoka's picture / model improved / 85c9513 / raw / history blame / 4.98 kB
#! /usr/bin/python3
import os
# Identifier of the pretrained model that is fine-tuned, and the name the
# fine-tuned models are saved under.
src = "KoichiYasuoka/deberta-large-japanese-wikipedia"
tgt = "KoichiYasuoka/deberta-large-japanese-wikipedia-ud-head"

# Treebank repository; `d` is the directory name a clone of it ends up in.
url = "https://github.com/UniversalDependencies/UD_Japanese-GSDLUW"
d = os.path.basename(url)

# First command: clone the repository unless its directory already exists.
# Second command: copy the train/dev/test files into the working directory
# as train.conllu, dev.conllu and test.conllu.
for command in (
    "test -d {} || git clone --depth=1 {}".format(d, url),
    "for F in train dev test ; do cp " + d + "/*-$F*.conllu $F.conllu ; done",
):
    os.system(command)
from transformers import (AutoTokenizer,AutoModelForQuestionAnswering,
AutoModelForTokenClassification,AutoConfig,DefaultDataCollator,
DataCollatorForTokenClassification,TrainingArguments,Trainer)
class HEADDataset(object):
    """Question-answering style dataset for predicting each word's head.

    The input file is read as tab-separated lines. Lines with exactly ten
    columns whose first column is a decimal number are word lines: column
    index 1 is the word form and column index 6 is the 1-based index of the
    head word ("0" marks the root, which is stored as pointing at itself).
    A blank line (or the end of the file) closes the current sentence.

    For every word of a sentence one example is stored in ``self.qa`` as a
    tuple ``(input_ids, head, bounds)``:

    * ``input_ids`` is ``CLS + word tokens + SEP`` followed by the tokens of
      the whole sentence, with the word itself replaced by one mask token,
      followed by ``SEP``;
    * ``head`` is the 0-based index of the head word;
    * ``bounds[j]`` is the cumulative token count through segment ``j``
      (segment 0 is the question, segment ``j + 1`` is word ``j``).

    With ``augment`` true, a second example with the word left unmasked is
    stored whenever the word's token sequence occurs exactly once in the
    sentence. Examples whose total length is not below ``length`` are
    dropped.

    Args:
        conllu: path of the file to read (opened as UTF-8).
        tokenizer: callable returning ``{"input_ids": [[int, ...], ...]}``
            for a list of words, and exposing ``pad_token_id``,
            ``cls_token_id``, ``sep_token_id`` and ``mask_token_id``.
        augment: also store the unmasked variant described above.
        length: fixed length every returned example is padded to.
    """

    def __init__(self, conllu, tokenizer, augment=False, length=384):
        self.qa, self.pad, self.length = [], tokenizer.pad_token_id, length
        with open(conllu, "r", encoding="utf-8") as r:
            form, head = [], []
            for line in r:
                w = line.split("\t")
                if len(w) == 10 and w[0].isdecimal():
                    form.append(w[1])
                    # The root ("0") is stored as its own head.
                    head.append(len(head) if w[6] == "0" else int(w[6]) - 1)
                elif line.strip() == "" and form:
                    self._add_sentence(tokenizer, form, head, augment)
                    form, head = [], []
            # A file that lacks the final blank line would otherwise lose
            # its last sentence.
            if form:
                self._add_sentence(tokenizer, form, head, augment)

    @staticmethod
    def _prefix_lengths(segments):
        """Return the cumulative token count after each segment."""
        total, bounds = 0, []
        for segment in segments:
            total += len(segment)
            bounds.append(total)
        return bounds

    def _add_example(self, segments, head):
        """Store one example unless it is too long for ``self.length``."""
        bounds = self._prefix_lengths(segments)
        if bounds[-1] < self.length:
            flat = [token for segment in segments for token in segment]
            self.qa.append((flat, head, bounds))

    def _add_sentence(self, tokenizer, form, head, augment):
        """Tokenize one sentence and store the examples for all its words."""
        v = tokenizer(form, add_special_tokens=False)["input_ids"]
        cls_id, sep_id = tokenizer.cls_token_id, tokenizer.sep_token_id
        mask_id = tokenizer.mask_token_id
        for i, word in enumerate(v):
            question = [cls_id] + word + [sep_id]
            segments = [question] + v[0:i] + [[mask_id]] + v[i + 1:] + [[sep_id]]
            self._add_example(segments, head[i])
            # Unmasked variant only when the word is unambiguous in context.
            if augment and v.count(word) == 1:
                segments[i + 1] = word
                self._add_example(segments, head[i])

    def __len__(self):
        """Return the number of stored examples."""
        return len(self.qa)

    def __getitem__(self, i):
        """Return example ``i`` as a dict padded to ``self.length``.

        ``token_type_ids`` is 0 over the question segment, 1 over the rest
        of the real tokens and 0 over the padding. ``start_positions`` and
        ``end_positions`` are the first and last token index of the head
        word inside ``input_ids``.
        """
        v, h, b = self.qa[i]
        k = self.length - b[-1]  # amount of padding
        return {
            "input_ids": v + [self.pad] * k,
            "attention_mask": [1] * b[-1] + [0] * k,
            "token_type_ids": [0] * b[0] + [1] * (b[-1] - b[0]) + [0] * k,
            # Word h is segment h + 1, so it spans b[h] .. b[h + 1] - 1.
            "start_positions": b[h],
            "end_positions": b[h + 1] - 1,
        }
class UPOSDataset(object):
    """Token-classification dataset built from selected columns of a file.

    The input file is read as tab-separated lines. Lines with exactly ten
    columns whose first column is a decimal number are word lines: column
    index 1 is the word form, and the label of the word is the values of
    the columns listed in ``fields`` joined with ``"|"``. A blank line (or
    the end of the file) closes the current sentence.

    Every sub-token of a word gets a label: ``"B-" + label`` for the first
    sub-token and ``"I-" + label`` for the remaining ones. Sentences are
    stored as ``CLS + tokens + SEP``, with the first label of the sentence
    reused for both special tokens.

    NOTE(review): when a sentence has more than ``model_max_length - 4``
    labels, it is stored truncated to ``model_max_length - 2`` tokens and
    WITHOUT the CLS/SEP tokens; this mirrors the original behavior and has
    not been changed - confirm it is intended.

    Args:
        conllu: path of the file to read (opened as UTF-8).
        tokenizer: callable returning ``{"input_ids": [[int, ...], ...]}``
            for a list of words, and exposing ``cls_token_id``,
            ``sep_token_id`` and ``model_max_length``.
        fields: column indices whose values make up the label.

    Attributes set: ``ids`` (token ids per sentence), ``upos`` (label
    strings per sentence) and ``label2id`` (label string -> integer id,
    numbered in sorted order).
    """

    def __init__(self, conllu, tokenizer, fields=(3,)):
        self.ids, self.upos = [], []
        label = set()
        with open(conllu, "r", encoding="utf-8") as r:
            form, upos = [], []
            for line in r:
                w = line.split("\t")
                if len(w) == 10 and w[0].isdecimal():
                    form.append(w[1])
                    upos.append("|".join(w[i] for i in fields))
                elif line.strip() == "" and form:
                    label.update(self._add_sentence(tokenizer, form, upos))
                    form, upos = [], []
            # A file that lacks the final blank line would otherwise lose
            # its last sentence.
            if form:
                label.update(self._add_sentence(tokenizer, form, upos))
        self.label2id = {l: i for i, l in enumerate(sorted(label))}

    def _add_sentence(self, tokenizer, form, upos):
        """Tokenize one sentence, store it, and return the labels stored.

        Returns an empty list when nothing was stored, so the caller never
        has to look at ``self.upos[-1]`` of a dataset that is still empty.
        """
        v = tokenizer(form, add_special_tokens=False)["input_ids"]
        u = []
        for x, y in zip(v, upos):
            # min(len(x), 1) yields no "B-" label for a word with no tokens.
            u.extend(["B-" + y] * min(len(x), 1) + ["I-" + y] * (len(x) - 1))
        limit = tokenizer.model_max_length
        if len(u) > limit - 4:
            flat = [token for x in v for token in x]
            self.ids.append(flat[0:limit - 2])
            self.upos.append(u[0:limit - 2])
        elif len(u) > 0:
            flat = [token for x in v for token in x]
            self.ids.append([tokenizer.cls_token_id] + flat + [tokenizer.sep_token_id])
            self.upos.append([u[0]] + u + [u[0]])
        else:
            return []
        return self.upos[-1]

    def __call__(self, *others):
        """Give this dataset and ``others`` one shared label numbering.

        The labels of all datasets involved are merged, numbered in sorted
        order, and the resulting dict is assigned to ``label2id`` of every
        dataset. Returns that dict.
        """
        datasets = (self,) + others
        label = set()
        for t in datasets:
            label.update(t.label2id)
        lid = {l: i for i, l in enumerate(sorted(label))}
        for t in datasets:
            t.label2id = lid
        return lid

    def __len__(self):
        """Return the number of stored sentences."""
        return len(self.ids)

    def __getitem__(self, i):
        """Return sentence ``i`` as token ids plus integer label ids."""
        return {
            "input_ids": self.ids[i],
            "labels": [self.label2id[t] for t in self.upos[i]],
        }
# One tokenizer and one set of training arguments are shared by all three
# training runs below.
tkz = AutoTokenizer.from_pretrained(src)

# Run 1: question-answering model that points at the head of each word.
# The training set also gets the unmasked (augmented) examples.
trainDS = HEADDataset("train.conllu", tkz, True)
devDS = HEADDataset("dev.conllu", tkz)
arg = TrainingArguments(
    num_train_epochs=3,
    per_device_train_batch_size=8,
    output_dir="/tmp",
    overwrite_output_dir=True,
    save_total_limit=2,
    evaluation_strategy="epoch",
    learning_rate=5e-05,
    warmup_ratio=0.1,
)
trn = Trainer(
    args=arg,
    data_collator=DefaultDataCollator(),
    model=AutoModelForQuestionAnswering.from_pretrained(src),
    train_dataset=trainDS,
    eval_dataset=devDS,
)
trn.train()
trn.save_model(tgt)
tkz.save_pretrained(tgt)

# Runs 2 and 3: token classifiers, saved in sub-directories of `tgt`.
# Column index 7 is labelled for "/deprel"; indices 3 and 5 for "/tagger"
# (DEPREL, and UPOS plus FEATS, in the CoNLL-U column layout).
for fields, subdir in (([7], "/deprel"), ([3, 5], "/tagger")):
    trainDS = UPOSDataset("train.conllu", tkz, fields)
    devDS = UPOSDataset("dev.conllu", tkz, fields)
    testDS = UPOSDataset("test.conllu", tkz, fields)
    # Calling the training set merges the labels of all three splits so
    # that they share one label numbering.
    lid = trainDS(devDS, testDS)
    cfg = AutoConfig.from_pretrained(
        src,
        num_labels=len(lid),
        label2id=lid,
        id2label={i: l for l, i in lid.items()},
    )
    trn = Trainer(
        args=arg,
        data_collator=DataCollatorForTokenClassification(tkz),
        model=AutoModelForTokenClassification.from_pretrained(src, config=cfg),
        train_dataset=trainDS,
        eval_dataset=devDS,
    )
    trn.train()
    trn.save_model(tgt + subdir)
    tkz.save_pretrained(tgt + subdir)