Cancer-Risk-Prediction-from-WSI / classes /genomic_plip_model.py
VatsalPatel18's picture
Upload 32 files
7b1328a verified
raw
history blame contribute delete
No virus
708 Bytes
import torch
from transformers import PreTrainedModel, CLIPConfig, CLIPModel
class GenomicPLIPModel(PreTrainedModel):
config_class = CLIPConfig
def __init__(self, config):
super(GenomicPLIPModel, self).__init__(config)
vision_config = CLIPModel.config_class.from_pretrained('openai/clip-vit-base-patch32')
self.vision_model = CLIPModel(vision_config).vision_model
self.vision_projection = torch.nn.Linear(768, 512)
def forward(self, pixel_values):
vision_output = self.vision_model(pixel_values)
pooled_output = vision_output.pooler_output
vision_features = self.vision_projection(pooled_output)
return vision_features