|
import gradio as gr |
|
from collections import defaultdict |
|
from transformers import BertTokenizer, BertForMaskedLM |
|
import jsonlines |
|
import torch |
|
from src.modeling_bert import EXBertForMaskedLM |
|
from higher.patch import monkeypatch as make_functional |
|
|
|
|
|
|
|
# Hugging Face Hub repository holding the fine-tuned KG-BERT style models; the
# "E-FB15k237" subfolder feeds the edit_* objects (used by edit_process) and
# "A-FB15k237" the add_* objects (used by add_process).
_KGEDITOR_REPO = "zjunlp/KGEditor"
_CPU = torch.device('cpu')

# --- editing task ----------------------------------------------------------
edit_origin_model = BertForMaskedLM.from_pretrained(
    pretrained_model_name_or_path=_KGEDITOR_REPO, subfolder="E-FB15k237"
)
edit_ex_model = EXBertForMaskedLM.from_pretrained(
    pretrained_model_name_or_path=_KGEDITOR_REPO, subfolder="E-FB15k237"
)
# NOTE(review): torch.load unpickles arbitrary Python objects; the learner is
# later called directly as a module, so weights_only loading is not an option.
# Only ever point these paths at trusted, local checkpoints.
edit_learner = torch.load("./learner_checkpoint/edit/learner_params.pt", map_location=_CPU)

# --- adding task -----------------------------------------------------------
add_learner = torch.load("./learner_checkpoint/add/learner_params.pt", map_location=_CPU)
add_origin_model = BertForMaskedLM.from_pretrained(
    pretrained_model_name_or_path=_KGEDITOR_REPO, subfolder="A-FB15k237"
)
add_ex_model = EXBertForMaskedLM.from_pretrained(
    pretrained_model_name_or_path=_KGEDITOR_REPO, subfolder="A-FB15k237"
)
|
|
|
|
|
# Lookup tables populated by init_triple_input(). They are defaultdicts, so
# reading an unknown key returns "" (or [] for corrupt_triple) and inserts that
# key instead of raising.
(
    ent_name2id,  # entity name -> entity id            (entity2text.txt)
    id2ent_name,  # entity id -> entity name            (entity2text.txt)
    rel_name2id,  # relation text -> relation id        (relation2text.txt)
    id2ent_text,  # entity id -> long description       (entity2textlong.txt)
    id2rel_text,  # relation id -> relation text        (relation2text.txt)
) = (defaultdict(str) for _ in range(5))

# "h r t" of an edit_test.jsonl "ori" triple -> that record's "cor" triple.
corrupt_triple = defaultdict(list)

# Plain BERT vocabulary, used for all encoding in solve().
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenizer whose added tokens are decoded in solve() to name entities in the
# learner's conditioning text (for both the edit and the add task).
# NOTE(review): despite the "add_" prefix it is loaded from the E-FB15k237
# subfolder; confirm the A-FB15k237 added-token vocabulary is identical.
add_tokenizer = BertTokenizer.from_pretrained('zjunlp/KGEditor', subfolder="E-FB15k237")
|
|
|
def init_triple_input():
    """Load the FB15k-237 lookup tables used by ``solve`` and the demo.

    Reads files under ``./dataset/fb15k237/`` and fills:

    * ``rel2token`` (global): relation id -> ``[RELATION_i]``, where ``i`` is
      the line index in ``relations.txt``;
    * ``ent_name2id`` / ``id2ent_name`` from ``entity2text.txt``;
    * ``rel_name2id`` / ``id2rel_text`` from ``relation2text.txt``;
    * ``id2ent_text`` from ``entity2textlong.txt``;
    * ``ent2token`` / ``ent2id`` / ``id2ent`` (globals), numbered in the order
      entities appear in ``entity2textlong.txt``;
    * ``corrupt_triple`` from ``edit_test.jsonl``: the space-joined ``ori``
      triple of each record -> that record's ``cor`` triple.

    Raises:
        OSError: if a dataset file is missing or unreadable.
        ValueError: if a non-blank line of a two-column file has no tab.
    """
    global ent2token
    global ent2id
    global id2ent
    global rel2token

    data_dir = "./dataset/fb15k237"

    def _tab_pairs(filename):
        """Yield ``(key, value)`` pairs from a two-column tab-separated file.

        Blank lines are skipped; only the first tab splits, so the value may
        itself contain tabs.
        """
        with open(f"{data_dir}/{filename}", "r", encoding="utf-8") as fh:
            for raw in fh:
                line = raw.rstrip("\n")
                if not line.strip():
                    continue
                key, value = line.split("\t", 1)
                yield key, value

    # Every line (blank ones included) is kept so that the i in [RELATION_i]
    # stays equal to the line's position in relations.txt.
    with open(f"{data_dir}/relations.txt", "r", encoding="utf-8") as fh:
        relations = [line.strip().split("\t")[0] for line in fh]
    rel2token = {rel: f"[RELATION_{i}]" for i, rel in enumerate(relations)}

    for ent_id, name in _tab_pairs("entity2text.txt"):
        ent_name2id[name] = ent_id
        id2ent_name[ent_id] = name

    for rel_id, name in _tab_pairs("relation2text.txt"):
        rel_name2id[name] = rel_id
        id2rel_text[rel_id] = name

    for ent_id, text in _tab_pairs("entity2textlong.txt"):
        # Turn literal backslash-n sequences into spaces, then drop any
        # remaining backslashes.
        id2ent_text[ent_id] = text.replace("\\n", " ").replace("\\", "")

    entities = list(id2ent_text)
    ent2token = {ent: f"[ENTITY_{i}]" for i, ent in enumerate(entities)}
    ent2id = {ent: i for i, ent in enumerate(entities)}
    id2ent = dict(enumerate(entities))

    with jsonlines.open(f"{data_dir}/edit_test.jsonl") as reader:
        for record in reader:
            corrupt_triple[" ".join(record["ori"])] = record["cor"]
|
|
|
def _encoding_to_batch(encoding):
    """Turn one tokenizer encoding into a batch-of-one dict of tensors.

    Keeps ``input_ids``, ``attention_mask`` and ``token_type_ids``. Each list is
    converted with ``torch.tensor`` and given a leading dimension of size 1.
    """
    return {
        key: torch.tensor(encoding[key]).unsqueeze(dim=0)
        for key in ("input_ids", "attention_mask", "token_type_ids")
    }


def _entity_id_for_name(name):
    """Return the entity id registered for *name* in ``ent_name2id``.

    Raises:
        ValueError: if *name* is not a known entity name. The ``in`` test comes
            first because indexing the defaultdict would silently insert and
            return ``""``.
    """
    if name not in ent_name2id:
        raise ValueError(f"Unknown entity name: {name!r}")
    return ent_name2id[name]


def solve(triple, alter_label, edit_task):
    """Build the model inputs for one masked query triple.

    Args:
        triple: ``"head|relation|tail"`` where the head or the tail is the
            literal ``[MASK]``; whitespace around each part is ignored. The
            relation must be a key of ``rel2token`` (a relation id).
        alter_label: Name of the entity the model should predict after the
            edit/add.
        edit_task: If true (editing), the "before" entity in the conditioning
            text is looked up in ``corrupt_triple``; otherwise (adding) it is
            *alter_label* itself.

    Returns:
        ``(inputs, cond_inputs, edit_inputs)``: three dicts of ``input_ids`` /
        ``attention_mask`` / ``token_type_ids`` tensors, each with a leading
        batch dimension of 1.

        * ``inputs``: the masked triple text, truncated to 64 tokens.
        * ``cond_inputs``: ``"<before> >> <after> || <text>"`` for the learner,
          padded to 64 tokens.
        * ``edit_inputs``: the two-segment encoding of the same query.

    Raises:
        ValueError: malformed triple, unknown relation or entity name, or (for
            editing) no corrupted counterpart in ``edit_test.jsonl``.
    """
    parts = [part.strip() for part in triple.split("|")]
    if len(parts) != 3:
        raise ValueError(f"Expected a 'head|relation|tail' triple, got {triple!r}")
    h, r, t = parts
    if r not in rel2token:
        raise ValueError(f"Unknown relation: {r!r}")
    alter_id = _entity_id_for_name(alter_label)

    # Relation text followed by its dedicated [RELATION_i] token (same for
    # head- and tail-prediction).
    text_b = id2rel_text[r] + " " + rel2token[r]

    if h == "[MASK]":
        # Head prediction: the known tail is given as token + description.
        t_id = _entity_id_for_name(t)
        text_a = "[MASK]"
        text_c = ent2token[t_id] + " " + id2ent_text[t_id]
        corrupt_key, corrupt_pos = " ".join([alter_id, r, t_id]), 0
        input_text_a = tokenizer.sep_token.join(["[MASK]", id2rel_text[r] + "[PAD]"])
        input_text_b = "[PAD]" + " " + id2ent_text[t_id]
    else:
        # Tail prediction: the given tail part is not used.
        h_id = _entity_id_for_name(h)
        text_a = ent2token[h_id]
        text_c = "[MASK]" + " " + id2ent_text[h_id]
        corrupt_key, corrupt_pos = " ".join([h_id, r, alter_id]), 2
        input_text_a = "[PAD] "
        input_text_b = tokenizer.sep_token.join(
            [id2rel_text[r] + "[PAD]", "[MASK]" + " " + id2ent_text[h_id]]
        )

    if edit_task:
        # .get() so a miss does not insert an empty list into the defaultdict.
        corrupted = corrupt_triple.get(corrupt_key)
        if not corrupted:
            raise ValueError(
                f"No corrupted counterpart for triple {corrupt_key!r} in edit_test.jsonl"
            )
        origin_label = corrupted[corrupt_pos]
    else:
        origin_label = alter_id

    # Entity i is represented by added token id len(tokenizer) + ent2id[...].
    cond_inputs_text = "{} >> {} || {}".format(
        add_tokenizer.added_tokens_decoder[ent2id[origin_label] + len(tokenizer)],
        add_tokenizer.added_tokens_decoder[ent2id[alter_id] + len(tokenizer)],
        input_text_a + input_text_b,
    )

    inputs = tokenizer(
        f"{text_a} [SEP] {text_b} [SEP] {text_c}",
        truncation="longest_first",
        max_length=64,
        padding="longest",
        add_special_tokens=True,
    )
    edit_inputs = tokenizer(
        input_text_a,
        input_text_b,
        truncation="longest_first",
        max_length=64,
        padding="longest",
        add_special_tokens=True,
    )
    cond_inputs = tokenizer(
        cond_inputs_text,
        truncation=True,
        max_length=64,
        padding="max_length",
        add_special_tokens=True,
    )

    return (
        _encoding_to_batch(inputs),
        _encoding_to_batch(cond_inputs),
        _encoding_to_batch(edit_inputs),
    )
|
|
|
def get_logits_orig_params_dict(inputs, cond_inputs, alter_label, ex_model, learner):
    """Ask *learner* for parameter updates that push *ex_model* to *alter_label*.

    Despite the name, only the learner's output is returned; no logits are.

    Steps:
      1. Run *ex_model* (switched to eval mode) on *inputs* and slice the
         vocabulary columns 30522..45472 at the ``[MASK]`` positions.
      2. Take the cross-entropy of the last ``[MASK]`` row against
         *alter_label*. Then take its gradient w.r.t. every parameter of
         *ex_model*.
      3. Call *learner* with the last row of *cond_inputs* and the gradients,
         keyed by parameter name.

    Args:
        inputs: dict with ``input_ids`` and ``attention_mask`` tensors,
            presumably a batch of one as built by ``solve``.
        cond_inputs: dict with ``input_ids`` and ``attention_mask`` for the
            learner's conditioning text.
        alter_label: Integer target class, i.e. an index into the sliced
            entity columns.
        ex_model: Model whose gradients are taken; ``.eval()`` is called on it.
        learner: Callable ``learner(input_ids, attention_mask, grads=...)``.

    Returns:
        Whatever *learner* returns (used by callers as a name -> delta mapping).
    """
    # Everything up to autograd.grad must run with grad enabled. If only the
    # forward pass did, a caller inside torch.no_grad() would get graph-less
    # slices/loss here and torch.autograd.grad would raise.
    with torch.enable_grad():
        logits = ex_model.eval()(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        ).logits

        _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
        mask_logits = logits[:, mask_idx, 30522:45473].squeeze(dim=0)

        loss = torch.nn.functional.cross_entropy(
            mask_logits[-1:, :],
            torch.tensor([alter_label]),
            reduction="none",
        ).mean(-1)
        grads = torch.autograd.grad(loss, list(ex_model.parameters()))

    named_grads = {
        name: grad
        for (name, _), grad in zip(ex_model.named_parameters(), grads)
    }

    return learner(
        cond_inputs["input_ids"][-1:],
        cond_inputs["attention_mask"][-1:],
        grads=named_grads,
    )
|
|
|
def edit_process(edit_input, alter_label):
    """Run one editing query and compare predictions before and after the edit.

    Args:
        edit_input: ``"head|relation|tail"`` string with the head or the tail
            replaced by ``[MASK]``.
        alter_label: Name of the entity the edited model should predict.

    Returns:
        Two newline-joined strings with the top-3 entity names predicted by
        ``edit_origin_model`` and by ``edit_ex_model`` after applying the
        learner's parameter updates.

    Raises:
        ValueError: if the encoded query contains no ``[MASK]`` token, or
            ``solve`` rejects the input.
    """
    inputs, cond_inputs, _ = solve(edit_input, alter_label, edit_task=True)

    # Vocabulary columns holding the entity tokens; column 30522 + i is the
    # entity id2ent[i].
    entity_cols = slice(30522, 45473)

    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
    if mask_idx.numel() == 0:
        raise ValueError("The encoded query contains no [MASK] token")
    # Use the last [MASK], as get_logits_orig_params_dict does for the loss.
    mask_pos = mask_idx[-1].item()

    def top3_names(entity_logits):
        """Map the 3 highest-scoring entity columns to entity names."""
        return [id2ent_name[id2ent[i]] for i in torch.topk(entity_logits, k=3).indices.tolist()]

    with torch.no_grad():
        origin_logits = edit_origin_model(**inputs).logits[0, mask_pos, entity_cols]
    origin_top3 = top3_names(origin_logits)

    fmodel = make_functional(edit_ex_model).eval()
    params_dict = get_logits_orig_params_dict(
        inputs, cond_inputs, ent2id[ent_name2id[alter_label]], edit_ex_model, edit_learner
    )

    # Pure inference with the updated weights; no graph is needed.
    with torch.no_grad():
        edit_logits = fmodel(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            params=[params_dict.get(n, 0) + p for n, p in edit_ex_model.named_parameters()],
        ).logits[0, mask_pos, entity_cols]
    edit_top3 = top3_names(edit_logits)

    return "\n".join(origin_top3), "\n".join(edit_top3)
|
|
|
def add_process(edit_input, alter_label):
    """Run one knowledge-adding query and compare predictions before and after.

    Args:
        edit_input: ``"head|relation|tail"`` string with the head or the tail
            replaced by ``[MASK]``.
        alter_label: Name of the entity to add (the "Inductive Entity" field
            in the UI).

    Returns:
        Two newline-joined strings with the top-3 entity names predicted by
        ``add_origin_model`` and by ``add_ex_model`` after applying the
        learner's parameter updates.

    Raises:
        ValueError: if the encoded query contains no ``[MASK]`` token, or
            ``solve`` rejects the input.
    """
    inputs, cond_inputs, _ = solve(edit_input, alter_label, edit_task=False)

    # Vocabulary columns holding the entity tokens; column 30522 + i is the
    # entity id2ent[i].
    entity_cols = slice(30522, 45473)

    _, mask_idx = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)
    if mask_idx.numel() == 0:
        raise ValueError("The encoded query contains no [MASK] token")
    # Use the last [MASK], as get_logits_orig_params_dict does for the loss.
    mask_pos = mask_idx[-1].item()

    def top3_names(entity_logits):
        """Map the 3 highest-scoring entity columns to entity names."""
        return [id2ent_name[id2ent[i]] for i in torch.topk(entity_logits, k=3).indices.tolist()]

    with torch.no_grad():
        origin_logits = add_origin_model(**inputs).logits[0, mask_pos, entity_cols]
    origin_top3 = top3_names(origin_logits)

    fmodel = make_functional(add_ex_model).eval()
    params_dict = get_logits_orig_params_dict(
        inputs, cond_inputs, ent2id[ent_name2id[alter_label]], add_ex_model, add_learner
    )

    # Pure inference with the updated weights; no graph is needed.
    with torch.no_grad():
        add_logits = fmodel(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            params=[params_dict.get(n, 0) + p for n, p in add_ex_model.named_parameters()],
        ).logits[0, mask_pos, entity_cols]
    add_top3 = top3_names(add_logits)

    return "\n".join(origin_top3), "\n".join(add_top3)
|
|
|
|
|
# Fill the FB15k-237 lookup tables first: the cached examples below call
# edit_process / add_process, which read them. (No separate warm-up call is
# needed; cache_examples=True already runs every example at startup.)
init_triple_input()

with gr.Blocks() as demo:
    gr.Markdown("# KGE Editing")

    with gr.Tabs():
        with gr.TabItem("E-FB15k237"):
            with gr.Row():
                with gr.Column():
                    edit_input = gr.Textbox(label="Input", lines=1, placeholder="Mask triple input")
                    alter_label = gr.Textbox(label="Alter Entity", lines=1, placeholder="Entity Name")
                    edit_button = gr.Button("Edit")

                with gr.Column():
                    origin_output = gr.Textbox(label="Before Edit", lines=3, placeholder="")
                    edit_output = gr.Textbox(label="After Edit", lines=3, placeholder="")

            gr.Examples(
                examples=[["[MASK]|/people/person/profession|Jack Black", "Kellie Martin"], ["Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"]],
                inputs=[edit_input, alter_label],
                outputs=[origin_output, edit_output],
                fn=edit_process,
                cache_examples=True,
            )

        with gr.TabItem("A-FB15k237"):
            with gr.Row():
                with gr.Column():
                    add_input = gr.Textbox(label="Input", lines=1, placeholder="New triple input")
                    inductive_entity = gr.Textbox(label="Inductive Entity", lines=1, placeholder="Entity Name")
                    add_button = gr.Button("Add")

                with gr.Column():
                    add_origin_output = gr.Textbox(label="Origin Results", lines=3, placeholder="")
                    add_output = gr.Textbox(label="Add Results", lines=3, placeholder="")

            gr.Examples(
                examples=[["Jane Wyman|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"], ["Red Skelton|/people/person/places_lived./people/place_lived/location|[MASK]", "Palm Springs"]],
                inputs=[add_input, inductive_entity],
                outputs=[add_origin_output, add_output],
                fn=add_process,
                cache_examples=True,
            )

    edit_button.click(fn=edit_process, inputs=[edit_input, alter_label], outputs=[origin_output, edit_output])
    add_button.click(fn=add_process, inputs=[add_input, inductive_entity], outputs=[add_origin_output, add_output])

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()