File size: 5,436 Bytes
0c954bd
7a1c034
 
 
2059531
 
7a1c034
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e372099
7a1c034
 
 
 
 
 
e372099
7a1c034
 
 
e372099
 
d7c91d6
7a1c034
e372099
a1b42e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e372099
eaa2345
02f29aa
 
 
eaa2345
5fcc4f7
 
 
eaa2345
5fcc4f7
 
a1b42e5
 
 
 
5fcc4f7
 
eaa2345
 
e372099
 
 
eaa2345
 
 
e372099
7a1c034
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


def getScores(ids, scores, pad_token_id):
    """Return the summed log-probability of each sequence from ``model.generate``.

    Args:
        ids: ``sequences`` tensor returned by ``model.generate``. Its first
            column is dropped as the decoder start token, so
            ``ids.shape[1] - 1`` is expected to equal ``len(scores)``.
        scores: ``scores`` returned by ``model.generate`` with
            ``output_scores=True``: one score tensor per generation step,
            assumed to be (num_sequences, vocab_size) — TODO confirm for
            beam settings other than ``num_beams=1``.
        pad_token_id: positions whose token equals this id contribute 0 to
            the sum (padding of sequences that finished early). ``None``
            disables masking.

    Returns:
        numpy array with one summed log-probability per sequence.
    """
    # (num_sequences, steps, vocab): normalise each step's scores to log-probs.
    log_probs = torch.log_softmax(torch.stack(scores, dim=1), dim=2)
    # Remove the decoder start token so token ids line up with the steps.
    target_ids = ids[:, 1:]
    # Log-prob of the token actually emitted at each step. A trailing singleton
    # index avoids materialising a (seq, steps, vocab) gather result only to
    # keep its first slice.
    token_log_probs = torch.gather(log_probs, 2, target_ids.unsqueeze(-1)).squeeze(-1)
    if pad_token_id is not None:
        # Padding must not count towards the sequence score.
        token_log_probs = token_log_probs.masked_fill(target_ids == pad_token_id, 0.0)
    return token_log_probs.sum(dim=-1).detach().cpu().numpy()

def topkSample(input, model, tokenizer, num_samples=5, num_beams=1, max_output_length=30):
    """Sample ``num_samples`` outputs for ``input`` and rank them by model score.

    Generation uses ``do_sample=True``. Each sampled sequence is decoded with
    special tokens stripped and scored with ``getScores``.

    Returns:
        list of ``(decoded_text, summed_log_prob)`` tuples, highest score first.
    """
    generation = model.generate(
        **tokenizer(input, return_tensors="pt"),
        do_sample=True,
        num_return_sequences=num_samples,
        num_beams=num_beams,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        output_scores=True,
        return_dict_in_generate=True,
        max_length=max_output_length,
    )
    sequences = generation.sequences
    texts = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    seq_scores = getScores(sequences, generation.scores, tokenizer.pad_token_id)

    ranked = list(zip(texts, seq_scores))
    # list.sort is stable, matching sorted(): ties keep generation order.
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    return ranked

def greedyPredict(input, model, tokenizer):
    """Generate one output for ``input`` by calling ``model.generate`` with no extra options.

    Returns the first decoded sequence with special tokens stripped.
    """
    encoded = tokenizer([input], return_tensors="pt")
    generated = model.generate(encoded.input_ids)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]
    
def predict_tail(entity, relation):
    """Predict tail entities for an (entity, relation) query.

    The entity and relation are joined as ``"<entity>| <relation>"``. Then 25
    sequences are sampled from the module-level ``model``/``tokenizer`` via
    ``topkSample``. The result maps each decoded string to its sequence
    probability (``exp`` of the summed log-probability).

    Returns:
        dict[str, float]: one entry per distinct decoded string, in
        descending probability order.
    """
    query = entity + "| " + relation
    ranked = topkSample(query, model, tokenizer, num_samples=25)
    predictions = {}
    for text, log_prob in ranked:
        # Different samples can decode to the same string. ``ranked`` is
        # sorted best-first, so keep the first (highest) score rather than
        # letting a lower-scoring duplicate overwrite it.
        if text not in predictions:
            predictions[text] = np.exp(log_prob).item()
    return predictions

    
# The tokenizer and the seq2seq model are loaded from the same Hub checkpoint.
_CHECKPOINT = "apoorvumang/kgt5-base-wikikg90mv2"
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT)


# Components for the Gradio interface: two single-line textboxes (entity and
# relation), prefilled with a sample query, and a Label output that displays
# the {prediction: probability} dict returned by predict_tail.
# NOTE(review): `gr.inputs.*` / `gr.outputs.*` and the `default=` keyword are
# the legacy Gradio component API (removed in later Gradio releases); this
# presumably relies on an older pinned Gradio version — confirm before upgrading.
ent_input = gr.inputs.Textbox(lines=1, default="Apoorv Umang Saxena")
rel_input = gr.inputs.Textbox(lines=1, default="country")
output = gr.outputs.Label()

# Clickable example queries shown under the demo. Each pair fills the entity
# and relation textboxes, in that order.
examples = [
    ["Adrian Kochsiek", "sex or gender"],
    ["Apoorv Umang Saxena", "family name"],
    ["World War II", "followed by"],
    ["Apoorv Umang Saxena", "country"],
    ["Ippolito Boccolini", "writing language"],
    ["Roelant", "writing system"],
    ["The Accountant 2227", "language of work or name"],
    ["Microbial Infection and AMR in Hospitalized Patients With Covid 19", "study type"],
    ["Carla Fracci", "manner of death"],
    ["list of programs broadcast by Comet", "is a list of"],
    ["Loreta Podhradí", "continent"],
    ["Opistognathotrema", "taxon rank"],
    ["Museum Arbeitswelt Steyr", "wheelchair accessibility"],
    ["Heliotropium tytoides", "subject has role"],
    ["School bus crash rates on routine and nonroutine routes.", "sponsor"],
    ["Tachigalieae", "taxon rank"],
    ["Irena Salusová", "place of detention"],
]
# Page heading and HTML intro text passed to gr.Interface (title/description).
title = "Interactive demo: KGT5"
description = """Demo for <a href='https://arxiv.org/abs/2203.10321'>Sequence-to-Sequence Knowledge Graph Completion and Question Answering </a> (KGT5). This particular model is a T5-base model trained on the task of tail prediction on the WikiKG90Mv2 dataset and obtains 0.239 validation MRR on this task (<a href="https://ogb.stanford.edu/docs/lsc/leaderboards/#wikikg90mv2">leaderboard</a>, see paper for details).
 To use it, simply give an entity name and relation and click 'submit'. Up to 25 model predictions will show up in a few seconds. The model works best when the exact entity/relation names that it has been trained on are used. 
 It is sometimes able to generalize to unseen entities as well (see examples).
"""
# HTML footer shown below the demo (gr.Interface `article`). It explains the
# sampling/ranking procedure and links the entity/relation alias lists, the
# GitHub repo and the Hugging Face model page. (An older, shorter footer that
# only linked the paper and repo used to be kept here as commented-out code.)

article = """
Under the hood, this demo concatenates the entity and relation, feeds it to the model and then samples 25 sequences, which are then ranked according to their sequence probabilities.
<br>
The text representations of the relations and entities can be downloaded from here: <a href="https://storage.googleapis.com/kgt5-wikikg90mv2/rel_alias_list.pickle">https://storage.googleapis.com/kgt5-wikikg90mv2/rel_alias_list.pickle</a> and 
<a href="https://storage.googleapis.com/kgt5-wikikg90mv2/ent_alias_list.pickle">https://storage.googleapis.com/kgt5-wikikg90mv2/ent_alias_list.pickle</a>
<br>
For more details see the <a href='https://github.com/apoorvumang/kgt5'>Github repo</a> or the <a href="https://huggingface.co/apoorvumang/kgt5-base-wikikg90mv2">hf model page</a>.
"""


# Connect predict_tail to the entity/relation textboxes and the label output,
# attach the page text and examples, then launch the app.
iface = gr.Interface(
    fn=predict_tail,
    inputs=[ent_input, rel_input],
    outputs=output,
    title=title,
    description=description,
    article=article,
    examples=examples,
)
iface.launch()