import torch.nn as nn

from transformers import BertPreTrainedModel, BertModel, AutoTokenizer
from colbert.utils.utils import torch_load_dnn

class HF_ColBERT(BertPreTrainedModel):
"""
Shallow wrapper around HuggingFace transformers. All new parameters should be defined at this level.
This makes sure `{from,save}_pretrained` and `init_weights` are applied to new parameters correctly.
"""
_keys_to_ignore_on_load_unexpected = [r"cls"]
    def __init__(self, config, colbert_config):
        super().__init__(config)

        self.dim = colbert_config.dim
        self.bert = BertModel(config)

        # Project BERT's hidden states down to the ColBERT embedding dimension.
        self.linear = nn.Linear(config.hidden_size, colbert_config.dim, bias=False)

        # if colbert_config.relu:
        #     self.score_scaler = nn.Linear(1, 1)

        self.init_weights()

        # if colbert_config.relu:
        #     self.score_scaler.weight.data.fill_(1.0)
        #     self.score_scaler.bias.data.fill_(-8.0)
    @classmethod
    def from_pretrained(cls, name_or_path, colbert_config):
        if name_or_path.endswith('.dnn'):
            # ColBERT ".dnn" checkpoint: recover the base HuggingFace model name from the
            # saved training arguments and load the stored state dict.
            dnn = torch_load_dnn(name_or_path)
            base = dnn.get('arguments', {}).get('model', 'bert-base-uncased')

            obj = super().from_pretrained(base, state_dict=dnn['model_state_dict'], colbert_config=colbert_config)
            obj.base = base

            return obj

        obj = super().from_pretrained(name_or_path, colbert_config=colbert_config)
        obj.base = name_or_path

        return obj
    @staticmethod
    def raw_tokenizer_from_pretrained(name_or_path):
        if name_or_path.endswith('.dnn'):
            dnn = torch_load_dnn(name_or_path)
            base = dnn.get('arguments', {}).get('model', 'bert-base-uncased')

            obj = AutoTokenizer.from_pretrained(base)
            obj.base = base

            return obj

        obj = AutoTokenizer.from_pretrained(name_or_path)
        obj.base = name_or_path

        return obj
"""
TODO: It's easy to write a class generator that takes "name_or_path" and loads AutoConfig to check the Architecture's
name, finds that name's *PreTrainedModel and *Model in dir(transformers), and then basically repeats the above.
It's easy for the BaseColBERT class to instantiate things from there.
"""