G2PTL Init
Browse files- G2PTL_utils.py +1542 -0
- Images/HTC.jpg +0 -0
- Images/Model.jpg +0 -0
- README.md +166 -0
- chn_2_code.pkl +3 -0
- config.json +33 -0
- configuration_G2PTL.py +97 -0
- graphormer.py +346 -0
- htc_loss.py +135 -0
- htc_mask_dict.pkl +3 -0
- modeling_G2PTL.py +1024 -0
- pytorch_model.bin +3 -0
- remap_code_2_chn.pkl +3 -0
- requirements.txt +5 -0
- special_tokens_map.json +7 -0
- tokenizer_config.json +15 -0
- vocab.txt +0 -0
G2PTL_utils.py
ADDED
@@ -0,0 +1,1542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! python3
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch.utils.checkpoint
|
5 |
+
from torch import nn
|
6 |
+
from transformers.utils import logging
|
7 |
+
import inspect
|
8 |
+
from typing import Set, Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
|
9 |
+
import re
|
10 |
+
import math
|
11 |
+
from typing import Optional, Tuple
|
12 |
+
from transformers.models.ernie.modeling_ernie import *
|
13 |
+
import torch
|
14 |
+
from fairseq import utils
|
15 |
+
from fairseq.modules.fairseq_dropout import FairseqDropout
|
16 |
+
from fairseq.modules.quant_noise import quant_noise
|
17 |
+
from torch import Tensor, nn
|
18 |
+
from torch.hub import load_state_dict_from_url
|
19 |
+
import torch.distributed as dist
|
20 |
+
|
21 |
+
|
22 |
+
from torch.hub import load_state_dict_from_url
|
23 |
+
import torch.distributed as dist
|
24 |
+
|
25 |
+
PRETRAINED_MODEL_URLS = {
|
26 |
+
"pcqm4mv1_graphormer_base":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv1.pt",
|
27 |
+
"pcqm4mv2_graphormer_base":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_best_pcqm4mv2.pt",
|
28 |
+
"oc20is2re_graphormer3d_base":"https://szheng.blob.core.windows.net/graphormer/modelzoo/oc20is2re/checkpoint_last_oc20_is2re.pt", # this pretrained model is temporarily unavailable
|
29 |
+
"pcqm4mv1_graphormer_base_for_molhiv":"https://ml2md.blob.core.windows.net/graphormer-ckpts/checkpoint_base_preln_pcqm4mv1_for_hiv.pt",
|
30 |
+
}
|
31 |
+
|
32 |
+
def load_pretrained_model(pretrained_model_name):
|
33 |
+
if pretrained_model_name not in PRETRAINED_MODEL_URLS:
|
34 |
+
raise ValueError("Unknown pretrained model name %s", pretrained_model_name)
|
35 |
+
if not dist.is_initialized():
|
36 |
+
return load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True)["model"]
|
37 |
+
else:
|
38 |
+
pretrained_model = load_state_dict_from_url(PRETRAINED_MODEL_URLS[pretrained_model_name], progress=True, file_name=f"{pretrained_model_name}_{dist.get_rank()}")["model"]
|
39 |
+
dist.barrier()
|
40 |
+
return pretrained_model
|
41 |
+
|
42 |
+
|
43 |
+
class MultiheadAttention(nn.Module):
|
44 |
+
"""Multi-headed attention.
|
45 |
+
|
46 |
+
See "Attention Is All You Need" for more details.
|
47 |
+
"""
|
48 |
+
|
49 |
+
def __init__(
|
50 |
+
self,
|
51 |
+
embed_dim,
|
52 |
+
num_heads,
|
53 |
+
kdim=None,
|
54 |
+
vdim=None,
|
55 |
+
dropout=0.0,
|
56 |
+
bias=True,
|
57 |
+
self_attention=False,
|
58 |
+
q_noise=0.0,
|
59 |
+
qn_block_size=8,
|
60 |
+
):
|
61 |
+
super().__init__()
|
62 |
+
self.embed_dim = embed_dim
|
63 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
64 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
65 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
66 |
+
|
67 |
+
self.num_heads = num_heads
|
68 |
+
self.dropout_module = FairseqDropout(
|
69 |
+
dropout, module_name=self.__class__.__name__
|
70 |
+
)
|
71 |
+
|
72 |
+
self.head_dim = embed_dim // num_heads
|
73 |
+
assert (
|
74 |
+
self.head_dim * num_heads == self.embed_dim
|
75 |
+
), "embed_dim must be divisible by num_heads"
|
76 |
+
self.scaling = self.head_dim ** -0.5
|
77 |
+
|
78 |
+
self.self_attention = self_attention
|
79 |
+
|
80 |
+
assert self.self_attention, "Only support self attention"
|
81 |
+
|
82 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
83 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
84 |
+
)
|
85 |
+
|
86 |
+
self.k_proj = quant_noise(
|
87 |
+
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
|
88 |
+
)
|
89 |
+
self.v_proj = quant_noise(
|
90 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
91 |
+
)
|
92 |
+
self.q_proj = quant_noise(
|
93 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
94 |
+
)
|
95 |
+
|
96 |
+
self.out_proj = quant_noise(
|
97 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
98 |
+
)
|
99 |
+
|
100 |
+
self.reset_parameters()
|
101 |
+
|
102 |
+
self.onnx_trace = False
|
103 |
+
|
104 |
+
def prepare_for_onnx_export_(self):
|
105 |
+
raise NotImplementedError
|
106 |
+
|
107 |
+
def reset_parameters(self):
|
108 |
+
if self.qkv_same_dim:
|
109 |
+
# Empirically observed the convergence to be much better with
|
110 |
+
# the scaled initialization
|
111 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
112 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
113 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
114 |
+
else:
|
115 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
116 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
117 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
118 |
+
|
119 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
120 |
+
if self.out_proj.bias is not None:
|
121 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
122 |
+
|
123 |
+
def forward(
|
124 |
+
self,
|
125 |
+
query,
|
126 |
+
key: Optional[Tensor],
|
127 |
+
value: Optional[Tensor],
|
128 |
+
attn_bias: Optional[Tensor],
|
129 |
+
key_padding_mask: Optional[Tensor] = None,
|
130 |
+
need_weights: bool = True,
|
131 |
+
attn_mask: Optional[Tensor] = None,
|
132 |
+
before_softmax: bool = False,
|
133 |
+
need_head_weights: bool = False,
|
134 |
+
) -> Tuple[Tensor, Optional[Tensor]]:
|
135 |
+
"""Input shape: Time x Batch x Channel
|
136 |
+
|
137 |
+
Args:
|
138 |
+
key_padding_mask (ByteTensor, optional): mask to exclude
|
139 |
+
keys that are pads, of shape `(batch, src_len)`, where
|
140 |
+
padding elements are indicated by 1s.
|
141 |
+
need_weights (bool, optional): return the attention weights,
|
142 |
+
averaged over heads (default: False).
|
143 |
+
attn_mask (ByteTensor, optional): typically used to
|
144 |
+
implement causal attention, where the mask prevents the
|
145 |
+
attention from looking forward in time (default: None).
|
146 |
+
before_softmax (bool, optional): return the raw attention
|
147 |
+
weights and values before the attention softmax.
|
148 |
+
need_head_weights (bool, optional): return the attention
|
149 |
+
weights for each head. Implies *need_weights*. Default:
|
150 |
+
return the average attention weights over all heads.
|
151 |
+
"""
|
152 |
+
if need_head_weights:
|
153 |
+
need_weights = True
|
154 |
+
|
155 |
+
tgt_len, bsz, embed_dim = query.size()
|
156 |
+
src_len = tgt_len
|
157 |
+
assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
|
158 |
+
assert list(query.size()) == [tgt_len, bsz, embed_dim]
|
159 |
+
if key is not None:
|
160 |
+
src_len, key_bsz, _ = key.size()
|
161 |
+
if not torch.jit.is_scripting():
|
162 |
+
assert key_bsz == bsz
|
163 |
+
assert value is not None
|
164 |
+
assert src_len, bsz == value.shape[:2]
|
165 |
+
|
166 |
+
q = self.q_proj(query)
|
167 |
+
k = self.k_proj(query)
|
168 |
+
v = self.v_proj(query)
|
169 |
+
q *= self.scaling
|
170 |
+
|
171 |
+
q = (
|
172 |
+
q.contiguous()
|
173 |
+
.view(tgt_len, bsz * self.num_heads, self.head_dim)
|
174 |
+
.transpose(0, 1)
|
175 |
+
)
|
176 |
+
if k is not None:
|
177 |
+
k = (
|
178 |
+
k.contiguous()
|
179 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
180 |
+
.transpose(0, 1)
|
181 |
+
)
|
182 |
+
if v is not None:
|
183 |
+
v = (
|
184 |
+
v.contiguous()
|
185 |
+
.view(-1, bsz * self.num_heads, self.head_dim)
|
186 |
+
.transpose(0, 1)
|
187 |
+
)
|
188 |
+
|
189 |
+
assert k is not None
|
190 |
+
assert k.size(1) == src_len
|
191 |
+
|
192 |
+
# This is part of a workaround to get around fork/join parallelism
|
193 |
+
# not supporting Optional types.
|
194 |
+
if key_padding_mask is not None and key_padding_mask.dim() == 0:
|
195 |
+
key_padding_mask = None
|
196 |
+
|
197 |
+
if key_padding_mask is not None:
|
198 |
+
assert key_padding_mask.size(0) == bsz
|
199 |
+
assert key_padding_mask.size(1) == src_len
|
200 |
+
attn_weights = torch.bmm(q, k.transpose(1, 2))
|
201 |
+
attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)
|
202 |
+
|
203 |
+
assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
204 |
+
|
205 |
+
if attn_bias is not None:
|
206 |
+
attn_weights += attn_bias.view(bsz * self.num_heads, tgt_len, src_len)
|
207 |
+
|
208 |
+
if attn_mask is not None:
|
209 |
+
attn_mask = attn_mask.unsqueeze(0)
|
210 |
+
attn_weights += attn_mask
|
211 |
+
|
212 |
+
if key_padding_mask is not None:
|
213 |
+
# don't attend to padding symbols
|
214 |
+
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
215 |
+
attn_weights = attn_weights.masked_fill(
|
216 |
+
key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
|
217 |
+
float("-inf"),
|
218 |
+
)
|
219 |
+
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
220 |
+
|
221 |
+
if before_softmax:
|
222 |
+
return attn_weights, v
|
223 |
+
|
224 |
+
attn_weights_float = utils.softmax(
|
225 |
+
attn_weights, dim=-1, onnx_trace=self.onnx_trace
|
226 |
+
)
|
227 |
+
attn_weights = attn_weights_float.type_as(attn_weights)
|
228 |
+
attn_probs = self.dropout_module(attn_weights)
|
229 |
+
|
230 |
+
assert v is not None
|
231 |
+
attn = torch.bmm(attn_probs, v)
|
232 |
+
assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
|
233 |
+
|
234 |
+
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
|
235 |
+
attn = self.out_proj(attn)
|
236 |
+
|
237 |
+
attn_weights: Optional[Tensor] = None
|
238 |
+
if need_weights:
|
239 |
+
attn_weights = attn_weights_float.view(
|
240 |
+
bsz, self.num_heads, tgt_len, src_len
|
241 |
+
).transpose(1, 0)
|
242 |
+
if not need_head_weights:
|
243 |
+
# average attention weights over heads
|
244 |
+
attn_weights = attn_weights.mean(dim=0)
|
245 |
+
|
246 |
+
return attn, attn_weights
|
247 |
+
|
248 |
+
def apply_sparse_mask(self, attn_weights, tgt_len: int, src_len: int, bsz: int):
|
249 |
+
return attn_weights
|
250 |
+
|
251 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
252 |
+
prefix = name + "." if name != "" else ""
|
253 |
+
items_to_add = {}
|
254 |
+
keys_to_remove = []
|
255 |
+
for k in state_dict.keys():
|
256 |
+
if k.endswith(prefix + "in_proj_weight"):
|
257 |
+
# in_proj_weight used to be q + k + v with same dimensions
|
258 |
+
dim = int(state_dict[k].shape[0] / 3)
|
259 |
+
items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
|
260 |
+
items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
|
261 |
+
items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]
|
262 |
+
|
263 |
+
keys_to_remove.append(k)
|
264 |
+
|
265 |
+
k_bias = prefix + "in_proj_bias"
|
266 |
+
if k_bias in state_dict.keys():
|
267 |
+
dim = int(state_dict[k].shape[0] / 3)
|
268 |
+
items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
|
269 |
+
items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
|
270 |
+
dim : 2 * dim
|
271 |
+
]
|
272 |
+
items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]
|
273 |
+
|
274 |
+
keys_to_remove.append(prefix + "in_proj_bias")
|
275 |
+
|
276 |
+
for k in keys_to_remove:
|
277 |
+
del state_dict[k]
|
278 |
+
|
279 |
+
for key, value in items_to_add.items():
|
280 |
+
state_dict[key] = value
|
281 |
+
|
282 |
+
|
283 |
+
def init_graphormer_params(module):
|
284 |
+
"""
|
285 |
+
Initialize the weights specific to the Graphormer Model.
|
286 |
+
"""
|
287 |
+
|
288 |
+
def normal_(data):
|
289 |
+
# with FSDP, module params will be on CUDA, so we cast them back to CPU
|
290 |
+
# so that the RNG is consistent with and without FSDP
|
291 |
+
data.copy_(data.cpu().normal_(mean=0.0, std=0.02).to(data.device))
|
292 |
+
|
293 |
+
if isinstance(module, nn.Linear):
|
294 |
+
normal_(module.weight.data)
|
295 |
+
if module.bias is not None:
|
296 |
+
module.bias.data.zero_()
|
297 |
+
if isinstance(module, nn.Embedding):
|
298 |
+
normal_(module.weight.data)
|
299 |
+
if module.padding_idx is not None:
|
300 |
+
module.weight.data[module.padding_idx].zero_()
|
301 |
+
if isinstance(module, MultiheadAttention):
|
302 |
+
normal_(module.q_proj.weight.data)
|
303 |
+
normal_(module.k_proj.weight.data)
|
304 |
+
normal_(module.v_proj.weight.data)
|
305 |
+
|
306 |
+
|
307 |
+
|
308 |
+
|
309 |
+
def add_start_docstrings(*docstr):
|
310 |
+
def docstring_decorator(fn):
|
311 |
+
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
|
312 |
+
return fn
|
313 |
+
|
314 |
+
return docstring_decorator
|
315 |
+
|
316 |
+
|
317 |
+
def add_start_docstrings_to_model_forward(*docstr):
|
318 |
+
def docstring_decorator(fn):
|
319 |
+
docstring = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
|
320 |
+
class_name = f"[`{fn.__qualname__.split('.')[0]}`]"
|
321 |
+
intro = f" The {class_name} forward method, overrides the `__call__` special method."
|
322 |
+
note = r"""
|
323 |
+
|
324 |
+
<Tip>
|
325 |
+
|
326 |
+
Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]
|
327 |
+
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
|
328 |
+
the latter silently ignores them.
|
329 |
+
|
330 |
+
</Tip>
|
331 |
+
"""
|
332 |
+
|
333 |
+
fn.__doc__ = intro + note + docstring
|
334 |
+
return fn
|
335 |
+
|
336 |
+
return docstring_decorator
|
337 |
+
|
338 |
+
|
339 |
+
def add_end_docstrings(*docstr):
|
340 |
+
def docstring_decorator(fn):
|
341 |
+
fn.__doc__ = (fn.__doc__ if fn.__doc__ is not None else "") + "".join(docstr)
|
342 |
+
return fn
|
343 |
+
|
344 |
+
return docstring_decorator
|
345 |
+
|
346 |
+
|
347 |
+
PT_RETURN_INTRODUCTION = r"""
|
348 |
+
Returns:
|
349 |
+
[`{full_output_type}`] or `tuple(torch.FloatTensor)`: A [`{full_output_type}`] or a tuple of
|
350 |
+
`torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various
|
351 |
+
elements depending on the configuration ([`{config_class}`]) and inputs.
|
352 |
+
|
353 |
+
"""
|
354 |
+
|
355 |
+
TF_RETURN_INTRODUCTION = r"""
|
356 |
+
Returns:
|
357 |
+
[`{full_output_type}`] or `tuple(tf.Tensor)`: A [`{full_output_type}`] or a tuple of `tf.Tensor` (if
|
358 |
+
`return_dict=False` is passed or when `config.return_dict=False`) comprising various elements depending on the
|
359 |
+
configuration ([`{config_class}`]) and inputs.
|
360 |
+
|
361 |
+
"""
|
362 |
+
|
363 |
+
|
364 |
+
def _get_indent(t):
|
365 |
+
"""Returns the indentation in the first line of t"""
|
366 |
+
search = re.search(r"^(\s*)\S", t)
|
367 |
+
return "" if search is None else search.groups()[0]
|
368 |
+
|
369 |
+
|
370 |
+
def _convert_output_args_doc(output_args_doc):
|
371 |
+
"""Convert output_args_doc to display properly."""
|
372 |
+
# Split output_arg_doc in blocks argument/description
|
373 |
+
indent = _get_indent(output_args_doc)
|
374 |
+
blocks = []
|
375 |
+
current_block = ""
|
376 |
+
for line in output_args_doc.split("\n"):
|
377 |
+
# If the indent is the same as the beginning, the line is the name of new arg.
|
378 |
+
if _get_indent(line) == indent:
|
379 |
+
if len(current_block) > 0:
|
380 |
+
blocks.append(current_block[:-1])
|
381 |
+
current_block = f"{line}\n"
|
382 |
+
else:
|
383 |
+
# Otherwise it's part of the description of the current arg.
|
384 |
+
# We need to remove 2 spaces to the indentation.
|
385 |
+
current_block += f"{line[2:]}\n"
|
386 |
+
blocks.append(current_block[:-1])
|
387 |
+
|
388 |
+
# Format each block for proper rendering
|
389 |
+
for i in range(len(blocks)):
|
390 |
+
blocks[i] = re.sub(r"^(\s+)(\S+)(\s+)", r"\1- **\2**\3", blocks[i])
|
391 |
+
blocks[i] = re.sub(r":\s*\n\s*(\S)", r" -- \1", blocks[i])
|
392 |
+
|
393 |
+
return "\n".join(blocks)
|
394 |
+
|
395 |
+
|
396 |
+
def _prepare_output_docstrings(output_type, config_class, min_indent=None):
|
397 |
+
"""
|
398 |
+
Prepares the return part of the docstring using `output_type`.
|
399 |
+
"""
|
400 |
+
output_docstring = output_type.__doc__
|
401 |
+
|
402 |
+
# Remove the head of the docstring to keep the list of args only
|
403 |
+
lines = output_docstring.split("\n")
|
404 |
+
i = 0
|
405 |
+
while i < len(lines) and re.search(r"^\s*(Args|Parameters):\s*$", lines[i]) is None:
|
406 |
+
i += 1
|
407 |
+
if i < len(lines):
|
408 |
+
params_docstring = "\n".join(lines[(i + 1):])
|
409 |
+
params_docstring = _convert_output_args_doc(params_docstring)
|
410 |
+
|
411 |
+
# Add the return introduction
|
412 |
+
full_output_type = f"{output_type.__module__}.{output_type.__name__}"
|
413 |
+
intro = TF_RETURN_INTRODUCTION if output_type.__name__.startswith("TF") else PT_RETURN_INTRODUCTION
|
414 |
+
intro = intro.format(full_output_type=full_output_type, config_class=config_class)
|
415 |
+
result = intro + params_docstring
|
416 |
+
|
417 |
+
# Apply minimum indent if necessary
|
418 |
+
if min_indent is not None:
|
419 |
+
lines = result.split("\n")
|
420 |
+
# Find the indent of the first nonempty line
|
421 |
+
i = 0
|
422 |
+
while len(lines[i]) == 0:
|
423 |
+
i += 1
|
424 |
+
indent = len(_get_indent(lines[i]))
|
425 |
+
# If too small, add indentation to all nonempty lines
|
426 |
+
if indent < min_indent:
|
427 |
+
to_add = " " * (min_indent - indent)
|
428 |
+
lines = [(f"{to_add}{line}" if len(line) > 0 else line) for line in lines]
|
429 |
+
result = "\n".join(lines)
|
430 |
+
|
431 |
+
return result
|
432 |
+
|
433 |
+
|
434 |
+
PT_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
435 |
+
Example:
|
436 |
+
|
437 |
+
```python
|
438 |
+
>>> from transformers import {processor_class}, {model_class}
|
439 |
+
>>> import torch
|
440 |
+
|
441 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
442 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
443 |
+
|
444 |
+
>>> inputs = tokenizer(
|
445 |
+
... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="pt"
|
446 |
+
... )
|
447 |
+
|
448 |
+
>>> with torch.no_grad():
|
449 |
+
... logits = model(**inputs).logits
|
450 |
+
|
451 |
+
>>> predicted_token_class_ids = logits.argmax(-1)
|
452 |
+
|
453 |
+
>>> # Note that tokens are classified rather then input words which means that
|
454 |
+
>>> # there might be more predicted token classes than words.
|
455 |
+
>>> # Multiple token classes might account for the same word
|
456 |
+
>>> predicted_tokens_classes = [model.config.id2label[t.item()] for t in predicted_token_class_ids[0]]
|
457 |
+
>>> predicted_tokens_classes
|
458 |
+
{expected_output}
|
459 |
+
```
|
460 |
+
|
461 |
+
```python
|
462 |
+
>>> labels = predicted_token_class_ids
|
463 |
+
>>> loss = model(**inputs, labels=labels).loss
|
464 |
+
>>> round(loss.item(), 2)
|
465 |
+
{expected_loss}
|
466 |
+
```
|
467 |
+
"""
|
468 |
+
|
469 |
+
PT_QUESTION_ANSWERING_SAMPLE = r"""
|
470 |
+
Example:
|
471 |
+
|
472 |
+
```python
|
473 |
+
>>> from transformers import {processor_class}, {model_class}
|
474 |
+
>>> import torch
|
475 |
+
|
476 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
477 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
478 |
+
|
479 |
+
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
480 |
+
|
481 |
+
>>> inputs = tokenizer(question, text, return_tensors="pt")
|
482 |
+
>>> with torch.no_grad():
|
483 |
+
... outputs = model(**inputs)
|
484 |
+
|
485 |
+
>>> answer_start_index = outputs.start_logits.argmax()
|
486 |
+
>>> answer_end_index = outputs.end_logits.argmax()
|
487 |
+
|
488 |
+
>>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
|
489 |
+
>>> tokenizer.decode(predict_answer_tokens)
|
490 |
+
{expected_output}
|
491 |
+
```
|
492 |
+
|
493 |
+
```python
|
494 |
+
>>> # target is "nice puppet"
|
495 |
+
>>> target_start_index = torch.tensor([{qa_target_start_index}])
|
496 |
+
>>> target_end_index = torch.tensor([{qa_target_end_index}])
|
497 |
+
|
498 |
+
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
|
499 |
+
>>> loss = outputs.loss
|
500 |
+
>>> round(loss.item(), 2)
|
501 |
+
{expected_loss}
|
502 |
+
```
|
503 |
+
"""
|
504 |
+
|
505 |
+
PT_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
506 |
+
Example of single-label classification:
|
507 |
+
|
508 |
+
```python
|
509 |
+
>>> import torch
|
510 |
+
>>> from transformers import {processor_class}, {model_class}
|
511 |
+
|
512 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
513 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
514 |
+
|
515 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
516 |
+
|
517 |
+
>>> with torch.no_grad():
|
518 |
+
... logits = model(**inputs).logits
|
519 |
+
|
520 |
+
>>> predicted_class_id = logits.argmax().item()
|
521 |
+
>>> model.config.id2label[predicted_class_id]
|
522 |
+
{expected_output}
|
523 |
+
```
|
524 |
+
|
525 |
+
```python
|
526 |
+
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
527 |
+
>>> num_labels = len(model.config.id2label)
|
528 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
|
529 |
+
|
530 |
+
>>> labels = torch.tensor([1])
|
531 |
+
>>> loss = model(**inputs, labels=labels).loss
|
532 |
+
>>> round(loss.item(), 2)
|
533 |
+
{expected_loss}
|
534 |
+
```
|
535 |
+
|
536 |
+
Example of multi-label classification:
|
537 |
+
|
538 |
+
```python
|
539 |
+
>>> import torch
|
540 |
+
>>> from transformers import {processor_class}, {model_class}
|
541 |
+
|
542 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
543 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}", problem_type="multi_label_classification")
|
544 |
+
|
545 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
546 |
+
|
547 |
+
>>> with torch.no_grad():
|
548 |
+
... logits = model(**inputs).logits
|
549 |
+
|
550 |
+
>>> predicted_class_id = logits.argmax().item()
|
551 |
+
>>> model.config.id2label[predicted_class_id]
|
552 |
+
{expected_output}
|
553 |
+
```
|
554 |
+
|
555 |
+
```python
|
556 |
+
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
557 |
+
>>> num_labels = len(model.config.id2label)
|
558 |
+
>>> model = {model_class}.from_pretrained(
|
559 |
+
... "{checkpoint}", num_labels=num_labels, problem_type="multi_label_classification"
|
560 |
+
... )
|
561 |
+
|
562 |
+
>>> labels = torch.nn.functional.one_hot(torch.tensor([predicted_class_id]), num_classes=num_labels).to(
|
563 |
+
... torch.float
|
564 |
+
... )
|
565 |
+
>>> loss = model(**inputs, labels=labels).loss
|
566 |
+
>>> loss.backward() # doctest: +IGNORE_RESULT
|
567 |
+
```
|
568 |
+
"""
|
569 |
+
|
570 |
+
PT_MASKED_LM_SAMPLE = r"""
|
571 |
+
Example:
|
572 |
+
|
573 |
+
```python
|
574 |
+
>>> from transformers import {processor_class}, {model_class}
|
575 |
+
>>> import torch
|
576 |
+
|
577 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
578 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
579 |
+
|
580 |
+
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="pt")
|
581 |
+
|
582 |
+
>>> with torch.no_grad():
|
583 |
+
... logits = model(**inputs).logits
|
584 |
+
|
585 |
+
>>> # retrieve index of {mask}
|
586 |
+
>>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
|
587 |
+
|
588 |
+
>>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
|
589 |
+
>>> tokenizer.decode(predicted_token_id)
|
590 |
+
{expected_output}
|
591 |
+
```
|
592 |
+
|
593 |
+
```python
|
594 |
+
>>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
|
595 |
+
>>> # mask labels of non-{mask} tokens
|
596 |
+
>>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
|
597 |
+
|
598 |
+
>>> outputs = model(**inputs, labels=labels)
|
599 |
+
>>> round(outputs.loss.item(), 2)
|
600 |
+
{expected_loss}
|
601 |
+
```
|
602 |
+
"""
|
603 |
+
|
604 |
+
PT_BASE_MODEL_SAMPLE = r"""
|
605 |
+
Example:
|
606 |
+
|
607 |
+
```python
|
608 |
+
>>> from transformers import {processor_class}, {model_class}
|
609 |
+
>>> import torch
|
610 |
+
|
611 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
612 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
613 |
+
|
614 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
615 |
+
>>> outputs = model(**inputs)
|
616 |
+
|
617 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
618 |
+
```
|
619 |
+
"""
|
620 |
+
|
621 |
+
PT_MULTIPLE_CHOICE_SAMPLE = r"""
|
622 |
+
Example:
|
623 |
+
|
624 |
+
```python
|
625 |
+
>>> from transformers import {processor_class}, {model_class}
|
626 |
+
>>> import torch
|
627 |
+
|
628 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
629 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
630 |
+
|
631 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
632 |
+
>>> choice0 = "It is eaten with a fork and a knife."
|
633 |
+
>>> choice1 = "It is eaten while held in the hand."
|
634 |
+
>>> labels = torch.tensor(0).unsqueeze(0) # choice0 is correct (according to Wikipedia ;)), batch size 1
|
635 |
+
|
636 |
+
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
|
637 |
+
>>> outputs = model(**{{k: v.unsqueeze(0) for k, v in encoding.items()}}, labels=labels) # batch size is 1
|
638 |
+
|
639 |
+
>>> # the linear classifier still needs to be trained
|
640 |
+
>>> loss = outputs.loss
|
641 |
+
>>> logits = outputs.logits
|
642 |
+
```
|
643 |
+
"""
|
644 |
+
|
645 |
+
PT_CAUSAL_LM_SAMPLE = r"""
|
646 |
+
Example:
|
647 |
+
|
648 |
+
```python
|
649 |
+
>>> import torch
|
650 |
+
>>> from transformers import {processor_class}, {model_class}
|
651 |
+
|
652 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
653 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
654 |
+
|
655 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
656 |
+
>>> outputs = model(**inputs, labels=inputs["input_ids"])
|
657 |
+
>>> loss = outputs.loss
|
658 |
+
>>> logits = outputs.logits
|
659 |
+
```
|
660 |
+
"""
|
661 |
+
|
662 |
+
PT_SPEECH_BASE_MODEL_SAMPLE = r"""
|
663 |
+
Example:
|
664 |
+
|
665 |
+
```python
|
666 |
+
>>> from transformers import {processor_class}, {model_class}
|
667 |
+
>>> import torch
|
668 |
+
>>> from datasets import load_dataset
|
669 |
+
|
670 |
+
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
671 |
+
>>> dataset = dataset.sort("id")
|
672 |
+
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
673 |
+
|
674 |
+
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
|
675 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
676 |
+
|
677 |
+
>>> # audio file is decoded on the fly
|
678 |
+
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
|
679 |
+
>>> with torch.no_grad():
|
680 |
+
... outputs = model(**inputs)
|
681 |
+
|
682 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
683 |
+
>>> list(last_hidden_states.shape)
|
684 |
+
{expected_output}
|
685 |
+
```
|
686 |
+
"""
|
687 |
+
|
688 |
+
PT_SPEECH_CTC_SAMPLE = r"""
|
689 |
+
Example:
|
690 |
+
|
691 |
+
```python
|
692 |
+
>>> from transformers import {processor_class}, {model_class}
|
693 |
+
>>> from datasets import load_dataset
|
694 |
+
>>> import torch
|
695 |
+
|
696 |
+
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
697 |
+
>>> dataset = dataset.sort("id")
|
698 |
+
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
699 |
+
|
700 |
+
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
|
701 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
702 |
+
|
703 |
+
>>> # audio file is decoded on the fly
|
704 |
+
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
|
705 |
+
>>> with torch.no_grad():
|
706 |
+
... logits = model(**inputs).logits
|
707 |
+
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
708 |
+
|
709 |
+
>>> # transcribe speech
|
710 |
+
>>> transcription = processor.batch_decode(predicted_ids)
|
711 |
+
>>> transcription[0]
|
712 |
+
{expected_output}
|
713 |
+
```
|
714 |
+
|
715 |
+
```python
|
716 |
+
>>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="pt").input_ids
|
717 |
+
|
718 |
+
>>> # compute loss
|
719 |
+
>>> loss = model(**inputs).loss
|
720 |
+
>>> round(loss.item(), 2)
|
721 |
+
{expected_loss}
|
722 |
+
```
|
723 |
+
"""
|
724 |
+
|
725 |
+
PT_SPEECH_SEQ_CLASS_SAMPLE = r"""
|
726 |
+
Example:
|
727 |
+
|
728 |
+
```python
|
729 |
+
>>> from transformers import {processor_class}, {model_class}
|
730 |
+
>>> from datasets import load_dataset
|
731 |
+
>>> import torch
|
732 |
+
|
733 |
+
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
734 |
+
>>> dataset = dataset.sort("id")
|
735 |
+
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
736 |
+
|
737 |
+
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
738 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
739 |
+
|
740 |
+
>>> # audio file is decoded on the fly
|
741 |
+
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
|
742 |
+
|
743 |
+
>>> with torch.no_grad():
|
744 |
+
... logits = model(**inputs).logits
|
745 |
+
|
746 |
+
>>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
|
747 |
+
>>> predicted_label = model.config.id2label[predicted_class_ids]
|
748 |
+
>>> predicted_label
|
749 |
+
{expected_output}
|
750 |
+
```
|
751 |
+
|
752 |
+
```python
|
753 |
+
>>> # compute loss - target_label is e.g. "down"
|
754 |
+
>>> target_label = model.config.id2label[0]
|
755 |
+
>>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
|
756 |
+
>>> loss = model(**inputs).loss
|
757 |
+
>>> round(loss.item(), 2)
|
758 |
+
{expected_loss}
|
759 |
+
```
|
760 |
+
"""
|
761 |
+
|
762 |
+
PT_SPEECH_FRAME_CLASS_SAMPLE = r"""
|
763 |
+
Example:
|
764 |
+
|
765 |
+
```python
|
766 |
+
>>> from transformers import {processor_class}, {model_class}
|
767 |
+
>>> from datasets import load_dataset
|
768 |
+
>>> import torch
|
769 |
+
|
770 |
+
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
771 |
+
>>> dataset = dataset.sort("id")
|
772 |
+
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
773 |
+
|
774 |
+
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
775 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
776 |
+
|
777 |
+
>>> # audio file is decoded on the fly
|
778 |
+
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
|
779 |
+
>>> with torch.no_grad():
|
780 |
+
... logits = model(**inputs).logits
|
781 |
+
|
782 |
+
>>> probabilities = torch.sigmoid(logits[0])
|
783 |
+
>>> # labels is a one-hot array of shape (num_frames, num_speakers)
|
784 |
+
>>> labels = (probabilities > 0.5).long()
|
785 |
+
>>> labels[0].tolist()
|
786 |
+
{expected_output}
|
787 |
+
```
|
788 |
+
"""
|
789 |
+
|
790 |
+
PT_SPEECH_XVECTOR_SAMPLE = r"""
|
791 |
+
Example:
|
792 |
+
|
793 |
+
```python
|
794 |
+
>>> from transformers import {processor_class}, {model_class}
|
795 |
+
>>> from datasets import load_dataset
|
796 |
+
>>> import torch
|
797 |
+
|
798 |
+
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
799 |
+
>>> dataset = dataset.sort("id")
|
800 |
+
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
801 |
+
|
802 |
+
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
803 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
804 |
+
|
805 |
+
>>> # audio file is decoded on the fly
|
806 |
+
>>> inputs = feature_extractor(
|
807 |
+
... [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
|
808 |
+
... )
|
809 |
+
>>> with torch.no_grad():
|
810 |
+
... embeddings = model(**inputs).embeddings
|
811 |
+
|
812 |
+
>>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
|
813 |
+
|
814 |
+
>>> # the resulting embeddings can be used for cosine similarity-based retrieval
|
815 |
+
>>> cosine_sim = torch.nn.CosineSimilarity(dim=-1)
|
816 |
+
>>> similarity = cosine_sim(embeddings[0], embeddings[1])
|
817 |
+
>>> threshold = 0.7 # the optimal threshold is dataset-dependent
|
818 |
+
>>> if similarity < threshold:
|
819 |
+
... print("Speakers are not the same!")
|
820 |
+
>>> round(similarity.item(), 2)
|
821 |
+
{expected_output}
|
822 |
+
```
|
823 |
+
"""
|
824 |
+
|
825 |
+
PT_VISION_BASE_MODEL_SAMPLE = r"""
|
826 |
+
Example:
|
827 |
+
|
828 |
+
```python
|
829 |
+
>>> from transformers import {processor_class}, {model_class}
|
830 |
+
>>> import torch
|
831 |
+
>>> from datasets import load_dataset
|
832 |
+
|
833 |
+
>>> dataset = load_dataset("huggingface/cats-image")
|
834 |
+
>>> image = dataset["test"]["image"][0]
|
835 |
+
|
836 |
+
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
837 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
838 |
+
|
839 |
+
>>> inputs = feature_extractor(image, return_tensors="pt")
|
840 |
+
|
841 |
+
>>> with torch.no_grad():
|
842 |
+
... outputs = model(**inputs)
|
843 |
+
|
844 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
845 |
+
>>> list(last_hidden_states.shape)
|
846 |
+
{expected_output}
|
847 |
+
```
|
848 |
+
"""
|
849 |
+
|
850 |
+
PT_VISION_SEQ_CLASS_SAMPLE = r"""
|
851 |
+
Example:
|
852 |
+
|
853 |
+
```python
|
854 |
+
>>> from transformers import {processor_class}, {model_class}
|
855 |
+
>>> import torch
|
856 |
+
>>> from datasets import load_dataset
|
857 |
+
|
858 |
+
>>> dataset = load_dataset("huggingface/cats-image")
|
859 |
+
>>> image = dataset["test"]["image"][0]
|
860 |
+
|
861 |
+
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
862 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
863 |
+
|
864 |
+
>>> inputs = feature_extractor(image, return_tensors="pt")
|
865 |
+
|
866 |
+
>>> with torch.no_grad():
|
867 |
+
... logits = model(**inputs).logits
|
868 |
+
|
869 |
+
>>> # model predicts one of the 1000 ImageNet classes
|
870 |
+
>>> predicted_label = logits.argmax(-1).item()
|
871 |
+
>>> print(model.config.id2label[predicted_label])
|
872 |
+
{expected_output}
|
873 |
+
```
|
874 |
+
"""
|
875 |
+
|
876 |
+
PT_SAMPLE_DOCSTRINGS = {
|
877 |
+
"SequenceClassification": PT_SEQUENCE_CLASSIFICATION_SAMPLE,
|
878 |
+
"QuestionAnswering": PT_QUESTION_ANSWERING_SAMPLE,
|
879 |
+
"TokenClassification": PT_TOKEN_CLASSIFICATION_SAMPLE,
|
880 |
+
"MultipleChoice": PT_MULTIPLE_CHOICE_SAMPLE,
|
881 |
+
"MaskedLM": PT_MASKED_LM_SAMPLE,
|
882 |
+
"LMHead": PT_CAUSAL_LM_SAMPLE,
|
883 |
+
"BaseModel": PT_BASE_MODEL_SAMPLE,
|
884 |
+
"SpeechBaseModel": PT_SPEECH_BASE_MODEL_SAMPLE,
|
885 |
+
"CTC": PT_SPEECH_CTC_SAMPLE,
|
886 |
+
"AudioClassification": PT_SPEECH_SEQ_CLASS_SAMPLE,
|
887 |
+
"AudioFrameClassification": PT_SPEECH_FRAME_CLASS_SAMPLE,
|
888 |
+
"AudioXVector": PT_SPEECH_XVECTOR_SAMPLE,
|
889 |
+
"VisionBaseModel": PT_VISION_BASE_MODEL_SAMPLE,
|
890 |
+
"ImageClassification": PT_VISION_SEQ_CLASS_SAMPLE,
|
891 |
+
}
|
892 |
+
|
893 |
+
TF_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
894 |
+
Example:
|
895 |
+
|
896 |
+
```python
|
897 |
+
>>> from transformers import {processor_class}, {model_class}
|
898 |
+
>>> import tensorflow as tf
|
899 |
+
|
900 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
901 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
902 |
+
|
903 |
+
>>> inputs = tokenizer(
|
904 |
+
... "HuggingFace is a company based in Paris and New York", add_special_tokens=False, return_tensors="tf"
|
905 |
+
... )
|
906 |
+
|
907 |
+
>>> logits = model(**inputs).logits
|
908 |
+
>>> predicted_token_class_ids = tf.math.argmax(logits, axis=-1)
|
909 |
+
|
910 |
+
>>> # Note that tokens are classified rather then input words which means that
|
911 |
+
>>> # there might be more predicted token classes than words.
|
912 |
+
>>> # Multiple token classes might account for the same word
|
913 |
+
>>> predicted_tokens_classes = [model.config.id2label[t] for t in predicted_token_class_ids[0].numpy().tolist()]
|
914 |
+
>>> predicted_tokens_classes
|
915 |
+
{expected_output}
|
916 |
+
```
|
917 |
+
|
918 |
+
```python
|
919 |
+
>>> labels = predicted_token_class_ids
|
920 |
+
>>> loss = tf.math.reduce_mean(model(**inputs, labels=labels).loss)
|
921 |
+
>>> round(float(loss), 2)
|
922 |
+
{expected_loss}
|
923 |
+
```
|
924 |
+
"""
|
925 |
+
|
926 |
+
TF_QUESTION_ANSWERING_SAMPLE = r"""
|
927 |
+
Example:
|
928 |
+
|
929 |
+
```python
|
930 |
+
>>> from transformers import {processor_class}, {model_class}
|
931 |
+
>>> import tensorflow as tf
|
932 |
+
|
933 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
934 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
935 |
+
|
936 |
+
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
937 |
+
|
938 |
+
>>> inputs = tokenizer(question, text, return_tensors="tf")
|
939 |
+
>>> outputs = model(**inputs)
|
940 |
+
|
941 |
+
>>> answer_start_index = int(tf.math.argmax(outputs.start_logits, axis=-1)[0])
|
942 |
+
>>> answer_end_index = int(tf.math.argmax(outputs.end_logits, axis=-1)[0])
|
943 |
+
|
944 |
+
>>> predict_answer_tokens = inputs.input_ids[0, answer_start_index : answer_end_index + 1]
|
945 |
+
>>> tokenizer.decode(predict_answer_tokens)
|
946 |
+
{expected_output}
|
947 |
+
```
|
948 |
+
|
949 |
+
```python
|
950 |
+
>>> # target is "nice puppet"
|
951 |
+
>>> target_start_index = tf.constant([{qa_target_start_index}])
|
952 |
+
>>> target_end_index = tf.constant([{qa_target_end_index}])
|
953 |
+
|
954 |
+
>>> outputs = model(**inputs, start_positions=target_start_index, end_positions=target_end_index)
|
955 |
+
>>> loss = tf.math.reduce_mean(outputs.loss)
|
956 |
+
>>> round(float(loss), 2)
|
957 |
+
{expected_loss}
|
958 |
+
```
|
959 |
+
"""
|
960 |
+
|
961 |
+
TF_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
962 |
+
Example:
|
963 |
+
|
964 |
+
```python
|
965 |
+
>>> from transformers import {processor_class}, {model_class}
|
966 |
+
>>> import tensorflow as tf
|
967 |
+
|
968 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
969 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
970 |
+
|
971 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
|
972 |
+
|
973 |
+
>>> logits = model(**inputs).logits
|
974 |
+
|
975 |
+
>>> predicted_class_id = int(tf.math.argmax(logits, axis=-1)[0])
|
976 |
+
>>> model.config.id2label[predicted_class_id]
|
977 |
+
{expected_output}
|
978 |
+
```
|
979 |
+
|
980 |
+
```python
|
981 |
+
>>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`
|
982 |
+
>>> num_labels = len(model.config.id2label)
|
983 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}", num_labels=num_labels)
|
984 |
+
|
985 |
+
>>> labels = tf.constant(1)
|
986 |
+
>>> loss = model(**inputs, labels=labels).loss
|
987 |
+
>>> round(float(loss), 2)
|
988 |
+
{expected_loss}
|
989 |
+
```
|
990 |
+
"""
|
991 |
+
|
992 |
+
TF_MASKED_LM_SAMPLE = r"""
|
993 |
+
Example:
|
994 |
+
|
995 |
+
```python
|
996 |
+
>>> from transformers import {processor_class}, {model_class}
|
997 |
+
>>> import tensorflow as tf
|
998 |
+
|
999 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1000 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1001 |
+
|
1002 |
+
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="tf")
|
1003 |
+
>>> logits = model(**inputs).logits
|
1004 |
+
|
1005 |
+
>>> # retrieve index of {mask}
|
1006 |
+
>>> mask_token_index = tf.where((inputs.input_ids == tokenizer.mask_token_id)[0])
|
1007 |
+
>>> selected_logits = tf.gather_nd(logits[0], indices=mask_token_index)
|
1008 |
+
|
1009 |
+
>>> predicted_token_id = tf.math.argmax(selected_logits, axis=-1)
|
1010 |
+
>>> tokenizer.decode(predicted_token_id)
|
1011 |
+
{expected_output}
|
1012 |
+
```
|
1013 |
+
|
1014 |
+
```python
|
1015 |
+
>>> labels = tokenizer("The capital of France is Paris.", return_tensors="tf")["input_ids"]
|
1016 |
+
>>> # mask labels of non-{mask} tokens
|
1017 |
+
>>> labels = tf.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
|
1018 |
+
|
1019 |
+
>>> outputs = model(**inputs, labels=labels)
|
1020 |
+
>>> round(float(outputs.loss), 2)
|
1021 |
+
{expected_loss}
|
1022 |
+
```
|
1023 |
+
"""
|
1024 |
+
|
1025 |
+
TF_BASE_MODEL_SAMPLE = r"""
|
1026 |
+
Example:
|
1027 |
+
|
1028 |
+
```python
|
1029 |
+
>>> from transformers import {processor_class}, {model_class}
|
1030 |
+
>>> import tensorflow as tf
|
1031 |
+
|
1032 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1033 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1034 |
+
|
1035 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
|
1036 |
+
>>> outputs = model(inputs)
|
1037 |
+
|
1038 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
1039 |
+
```
|
1040 |
+
"""
|
1041 |
+
|
1042 |
+
TF_MULTIPLE_CHOICE_SAMPLE = r"""
|
1043 |
+
Example:
|
1044 |
+
|
1045 |
+
```python
|
1046 |
+
>>> from transformers import {processor_class}, {model_class}
|
1047 |
+
>>> import tensorflow as tf
|
1048 |
+
|
1049 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1050 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1051 |
+
|
1052 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
1053 |
+
>>> choice0 = "It is eaten with a fork and a knife."
|
1054 |
+
>>> choice1 = "It is eaten while held in the hand."
|
1055 |
+
|
1056 |
+
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="tf", padding=True)
|
1057 |
+
>>> inputs = {{k: tf.expand_dims(v, 0) for k, v in encoding.items()}}
|
1058 |
+
>>> outputs = model(inputs) # batch size is 1
|
1059 |
+
|
1060 |
+
>>> # the linear classifier still needs to be trained
|
1061 |
+
>>> logits = outputs.logits
|
1062 |
+
```
|
1063 |
+
"""
|
1064 |
+
|
1065 |
+
TF_CAUSAL_LM_SAMPLE = r"""
|
1066 |
+
Example:
|
1067 |
+
|
1068 |
+
```python
|
1069 |
+
>>> from transformers import {processor_class}, {model_class}
|
1070 |
+
>>> import tensorflow as tf
|
1071 |
+
|
1072 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1073 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1074 |
+
|
1075 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf")
|
1076 |
+
>>> outputs = model(inputs)
|
1077 |
+
>>> logits = outputs.logits
|
1078 |
+
```
|
1079 |
+
"""
|
1080 |
+
|
1081 |
+
TF_SPEECH_BASE_MODEL_SAMPLE = r"""
|
1082 |
+
Example:
|
1083 |
+
|
1084 |
+
```python
|
1085 |
+
>>> from transformers import {processor_class}, {model_class}
|
1086 |
+
>>> from datasets import load_dataset
|
1087 |
+
|
1088 |
+
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
1089 |
+
>>> dataset = dataset.sort("id")
|
1090 |
+
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
1091 |
+
|
1092 |
+
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
|
1093 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1094 |
+
|
1095 |
+
>>> # audio file is decoded on the fly
|
1096 |
+
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
|
1097 |
+
>>> outputs = model(**inputs)
|
1098 |
+
|
1099 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
1100 |
+
>>> list(last_hidden_states.shape)
|
1101 |
+
{expected_output}
|
1102 |
+
```
|
1103 |
+
"""
|
1104 |
+
|
1105 |
+
TF_SPEECH_CTC_SAMPLE = r"""
|
1106 |
+
Example:
|
1107 |
+
|
1108 |
+
```python
|
1109 |
+
>>> from transformers import {processor_class}, {model_class}
|
1110 |
+
>>> from datasets import load_dataset
|
1111 |
+
>>> import tensorflow as tf
|
1112 |
+
|
1113 |
+
>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
|
1114 |
+
>>> dataset = dataset.sort("id")
|
1115 |
+
>>> sampling_rate = dataset.features["audio"].sampling_rate
|
1116 |
+
|
1117 |
+
>>> processor = {processor_class}.from_pretrained("{checkpoint}")
|
1118 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1119 |
+
|
1120 |
+
>>> # audio file is decoded on the fly
|
1121 |
+
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="tf")
|
1122 |
+
>>> logits = model(**inputs).logits
|
1123 |
+
>>> predicted_ids = tf.math.argmax(logits, axis=-1)
|
1124 |
+
|
1125 |
+
>>> # transcribe speech
|
1126 |
+
>>> transcription = processor.batch_decode(predicted_ids)
|
1127 |
+
>>> transcription[0]
|
1128 |
+
{expected_output}
|
1129 |
+
```
|
1130 |
+
|
1131 |
+
```python
|
1132 |
+
>>> inputs["labels"] = processor(text=dataset[0]["text"], return_tensors="tf").input_ids
|
1133 |
+
|
1134 |
+
>>> # compute loss
|
1135 |
+
>>> loss = model(**inputs).loss
|
1136 |
+
>>> round(float(loss), 2)
|
1137 |
+
{expected_loss}
|
1138 |
+
```
|
1139 |
+
"""
|
1140 |
+
|
1141 |
+
TF_VISION_BASE_MODEL_SAMPLE = r"""
|
1142 |
+
Example:
|
1143 |
+
|
1144 |
+
```python
|
1145 |
+
>>> from transformers import {processor_class}, {model_class}
|
1146 |
+
>>> from datasets import load_dataset
|
1147 |
+
|
1148 |
+
>>> dataset = load_dataset("huggingface/cats-image")
|
1149 |
+
>>> image = dataset["test"]["image"][0]
|
1150 |
+
|
1151 |
+
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
1152 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1153 |
+
|
1154 |
+
>>> inputs = feature_extractor(image, return_tensors="tf")
|
1155 |
+
>>> outputs = model(**inputs)
|
1156 |
+
|
1157 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
1158 |
+
>>> list(last_hidden_states.shape)
|
1159 |
+
{expected_output}
|
1160 |
+
```
|
1161 |
+
"""
|
1162 |
+
|
1163 |
+
TF_VISION_SEQ_CLASS_SAMPLE = r"""
|
1164 |
+
Example:
|
1165 |
+
|
1166 |
+
```python
|
1167 |
+
>>> from transformers import {processor_class}, {model_class}
|
1168 |
+
>>> import tensorflow as tf
|
1169 |
+
>>> from datasets import load_dataset
|
1170 |
+
|
1171 |
+
>>> dataset = load_dataset("huggingface/cats-image")
|
1172 |
+
>>> image = dataset["test"]["image"][0]
|
1173 |
+
|
1174 |
+
>>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
|
1175 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1176 |
+
|
1177 |
+
>>> inputs = feature_extractor(image, return_tensors="tf")
|
1178 |
+
>>> logits = model(**inputs).logits
|
1179 |
+
|
1180 |
+
>>> # model predicts one of the 1000 ImageNet classes
|
1181 |
+
>>> predicted_label = int(tf.math.argmax(logits, axis=-1))
|
1182 |
+
>>> print(model.config.id2label[predicted_label])
|
1183 |
+
{expected_output}
|
1184 |
+
```
|
1185 |
+
"""
|
1186 |
+
|
1187 |
+
TF_SAMPLE_DOCSTRINGS = {
|
1188 |
+
"SequenceClassification": TF_SEQUENCE_CLASSIFICATION_SAMPLE,
|
1189 |
+
"QuestionAnswering": TF_QUESTION_ANSWERING_SAMPLE,
|
1190 |
+
"TokenClassification": TF_TOKEN_CLASSIFICATION_SAMPLE,
|
1191 |
+
"MultipleChoice": TF_MULTIPLE_CHOICE_SAMPLE,
|
1192 |
+
"MaskedLM": TF_MASKED_LM_SAMPLE,
|
1193 |
+
"LMHead": TF_CAUSAL_LM_SAMPLE,
|
1194 |
+
"BaseModel": TF_BASE_MODEL_SAMPLE,
|
1195 |
+
"SpeechBaseModel": TF_SPEECH_BASE_MODEL_SAMPLE,
|
1196 |
+
"CTC": TF_SPEECH_CTC_SAMPLE,
|
1197 |
+
"VisionBaseModel": TF_VISION_BASE_MODEL_SAMPLE,
|
1198 |
+
"ImageClassification": TF_VISION_SEQ_CLASS_SAMPLE,
|
1199 |
+
}
|
1200 |
+
|
1201 |
+
FLAX_TOKEN_CLASSIFICATION_SAMPLE = r"""
|
1202 |
+
Example:
|
1203 |
+
|
1204 |
+
```python
|
1205 |
+
>>> from transformers import {processor_class}, {model_class}
|
1206 |
+
|
1207 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1208 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1209 |
+
|
1210 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
1211 |
+
|
1212 |
+
>>> outputs = model(**inputs)
|
1213 |
+
>>> logits = outputs.logits
|
1214 |
+
```
|
1215 |
+
"""
|
1216 |
+
|
1217 |
+
FLAX_QUESTION_ANSWERING_SAMPLE = r"""
|
1218 |
+
Example:
|
1219 |
+
|
1220 |
+
```python
|
1221 |
+
>>> from transformers import {processor_class}, {model_class}
|
1222 |
+
|
1223 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1224 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1225 |
+
|
1226 |
+
>>> question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
1227 |
+
>>> inputs = tokenizer(question, text, return_tensors="jax")
|
1228 |
+
|
1229 |
+
>>> outputs = model(**inputs)
|
1230 |
+
>>> start_scores = outputs.start_logits
|
1231 |
+
>>> end_scores = outputs.end_logits
|
1232 |
+
```
|
1233 |
+
"""
|
1234 |
+
|
1235 |
+
FLAX_SEQUENCE_CLASSIFICATION_SAMPLE = r"""
|
1236 |
+
Example:
|
1237 |
+
|
1238 |
+
```python
|
1239 |
+
>>> from transformers import {processor_class}, {model_class}
|
1240 |
+
|
1241 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1242 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1243 |
+
|
1244 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
1245 |
+
|
1246 |
+
>>> outputs = model(**inputs)
|
1247 |
+
>>> logits = outputs.logits
|
1248 |
+
```
|
1249 |
+
"""
|
1250 |
+
|
1251 |
+
FLAX_MASKED_LM_SAMPLE = r"""
|
1252 |
+
Example:
|
1253 |
+
|
1254 |
+
```python
|
1255 |
+
>>> from transformers import {processor_class}, {model_class}
|
1256 |
+
|
1257 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1258 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1259 |
+
|
1260 |
+
>>> inputs = tokenizer("The capital of France is {mask}.", return_tensors="jax")
|
1261 |
+
|
1262 |
+
>>> outputs = model(**inputs)
|
1263 |
+
>>> logits = outputs.logits
|
1264 |
+
```
|
1265 |
+
"""
|
1266 |
+
|
1267 |
+
FLAX_BASE_MODEL_SAMPLE = r"""
|
1268 |
+
Example:
|
1269 |
+
|
1270 |
+
```python
|
1271 |
+
>>> from transformers import {processor_class}, {model_class}
|
1272 |
+
|
1273 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1274 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1275 |
+
|
1276 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="jax")
|
1277 |
+
>>> outputs = model(**inputs)
|
1278 |
+
|
1279 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
1280 |
+
```
|
1281 |
+
"""
|
1282 |
+
|
1283 |
+
FLAX_MULTIPLE_CHOICE_SAMPLE = r"""
|
1284 |
+
Example:
|
1285 |
+
|
1286 |
+
```python
|
1287 |
+
>>> from transformers import {processor_class}, {model_class}
|
1288 |
+
|
1289 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1290 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1291 |
+
|
1292 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
1293 |
+
>>> choice0 = "It is eaten with a fork and a knife."
|
1294 |
+
>>> choice1 = "It is eaten while held in the hand."
|
1295 |
+
|
1296 |
+
>>> encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="jax", padding=True)
|
1297 |
+
>>> outputs = model(**{{k: v[None, :] for k, v in encoding.items()}})
|
1298 |
+
|
1299 |
+
>>> logits = outputs.logits
|
1300 |
+
```
|
1301 |
+
"""
|
1302 |
+
|
1303 |
+
FLAX_CAUSAL_LM_SAMPLE = r"""
|
1304 |
+
Example:
|
1305 |
+
|
1306 |
+
```python
|
1307 |
+
>>> from transformers import {processor_class}, {model_class}
|
1308 |
+
|
1309 |
+
>>> tokenizer = {processor_class}.from_pretrained("{checkpoint}")
|
1310 |
+
>>> model = {model_class}.from_pretrained("{checkpoint}")
|
1311 |
+
|
1312 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="np")
|
1313 |
+
>>> outputs = model(**inputs)
|
1314 |
+
|
1315 |
+
>>> # retrieve logts for next token
|
1316 |
+
>>> next_token_logits = outputs.logits[:, -1]
|
1317 |
+
```
|
1318 |
+
"""
|
1319 |
+
|
1320 |
+
FLAX_SAMPLE_DOCSTRINGS = {
|
1321 |
+
"SequenceClassification": FLAX_SEQUENCE_CLASSIFICATION_SAMPLE,
|
1322 |
+
"QuestionAnswering": FLAX_QUESTION_ANSWERING_SAMPLE,
|
1323 |
+
"TokenClassification": FLAX_TOKEN_CLASSIFICATION_SAMPLE,
|
1324 |
+
"MultipleChoice": FLAX_MULTIPLE_CHOICE_SAMPLE,
|
1325 |
+
"MaskedLM": FLAX_MASKED_LM_SAMPLE,
|
1326 |
+
"BaseModel": FLAX_BASE_MODEL_SAMPLE,
|
1327 |
+
"LMHead": FLAX_CAUSAL_LM_SAMPLE,
|
1328 |
+
}
|
1329 |
+
|
1330 |
+
|
1331 |
+
def add_code_sample_docstrings(
|
1332 |
+
*docstr,
|
1333 |
+
processor_class=None,
|
1334 |
+
checkpoint=None,
|
1335 |
+
output_type=None,
|
1336 |
+
config_class=None,
|
1337 |
+
mask="[MASK]",
|
1338 |
+
qa_target_start_index=14,
|
1339 |
+
qa_target_end_index=15,
|
1340 |
+
model_cls=None,
|
1341 |
+
modality=None,
|
1342 |
+
expected_output="",
|
1343 |
+
expected_loss="",
|
1344 |
+
):
|
1345 |
+
def docstring_decorator(fn):
|
1346 |
+
# model_class defaults to function's class if not specified otherwise
|
1347 |
+
model_class = fn.__qualname__.split(".")[0] if model_cls is None else model_cls
|
1348 |
+
|
1349 |
+
if model_class[:2] == "TF":
|
1350 |
+
sample_docstrings = TF_SAMPLE_DOCSTRINGS
|
1351 |
+
elif model_class[:4] == "Flax":
|
1352 |
+
sample_docstrings = FLAX_SAMPLE_DOCSTRINGS
|
1353 |
+
else:
|
1354 |
+
sample_docstrings = PT_SAMPLE_DOCSTRINGS
|
1355 |
+
|
1356 |
+
# putting all kwargs for docstrings in a dict to be used
|
1357 |
+
# with the `.format(**doc_kwargs)`. Note that string might
|
1358 |
+
# be formatted with non-existing keys, which is fine.
|
1359 |
+
doc_kwargs = dict(
|
1360 |
+
model_class=model_class,
|
1361 |
+
processor_class=processor_class,
|
1362 |
+
checkpoint=checkpoint,
|
1363 |
+
mask=mask,
|
1364 |
+
qa_target_start_index=qa_target_start_index,
|
1365 |
+
qa_target_end_index=qa_target_end_index,
|
1366 |
+
expected_output=expected_output,
|
1367 |
+
expected_loss=expected_loss,
|
1368 |
+
)
|
1369 |
+
|
1370 |
+
if "SequenceClassification" in model_class and modality == "audio":
|
1371 |
+
code_sample = sample_docstrings["AudioClassification"]
|
1372 |
+
elif "SequenceClassification" in model_class:
|
1373 |
+
code_sample = sample_docstrings["SequenceClassification"]
|
1374 |
+
elif "QuestionAnswering" in model_class:
|
1375 |
+
code_sample = sample_docstrings["QuestionAnswering"]
|
1376 |
+
elif "TokenClassification" in model_class:
|
1377 |
+
code_sample = sample_docstrings["TokenClassification"]
|
1378 |
+
elif "MultipleChoice" in model_class:
|
1379 |
+
code_sample = sample_docstrings["MultipleChoice"]
|
1380 |
+
elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
|
1381 |
+
code_sample = sample_docstrings["MaskedLM"]
|
1382 |
+
elif "LMHead" in model_class or "CausalLM" in model_class:
|
1383 |
+
code_sample = sample_docstrings["LMHead"]
|
1384 |
+
elif "CTC" in model_class:
|
1385 |
+
code_sample = sample_docstrings["CTC"]
|
1386 |
+
elif "AudioFrameClassification" in model_class:
|
1387 |
+
code_sample = sample_docstrings["AudioFrameClassification"]
|
1388 |
+
elif "XVector" in model_class and modality == "audio":
|
1389 |
+
code_sample = sample_docstrings["AudioXVector"]
|
1390 |
+
elif "Model" in model_class and modality == "audio":
|
1391 |
+
code_sample = sample_docstrings["SpeechBaseModel"]
|
1392 |
+
elif "Model" in model_class and modality == "vision":
|
1393 |
+
code_sample = sample_docstrings["VisionBaseModel"]
|
1394 |
+
elif "Model" in model_class or "Encoder" in model_class:
|
1395 |
+
code_sample = sample_docstrings["BaseModel"]
|
1396 |
+
elif "ImageClassification" in model_class:
|
1397 |
+
code_sample = sample_docstrings["ImageClassification"]
|
1398 |
+
else:
|
1399 |
+
raise ValueError(f"Docstring can't be built for model {model_class}")
|
1400 |
+
|
1401 |
+
func_doc = (fn.__doc__ or "") + "".join(docstr)
|
1402 |
+
output_doc = "" if output_type is None else _prepare_output_docstrings(output_type, config_class)
|
1403 |
+
built_doc = code_sample.format(**doc_kwargs)
|
1404 |
+
fn.__doc__ = func_doc + output_doc + built_doc
|
1405 |
+
return fn
|
1406 |
+
|
1407 |
+
return docstring_decorator
|
1408 |
+
|
1409 |
+
|
1410 |
+
def prune_linear_layer(layer: nn.Linear, index: torch.LongTensor, dim: int = 0) -> nn.Linear:
|
1411 |
+
"""
|
1412 |
+
Prune a linear layer to keep only entries in index.
|
1413 |
+
|
1414 |
+
Used to remove heads.
|
1415 |
+
|
1416 |
+
Args:
|
1417 |
+
layer (`torch.nn.Linear`): The layer to prune.
|
1418 |
+
index (`torch.LongTensor`): The indices to keep in the layer.
|
1419 |
+
dim (`int`, *optional*, defaults to 0): The dimension on which to keep the indices.
|
1420 |
+
|
1421 |
+
Returns:
|
1422 |
+
`torch.nn.Linear`: The pruned layer as a new layer with `requires_grad=True`.
|
1423 |
+
"""
|
1424 |
+
index = index.to(layer.weight.device)
|
1425 |
+
W = layer.weight.index_select(dim, index).clone().detach()
|
1426 |
+
if layer.bias is not None:
|
1427 |
+
if dim == 1:
|
1428 |
+
b = layer.bias.clone().detach()
|
1429 |
+
else:
|
1430 |
+
b = layer.bias[index].clone().detach()
|
1431 |
+
new_size = list(layer.weight.size())
|
1432 |
+
new_size[dim] = len(index)
|
1433 |
+
new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)
|
1434 |
+
new_layer.weight.requires_grad = False
|
1435 |
+
new_layer.weight.copy_(W.contiguous())
|
1436 |
+
new_layer.weight.requires_grad = True
|
1437 |
+
if layer.bias is not None:
|
1438 |
+
new_layer.bias.requires_grad = False
|
1439 |
+
new_layer.bias.copy_(b.contiguous())
|
1440 |
+
new_layer.bias.requires_grad = True
|
1441 |
+
return new_layer
|
1442 |
+
|
1443 |
+
|
1444 |
+
def apply_chunking_to_forward(
|
1445 |
+
forward_fn: Callable[..., torch.Tensor], chunk_size: int, chunk_dim: int, *input_tensors
|
1446 |
+
) -> torch.Tensor:
|
1447 |
+
"""
|
1448 |
+
This function chunks the `input_tensors` into smaller input tensor parts of size `chunk_size` over the dimension
|
1449 |
+
`chunk_dim`. It then applies a layer `forward_fn` to each chunk independently to save memory.
|
1450 |
+
|
1451 |
+
If the `forward_fn` is independent across the `chunk_dim` this function will yield the same result as directly
|
1452 |
+
applying `forward_fn` to `input_tensors`.
|
1453 |
+
|
1454 |
+
Args:
|
1455 |
+
forward_fn (`Callable[..., torch.Tensor]`):
|
1456 |
+
The forward function of the model.
|
1457 |
+
chunk_size (`int`):
|
1458 |
+
The chunk size of a chunked tensor: `num_chunks = len(input_tensors[0]) / chunk_size`.
|
1459 |
+
chunk_dim (`int`):
|
1460 |
+
The dimension over which the `input_tensors` should be chunked.
|
1461 |
+
input_tensors (`Tuple[torch.Tensor]`):
|
1462 |
+
The input tensors of `forward_fn` which will be chunked
|
1463 |
+
|
1464 |
+
Returns:
|
1465 |
+
`torch.Tensor`: A tensor with the same shape as the `forward_fn` would have given if applied`.
|
1466 |
+
|
1467 |
+
|
1468 |
+
Examples:
|
1469 |
+
|
1470 |
+
```python
|
1471 |
+
# rename the usual forward() fn to forward_chunk()
|
1472 |
+
def forward_chunk(self, hidden_states):
|
1473 |
+
hidden_states = self.decoder(hidden_states)
|
1474 |
+
return hidden_states
|
1475 |
+
|
1476 |
+
|
1477 |
+
# implement a chunked forward function
|
1478 |
+
def forward(self, hidden_states):
|
1479 |
+
return apply_chunking_to_forward(self.forward_chunk, self.chunk_size_lm_head, self.seq_len_dim, hidden_states)
|
1480 |
+
```"""
|
1481 |
+
|
1482 |
+
assert len(input_tensors) > 0, f"{input_tensors} has to be a tuple/list of tensors"
|
1483 |
+
|
1484 |
+
# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
|
1485 |
+
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
|
1486 |
+
if num_args_in_forward_chunk_fn != len(input_tensors):
|
1487 |
+
raise ValueError(
|
1488 |
+
f"forward_chunk_fn expects {num_args_in_forward_chunk_fn} arguments, but only {len(input_tensors)} input "
|
1489 |
+
"tensors are given"
|
1490 |
+
)
|
1491 |
+
|
1492 |
+
if chunk_size > 0:
|
1493 |
+
tensor_shape = input_tensors[0].shape[chunk_dim]
|
1494 |
+
for input_tensor in input_tensors:
|
1495 |
+
if input_tensor.shape[chunk_dim] != tensor_shape:
|
1496 |
+
raise ValueError(
|
1497 |
+
f"All input tenors have to be of the same shape: {tensor_shape}, "
|
1498 |
+
f"found shape {input_tensor.shape[chunk_dim]}"
|
1499 |
+
)
|
1500 |
+
|
1501 |
+
if input_tensors[0].shape[chunk_dim] % chunk_size != 0:
|
1502 |
+
raise ValueError(
|
1503 |
+
f"The dimension to be chunked {input_tensors[0].shape[chunk_dim]} has to be a multiple of the chunk "
|
1504 |
+
f"size {chunk_size}"
|
1505 |
+
)
|
1506 |
+
|
1507 |
+
num_chunks = input_tensors[0].shape[chunk_dim] // chunk_size
|
1508 |
+
|
1509 |
+
# chunk input tensor into tuples
|
1510 |
+
input_tensors_chunks = tuple(input_tensor.chunk(num_chunks, dim=chunk_dim) for input_tensor in input_tensors)
|
1511 |
+
# apply forward fn to every tuple
|
1512 |
+
output_chunks = tuple(forward_fn(*input_tensors_chunk) for input_tensors_chunk in zip(*input_tensors_chunks))
|
1513 |
+
# concatenate output at same dimension
|
1514 |
+
return torch.cat(output_chunks, dim=chunk_dim)
|
1515 |
+
|
1516 |
+
return forward_fn(*input_tensors)
|
1517 |
+
|
1518 |
+
|
1519 |
+
def find_pruneable_heads_and_indices(
|
1520 |
+
heads: List[int], n_heads: int, head_size: int, already_pruned_heads: Set[int]
|
1521 |
+
) -> Tuple[Set[int], torch.LongTensor]:
|
1522 |
+
"""
|
1523 |
+
Finds the heads and their indices taking `already_pruned_heads` into account.
|
1524 |
+
|
1525 |
+
Args:
|
1526 |
+
heads (`List[int]`): List of the indices of heads to prune.
|
1527 |
+
n_heads (`int`): The number of heads in the model.
|
1528 |
+
head_size (`int`): The size of each head.
|
1529 |
+
already_pruned_heads (`Set[int]`): A set of already pruned heads.
|
1530 |
+
|
1531 |
+
Returns:
|
1532 |
+
`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
|
1533 |
+
"""
|
1534 |
+
mask = torch.ones(n_heads, head_size)
|
1535 |
+
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
|
1536 |
+
for head in heads:
|
1537 |
+
# Compute how many pruned heads are before the head and move the index accordingly
|
1538 |
+
head = head - sum(1 if h < head else 0 for h in already_pruned_heads)
|
1539 |
+
mask[head] = 0
|
1540 |
+
mask = mask.view(-1).contiguous().eq(1)
|
1541 |
+
index: torch.LongTensor = torch.arange(len(mask))[mask].long()
|
1542 |
+
return heads, index
|
Images/HTC.jpg
ADDED
Images/Model.jpg
ADDED
README.md
CHANGED
@@ -1,3 +1,169 @@
|
|
1 |
---
|
|
|
2 |
license: apache-2.0
|
3 |
---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
language: zh
|
3 |
license: apache-2.0
|
4 |
---
|
5 |
+
|
6 |
+
|
7 |
+
# G2PTL
|
8 |
+
|
9 |
+
## Introduction
|
10 |
+
|
11 |
+
G2PTL: A Geography-Graph Pre-trained model for address.
|
12 |
+
|
13 |
+
|
14 |
+
## Model description
|
15 |
+
G2PTL is a Transformer model that is pretrained on a large corpus of Chinese addresses in a self-supervised manner. It has three pretraining objectives:
|
16 |
+
|
17 |
+
- Masked language modeling (MLM): taking an address, the model randomly masks some words in the input text and predicts the masked words. It should be noted that for the geographical entities in the address, we adopt the Whole Word Masking (WWM) approach to mask them and learn the co-occurrence relationships among them.
|
18 |
+
|
19 |
+
- Hierarchical text modeling (HTC): an address is a text with a hierarchical structure of province, city, district, and street. HTC is used to model the hierarchical relationship among these levels in addresses.
|
20 |
+
![HTC.jpg](./Images/HTC.jpg)
|
21 |
+
|
22 |
+
- Geocoding (GC): an address can be represented by a point with latitude and longitude in the real world. The GC task is designed to learn the mapping relationship between address text and geographical location.
|
23 |
+
|
24 |
+
More detail: https://arxiv.org/abs/2304.01559
|
25 |
+
![Model.jpg](./Images/Model.jpg)
|
26 |
+
|
27 |
+
|
28 |
+
## Intended uses & limitations
|
29 |
+
|
30 |
+
This model is designed for decision tasks based on address text, including tasks related to understanding address texts and Spatial-Temporal downstream tasks which rely on address text representation.
|
31 |
+
|
32 |
+
1. Address text understanding tasks
|
33 |
+
- Geocoding
|
34 |
+
- Named Entity Recognition
|
35 |
+
- Geographic Entity Alignment
|
36 |
+
- Address Text Similarity
|
37 |
+
- Address Texy Classification
|
38 |
+
2. Spatial-Temporal downstream tasks:
|
39 |
+
- Estimated Time of Arrival (ETA) Prediction
|
40 |
+
- Pick-up & Delivery Route Prediction.
|
41 |
+
|
42 |
+
The model currently only supports Chinese addresses.
|
43 |
+
|
44 |
+
|
45 |
+
## How to use
|
46 |
+
You can use this model directly with a pipeline for masked language modeling:
|
47 |
+
|
48 |
+
```Python
|
49 |
+
>>> from transformers import pipeline, AutoModel, AutoTokenizer
|
50 |
+
>>> model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
51 |
+
>>> tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
52 |
+
|
53 |
+
>>> mask_filler = pipeline(task= 'fill-mask', model= model,tokenizer = tokenizer)
|
54 |
+
>>> mask_filler("浙江省杭州市[MASK]杭区五常街道阿里巴巴西溪园区")
|
55 |
+
```
|
56 |
+
```json
|
57 |
+
[{'score': 1.0,
|
58 |
+
'token': 562,
|
59 |
+
'token_str': '余',
|
60 |
+
'sequence': '浙 江 省 杭 州 市 余 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'},
|
61 |
+
{'score': 7.49648343401077e-09,
|
62 |
+
'token': 1852,
|
63 |
+
'token_str': '杭',
|
64 |
+
'sequence': '浙 江 省 杭 州 市 杭 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'},
|
65 |
+
{'score': 5.823675763849678e-09,
|
66 |
+
'token': 213,
|
67 |
+
'token_str': '西',
|
68 |
+
'sequence': '浙 江 省 杭 州 市 西 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'},
|
69 |
+
{'score': 3.383779922927488e-09,
|
70 |
+
'token': 346,
|
71 |
+
'token_str': '五',
|
72 |
+
'sequence': '浙 江 省 杭 州 市 五 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'},
|
73 |
+
{'score': 2.9116642430437878e-09,
|
74 |
+
'token': 2268,
|
75 |
+
'token_str': '荆',
|
76 |
+
'sequence': '浙 江 省 杭 州 市 荆 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区'}]
|
77 |
+
```
|
78 |
+
|
79 |
+
You can also use this model for multiple [MASK] filling in PyTorch:
|
80 |
+
```python
|
81 |
+
from transformers import pipeline, AutoModel, AutoTokenizer
|
82 |
+
import torch
|
83 |
+
model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
84 |
+
tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
85 |
+
model.eval()
|
86 |
+
text = ['浙江省杭州市[MASK][MASK][MASK]五常街道阿里巴巴西溪园区']
|
87 |
+
encoded_input = tokenizer(text, return_tensors='pt')
|
88 |
+
outputs = model(**encoded_input)
|
89 |
+
prediction_scores = outputs.logits
|
90 |
+
prediction_scores = torch.argmax(prediction_scores, dim=-1)
|
91 |
+
prediction_scores = prediction_scores.cpu().detach().numpy()
|
92 |
+
input_ids = encoded_input['input_ids']
|
93 |
+
print('G2PTL:', tokenizer.decode(prediction_scores[torch.where(input_ids.cpu()>0)][1:-1]))
|
94 |
+
```
|
95 |
+
|
96 |
+
```json
|
97 |
+
G2PTL: 浙 江 省 杭 州 市 余 杭 区 五 常 街 道 阿 里 巴 巴 西 溪 园 区
|
98 |
+
```
|
99 |
+
|
100 |
+
Here is how to use this model to get the HTC output of a given text in PyTorch:
|
101 |
+
|
102 |
+
```python
|
103 |
+
from transformers import pipeline, AutoModel, AutoTokenizer
|
104 |
+
model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
105 |
+
tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
106 |
+
model.eval()
|
107 |
+
text = "浙江省杭州市五常街道阿里巴巴西溪园区"
|
108 |
+
encoded_input = tokenizer(text, return_tensors='pt')
|
109 |
+
output = model(**encoded_input)
|
110 |
+
htc_layer_out = output.htc_layer_out
|
111 |
+
htc_pred = model.get_htc_code(htc_layer_out)
|
112 |
+
print('HTC Result: ', model.decode_htc_code_2_chn(htc_pred))
|
113 |
+
```
|
114 |
+
```json
|
115 |
+
HTC Result: ['浙江省杭州市余杭区五常街道', '浙江省杭州市五常街道']
|
116 |
+
```
|
117 |
+
|
118 |
+
Here is how to use this model to get the features/embeddings of a given text in PyTorch:
|
119 |
+
|
120 |
+
```python
|
121 |
+
from transformers import pipeline, AutoModel, AutoTokenizer
|
122 |
+
model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
123 |
+
tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
124 |
+
model.eval()
|
125 |
+
text = "浙江省杭州市余杭区五常街道阿里巴巴西溪园区"
|
126 |
+
encoded_input = tokenizer(text, return_tensors='pt')
|
127 |
+
output = model(**encoded_input)
|
128 |
+
final_hidden_state = output.final_hidden_state
|
129 |
+
```
|
130 |
+
|
131 |
+
Here is how to use this model to get cosine similarity between two address texts in PyTorch:
|
132 |
+
|
133 |
+
```python
|
134 |
+
from transformers import pipeline, AutoModel, AutoTokenizer
|
135 |
+
import torch
|
136 |
+
model = AutoModel.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
137 |
+
tokenizer = AutoTokenizer.from_pretrained('JunhongLou/G2PTL', trust_remote_code=True)
|
138 |
+
model.eval()
|
139 |
+
text = ["浙江省杭州市余杭区五常街道阿里巴巴西溪园区", "浙江省杭州市阿里巴巴西溪园区"]
|
140 |
+
encoded_input = tokenizer(text, return_tensors='pt', padding=True)
|
141 |
+
output = model(**encoded_input)
|
142 |
+
final_pooler_output = output.final_pooler_output
|
143 |
+
cos_sim = torch.cosine_similarity(final_pooler_output[0], final_pooler_output[1])
|
144 |
+
print('Cosin Similarity: ', cos_sim[0].detach().numpy())
|
145 |
+
```
|
146 |
+
```json
|
147 |
+
Cosin Similarity: 0.8974346
|
148 |
+
```
|
149 |
+
## Requirements
|
150 |
+
python>=3.8
|
151 |
+
```shell
|
152 |
+
tqdm==4.65.0
|
153 |
+
torch==1.13.1
|
154 |
+
transformers==4.27.4
|
155 |
+
datasets==2.11.0
|
156 |
+
fairseq==0.12.2
|
157 |
+
```
|
158 |
+
|
159 |
+
## Citation
|
160 |
+
```bibtex
|
161 |
+
@misc{wu2023g2ptl,
|
162 |
+
title={G2PTL: A Pre-trained Model for Delivery Address and its Applications in Logistics System},
|
163 |
+
author={Lixia Wu and Jianlin Liu and Junhong Lou and Haoyuan Hu and Jianbin Zheng and Haomin Wen and Chao Song and Shu He},
|
164 |
+
year={2023},
|
165 |
+
eprint={2304.01559},
|
166 |
+
archivePrefix={arXiv},
|
167 |
+
primaryClass={cs.AI}
|
168 |
+
}
|
169 |
+
```
|
chn_2_code.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2850ae4e9d3ad005d519d2e1d3e7916b1a8fab7884ef9ad88da62d8159673ee2
|
3 |
+
size 6044124
|
config.json
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"G2PTL"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "configuration_G2PTL.G2PTLConfig",
|
7 |
+
"AutoModel": "modeling_G2PTL.G2PTL",
|
8 |
+
"AutoModelForMaskedLM": "modeling_G2PTL.G2PTL"
|
9 |
+
},
|
10 |
+
"attention_probs_dropout_prob": 0.1,
|
11 |
+
"classifier_dropout": null,
|
12 |
+
"hidden_act": "gelu",
|
13 |
+
"hidden_dropout_prob": 0.1,
|
14 |
+
"hidden_size": 768,
|
15 |
+
"initializer_range": 0.02,
|
16 |
+
"intermediate_size": 3072,
|
17 |
+
"layer_norm_eps": 1e-05,
|
18 |
+
"max_position_embeddings": 2048,
|
19 |
+
"model_type": "G2PTL",
|
20 |
+
"num_attention_heads": 12,
|
21 |
+
"num_hidden_layers": 12,
|
22 |
+
"output_attentions": true,
|
23 |
+
"output_hidden_states": true,
|
24 |
+
"pad_token_id": 0,
|
25 |
+
"position_embedding_type": "absolute",
|
26 |
+
"task_type_vocab_size": 3,
|
27 |
+
"torch_dtype": "float32",
|
28 |
+
"transformers_version": "4.27.1",
|
29 |
+
"type_vocab_size": 4,
|
30 |
+
"use_cache": true,
|
31 |
+
"use_task_id": true,
|
32 |
+
"vocab_size": 40000
|
33 |
+
}
|
configuration_G2PTL.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
from transformers.configuration_utils import PretrainedConfig
|
5 |
+
|
6 |
+
|
7 |
+
class G2PTLConfig(PretrainedConfig):
|
8 |
+
r"""
|
9 |
+
G2PTL model configuration
|
10 |
+
|
11 |
+
Args:
|
12 |
+
vocab_size (`int`, *optional*, defaults to 40000):
|
13 |
+
Vocabulary size of the STELLAR model.
|
14 |
+
hidden_size (`int`, *optional*, defaults to 768):
|
15 |
+
Dimensionality of the encoder layers and the pooler layer.
|
16 |
+
num_hidden_layers (`int`, *optional*, defaults to 12):
|
17 |
+
Number of hidden layers in the Transformer encoder.
|
18 |
+
num_attention_heads (`int`, *optional*, defaults to 12):
|
19 |
+
Number of attention heads for each attention layer in the Transformer encoder.
|
20 |
+
intermediate_size (`int`, *optional*, defaults to 3072):
|
21 |
+
Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
|
22 |
+
hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
|
23 |
+
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
24 |
+
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
25 |
+
hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
|
26 |
+
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
27 |
+
attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
|
28 |
+
The dropout ratio for the attention probabilities.
|
29 |
+
max_position_embeddings (`int`, *optional*, defaults to 512):
|
30 |
+
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
31 |
+
just in case (e.g., 512 or 1024 or 2048).
|
32 |
+
type_vocab_size (`int`, *optional*, defaults to 2):
|
33 |
+
The vocabulary size of the `token_type_ids` passed.
|
34 |
+
task_type_vocab_size (`int`, *optional*, defaults to 3):
|
35 |
+
The vocabulary size of the `task_type_ids`
|
36 |
+
use_task_id (`bool`, *optional*, defaults to `False`):
|
37 |
+
Whether or not the model support `task_type_ids`
|
38 |
+
initializer_range (`float`, *optional*, defaults to 0.02):
|
39 |
+
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
40 |
+
layer_norm_eps (`float`, *optional*, defaults to 1e-12):
|
41 |
+
The epsilon used by the layer normalization layers.
|
42 |
+
position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
|
43 |
+
Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
|
44 |
+
positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
|
45 |
+
[Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
|
46 |
+
For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
|
47 |
+
with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
|
48 |
+
use_cache (`bool`, *optional*, defaults to `True`):
|
49 |
+
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
50 |
+
relevant if `config.is_decoder=True`.
|
51 |
+
classifier_dropout (`float`, *optional*):
|
52 |
+
The dropout ratio for the classification head.
|
53 |
+
"""
|
54 |
+
model_type = "M2PTL"
|
55 |
+
|
56 |
+
def __init__(
|
57 |
+
self,
|
58 |
+
vocab_size=40000,
|
59 |
+
hidden_size=768,
|
60 |
+
num_hidden_layers=12,
|
61 |
+
num_attention_heads=12,
|
62 |
+
intermediate_size=3072,
|
63 |
+
hidden_act="gelu",
|
64 |
+
hidden_dropout_prob=0.1,
|
65 |
+
attention_probs_dropout_prob=0.1,
|
66 |
+
max_position_embeddings=2048,
|
67 |
+
type_vocab_size=4,
|
68 |
+
task_type_vocab_size=3,
|
69 |
+
use_task_id=True,
|
70 |
+
initializer_range=0.02,
|
71 |
+
layer_norm_eps=1e-05,
|
72 |
+
pad_token_id=0,
|
73 |
+
position_embedding_type="absolute",
|
74 |
+
use_cache=True,
|
75 |
+
classifier_dropout=None,
|
76 |
+
**kwargs
|
77 |
+
):
|
78 |
+
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
79 |
+
|
80 |
+
self.vocab_size = vocab_size
|
81 |
+
self.hidden_size = hidden_size
|
82 |
+
self.num_hidden_layers = num_hidden_layers
|
83 |
+
self.num_attention_heads = num_attention_heads
|
84 |
+
self.hidden_act = hidden_act
|
85 |
+
self.intermediate_size = intermediate_size
|
86 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
87 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
88 |
+
self.max_position_embeddings = max_position_embeddings
|
89 |
+
self.type_vocab_size = type_vocab_size
|
90 |
+
self.task_type_vocab_size = task_type_vocab_size
|
91 |
+
self.use_task_id = use_task_id
|
92 |
+
self.initializer_range = initializer_range
|
93 |
+
self.layer_norm_eps = layer_norm_eps
|
94 |
+
self.position_embedding_type = position_embedding_type
|
95 |
+
self.use_cache = use_cache
|
96 |
+
self.classifier_dropout = classifier_dropout
|
97 |
+
|
graphormer.py
ADDED
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! python3
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
from copy import deepcopy
|
5 |
+
from torch.nn.init import xavier_uniform_
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import Parameter
|
8 |
+
from torch.nn.init import normal_
|
9 |
+
import torch.utils.checkpoint
|
10 |
+
from torch import Tensor, device
|
11 |
+
from .G2PTL_utils import *
|
12 |
+
from transformers.modeling_utils import ModuleUtilsMixin
|
13 |
+
from fairseq import utils
|
14 |
+
from fairseq.models import (
|
15 |
+
FairseqEncoder,
|
16 |
+
register_model,
|
17 |
+
register_model_architecture,
|
18 |
+
)
|
19 |
+
from fairseq.modules import (
|
20 |
+
LayerNorm,
|
21 |
+
)
|
22 |
+
|
23 |
+
def init_params(module, n_layers):
|
24 |
+
if isinstance(module, nn.Linear):
|
25 |
+
module.weight.data.normal_(mean=0.0, std=0.02 / math.sqrt(n_layers))
|
26 |
+
if module.bias is not None:
|
27 |
+
module.bias.data.zero_()
|
28 |
+
if isinstance(module, nn.Embedding):
|
29 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
30 |
+
|
31 |
+
|
32 |
+
@torch.jit.script
|
33 |
+
def softmax_dropout(input, dropout_prob: float, is_training: bool):
|
34 |
+
return F.dropout(F.softmax(input, -1), dropout_prob, is_training)
|
35 |
+
|
36 |
+
|
37 |
+
class SelfMultiheadAttention(nn.Module):
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
embed_dim,
|
41 |
+
num_heads,
|
42 |
+
dropout=0.0,
|
43 |
+
bias=True,
|
44 |
+
scaling_factor=1,
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
self.embed_dim = embed_dim
|
48 |
+
|
49 |
+
self.num_heads = num_heads
|
50 |
+
self.dropout = dropout
|
51 |
+
|
52 |
+
self.head_dim = embed_dim // num_heads
|
53 |
+
assert (self.head_dim * num_heads == self.embed_dim), "embed_dim must be divisible by num_heads"
|
54 |
+
self.scaling = (self.head_dim * scaling_factor) ** -0.5
|
55 |
+
|
56 |
+
self.linear_q = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
|
57 |
+
self.linear_k = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
|
58 |
+
self.linear_v = nn.Linear(self.embed_dim, self.num_heads * self.head_dim)
|
59 |
+
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias)
|
60 |
+
|
61 |
+
def forward(
|
62 |
+
self,
|
63 |
+
query: Tensor,
|
64 |
+
attn_bias: Tensor = None,
|
65 |
+
) -> Tensor:
|
66 |
+
n_graph, n_node, embed_dim = query.size()
|
67 |
+
# q, k, v = self.in_proj(query).chunk(3, dim=-1)
|
68 |
+
|
69 |
+
_shape = (-1, n_graph * self.num_heads, self.head_dim)
|
70 |
+
q = self.linear_q(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2) * self.scaling
|
71 |
+
k = self.linear_k(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
72 |
+
v = self.linear_v(query).contiguous().view(n_graph, -1, self.num_heads, self.head_dim).transpose(1, 2)
|
73 |
+
|
74 |
+
attn_weights = torch.matmul(q, k.transpose(2, 3))
|
75 |
+
attn_weights = attn_weights + attn_bias
|
76 |
+
attn_probs = softmax_dropout(attn_weights, self.dropout, self.training)
|
77 |
+
|
78 |
+
attn = torch.matmul(attn_probs, v)
|
79 |
+
attn = attn.transpose(1, 2).contiguous().view(n_graph, -1, embed_dim)
|
80 |
+
attn = self.out_proj(attn)
|
81 |
+
return attn
|
82 |
+
|
83 |
+
|
84 |
+
class Graphormer3DEncoderLayer(nn.Module):
|
85 |
+
"""
|
86 |
+
Implements a Graphormer-3D Encoder Layer.
|
87 |
+
"""
|
88 |
+
|
89 |
+
def __init__(
|
90 |
+
self,
|
91 |
+
embedding_dim: int = 768,
|
92 |
+
ffn_embedding_dim: int = 3072,
|
93 |
+
num_attention_heads: int = 8,
|
94 |
+
dropout: float = 0.1,
|
95 |
+
attention_dropout: float = 0.1,
|
96 |
+
activation_dropout: float = 0.1,
|
97 |
+
) -> None:
|
98 |
+
super().__init__()
|
99 |
+
|
100 |
+
# Initialize parameters
|
101 |
+
self.embedding_dim = embedding_dim
|
102 |
+
self.num_attention_heads = num_attention_heads
|
103 |
+
self.attention_dropout = attention_dropout
|
104 |
+
|
105 |
+
self.dropout = dropout
|
106 |
+
self.activation_dropout = activation_dropout
|
107 |
+
|
108 |
+
self.self_attn = SelfMultiheadAttention(self.embedding_dim, num_attention_heads, dropout=attention_dropout)
|
109 |
+
# layer norm associated with the self attention layer
|
110 |
+
self.self_attn_layer_norm = nn.LayerNorm(self.embedding_dim)
|
111 |
+
self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
|
112 |
+
self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)
|
113 |
+
self.final_layer_norm = nn.LayerNorm(self.embedding_dim)
|
114 |
+
|
115 |
+
def forward(self, x: Tensor, attn_bias: Tensor = None):
|
116 |
+
residual = x
|
117 |
+
x = self.self_attn_layer_norm(x)
|
118 |
+
x = self.self_attn(query=x, attn_bias=attn_bias)
|
119 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
120 |
+
x = residual + x
|
121 |
+
|
122 |
+
residual = x
|
123 |
+
x = self.final_layer_norm(x)
|
124 |
+
x = F.gelu(self.fc1(x))
|
125 |
+
x = F.dropout(x, p=self.activation_dropout, training=self.training)
|
126 |
+
x = self.fc2(x)
|
127 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
128 |
+
x = residual + x
|
129 |
+
return x
|
130 |
+
|
131 |
+
|
132 |
+
from fairseq.models import (
|
133 |
+
BaseFairseqModel,
|
134 |
+
register_model,
|
135 |
+
register_model_architecture,
|
136 |
+
)
|
137 |
+
|
138 |
+
|
139 |
+
class Graphormer3D(BaseFairseqModel):
|
140 |
+
def __init__(self):
|
141 |
+
super().__init__()
|
142 |
+
self.atom_types = 64
|
143 |
+
self.edge_types = 64 * 64
|
144 |
+
self.embed_dim = 768
|
145 |
+
self.layer_nums = 12
|
146 |
+
self.ffn_embed_dim = 768
|
147 |
+
self.blocks = 4
|
148 |
+
self.attention_heads = 48
|
149 |
+
self.input_dropout = 0.0
|
150 |
+
self.dropout = 0.1
|
151 |
+
self.attention_dropout = 0.1
|
152 |
+
self.activation_dropout = 0.0
|
153 |
+
self.node_loss_weight = 15
|
154 |
+
self.min_node_loss_weight = 1
|
155 |
+
self.eng_loss_weight = 1
|
156 |
+
self.num_kernel = 128
|
157 |
+
self.atom_encoder = nn.Embedding(self.atom_types, self.embed_dim, padding_idx=0)
|
158 |
+
self.edge_embedding = nn.Embedding(32, self.attention_heads, padding_idx=0)
|
159 |
+
self.input_dropout = nn.Dropout(0.1)
|
160 |
+
self.layers = nn.ModuleList(
|
161 |
+
[
|
162 |
+
Graphormer3DEncoderLayer(
|
163 |
+
self.embed_dim,
|
164 |
+
self.ffn_embed_dim,
|
165 |
+
num_attention_heads=self.attention_heads,
|
166 |
+
dropout=self.dropout,
|
167 |
+
attention_dropout=self.attention_dropout,
|
168 |
+
activation_dropout=self.activation_dropout,
|
169 |
+
)
|
170 |
+
for _ in range(self.layer_nums)
|
171 |
+
]
|
172 |
+
)
|
173 |
+
self.atom_encoder = nn.Embedding(512 * 9 + 1, self.embed_dim, padding_idx=0)
|
174 |
+
self.edge_encoder = nn.Embedding(512 * 3 + 1, self.attention_heads, padding_idx=0)
|
175 |
+
self.edge_type = 'multi_hop'
|
176 |
+
if self.edge_type == 'multi_hop':
|
177 |
+
self.edge_dis_encoder = nn.Embedding(16 * self.attention_heads * self.attention_heads, 1)
|
178 |
+
self.spatial_pos_encoder = nn.Embedding(512, self.attention_heads, padding_idx=0)
|
179 |
+
self.in_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
|
180 |
+
self.out_degree_encoder = nn.Embedding(512, self.embed_dim, padding_idx=0)
|
181 |
+
self.node_position_ids_encoder = nn.Embedding(10, self.embed_dim, padding_idx=0)
|
182 |
+
|
183 |
+
self.final_ln: Callable[[Tensor], Tensor] = nn.LayerNorm(self.embed_dim)
|
184 |
+
|
185 |
+
self.engergy_proj: Callable[[Tensor], Tensor] = NonLinear(self.embed_dim, 1)
|
186 |
+
self.energe_agg_factor: Callable[[Tensor], Tensor] = nn.Embedding(3, 1)
|
187 |
+
nn.init.normal_(self.energe_agg_factor.weight, 0, 0.01)
|
188 |
+
|
189 |
+
self.graph_token = nn.Embedding(1, 768)
|
190 |
+
self.graph_token_virtual_distance = nn.Embedding(1, self.attention_heads)
|
191 |
+
|
192 |
+
K = self.num_kernel
|
193 |
+
|
194 |
+
self.gbf: Callable[[Tensor, Tensor], Tensor] = GaussianLayer(K, self.edge_types)
|
195 |
+
self.bias_proj: Callable[[Tensor], Tensor] = NonLinear(K, self.attention_heads)
|
196 |
+
self.edge_proj: Callable[[Tensor], Tensor] = nn.Linear(K, self.embed_dim)
|
197 |
+
self.node_proc: Callable[[Tensor, Tensor, Tensor], Tensor] = NodeTaskHead(self.embed_dim, self.attention_heads)
|
198 |
+
|
199 |
+
def forward(self, node_feature, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input, node_position_ids):
|
200 |
+
"""
|
201 |
+
node_feature: text embedding
|
202 |
+
spatial_pos: The shortest path length between nodes in the graph, shape: (n_graph, n_node, n_node)
|
203 |
+
in_degree: The in-degree of nodes in the graph, shape: (n_graph, n_node)
|
204 |
+
out_degree: The out-degree of nodes in the graph, shape: (n_graph, n_node)
|
205 |
+
edge_type_matrix: The edge type of edges in the graph
|
206 |
+
edge_input: The shortest path route between nodes in the graph, shape: (n_graph, n_node, n_node, multi_hop_max_dist, n_edge_features)
|
207 |
+
node_position_ids: node poistion ids
|
208 |
+
"""
|
209 |
+
attn_edge_type = self.edge_embedding(edge_type_matrix)
|
210 |
+
edge_input = self.edge_embedding(edge_input)
|
211 |
+
n_graph, n_node = node_feature.size()[:2]
|
212 |
+
spatial_pos_bias = self.spatial_pos_encoder(spatial_pos).permute(0, 3, 1, 2)
|
213 |
+
|
214 |
+
if self.edge_type == 'multi_hop':
|
215 |
+
spatial_pos_ = spatial_pos.clone()
|
216 |
+
spatial_pos_[spatial_pos_ == 0] = 1 # set pad to 1
|
217 |
+
spatial_pos_ = torch.where(spatial_pos_ > 1, spatial_pos_ - 1, spatial_pos_)
|
218 |
+
max_dist = edge_input.size(-2)
|
219 |
+
edge_input_flat = edge_input.permute(3, 0, 1, 2, 4).reshape(max_dist, -1, self.attention_heads)
|
220 |
+
edge_input_flat = torch.bmm(edge_input_flat, self.edge_dis_encoder.weight.reshape(-1, self.attention_heads, self.attention_heads)[:max_dist, :, :])
|
221 |
+
edge_input = edge_input_flat.reshape(max_dist, n_graph, n_node, n_node, self.attention_heads).permute(1, 2, 3, 0, 4)
|
222 |
+
edge_input = (edge_input.sum(-2) / (spatial_pos_.float().unsqueeze(-1))).permute(0, 3, 1, 2)
|
223 |
+
else:
|
224 |
+
# [n_graph, n_node, n_node, n_head] -> [n_graph, n_head, n_node, n_node]
|
225 |
+
edge_input = self.edge_encoder(attn_edge_type).mean(-2).permute(0, 3, 1, 2)
|
226 |
+
|
227 |
+
graph_attn_bias = spatial_pos_bias + edge_input
|
228 |
+
node_position_embedding = self.node_position_ids_encoder(node_position_ids)
|
229 |
+
node_position_embedding = node_position_embedding.contiguous().view(n_graph, n_node, self.embed_dim)
|
230 |
+
node_feature = node_feature + self.in_degree_encoder(in_degree) + \
|
231 |
+
self.out_degree_encoder(out_degree) + node_position_embedding
|
232 |
+
|
233 |
+
# transfomrer encoder
|
234 |
+
output = self.input_dropout(node_feature)
|
235 |
+
for enc_layer in self.layers:
|
236 |
+
output = enc_layer(output, graph_attn_bias)
|
237 |
+
output = self.final_ln(output)
|
238 |
+
|
239 |
+
return output
|
240 |
+
|
241 |
+
|
242 |
+
@torch.jit.script
|
243 |
+
def gaussian(x, mean, std):
|
244 |
+
pi = 3.14159
|
245 |
+
a = (2 * pi) ** 0.5
|
246 |
+
return torch.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
|
247 |
+
|
248 |
+
|
249 |
+
class GaussianLayer(nn.Module):
|
250 |
+
def __init__(self, K=128, edge_types=1024):
|
251 |
+
super().__init__()
|
252 |
+
self.K = K
|
253 |
+
self.means = nn.Embedding(1, K)
|
254 |
+
self.stds = nn.Embedding(1, K)
|
255 |
+
self.mul = nn.Embedding(edge_types, 1)
|
256 |
+
self.bias = nn.Embedding(edge_types, 1)
|
257 |
+
nn.init.uniform_(self.means.weight, 0, 3)
|
258 |
+
nn.init.uniform_(self.stds.weight, 0, 3)
|
259 |
+
nn.init.constant_(self.bias.weight, 0)
|
260 |
+
nn.init.constant_(self.mul.weight, 1)
|
261 |
+
|
262 |
+
def forward(self, x, edge_types):
|
263 |
+
mul = self.mul(edge_types)
|
264 |
+
bias = self.bias(edge_types)
|
265 |
+
x = mul * x.unsqueeze(-1) + bias
|
266 |
+
x = x.expand(-1, -1, -1, self.K)
|
267 |
+
mean = self.means.weight.float().view(-1)
|
268 |
+
std = self.stds.weight.float().view(-1).abs() + 1e-5
|
269 |
+
return gaussian(x.float(), mean, std).type_as(self.means.weight)
|
270 |
+
|
271 |
+
|
272 |
+
class RBF(nn.Module):
|
273 |
+
def __init__(self, K, edge_types):
|
274 |
+
super().__init__()
|
275 |
+
self.K = K
|
276 |
+
self.means = nn.parameter.Parameter(torch.empty(K))
|
277 |
+
self.temps = nn.parameter.Parameter(torch.empty(K))
|
278 |
+
self.mul: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
|
279 |
+
self.bias: Callable[..., Tensor] = nn.Embedding(edge_types, 1)
|
280 |
+
nn.init.uniform_(self.means, 0, 3)
|
281 |
+
nn.init.uniform_(self.temps, 0.1, 10)
|
282 |
+
nn.init.constant_(self.bias.weight, 0)
|
283 |
+
nn.init.constant_(self.mul.weight, 1)
|
284 |
+
|
285 |
+
def forward(self, x: Tensor, edge_types):
|
286 |
+
mul = self.mul(edge_types)
|
287 |
+
bias = self.bias(edge_types)
|
288 |
+
x = mul * x.unsqueeze(-1) + bias
|
289 |
+
mean = self.means.float()
|
290 |
+
temp = self.temps.float().abs()
|
291 |
+
return ((x - mean).square() * (-temp)).exp().type_as(self.means)
|
292 |
+
|
293 |
+
|
294 |
+
class NonLinear(nn.Module):
|
295 |
+
def __init__(self, input, output_size, hidden=None):
|
296 |
+
super(NonLinear, self).__init__()
|
297 |
+
if hidden is None:
|
298 |
+
hidden = input
|
299 |
+
self.layer1 = nn.Linear(input, hidden)
|
300 |
+
self.layer2 = nn.Linear(hidden, output_size)
|
301 |
+
|
302 |
+
def forward(self, x):
|
303 |
+
x = F.gelu(self.layer1(x))
|
304 |
+
x = self.layer2(x)
|
305 |
+
return x
|
306 |
+
|
307 |
+
|
308 |
+
class NodeTaskHead(nn.Module):
|
309 |
+
def __init__(
|
310 |
+
self,
|
311 |
+
embed_dim: int,
|
312 |
+
num_heads: int,
|
313 |
+
):
|
314 |
+
super().__init__()
|
315 |
+
self.embed_dim = embed_dim
|
316 |
+
self.q_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
|
317 |
+
self.k_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
|
318 |
+
self.v_proj: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, embed_dim)
|
319 |
+
self.num_heads = num_heads
|
320 |
+
self.scaling = (embed_dim // num_heads) ** -0.5
|
321 |
+
self.force_proj1: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
|
322 |
+
self.force_proj2: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
|
323 |
+
self.force_proj3: Callable[[Tensor], Tensor] = nn.Linear(embed_dim, 1)
|
324 |
+
|
325 |
+
def forward(
|
326 |
+
self,
|
327 |
+
query: Tensor,
|
328 |
+
attn_bias: Tensor,
|
329 |
+
delta_pos: Tensor,
|
330 |
+
) -> Tensor:
|
331 |
+
bsz, n_node, _ = query.size()
|
332 |
+
q = (self.q_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2) * self.scaling)
|
333 |
+
k = self.k_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
|
334 |
+
v = self.v_proj(query).view(bsz, n_node, self.num_heads, -1).transpose(1, 2)
|
335 |
+
attn = q @ k.transpose(-1, -2) # [bsz, head, n, n]
|
336 |
+
attn_probs = softmax_dropout(attn.view(-1, n_node, n_node) + attn_bias, 0.1, self.training).view(bsz, self.num_heads, n_node, n_node)
|
337 |
+
rot_attn_probs = attn_probs.unsqueeze(-1) * delta_pos.unsqueeze(1).type_as(attn_probs) # [bsz, head, n, n, 3]
|
338 |
+
rot_attn_probs = rot_attn_probs.permute(0, 1, 4, 2, 3)
|
339 |
+
x = rot_attn_probs @ v.unsqueeze(2) # [bsz, head , 3, n, d]
|
340 |
+
x = x.permute(0, 3, 2, 1, 4).contiguous().view(bsz, n_node, 3, -1)
|
341 |
+
f1 = self.force_proj1(x[:, :, 0, :]).view(bsz, n_node, 1)
|
342 |
+
f2 = self.force_proj2(x[:, :, 1, :]).view(bsz, n_node, 1)
|
343 |
+
f3 = self.force_proj3(x[:, :, 2, :]).view(bsz, n_node, 1)
|
344 |
+
cur_force = torch.cat([f1, f2, f3], dim=-1).float()
|
345 |
+
return cur_force
|
346 |
+
|
htc_loss.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! python3
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import pandas as pd
|
7 |
+
import sys
|
8 |
+
import os
|
9 |
+
|
10 |
+
|
11 |
+
from transformers.utils.hub import cached_file
|
12 |
+
|
13 |
+
resolved_module_file = cached_file(
|
14 |
+
'JunhongLou/G2PTL',
|
15 |
+
'htc_mask_dict.pkl',
|
16 |
+
)
|
17 |
+
|
18 |
+
htc_weights = [0.067, 0.133, 0.2, 0.267, 0.333]
|
19 |
+
htc_mask_dict = pd.read_pickle(resolved_module_file)
|
20 |
+
import numpy as np
|
21 |
+
import operator
|
22 |
+
def calculate_multi_htc_acc_batch(predicted_htc, y, sequence_len = 6):
|
23 |
+
acc_cnt = np.array([0, 0, 0, 0, 0])
|
24 |
+
y = y.view(-1, sequence_len, 5).tolist()
|
25 |
+
predicted = np.array(predicted_htc).reshape(-1, sequence_len, 5).tolist()
|
26 |
+
batch_size = len(y)
|
27 |
+
total_cnt = np.array([0, 0, 0, 0, 0])
|
28 |
+
for batch_i in range(batch_size):
|
29 |
+
for index, s2 in enumerate(y[batch_i]):
|
30 |
+
for c, i in enumerate(range(5)):
|
31 |
+
y_l10 = y[batch_i][index][:i+1]
|
32 |
+
p_l10 = predicted[batch_i][index][:i+1]
|
33 |
+
if -100 in y_l10:
|
34 |
+
break
|
35 |
+
if operator.eq(y_l10, p_l10):
|
36 |
+
acc_cnt[c] += 1
|
37 |
+
total_cnt[c] += 1
|
38 |
+
return acc_cnt, total_cnt
|
39 |
+
|
40 |
+
|
41 |
+
class HTCLoss(torch.nn.Module):
|
42 |
+
def __init__(self, device, reduction='mean', using_htc = True):
|
43 |
+
super(HTCLoss, self).__init__()
|
44 |
+
self.reduction = reduction
|
45 |
+
self.htc_weights = htc_weights
|
46 |
+
self.device = device
|
47 |
+
self.using_htc = using_htc
|
48 |
+
self.htc_mask_dict = htc_mask_dict
|
49 |
+
for key, value in self.htc_mask_dict.items():
|
50 |
+
self.htc_mask_dict[key] = torch.tensor(value).clone().detach().to(self.device)
|
51 |
+
|
52 |
+
def forward(self, logits, target):
|
53 |
+
target = target.reshape(-1, 1)
|
54 |
+
target_mask = target != -100
|
55 |
+
target_mask = target_mask.squeeze()
|
56 |
+
target_mask_idx = torch.where(target == -100)
|
57 |
+
target_new = target.clone()
|
58 |
+
target_new[target_mask_idx] = 0
|
59 |
+
predict_res = []
|
60 |
+
if not self.using_htc:
|
61 |
+
log_pro = -1.0 * F.log_softmax(logits, dim=1)
|
62 |
+
else:
|
63 |
+
logits_reshaped = logits.clone()
|
64 |
+
logits_reshaped = logits_reshaped.reshape(-1, 5, 100)
|
65 |
+
_, aa_predicted = torch.max(logits_reshaped[:,0,1:32], 1)
|
66 |
+
aa_predicted += 1
|
67 |
+
logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
|
68 |
+
logits_new[:,0,1:32] = logits_reshaped[:,0,1:32]
|
69 |
+
for sample_idx, aa in enumerate(aa_predicted):
|
70 |
+
# Using mask_dict to get candidates for the next hierarchical
|
71 |
+
bb_idx = htc_mask_dict['{:02d}'.format(aa)]
|
72 |
+
_, bb_idy = torch.max(logits_reshaped[sample_idx,1,bb_idx], 0)
|
73 |
+
bb = bb_idx[bb_idy]
|
74 |
+
logits_new[sample_idx,1,bb_idx] = logits_reshaped[sample_idx,1,bb_idx]
|
75 |
+
cc_idx = htc_mask_dict['{:02d}{:02d}'.format(aa, bb)]
|
76 |
+
_, cc_idy = torch.max(logits_reshaped[sample_idx,2,cc_idx], 0)
|
77 |
+
logits_new[sample_idx,2,cc_idx] = logits_reshaped[sample_idx,2,cc_idx]
|
78 |
+
cc = cc_idx[cc_idy]
|
79 |
+
d_idx = htc_mask_dict['{:02d}{:02d}{:02d}'.format(aa, bb, cc)]
|
80 |
+
_, d_idy = torch.max(logits_reshaped[sample_idx,3,d_idx], 0)
|
81 |
+
logits_new[sample_idx,3,d_idx] = logits_reshaped[sample_idx,3,d_idx]
|
82 |
+
d = d_idx[d_idy]
|
83 |
+
ee_idx = htc_mask_dict['{:02d}{:02d}{:02d}{:01d}'.format(aa, bb, cc, d)]
|
84 |
+
_, ee_idy = torch.max(logits_reshaped[sample_idx,4,ee_idx], 0)
|
85 |
+
logits_new[sample_idx,4,ee_idx] = logits_reshaped[sample_idx,4,ee_idx]
|
86 |
+
ee = ee_idx[ee_idy]
|
87 |
+
predict_res.extend([aa.item(), bb.item(), cc.item(), d.item(), ee.item()])
|
88 |
+
|
89 |
+
logits_new = logits_new.reshape(-1, 100)
|
90 |
+
log_pro = -1.0 * F.log_softmax(logits_new, dim=1)
|
91 |
+
logits = logits.contiguous().view(-1, 100)
|
92 |
+
one_hot = torch.zeros(logits.shape[0], logits.shape[1]).to(self.device) # .cuda()
|
93 |
+
one_hot = one_hot.scatter_(1, target_new, 1)
|
94 |
+
loss = torch.mul(log_pro, one_hot).sum(dim=1)
|
95 |
+
loss = loss*target_mask
|
96 |
+
bs = int(loss.shape[0] / 5)
|
97 |
+
w_loss = []
|
98 |
+
for i in range(bs):
|
99 |
+
w_loss.extend(self.htc_weights)
|
100 |
+
w_loss = torch.FloatTensor(w_loss).to(self.device)
|
101 |
+
loss = loss.mul(w_loss) * 5
|
102 |
+
if self.reduction == 'mean':
|
103 |
+
loss = loss[torch.where(loss>0)].mean()
|
104 |
+
elif self.reduction == 'sum':
|
105 |
+
loss = loss[torch.where(loss>0)].sum()
|
106 |
+
return loss, predict_res
|
107 |
+
|
108 |
+
def get_htc_code(self, logits):
|
109 |
+
logits_reshaped = logits.clone()
|
110 |
+
logits_reshaped = logits_reshaped.reshape(-1, 5, 100)
|
111 |
+
_, aa_predicted = torch.max(logits_reshaped[:,0,1:32], 1)
|
112 |
+
aa_predicted += 1
|
113 |
+
logits_new = -5 * torch.ones_like(logits_reshaped).to(self.device)
|
114 |
+
logits_new[:,0,1:32] = logits_reshaped[:,0,1:32]
|
115 |
+
predict_res = []
|
116 |
+
for sample_idx, aa in enumerate(aa_predicted):
|
117 |
+
bb_idx = htc_mask_dict['{:02d}'.format(aa)]
|
118 |
+
_, bb_idy = torch.max(logits_reshaped[sample_idx,1,bb_idx], 0)
|
119 |
+
bb = bb_idx[bb_idy]
|
120 |
+
logits_new[sample_idx,1,bb_idx] = logits_reshaped[sample_idx,1,bb_idx]
|
121 |
+
cc_idx = htc_mask_dict['{:02d}{:02d}'.format(aa, bb)]
|
122 |
+
_, cc_idy = torch.max(logits_reshaped[sample_idx,2,cc_idx], 0)
|
123 |
+
logits_new[sample_idx,2,cc_idx] = logits_reshaped[sample_idx,2,cc_idx]
|
124 |
+
cc = cc_idx[cc_idy]
|
125 |
+
d_idx = htc_mask_dict['{:02d}{:02d}{:02d}'.format(aa, bb, cc)]
|
126 |
+
_, d_idy = torch.max(logits_reshaped[sample_idx,3,d_idx], 0)
|
127 |
+
logits_new[sample_idx,3,d_idx] = logits_reshaped[sample_idx,3,d_idx]
|
128 |
+
d = d_idx[d_idy]
|
129 |
+
ee_idx = htc_mask_dict['{:02d}{:02d}{:02d}{:01d}'.format(aa, bb, cc, d)]
|
130 |
+
_, ee_idy = torch.max(logits_reshaped[sample_idx,4,ee_idx], 0)
|
131 |
+
logits_new[sample_idx,4,ee_idx] = logits_reshaped[sample_idx,4,ee_idx]
|
132 |
+
ee = ee_idx[ee_idy]
|
133 |
+
predict_res.extend([aa.item(), bb.item(), cc.item(), d.item(), ee.item()])
|
134 |
+
return predict_res
|
135 |
+
|
htc_mask_dict.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cf03eaf44926730e193f5b37ccf7fb36561b411d64d635495b2e9c87d8e5ecea
|
3 |
+
size 250511
|
modeling_G2PTL.py
ADDED
@@ -0,0 +1,1024 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#! python3
|
2 |
+
# -*- encoding: utf-8 -*-
|
3 |
+
|
4 |
+
from copy import deepcopy
|
5 |
+
from torch.nn.init import xavier_uniform_
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import Parameter
|
8 |
+
from torch.nn.init import normal_
|
9 |
+
import torch.utils.checkpoint
|
10 |
+
from torch import Tensor, device
|
11 |
+
from .G2PTL_utils import *
|
12 |
+
from transformers.modeling_utils import ModuleUtilsMixin
|
13 |
+
from .graphormer import Graphormer3D
|
14 |
+
import pickle
|
15 |
+
from transformers.modeling_outputs import ModelOutput
|
16 |
+
import numpy as np
|
17 |
+
# with open('remap_code_2_chn.bin', 'rb') as fr:
|
18 |
+
# remap_code_2_chn = pickle.loads(fr.read())
|
19 |
+
|
20 |
+
from .htc_loss import HTCLoss
|
21 |
+
from transformers.utils.hub import cached_file
|
22 |
+
remap_code_2_chn_file_path = cached_file(
|
23 |
+
'JunhongLou/G2PTL',
|
24 |
+
'remap_code_2_chn.pkl',
|
25 |
+
)
|
26 |
+
|
27 |
+
class G2PTLEmbedding(nn.Module):
|
28 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
29 |
+
|
30 |
+
def __init__(self, config):
|
31 |
+
super().__init__()
|
32 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
33 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
34 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
35 |
+
self.ner_type_embeddings = nn.Embedding(10, config.hidden_size)
|
36 |
+
self.use_task_id = config.use_task_id
|
37 |
+
if config.use_task_id:
|
38 |
+
self.task_type_embeddings = nn.Embedding(config.task_type_vocab_size, config.hidden_size)
|
39 |
+
|
40 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
41 |
+
# any TensorFlow checkpoint file
|
42 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
43 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
44 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
45 |
+
self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
|
46 |
+
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
|
47 |
+
self.register_buffer("token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long),
|
48 |
+
persistent=False)
|
49 |
+
self._reset_parameters()
|
50 |
+
|
51 |
+
def forward(
|
52 |
+
self,
|
53 |
+
input_ids: Optional[torch.LongTensor] = None,
|
54 |
+
token_type_ids: Optional[torch.LongTensor] = None,
|
55 |
+
ner_type_ids: Optional[torch.LongTensor] = None,
|
56 |
+
task_type_ids: Optional[torch.LongTensor] = None,
|
57 |
+
position_ids: Optional[torch.LongTensor] = None,
|
58 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
59 |
+
past_key_values_length: int = 0,
|
60 |
+
) -> torch.Tensor:
|
61 |
+
if input_ids is not None:
|
62 |
+
input_shape = input_ids.size()
|
63 |
+
else:
|
64 |
+
input_shape = inputs_embeds.size()[:-1]
|
65 |
+
|
66 |
+
seq_length = input_shape[1]
|
67 |
+
|
68 |
+
if position_ids is None:
|
69 |
+
position_ids = self.position_ids[:, past_key_values_length: seq_length + past_key_values_length]
|
70 |
+
|
71 |
+
# Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
|
72 |
+
# when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
|
73 |
+
# issue #5664
|
74 |
+
if token_type_ids is None:
|
75 |
+
if hasattr(self, "token_type_ids"):
|
76 |
+
buffered_token_type_ids = self.token_type_ids[:, :seq_length]
|
77 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
|
78 |
+
token_type_ids = buffered_token_type_ids_expanded
|
79 |
+
else:
|
80 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
81 |
+
|
82 |
+
if inputs_embeds is None:
|
83 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
84 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
85 |
+
if ner_type_ids is not None:
|
86 |
+
ner_type_embeddings = self.ner_type_embeddings(ner_type_ids)
|
87 |
+
|
88 |
+
embeddings = inputs_embeds + token_type_embeddings + ner_type_embeddings
|
89 |
+
else:
|
90 |
+
embeddings = inputs_embeds + token_type_embeddings
|
91 |
+
if self.position_embedding_type == "absolute":
|
92 |
+
position_embeddings = self.position_embeddings(position_ids)
|
93 |
+
embeddings += position_embeddings
|
94 |
+
|
95 |
+
if self.use_task_id:
|
96 |
+
if task_type_ids is None:
|
97 |
+
task_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
|
98 |
+
task_type_embeddings = self.task_type_embeddings(task_type_ids)
|
99 |
+
embeddings += task_type_embeddings
|
100 |
+
|
101 |
+
embeddings = self.LayerNorm(embeddings)
|
102 |
+
embeddings = self.dropout(embeddings)
|
103 |
+
return embeddings
|
104 |
+
|
105 |
+
def _reset_parameters(self):
|
106 |
+
for p in self.parameters():
|
107 |
+
if p.dim() > 1:
|
108 |
+
normal_(p, mean=0.0, std=0.02)
|
109 |
+
|
110 |
+
def save_weights(self, path):
|
111 |
+
torch.save(self.state_dict(), path)
|
112 |
+
|
113 |
+
def load_weights(self, path):
|
114 |
+
self.load_state_dict(torch.load(path))
|
115 |
+
|
116 |
+
|
117 |
+
# Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert
|
118 |
+
class TransformerSelfAttention(nn.Module):
|
119 |
+
def __init__(self, config, position_embedding_type=None):
|
120 |
+
super().__init__()
|
121 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
122 |
+
raise ValueError(
|
123 |
+
f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
|
124 |
+
f"heads ({config.num_attention_heads})"
|
125 |
+
)
|
126 |
+
|
127 |
+
self.num_attention_heads = config.num_attention_heads
|
128 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
129 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
130 |
+
|
131 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
132 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
133 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
134 |
+
|
135 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
136 |
+
self.position_embedding_type = position_embedding_type or getattr(
|
137 |
+
config, "position_embedding_type", "absolute"
|
138 |
+
)
|
139 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
140 |
+
self.max_position_embeddings = config.max_position_embeddings
|
141 |
+
self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
|
142 |
+
|
143 |
+
self.is_decoder = config.is_decoder
|
144 |
+
|
145 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
146 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
147 |
+
x = x.view(new_x_shape)
|
148 |
+
return x.permute(0, 2, 1, 3)
|
149 |
+
|
150 |
+
def forward(
|
151 |
+
self,
|
152 |
+
hidden_states: torch.Tensor,
|
153 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
154 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
155 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
156 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
157 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
158 |
+
output_attentions: Optional[bool] = False,
|
159 |
+
) -> Tuple[torch.Tensor]:
|
160 |
+
mixed_query_layer = self.query(hidden_states)
|
161 |
+
|
162 |
+
# If this is instantiated as a cross-attention module, the keys
|
163 |
+
# and values come from an encoder; the attention mask needs to be
|
164 |
+
# such that the encoder's padding tokens are not attended to.
|
165 |
+
is_cross_attention = encoder_hidden_states is not None
|
166 |
+
|
167 |
+
if is_cross_attention and past_key_value is not None:
|
168 |
+
# reuse k,v, cross_attentions
|
169 |
+
key_layer = past_key_value[0]
|
170 |
+
value_layer = past_key_value[1]
|
171 |
+
attention_mask = encoder_attention_mask
|
172 |
+
elif is_cross_attention:
|
173 |
+
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
|
174 |
+
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
|
175 |
+
attention_mask = encoder_attention_mask
|
176 |
+
elif past_key_value is not None:
|
177 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
178 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
179 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
180 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
181 |
+
else:
|
182 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
183 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
184 |
+
|
185 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
186 |
+
|
187 |
+
use_cache = past_key_value is not None
|
188 |
+
if self.is_decoder:
|
189 |
+
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
|
190 |
+
# Further calls to cross_attention layer can then reuse all cross-attention
|
191 |
+
# key/value_states (first "if" case)
|
192 |
+
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
|
193 |
+
# all previous decoder key/value_states. Further calls to uni-directional self-attention
|
194 |
+
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
|
195 |
+
# if encoder bi-directional self-attention `past_key_value` is always `None`
|
196 |
+
past_key_value = (key_layer, value_layer)
|
197 |
+
|
198 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
199 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
200 |
+
|
201 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
202 |
+
query_length, key_length = query_layer.shape[2], key_layer.shape[2]
|
203 |
+
if use_cache:
|
204 |
+
position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
|
205 |
+
-1, 1
|
206 |
+
)
|
207 |
+
else:
|
208 |
+
position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
209 |
+
position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
210 |
+
distance = position_ids_l - position_ids_r
|
211 |
+
|
212 |
+
positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
|
213 |
+
positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
|
214 |
+
|
215 |
+
if self.position_embedding_type == "relative_key":
|
216 |
+
relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
217 |
+
attention_scores = attention_scores + relative_position_scores
|
218 |
+
elif self.position_embedding_type == "relative_key_query":
|
219 |
+
relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
|
220 |
+
relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
|
221 |
+
attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
|
222 |
+
|
223 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
224 |
+
if attention_mask is not None:
|
225 |
+
# Apply the attention mask is (precomputed for all layers in TransformerModel forward() function)
|
226 |
+
attention_scores = attention_scores + attention_mask
|
227 |
+
|
228 |
+
# Normalize the attention scores to probabilities.
|
229 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
230 |
+
|
231 |
+
# This is actually dropping out entire tokens to attend to, which might
|
232 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
233 |
+
attention_probs = self.dropout(attention_probs)
|
234 |
+
|
235 |
+
# Mask heads if we want to
|
236 |
+
if head_mask is not None:
|
237 |
+
attention_probs = attention_probs * head_mask
|
238 |
+
|
239 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
240 |
+
|
241 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
242 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
243 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
244 |
+
|
245 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
246 |
+
|
247 |
+
if self.is_decoder:
|
248 |
+
outputs = outputs + (past_key_value,)
|
249 |
+
return outputs
|
250 |
+
|
251 |
+
|
252 |
+
# Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert
|
253 |
+
class TransformerSelfOutput(nn.Module):
|
254 |
+
def __init__(self, config):
|
255 |
+
super().__init__()
|
256 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
257 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
258 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
259 |
+
|
260 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
261 |
+
hidden_states = self.dense(hidden_states)
|
262 |
+
hidden_states = self.dropout(hidden_states)
|
263 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
264 |
+
return hidden_states
|
265 |
+
|
266 |
+
|
267 |
+
# Copied from transformers.models.bert.modeling_bert.BertAttention with Bert
|
268 |
+
class TransformerAttention(nn.Module):
|
269 |
+
def __init__(self, config, position_embedding_type=None):
|
270 |
+
super().__init__()
|
271 |
+
self.self = TransformerSelfAttention(config, position_embedding_type=position_embedding_type)
|
272 |
+
self.output = TransformerSelfOutput(config)
|
273 |
+
self.pruned_heads = set()
|
274 |
+
|
275 |
+
def prune_heads(self, heads):
|
276 |
+
if len(heads) == 0:
|
277 |
+
return
|
278 |
+
heads, index = find_pruneable_heads_and_indices(
|
279 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
280 |
+
)
|
281 |
+
|
282 |
+
# Prune linear layers
|
283 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
284 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
285 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
286 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
287 |
+
|
288 |
+
# Update hyper params and store pruned heads
|
289 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
290 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
291 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
292 |
+
|
293 |
+
def forward(
|
294 |
+
self,
|
295 |
+
hidden_states: torch.Tensor,
|
296 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
297 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
298 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
299 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
300 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
301 |
+
output_attentions: Optional[bool] = False,
|
302 |
+
) -> Tuple[torch.Tensor]:
|
303 |
+
self_outputs = self.self(
|
304 |
+
hidden_states,
|
305 |
+
attention_mask,
|
306 |
+
head_mask,
|
307 |
+
encoder_hidden_states,
|
308 |
+
encoder_attention_mask,
|
309 |
+
past_key_value,
|
310 |
+
output_attentions,
|
311 |
+
)
|
312 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
313 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
314 |
+
return outputs
|
315 |
+
|
316 |
+
# Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert
|
317 |
+
class TransformerIntermediate(nn.Module):
|
318 |
+
def __init__(self, config):
|
319 |
+
super().__init__()
|
320 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
321 |
+
if isinstance(config.hidden_act, str):
|
322 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
323 |
+
else:
|
324 |
+
self.intermediate_act_fn = config.hidden_act
|
325 |
+
|
326 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
327 |
+
hidden_states = self.dense(hidden_states)
|
328 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
329 |
+
return hidden_states
|
330 |
+
|
331 |
+
# Copied from transformers.models.bert.modeling_bert.BertOutput with Bert
|
332 |
+
class TransformerOutput(nn.Module):
|
333 |
+
def __init__(self, config):
|
334 |
+
super().__init__()
|
335 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
336 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
337 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
338 |
+
|
339 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
340 |
+
hidden_states = self.dense(hidden_states)
|
341 |
+
hidden_states = self.dropout(hidden_states)
|
342 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
343 |
+
return hidden_states
|
344 |
+
|
345 |
+
|
346 |
+
# Copied from transformers.models.bert.modeling_bert.BertLayer
|
347 |
+
class TransformerLayer(nn.Module):
|
348 |
+
def __init__(self, config):
|
349 |
+
super().__init__()
|
350 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
351 |
+
self.seq_len_dim = 1
|
352 |
+
self.attention = TransformerAttention(config)
|
353 |
+
self.is_decoder = config.is_decoder
|
354 |
+
self.add_cross_attention = config.add_cross_attention
|
355 |
+
if self.add_cross_attention:
|
356 |
+
if not self.is_decoder:
|
357 |
+
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
|
358 |
+
self.crossattention = TransformerAttention(config, position_embedding_type="absolute")
|
359 |
+
self.intermediate = TransformerIntermediate(config)
|
360 |
+
self.output = TransformerOutput(config)
|
361 |
+
|
362 |
+
def forward(
|
363 |
+
self,
|
364 |
+
hidden_states: torch.Tensor,
|
365 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
366 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
367 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
368 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
369 |
+
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
370 |
+
output_attentions: Optional[bool] = False,
|
371 |
+
) -> Tuple[torch.Tensor]:
|
372 |
+
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
|
373 |
+
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
|
374 |
+
self_attention_outputs = self.attention(
|
375 |
+
hidden_states,
|
376 |
+
attention_mask,
|
377 |
+
head_mask,
|
378 |
+
output_attentions=output_attentions,
|
379 |
+
past_key_value=self_attn_past_key_value,
|
380 |
+
)
|
381 |
+
attention_output = self_attention_outputs[0]
|
382 |
+
|
383 |
+
# if decoder, the last output is tuple of self-attn cache
|
384 |
+
if self.is_decoder:
|
385 |
+
outputs = self_attention_outputs[1:-1]
|
386 |
+
present_key_value = self_attention_outputs[-1]
|
387 |
+
else:
|
388 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
389 |
+
|
390 |
+
cross_attn_present_key_value = None
|
391 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
392 |
+
if not hasattr(self, "crossattention"):
|
393 |
+
raise ValueError(
|
394 |
+
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
|
395 |
+
" by setting `config.add_cross_attention=True`"
|
396 |
+
)
|
397 |
+
|
398 |
+
# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
|
399 |
+
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
|
400 |
+
cross_attention_outputs = self.crossattention(
|
401 |
+
attention_output,
|
402 |
+
attention_mask,
|
403 |
+
head_mask,
|
404 |
+
encoder_hidden_states,
|
405 |
+
encoder_attention_mask,
|
406 |
+
cross_attn_past_key_value,
|
407 |
+
output_attentions,
|
408 |
+
)
|
409 |
+
attention_output = cross_attention_outputs[0]
|
410 |
+
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
|
411 |
+
|
412 |
+
# add cross-attn cache to positions 3,4 of present_key_value tuple
|
413 |
+
cross_attn_present_key_value = cross_attention_outputs[-1]
|
414 |
+
present_key_value = present_key_value + cross_attn_present_key_value
|
415 |
+
|
416 |
+
layer_output = apply_chunking_to_forward(
|
417 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
418 |
+
)
|
419 |
+
outputs = (layer_output,) + outputs
|
420 |
+
|
421 |
+
# if decoder, return the attn key/values as the last output
|
422 |
+
if self.is_decoder:
|
423 |
+
outputs = outputs + (present_key_value,)
|
424 |
+
|
425 |
+
return outputs
|
426 |
+
|
427 |
+
def feed_forward_chunk(self, attention_output):
|
428 |
+
intermediate_output = self.intermediate(attention_output)
|
429 |
+
layer_output = self.output(intermediate_output, attention_output)
|
430 |
+
return layer_output
|
431 |
+
|
432 |
+
|
433 |
+
class TransformerEncoder(nn.Module):
|
434 |
+
def __init__(self, config):
|
435 |
+
super().__init__()
|
436 |
+
self.config = config
|
437 |
+
self.layer = nn.ModuleList([TransformerLayer(config) for _ in range(config.num_hidden_layers)])
|
438 |
+
self.gradient_checkpointing = False
|
439 |
+
|
440 |
+
def forward(
|
441 |
+
self,
|
442 |
+
hidden_states: torch.Tensor,
|
443 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
444 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
445 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
446 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
447 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
448 |
+
use_cache: Optional[bool] = None,
|
449 |
+
output_attentions: Optional[bool] = False,
|
450 |
+
output_hidden_states: Optional[bool] = False,
|
451 |
+
return_dict: Optional[bool] = True,
|
452 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
|
453 |
+
all_hidden_states = () if output_hidden_states else None
|
454 |
+
all_self_attentions = () if output_attentions else None
|
455 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
456 |
+
|
457 |
+
next_decoder_cache = () if use_cache else None
|
458 |
+
for i, layer_module in enumerate(self.layer):
|
459 |
+
if output_hidden_states:
|
460 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
461 |
+
|
462 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
463 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
464 |
+
|
465 |
+
if self.gradient_checkpointing and self.training:
|
466 |
+
|
467 |
+
if use_cache:
|
468 |
+
logger.warning(
|
469 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
470 |
+
)
|
471 |
+
use_cache = False
|
472 |
+
|
473 |
+
def create_custom_forward(module):
|
474 |
+
def custom_forward(*inputs):
|
475 |
+
return module(*inputs, past_key_value, output_attentions)
|
476 |
+
|
477 |
+
return custom_forward
|
478 |
+
|
479 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
480 |
+
create_custom_forward(layer_module),
|
481 |
+
hidden_states,
|
482 |
+
attention_mask,
|
483 |
+
layer_head_mask,
|
484 |
+
encoder_hidden_states,
|
485 |
+
encoder_attention_mask,
|
486 |
+
)
|
487 |
+
else:
|
488 |
+
layer_outputs = layer_module(
|
489 |
+
hidden_states,
|
490 |
+
attention_mask,
|
491 |
+
layer_head_mask,
|
492 |
+
encoder_hidden_states,
|
493 |
+
encoder_attention_mask,
|
494 |
+
past_key_value,
|
495 |
+
output_attentions,
|
496 |
+
)
|
497 |
+
|
498 |
+
hidden_states = layer_outputs[0]
|
499 |
+
if use_cache:
|
500 |
+
next_decoder_cache += (layer_outputs[-1],)
|
501 |
+
if output_attentions:
|
502 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
503 |
+
if self.config.add_cross_attention:
|
504 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
|
505 |
+
|
506 |
+
if output_hidden_states:
|
507 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
508 |
+
|
509 |
+
if not return_dict:
|
510 |
+
return tuple(
|
511 |
+
v
|
512 |
+
for v in [
|
513 |
+
hidden_states,
|
514 |
+
next_decoder_cache,
|
515 |
+
all_hidden_states,
|
516 |
+
all_self_attentions,
|
517 |
+
all_cross_attentions,
|
518 |
+
]
|
519 |
+
if v is not None
|
520 |
+
)
|
521 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
522 |
+
last_hidden_state=hidden_states,
|
523 |
+
past_key_values=next_decoder_cache,
|
524 |
+
hidden_states=all_hidden_states,
|
525 |
+
attentions=all_self_attentions,
|
526 |
+
cross_attentions=all_cross_attentions,
|
527 |
+
)
|
528 |
+
|
529 |
+
|
530 |
+
# Copied from transformers.models.bert.modeling_bert.BertPooler
|
531 |
+
class Pooler(nn.Module):
|
532 |
+
def __init__(self, config):
|
533 |
+
super().__init__()
|
534 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
535 |
+
self.activation = nn.Tanh()
|
536 |
+
|
537 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
538 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
539 |
+
# to the first token.
|
540 |
+
first_token_tensor = hidden_states[:, 0]
|
541 |
+
pooled_output = self.dense(first_token_tensor)
|
542 |
+
pooled_output = self.activation(pooled_output)
|
543 |
+
return pooled_output
|
544 |
+
|
545 |
+
|
546 |
+
class TransformerModel(nn.Module):
|
547 |
+
"""
|
548 |
+
"""
|
549 |
+
|
550 |
+
def __init__(self, config, add_pooling_layer=True):
|
551 |
+
super().__init__()
|
552 |
+
self.config = config
|
553 |
+
self.encoder = TransformerEncoder(config)
|
554 |
+
self.pooler = Pooler(config) if add_pooling_layer else None
|
555 |
+
# Initialize weights and apply final processing
|
556 |
+
self._reset_parameters()
|
557 |
+
|
558 |
+
# Copied from transformers.models.bert.modeling_bert.BertModel._prune_heads
|
559 |
+
def _prune_heads(self, heads_to_prune):
|
560 |
+
"""
|
561 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
562 |
+
class PreTrainedModel
|
563 |
+
"""
|
564 |
+
for layer, heads in heads_to_prune.items():
|
565 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
566 |
+
|
567 |
+
def forward(
|
568 |
+
self,
|
569 |
+
h_input,
|
570 |
+
input_ids: Optional[torch.Tensor] = None,
|
571 |
+
attention_mask: Optional[torch.Tensor] = None,
|
572 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
573 |
+
task_type_ids: Optional[torch.Tensor] = None,
|
574 |
+
position_ids: Optional[torch.Tensor] = None,
|
575 |
+
head_mask: Optional[torch.Tensor] = None,
|
576 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
577 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
578 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
579 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
580 |
+
use_cache: Optional[bool] = None,
|
581 |
+
output_attentions: Optional[bool] = None,
|
582 |
+
output_hidden_states: Optional[bool] = None,
|
583 |
+
return_dict: Optional[bool] = None,
|
584 |
+
) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
|
585 |
+
r"""
|
586 |
+
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
|
587 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
588 |
+
the model is configured as a decoder.
|
589 |
+
encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
590 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
591 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
|
592 |
+
|
593 |
+
- 1 for tokens that are **not masked**,
|
594 |
+
- 0 for tokens that are **masked**.
|
595 |
+
past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
596 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
597 |
+
|
598 |
+
If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
|
599 |
+
don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
|
600 |
+
`decoder_input_ids` of shape `(batch_size, sequence_length)`.
|
601 |
+
use_cache (`bool`, *optional*):
|
602 |
+
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
|
603 |
+
`past_key_values`).
|
604 |
+
"""
|
605 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
606 |
+
output_hidden_states = (
|
607 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
608 |
+
)
|
609 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
610 |
+
|
611 |
+
if self.config.is_decoder:
|
612 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
613 |
+
else:
|
614 |
+
use_cache = False
|
615 |
+
|
616 |
+
if input_ids is not None and inputs_embeds is not None:
|
617 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
618 |
+
elif input_ids is not None:
|
619 |
+
input_shape = input_ids.size()
|
620 |
+
elif inputs_embeds is not None:
|
621 |
+
input_shape = inputs_embeds.size()[:-1]
|
622 |
+
else:
|
623 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
624 |
+
|
625 |
+
batch_size, seq_length = input_shape
|
626 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
627 |
+
|
628 |
+
# past_key_values_length
|
629 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
630 |
+
|
631 |
+
if attention_mask is None:
|
632 |
+
attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
|
633 |
+
|
634 |
+
if token_type_ids is None:
|
635 |
+
if hasattr(self.embeddings, "token_type_ids"):
|
636 |
+
buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
|
637 |
+
buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
|
638 |
+
token_type_ids = buffered_token_type_ids_expanded
|
639 |
+
else:
|
640 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
641 |
+
|
642 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
643 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
644 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
|
645 |
+
|
646 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
647 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
648 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
649 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
650 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
651 |
+
if encoder_attention_mask is None:
|
652 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
653 |
+
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
654 |
+
else:
|
655 |
+
encoder_extended_attention_mask = None
|
656 |
+
|
657 |
+
# Prepare head mask if needed
|
658 |
+
# 1.0 in head_mask indicate we keep the head
|
659 |
+
# attention_probs has shape bsz x n_heads x N x N
|
660 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
661 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
662 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
663 |
+
|
664 |
+
encoder_outputs = self.encoder(
|
665 |
+
h_input,
|
666 |
+
attention_mask=extended_attention_mask,
|
667 |
+
head_mask=head_mask,
|
668 |
+
encoder_hidden_states=encoder_hidden_states,
|
669 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
670 |
+
past_key_values=past_key_values,
|
671 |
+
use_cache=use_cache,
|
672 |
+
output_attentions=output_attentions,
|
673 |
+
output_hidden_states=output_hidden_states,
|
674 |
+
return_dict=return_dict,
|
675 |
+
)
|
676 |
+
sequence_output = encoder_outputs[0]
|
677 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
678 |
+
|
679 |
+
if not return_dict:
|
680 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
681 |
+
|
682 |
+
return BaseModelOutputWithPoolingAndCrossAttentions(
|
683 |
+
last_hidden_state=sequence_output,
|
684 |
+
pooler_output=pooled_output,
|
685 |
+
past_key_values=encoder_outputs.past_key_values,
|
686 |
+
hidden_states=encoder_outputs.hidden_states,
|
687 |
+
attentions=encoder_outputs.attentions,
|
688 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
689 |
+
)
|
690 |
+
|
691 |
+
def get_extended_attention_mask(
|
692 |
+
self, attention_mask: Tensor, input_shape: Tuple[int], device: device = None, dtype: torch.float = None
|
693 |
+
) -> Tensor:
|
694 |
+
"""
|
695 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
696 |
+
|
697 |
+
Arguments:
|
698 |
+
attention_mask (`torch.Tensor`):
|
699 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
700 |
+
input_shape (`Tuple[int]`):
|
701 |
+
The shape of the input to the model.
|
702 |
+
|
703 |
+
Returns:
|
704 |
+
`torch.Tensor` The extended attention mask, with a the same dtype as `attention_mask.dtype`.
|
705 |
+
"""
|
706 |
+
if dtype is None:
|
707 |
+
dtype = torch.float32
|
708 |
+
|
709 |
+
if not (attention_mask.dim() == 2 and self.config.is_decoder):
|
710 |
+
# show warning only if it won't be shown in `create_extended_attention_mask_for_decoder`
|
711 |
+
if device is not None:
|
712 |
+
warnings.warn(
|
713 |
+
"The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
|
714 |
+
)
|
715 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
716 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
717 |
+
if attention_mask.dim() == 3:
|
718 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
719 |
+
elif attention_mask.dim() == 2:
|
720 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
721 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
722 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
723 |
+
if self.config.is_decoder:
|
724 |
+
extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
|
725 |
+
input_shape, attention_mask, device
|
726 |
+
)
|
727 |
+
else:
|
728 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
729 |
+
else:
|
730 |
+
raise ValueError(
|
731 |
+
f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
|
732 |
+
)
|
733 |
+
|
734 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
735 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
736 |
+
# positions we want to attend and the dtype's smallest value for masked positions.
|
737 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
738 |
+
# effectively the same as removing these entirely.
|
739 |
+
extended_attention_mask = extended_attention_mask.to(dtype=dtype) # fp16 compatibility
|
740 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(dtype).min
|
741 |
+
return extended_attention_mask
|
742 |
+
|
743 |
+
def get_head_mask(
|
744 |
+
self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
|
745 |
+
) -> Tensor:
|
746 |
+
"""
|
747 |
+
Prepare the head mask if needed.
|
748 |
+
|
749 |
+
Args:
|
750 |
+
head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
|
751 |
+
The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
|
752 |
+
num_hidden_layers (`int`):
|
753 |
+
The number of hidden layers in the model.
|
754 |
+
is_attention_chunked: (`bool`, *optional*, defaults to `False`):
|
755 |
+
Whether or not the attentions scores are computed by chunks or not.
|
756 |
+
|
757 |
+
Returns:
|
758 |
+
`torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` or list with
|
759 |
+
`[None]` for each layer.
|
760 |
+
"""
|
761 |
+
if head_mask is not None:
|
762 |
+
head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
|
763 |
+
if is_attention_chunked is True:
|
764 |
+
head_mask = head_mask.unsqueeze(-1)
|
765 |
+
else:
|
766 |
+
head_mask = [None] * num_hidden_layers
|
767 |
+
|
768 |
+
return head_mask
|
769 |
+
|
770 |
+
def _reset_parameters(self):
|
771 |
+
r"""Initiate parameters in the transformer model."""
|
772 |
+
for p in self.parameters():
|
773 |
+
if p.dim() > 1:
|
774 |
+
normal_(p, mean=0.0, std=self.config.initializer_range)
|
775 |
+
|
776 |
+
def save_weights(self, path):
|
777 |
+
torch.save(self.state_dict(), path)
|
778 |
+
|
779 |
+
def load_weights(self, path):
|
780 |
+
self.load_state_dict(torch.load(path))
|
781 |
+
|
782 |
+
@dataclass
|
783 |
+
|
784 |
+
@dataclass
|
785 |
+
class G2PTLMaskedLMOutput(ModelOutput):
|
786 |
+
"""
|
787 |
+
Base class for masked language models outputs.
|
788 |
+
|
789 |
+
Args:
|
790 |
+
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
|
791 |
+
Masked language modeling (MLM) loss.
|
792 |
+
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
793 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
794 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
795 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
|
796 |
+
one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
|
797 |
+
|
798 |
+
Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
|
799 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
800 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
801 |
+
sequence_length)`.
|
802 |
+
|
803 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
804 |
+
heads.
|
805 |
+
"""
|
806 |
+
|
807 |
+
loss: Optional[torch.FloatTensor] = None
|
808 |
+
logits: torch.FloatTensor = None
|
809 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
810 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
811 |
+
gc_layer_out: Optional[torch.FloatTensor] = None
|
812 |
+
final_pooler_output: Optional[torch.FloatTensor] = None
|
813 |
+
final_hidden_state: Optional[torch.FloatTensor] = None
|
814 |
+
last_hidden_state: Optional[torch.FloatTensor] = None
|
815 |
+
htc_layer_out: Optional[Tuple[torch.FloatTensor]] = None
|
816 |
+
|
817 |
+
from transformers.activations import ACT2FN
|
818 |
+
# Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert
|
819 |
+
class TransformerPredictionHeadTransform(nn.Module):
|
820 |
+
def __init__(self, config):
|
821 |
+
super().__init__()
|
822 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
823 |
+
if isinstance(config.hidden_act, str):
|
824 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
825 |
+
else:
|
826 |
+
self.transform_act_fn = config.hidden_act
|
827 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
828 |
+
|
829 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
830 |
+
hidden_states = self.dense(hidden_states)
|
831 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
832 |
+
hidden_states = self.LayerNorm(hidden_states)
|
833 |
+
return hidden_states
|
834 |
+
|
835 |
+
# Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert
|
836 |
+
class TransformerLMPredictionHead(nn.Module):
|
837 |
+
def __init__(self, config):
|
838 |
+
super().__init__()
|
839 |
+
self.transform = TransformerPredictionHeadTransform(config)
|
840 |
+
|
841 |
+
# The output weights are the same as the input embeddings, but there is
|
842 |
+
# an output-only bias for each token.
|
843 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
844 |
+
|
845 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
846 |
+
|
847 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
848 |
+
self.decoder.bias = self.bias
|
849 |
+
|
850 |
+
def forward(self, hidden_states):
|
851 |
+
hidden_states = self.transform(hidden_states)
|
852 |
+
hidden_states = self.decoder(hidden_states)
|
853 |
+
return hidden_states
|
854 |
+
|
855 |
+
|
856 |
+
# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->Transformer
|
857 |
+
class TransformerOnlyMLMHead(nn.Module):
|
858 |
+
def __init__(self, config):
|
859 |
+
super().__init__()
|
860 |
+
self.predictions = TransformerLMPredictionHead(config)
|
861 |
+
|
862 |
+
def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
|
863 |
+
prediction_scores = self.predictions(sequence_output)
|
864 |
+
return prediction_scores
|
865 |
+
|
866 |
+
class G2PTL(PreTrainedModel):
|
867 |
+
def __init__(self, config, return_last_hidden_state=False):
|
868 |
+
super(G2PTL, self).__init__(config)
|
869 |
+
|
870 |
+
self.config = deepcopy(config)
|
871 |
+
self.return_last_hidden_state = return_last_hidden_state
|
872 |
+
self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
|
873 |
+
# ================ G2PTLEmbedding =====================
|
874 |
+
self.embedding = G2PTLEmbedding(self.config)
|
875 |
+
# ================ TransformerModel =====================
|
876 |
+
self.G2PTL_config = deepcopy(config)
|
877 |
+
self.transformer_model = TransformerModel(self.G2PTL_config)
|
878 |
+
# ================ TranSAGE =====================
|
879 |
+
self.graphormer = Graphormer3D()
|
880 |
+
# ================ Encoding =====================
|
881 |
+
self.encoder_config = deepcopy(config)
|
882 |
+
self.encoder_config.num_hidden_layers = 1
|
883 |
+
self.encoder = TransformerModel(self.encoder_config)
|
884 |
+
self.encoder_out_dim = self.encoder_config.hidden_size
|
885 |
+
# ================ GC =====================
|
886 |
+
self.gc_trans = nn.Linear(self.encoder_out_dim, 16 * 33, bias=True)
|
887 |
+
# ================ MLM =====================
|
888 |
+
self.cls = TransformerOnlyMLMHead(self.G2PTL_config)
|
889 |
+
# ================ HTC =====================
|
890 |
+
self.htc_trans = nn.Linear(self.encoder_out_dim, 5 * 100, bias=True)
|
891 |
+
# ================ alias =====================
|
892 |
+
self.down_hidden_dim = 512
|
893 |
+
self.down_kernel_num = 128
|
894 |
+
self.alias_trans = nn.Linear(self.encoder_out_dim, self.down_hidden_dim, bias=True)
|
895 |
+
self.alias_trans2 = torch.nn.Conv2d(1, self.down_kernel_num, (2, self.down_hidden_dim), stride=1, bias=True)
|
896 |
+
self.alias_layer = nn.Linear(self.down_kernel_num * 5, 2 * 5, bias=True)
|
897 |
+
# ================ AOI =====================
|
898 |
+
self.aoi_trans = nn.Linear(self.encoder_out_dim, self.down_hidden_dim, bias=True)
|
899 |
+
self.aoi_trans2 = torch.nn.Conv2d(1, self.down_kernel_num, (2, self.down_hidden_dim), stride=1, bias=True)
|
900 |
+
self.aoi_layer = nn.Linear(self.down_kernel_num * 5, 2 * 5, bias=True)
|
901 |
+
|
902 |
+
self._reset_parameters()
|
903 |
+
|
904 |
+
def forward(self,
|
905 |
+
input_ids,
|
906 |
+
attention_mask : Optional[torch.Tensor] = None,
|
907 |
+
token_type_ids : Optional[torch.Tensor] = None,
|
908 |
+
node_position_ids: Optional[torch.Tensor] = None,
|
909 |
+
spatial_pos: Optional[torch.Tensor] = None,
|
910 |
+
in_degree: Optional[torch.Tensor] = None,
|
911 |
+
out_degree: Optional[torch.Tensor] = None,
|
912 |
+
edge_type_matrix: Optional[torch.Tensor] = None,
|
913 |
+
edge_input : Optional[torch.Tensor] = None,
|
914 |
+
prov_city_mask: Optional[torch.Tensor] = None,
|
915 |
+
sequence_len : Optional[int] = 1,
|
916 |
+
labels: Optional[torch.Tensor] = None
|
917 |
+
):
|
918 |
+
"""
|
919 |
+
:param input_ids: [sequence_len * batch_size, src_len]
|
920 |
+
:param attention_mask: [sequence_len * batch_size, src_len]
|
921 |
+
:param token_type_ids: [sequence_len * batch_size, src_len]
|
922 |
+
:param sequence_len: int
|
923 |
+
:param labels:
|
924 |
+
:param is_eval: bool
|
925 |
+
:return:
|
926 |
+
"""
|
927 |
+
|
928 |
+
batch_size_input = int(input_ids.shape[0] / sequence_len)
|
929 |
+
|
930 |
+
# If the model inputs missing graph information, a single-node subgraph is constructed by default.
|
931 |
+
if spatial_pos is None:
|
932 |
+
# The shortest path length between nodes in the graph
|
933 |
+
spatial_pos = torch.LongTensor(np.zeros((batch_size_input, 1, 1), dtype=np.int64)).to(self.device)
|
934 |
+
if in_degree is None:
|
935 |
+
# The in-degree of nodes in the graph
|
936 |
+
in_degree = torch.LongTensor(np.ones((batch_size_input, 1), dtype=np.int64)).to(self.device)
|
937 |
+
if out_degree is None:
|
938 |
+
# The out-degree of nodes in the graph
|
939 |
+
out_degree = torch.LongTensor(np.ones((batch_size_input, 1), dtype=np.int64)).to(self.device)
|
940 |
+
if edge_type_matrix is None:
|
941 |
+
# The edge type of edges in the graph
|
942 |
+
edge_type_matrix = torch.LongTensor(8*np.ones((batch_size_input, 1, 1), dtype=np.int64)).to(self.device)
|
943 |
+
if edge_input is None:
|
944 |
+
# The shortest path route between nodes in the graph
|
945 |
+
edge_input = torch.LongTensor(8*np.ones((batch_size_input, 1, 1, 1), dtype=np.int64)).to(self.device)
|
946 |
+
if node_position_ids is None:
|
947 |
+
# node poistion ids
|
948 |
+
node_position_ids = torch.tensor(np.ones((batch_size_input, 1), dtype=np.int64)).to(self.device)
|
949 |
+
|
950 |
+
embedding_output = self.embedding(input_ids=input_ids, token_type_ids=token_type_ids)
|
951 |
+
|
952 |
+
transformer_predictions = self.transformer_model(embedding_output,
|
953 |
+
input_ids=input_ids,
|
954 |
+
token_type_ids=token_type_ids,
|
955 |
+
attention_mask=attention_mask)
|
956 |
+
last_hidden_state = transformer_predictions[0].contiguous().view(batch_size_input, sequence_len, -1,
|
957 |
+
self.encoder_out_dim)
|
958 |
+
pooler_output = transformer_predictions[1].contiguous().view(batch_size_input, sequence_len, self.encoder_out_dim)
|
959 |
+
|
960 |
+
h_ = self.graphormer(pooler_output, spatial_pos, in_degree, out_degree, edge_type_matrix, edge_input, node_position_ids)
|
961 |
+
h_ = h_.unsqueeze(2)
|
962 |
+
new_hidden_state = torch.cat((h_, last_hidden_state[:, :, 1:, :]), dim=2)
|
963 |
+
new_hidden_state = new_hidden_state.contiguous().view(batch_size_input * sequence_len, -1, self.encoder_out_dim)
|
964 |
+
encoder_outputs = self.encoder(new_hidden_state,
|
965 |
+
input_ids=input_ids,
|
966 |
+
token_type_ids=token_type_ids,
|
967 |
+
attention_mask=attention_mask)
|
968 |
+
final_hidden_state = encoder_outputs[0]
|
969 |
+
final_pooler_output = encoder_outputs[1].contiguous().view(batch_size_input, sequence_len, self.encoder_out_dim)
|
970 |
+
prediction_scores = self.cls(final_hidden_state) # Logits For MLM
|
971 |
+
|
972 |
+
gc_layer_out = self.gc_trans(final_pooler_output)
|
973 |
+
gc_layer_out = gc_layer_out.contiguous().view(-1, 16) # Logits For GC
|
974 |
+
|
975 |
+
htc_layer_out = self.htc_trans(final_pooler_output)
|
976 |
+
htc_layer_out = htc_layer_out.contiguous().view(-1, 100) # Logits For HTC
|
977 |
+
|
978 |
+
masked_lm_loss = None
|
979 |
+
|
980 |
+
# MLM loss
|
981 |
+
if labels is not None:
|
982 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
983 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
984 |
+
|
985 |
+
if self.return_last_hidden_state:
|
986 |
+
return final_pooler_output, pooler_output
|
987 |
+
|
988 |
+
return G2PTLMaskedLMOutput(
|
989 |
+
loss=masked_lm_loss,
|
990 |
+
logits=prediction_scores,
|
991 |
+
hidden_states=final_hidden_state,
|
992 |
+
attentions=encoder_outputs.attentions,
|
993 |
+
gc_layer_out = gc_layer_out,
|
994 |
+
final_pooler_output = final_pooler_output,
|
995 |
+
final_hidden_state = final_hidden_state,
|
996 |
+
last_hidden_state = last_hidden_state,
|
997 |
+
htc_layer_out = htc_layer_out
|
998 |
+
)
|
999 |
+
|
1000 |
+
def get_htc_code(self, htc_layer_out):
|
1001 |
+
htc_loss_fct = HTCLoss(device=self.device, reduction='mean')
|
1002 |
+
htc_pred = htc_loss_fct.get_htc_code(htc_layer_out)
|
1003 |
+
return htc_pred
|
1004 |
+
|
1005 |
+
def decode_htc_code_2_chn(self, htc_pred):
|
1006 |
+
with open(remap_code_2_chn_file_path, 'rb') as fr:
|
1007 |
+
remap_code_2_chn = pickle.loads(fr.read())
|
1008 |
+
htc_pred = np.array(htc_pred).reshape(-1, 5)
|
1009 |
+
htc_res = []
|
1010 |
+
for arr in htc_pred:
|
1011 |
+
htc_res.append(remap_code_2_chn['{:02d}{:02d}{:02d}{:01d}{:02d}'.format(arr[0], arr[1], arr[2], arr[3], arr[4])])
|
1012 |
+
return htc_res
|
1013 |
+
|
1014 |
+
def _reset_parameters(self):
|
1015 |
+
for p in self.parameters():
|
1016 |
+
if p.dim() > 1:
|
1017 |
+
xavier_uniform_(p)
|
1018 |
+
|
1019 |
+
def save_weights(self, path):
|
1020 |
+
torch.save(self.state_dict(), path)
|
1021 |
+
|
1022 |
+
def load_weights(self, path):
|
1023 |
+
self.load_state_dict(torch.load(path, map_location=torch.device('cpu')), False)
|
1024 |
+
|
pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:21e06d160d8ffddc861d52f65e07e8dbe459feb666f9f33f856a169c1a5eb244
|
3 |
+
size 833629489
|
remap_code_2_chn.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9e998605c058964cd9cead64edeaecfadef6bd754c025c28b1bacb5af5fe02f3
|
3 |
+
size 4159356
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tqdm
|
2 |
+
torch==1.13.1
|
3 |
+
transformers==4.27.4
|
4 |
+
datasets
|
5 |
+
fairseq
|
special_tokens_map.json
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"mask_token": "[MASK]",
|
4 |
+
"pad_token": "[PAD]",
|
5 |
+
"sep_token": "[SEP]",
|
6 |
+
"unk_token": "[UNK]"
|
7 |
+
}
|
tokenizer_config.json
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cls_token": "[CLS]",
|
3 |
+
"do_basic_tokenize": true,
|
4 |
+
"do_lower_case": true,
|
5 |
+
"mask_token": "[MASK]",
|
6 |
+
"model_max_length": 1000000000000000019884624838656,
|
7 |
+
"never_split": null,
|
8 |
+
"pad_token": "[PAD]",
|
9 |
+
"sep_token": "[SEP]",
|
10 |
+
"special_tokens_map_file": null,
|
11 |
+
"strip_accents": null,
|
12 |
+
"tokenize_chinese_chars": true,
|
13 |
+
"tokenizer_class": "BertTokenizer",
|
14 |
+
"unk_token": "[UNK]"
|
15 |
+
}
|
vocab.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|