File size: 1,699 Bytes
8c5c4c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
import torch
from transformers import ZeroShotClassificationPipeline
class CustomConfig(PretrainedConfig):
    """Configuration for the dummy zero-shot test model.

    Declares no fields of its own: every keyword argument is handed
    unchanged to ``PretrainedConfig``.
    """

    # Identifier under which this config type is stored in saved configs.
    model_type = "test-zeroshot"

    def __init__(self, **kwargs):
        """Forward all keyword arguments to the base configuration."""
        super(CustomConfig, self).__init__(**kwargs)
class CustomModel(PreTrainedModel):
    """Dummy model whose forward pass ignores its inputs.

    Every call returns the same fixed logits tensor wrapped in a
    ``SequenceClassifierOutput``.
    """

    config_class = CustomConfig

    def __init__(self, config: CustomConfig):
        super().__init__(config)
        # Keep a direct reference to the config object that was passed in.
        self.config = config
        # forward() never reads this 1x1 embedding. NOTE(review): presumably it
        # exists only so the module owns at least one parameter; confirm.
        self.embeddings = torch.nn.Embedding(num_embeddings=1, embedding_dim=1)

    def forward(self, **kwargs) -> SequenceClassifierOutput:
        """Return constant logits, ignoring every keyword argument.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` is the integer
            tensor ``[[1, 2, 3]]``, independent of the inputs.
        """
        fixed_logits = torch.tensor([[1, 2, 3]])
        return SequenceClassifierOutput(logits=fixed_logits)
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification
if __name__ == "__main__":
    # Build a dummy zero-shot pipeline around CustomModel and push it to the
    # paulhindemith/zero-shot-classification repository.
    from huggingface_hub import Repository
    from transformers import AutoTokenizer

    # Only the tokenizer of cl-tohoku/bert-base-japanese is used. Loading it
    # directly avoids downloading and instantiating the full BERT model that
    # pipeline("zero-shot-classification", ...) would build just so that
    # `.tokenizer` could be read off it.
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

    # Register the custom classes with the transformers auto API so that
    # save_pretrained records them for remote-code loading.
    # NOTE(review): the model is registered for "AutoModel", not
    # "AutoModelForSequenceClassification" -- confirm this is intended.
    CustomConfig.register_for_auto_class()
    CustomModel.register_for_auto_class("AutoModel")

    dummy_pipeline = ZeroShotClassificationPipeline(
        model=CustomModel(CustomConfig()),
        tokenizer=tokenizer,
    )

    # Clone the target repo into ./zero-shot-classification, write the
    # pipeline into that working copy, then push it.
    repo = Repository(
        "zero-shot-classification",
        clone_from="paulhindemith/zero-shot-classification",
    )
    dummy_pipeline.save_pretrained("zero-shot-classification")
    repo.push_to_hub()