Katsumata420 committed
Commit: 4dcdff9 · Parent(s): de9ad40
Update modeling_retrieva_bert.py
modeling_retrieva_bert.py CHANGED

@@ -65,7 +65,7 @@ from .configuration_retrieva_bert import RetrievaBertConfig
 logger = logging.get_logger(__name__)

 _CONFIG_FOR_DOC = "RetrievaBertConfig"
-_CHECKPOINT_FOR_DOC = "
+_CHECKPOINT_FOR_DOC = "retrieva-jp/bert-1.3b"


 def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path):

@@ -1170,8 +1170,8 @@ class RetrievaBertForPreTraining(RetrievaBertPreTrainedModel):
 >>> from models import RetrievaBertForPreTraining
 >>> import torch

->>> tokenizer = AutoTokenizer.from_pretrained("
->>> model = RetrievaBertForPreTraining.from_pretrained("
+>>> tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
+>>> model = RetrievaBertForPreTraining.from_pretrained("retrieva-jp/bert-1.3b")

 >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
 >>> outputs = model(**inputs)

@@ -1294,8 +1294,8 @@ class RetrievaBertForCausalLM(RetrievaBertPreTrainedModel):
 >>> from models import RetrievaBertForCausalLM, RetrievaBertConfig
 >>> import torch

->>> tokenizer = AutoTokenizer.from_pretrained("
->>> model = RetrievaBertForCausalLM.from_pretrained("
+>>> tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
+>>> model = RetrievaBertForCausalLM.from_pretrained("retrieva-jp/bert-1.3b", is_decoder=True)

 >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
 >>> outputs = model(**inputs)

@@ -1528,8 +1528,8 @@ class RetrievaBertForNextSentencePrediction(RetrievaBertPreTrainedModel):
 >>> from models import RetrievaBertForNextSentencePrediction
 >>> import torch

->>> tokenizer = AutoTokenizer.from_pretrained("
->>> model = RetrievaBertForNextSentencePrediction.from_pretrained("
+>>> tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
+>>> model = RetrievaBertForNextSentencePrediction.from_pretrained("retrieva-jp/bert-1.3b")

 >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
 >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
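For context, a minimal sketch of how the checkpoint name the docstrings now point at would typically be loaded from the Hub. This is an assumption-laden example, not part of the commit: it assumes the retrieva-jp/bert-1.3b repository ships this modeling file as custom code and registers a masked-LM auto mapping, so trust_remote_code=True is needed when going through the Auto classes (the docstring examples above instead import the classes from a local models package).

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Assumption: the checkpoint referenced by _CHECKPOINT_FOR_DOC hosts this custom
# modeling code, so loading it via the Auto classes requires trust_remote_code=True.
tokenizer = AutoTokenizer.from_pretrained("retrieva-jp/bert-1.3b")
model = AutoModelForMaskedLM.from_pretrained("retrieva-jp/bert-1.3b", trust_remote_code=True)

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)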