Taizo Kaneko
commit files to HF hub
8c5c4c6
raw
history blame
1.7 kB
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.configuration_utils import PretrainedConfig
import torch
from transformers import ZeroShotClassificationPipeline
class CustomConfig(PretrainedConfig):
    """Configuration for the dummy ``test-zeroshot`` model.

    Declares no fields of its own: the only thing it adds on top of
    :class:`PretrainedConfig` is its ``model_type`` string. Every keyword
    argument is passed through to the base class untouched.
    """

    model_type = "test-zeroshot"

    def __init__(self, **kwargs):
        # Pure pass-through; all configuration handling lives in the base class.
        super().__init__(**kwargs)
class CustomModel(PreTrainedModel):
    """Dummy sequence-classification model that ignores the content of its inputs.

    For every example in the batch it emits the same logits row
    ``[1.0, 2.0, 3.0]``. Pipeline results are therefore deterministic and do
    not depend on any trained weights.
    """

    config_class = CustomConfig

    def __init__(self, config: CustomConfig):
        super().__init__(config)
        self.config = config
        # Single placeholder parameter; it is never used in ``forward``.
        # NOTE(review): presumably kept so the model owns at least one
        # parameter (utilities such as ``.device`` inspect the parameters) -
        # confirm before removing.
        self.embeddings = torch.nn.Embedding(num_embeddings=1, embedding_dim=1)

    def forward(self, **kwargs) -> SequenceClassifierOutput:
        """Return fixed logits, one identical row per example in the batch.

        Args:
            **kwargs: Tokenizer outputs (``input_ids``, ``attention_mask``, ...).
                Only the length and device of ``input_ids`` are used. Every
                other input is ignored.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` is a float tensor
            of shape ``(batch_size, 3)``, every row ``[1.0, 2.0, 3.0]``.
            ``batch_size`` is 1 when no ``input_ids`` are given.
        """
        input_ids = kwargs.get("input_ids")
        batch_size = 1 if input_ids is None else len(input_ids)
        # Pipelines split outputs along dim 0, so the output batch must match
        # the input batch. Logits are floating point by convention.
        logits = torch.tensor(
            [[1.0, 2.0, 3.0]], device=getattr(input_ids, "device", None)
        ).repeat(batch_size, 1)
        return SequenceClassifierOutput(logits=logits)
from transformers.pipelines import PIPELINE_REGISTRY
from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification
if __name__ == "__main__":
    # One-off publishing script: wrap the dummy model in a zero-shot
    # classification pipeline and push the result to the Hub.
    from huggingface_hub import Repository
    from transformers import AutoTokenizer

    local_dir = "zero-shot-classification"
    hub_repo_id = "paulhindemith/zero-shot-classification"

    # Only the tokenizer of the reference checkpoint is needed. Loading it
    # directly avoids downloading and instantiating the full BERT model,
    # which ``pipeline(...)`` would also build.
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

    # Register the custom classes with the Auto API so that save_pretrained
    # records them for loading the pushed repo with remote code.
    CustomConfig.register_for_auto_class()
    CustomModel.register_for_auto_class("AutoModel")

    p = ZeroShotClassificationPipeline(
        model=CustomModel(CustomConfig()),
        tokenizer=tokenizer,
    )

    # Clone first so the pipeline files land in the git working copy that
    # push_to_hub commits from.
    # NOTE(review): huggingface_hub deprecates the git-based ``Repository``
    # in favour of ``HfApi.upload_folder``; migrate once the pinned version
    # is confirmed.
    repo = Repository(local_dir, clone_from=hub_repo_id)
    p.save_pretrained(local_dir)
    repo.push_to_hub()