|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
from transformers.modeling_utils import PreTrainedModel |
|
from transformers.configuration_utils import PretrainedConfig |
|
import torch |
|
from transformers import ZeroShotClassificationPipeline |
|
|
|
|
|
class CustomConfig(PretrainedConfig):
    """Minimal configuration for the placeholder zero-shot test model.

    Declares no fields of its own: every keyword argument is handed straight
    to ``PretrainedConfig``. Its only customisation is the ``model_type`` tag.
    """

    model_type = "test-zeroshot"

    def __init__(self, **kwargs):
        # Nothing model-specific to set up; the base class handles all kwargs.
        super().__init__(**kwargs)
|
|
|
|
|
class CustomModel(PreTrainedModel):
    """Placeholder sequence-classification model that emits constant logits.

    The values of the inputs are ignored: every example is scored with the same
    logits ``[1.0, 2.0, 3.0]``. It exists only to exercise the
    zero-shot-classification pipeline and its save/push code path.
    """

    config_class = CustomConfig

    def __init__(self, config: CustomConfig):
        super().__init__(config)
        self.config = config
        # Not used by forward(); presumably kept so the model owns at least one
        # parameter to save/load and to infer a device from -- TODO confirm.
        self.embeddings = torch.nn.Embedding(num_embeddings=1, embedding_dim=1)

    def forward(self, **kwargs) -> SequenceClassifierOutput:
        """Return constant logits, one row per input example.

        Args:
            **kwargs: Model inputs produced by the tokenizer/pipeline. Only the
                leading (batch) dimension of ``input_ids`` is consulted, when it
                is present as a tensor; otherwise a batch of 1 is assumed.

        Returns:
            A ``SequenceClassifierOutput`` whose ``logits`` is a float tensor of
            shape ``(batch_size, 3)`` with every row equal to ``[1.0, 2.0, 3.0]``.
        """
        input_ids = kwargs.get("input_ids")
        batch_size = input_ids.shape[0] if isinstance(input_ids, torch.Tensor) else 1
        # Float logits (not int64) and one row per example, so batched calls
        # get a row for every input instead of a fixed single row.
        row = torch.tensor([1.0, 2.0, 3.0], device=self.embeddings.weight.device)
        logits = row.unsqueeze(0).repeat(batch_size, 1)
        return SequenceClassifierOutput(logits=logits)
|
|
|
|
|
from transformers.pipelines import PIPELINE_REGISTRY |
|
|
|
from transformers import AutoModelForSequenceClassification, TFAutoModelForSequenceClassification |
|
|
|
if __name__ == "__main__":
    # Build a zero-shot-classification pipeline around the placeholder model,
    # save it (with its custom code) into a clone of the Hub repo, and push it.
    import importlib
    from pathlib import Path

    from huggingface_hub import Repository
    from transformers import AutoTokenizer

    # transformers will not copy custom code whose classes live in __main__ into
    # the saved folder (it only logs a warning), so the pushed repo would lack
    # the modelling code and auto_map. Re-importing this file under its own
    # module name gives the classes a real __module__ that can be saved.
    # NOTE: requires this file's name to be an importable module name.
    this_module = importlib.import_module(Path(__file__).stem)

    # Only the tokenizer is needed; loading it directly avoids downloading and
    # instantiating the full BERT checkpoint through pipeline().
    tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese")

    this_module.CustomConfig.register_for_auto_class()
    # The zero-shot-classification task loads models via
    # AutoModelForSequenceClassification, so register under that auto class.
    this_module.CustomModel.register_for_auto_class("AutoModelForSequenceClassification")

    p = ZeroShotClassificationPipeline(
        model=this_module.CustomModel(this_module.CustomConfig()),
        tokenizer=tokenizer,
    )

    # Clone first so save_pretrained writes into the working copy that is pushed.
    repo = Repository(
        "zero-shot-classification",
        clone_from="paulhindemith/zero-shot-classification",
    )
    p.save_pretrained("zero-shot-classification")
    repo.push_to_hub()