import torch

from transformers import PreTrainedModel, CLIPConfig, CLIPModel


class GenomicPLIPModel(PreTrainedModel):
    config_class = CLIPConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the vision tower from the base CLIP ViT-B/32 configuration
        # (randomly initialized here; pretrained weights are loaded separately).
        vision_config = CLIPConfig.from_pretrained("openai/clip-vit-base-patch32")
        self.vision_model = CLIPModel(vision_config).vision_model
        # Project the 768-dim ViT pooled output down to a 512-dim embedding.
        self.vision_projection = torch.nn.Linear(768, 512)

    def forward(self, pixel_values):
        vision_output = self.vision_model(pixel_values=pixel_values)
        # pooler_output is the [CLS] token representation after the
        # final layer norm of the vision transformer.
        pooled_output = vision_output.pooler_output
        vision_features = self.vision_projection(pooled_output)
        return vision_features
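
# A minimal usage sketch (an illustrative addition, not part of the model
# definition): it fetches the CLIP config from the Hub and runs a forward
# pass on a random tensor standing in for a preprocessed image batch, since
# ViT-B/32 expects 3x224x224 pixel inputs.
if __name__ == "__main__":
    config = CLIPConfig.from_pretrained("openai/clip-vit-base-patch32")
    model = GenomicPLIPModel(config)
    dummy_pixels = torch.randn(2, 3, 224, 224)  # batch of two fake images
    features = model(dummy_pixels)
    print(features.shape)  # torch.Size([2, 512])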