# NOTE: page-header residue from web extraction ("Spaces: Sleeping"); not part of the script.
import argparse

import torch

from reader.data.relik_reader_sample import load_relik_reader_samples
from relik.reader.pytorch_modules.hf.modeling_relik import (
    RelikReaderConfig,
    RelikReaderREModel,
)
from relik.reader.relik_reader_re import RelikReaderForTripletExtraction
from relik.reader.utils.relation_matching_eval import StrongMatching
# Maps each NYT (Freebase-style) relation identifier to the natural-language
# label that replaces it in ``sample.candidates`` and in the gold triplets
# before the samples are handed to the reader.
#
# NOTE(review): the label for "/people/ethnicity/geographic_distribution"
# contains what looks like a typo ("distributi6on"). It is kept byte-for-byte
# because these strings are model inputs; confirm which spelling the
# checkpoints were trained with before correcting it.
dict_nyt = {
    "/people/person/nationality": "nationality",
    "/sports/sports_team/location": "sports team location",
    "/location/country/administrative_divisions": "administrative divisions",
    "/business/company/major_shareholders": "shareholders",
    "/people/ethnicity/people": "ethnicity",
    "/people/ethnicity/geographic_distribution": "geographic distributi6on",
    "/business/company_shareholder/major_shareholder_of": "major shareholder",
    "/location/location/contains": "location",
    "/business/company/founders": "founders",
    "/business/person/company": "company",
    "/business/company/advisors": "advisor",
    "/people/deceased_person/place_of_death": "place of death",
    "/business/company/industry": "industry",
    "/people/person/ethnicity": "ethnic background",
    "/people/person/place_of_birth": "place of birth",
    "/location/administrative_division/country": "country of an administration division",
    "/people/person/place_lived": "place lived",
    "/sports/sports_team_location/teams": "sports team",
    "/people/person/children": "child",
    "/people/person/religion": "religion",
    "/location/neighborhood/neighborhood_of": "neighborhood",
    "/location/country/capital": "capital",
    "/business/company/place_founded": "company founded location",
    "/people/person/profession": "occupation",
}
def _load_reader(model_path, device):
    """Build a ``RelikReaderForTripletExtraction`` for ``model_path`` on ``device``."""
    if not model_path.endswith(".ckpt"):
        # Not a Lightning checkpoint: let ``from_pretrained`` resolve it. This
        # may be a local path or a model identifier from the hub.
        model = RelikReaderREModel.from_pretrained(model_path)
        return RelikReaderForTripletExtraction(model, training=False, device=device)

    # Lightning checkpoint: rebuild tokenizer, config and model from the stored
    # hyper-parameters, then load the weights.
    from transformers import AutoTokenizer

    from relik.reader.utils.special_symbols import get_special_symbols_re

    # NOTE(security): ``torch.load`` unpickles the file; only load checkpoints
    # from a trusted source.
    # ``map_location="cpu"`` lets a checkpoint saved on GPU be deserialised on
    # any machine; the reader moves the model to ``device`` afterwards.
    checkpoint = torch.load(model_path, map_location="cpu")
    hparams = checkpoint["hyper_parameters"]

    special_symbols = get_special_symbols_re(hparams["additional_special_symbols"] - 1)
    tokenizer = AutoTokenizer.from_pretrained(
        hparams["transformer_model"],
        additional_special_tokens=special_symbols,
        add_prefix_space=True,
    )
    config_model = RelikReaderConfig(
        hparams["transformer_model"],
        len(special_symbols),
        training=False,
    )
    model = RelikReaderREModel(config_model)

    # The checkpoint keys carry the Lightning module's attribute name; drop it
    # so they line up with the bare model's parameter names.
    state_dict = {
        key.replace("relik_reader_re_model.", ""): value
        for key, value in checkpoint["state_dict"].items()
    }
    # ``strict=False`` is kept from the original: unmatched keys are tolerated.
    model.load_state_dict(state_dict, strict=False)

    return RelikReaderForTripletExtraction(
        model, training=False, device=device, tokenizer=tokenizer
    )


def eval(model_path, data_path, is_eval, output_path=None, device=None):
    """Run triplet extraction on an NYT-style dataset, optionally scoring and saving it.

    The function name shadows the ``eval`` builtin; it is kept unchanged so
    existing callers keep working.

    Args:
        model_path: Path to a Lightning ``.ckpt`` file, or anything accepted by
            ``RelikReaderREModel.from_pretrained``.
        data_path: Dataset path passed to ``load_relik_reader_samples``.
        is_eval: When true, print every ``StrongMatching`` metric, prefixed
            with ``test_``.
        output_path: When not ``None``, write one ``to_jsons()`` line per
            predicted sample to this file.
        device: Device string given to the reader. ``None`` (default) selects
            ``"cuda"`` when available and ``"cpu"`` otherwise.

    Raises:
        KeyError: If a candidate or gold relation name is missing from
            ``dict_nyt``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    reader = _load_reader(model_path, device)

    samples = list(load_relik_reader_samples(data_path))
    for sample in samples:
        # Replace the raw relation identifiers with their natural-language
        # labels, both in the candidate list and in the gold triplets.
        sample.candidates = [dict_nyt[cand] for cand in sample.candidates]
        sample.triplets = [
            {
                "subject": triplet["subject"],
                "relation": {
                    "name": dict_nyt[triplet["relation"]["name"]],
                    "type": triplet["relation"]["type"],
                },
                "object": triplet["object"],
            }
            for triplet in sample.triplets
        ]

    # Materialise once up front: the predictions may be consumed twice (metric
    # and file output), and if ``read`` is lazy this also guarantees inference
    # actually runs when neither consumer is enabled.
    predicted_samples = list(reader.read(samples=samples, progress_bar=True))

    if is_eval:
        strong_matching_metric = StrongMatching()
        for k, v in strong_matching_metric(predicted_samples).items():
            print(f"test_{k}", v)

    if output_path is not None:
        # Explicit encoding so the output does not depend on the platform locale.
        with open(output_path, "w", encoding="utf-8") as f:
            f.writelines(sample.to_jsons() + "\n" for sample in predicted_samples)
def main(argv=None):
    """Parse command-line options and run :func:`eval` with them.

    Args:
        argv: Optional list of argument strings. ``None`` (default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.
    """
    parser = argparse.ArgumentParser(
        description="Run a ReLiK relation-extraction reader on an NYT-style dataset."
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/models/relik_re_reader_base",
        help="Lightning .ckpt file, or a path/identifier loadable with from_pretrained.",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="/home/huguetcabot/alby-re/relik/relik/reader/data/testa.jsonl",
        help="Dataset file to run predictions on.",
    )
    # ``--is_eval`` is accepted as well so the flag matches the underscore
    # style of the other options; the original ``--is-eval`` still works.
    parser.add_argument(
        "--is-eval",
        "--is_eval",
        dest="is_eval",
        action="store_true",
        help="Print evaluation metrics for the predictions.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="If given, write the predicted samples to this file, one JSON per line.",
    )
    args = parser.parse_args(argv)
    eval(args.model_path, args.data_path, args.is_eval, args.output_path)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()