File size: 562 Bytes
7c1c90f
 
 
 
 
 
 
 
 
 
 
 
 
2374ae7
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import  PreTrainedModel
from facenet_pytorch import MTCNN, InceptionResnetV1
from deepfakeconfig import DeepFakeConfig

class DeepFakeModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around facenet-pytorch's ``InceptionResnetV1``.

    The backbone is built from the ``"vggface2"`` pretrained weights with
    ``classify=True`` and a single-output head (``num_classes=1``), and is
    placed on the device named by ``config.DEVICE``.
    """

    config_class = DeepFakeConfig

    def __init__(self, config):
        """Build the wrapped backbone.

        Args:
            config: A ``DeepFakeConfig``; only its ``DEVICE`` attribute is read here.
        """
        super().__init__(config)
        self.model = InceptionResnetV1(
            pretrained="vggface2",
            classify=True,
            num_classes=1,
            device=config.DEVICE,
        )

    def forward(self, pixel_values):
        """Run the wrapped ``InceptionResnetV1`` on ``pixel_values``.

        Without this override, calling the model (``model(x)``) falls through
        to ``torch.nn.Module``'s default ``forward``, which raises
        ``NotImplementedError``.

        Args:
            pixel_values: Input batch, passed unchanged to the backbone.
                NOTE(review): presumably a float image tensor of shape
                (batch, 3, H, W) already on ``config.DEVICE`` — confirm
                against the preprocessing / pipeline code.

        Returns:
            Whatever ``self.model`` returns. With ``classify=True`` and
            ``num_classes=1``, this is presumably one raw score per sample
            from the classification head — confirm against facenet-pytorch.
        """
        return self.model(pixel_values)


# Register the custom classes so the Auto* factories can load them from the
# Hub (e.g. with trust_remote_code=True). "AutoConfig" is the default auto
# class for configurations, spelled out here for clarity.
DeepFakeConfig.register_for_auto_class(auto_class="AutoConfig")
DeepFakeModel.register_for_auto_class(auto_class="AutoModelForImageClassification")