Rediones-AI / utils /model_serving /model_loader.py
Testys's picture
Pushing First version before making full changes
67d6f5b
raw
history blame
1.1 kB
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM
class ModelLoader:
    """Hold the BLIP image-captioning and seq2seq "topic generator" models.

    Both model/processor pairs are loaded eagerly in ``__init__`` with
    ``from_pretrained`` and kept on the instance. The ``load_*`` methods do
    not load anything again: they switch the cached model to eval mode and
    return the cached ``(model, processor)`` pair.

    The checkpoint identifiers default to the ones this project was written
    against and may be overridden (keyword-only) with any other checkpoint
    identifier accepted by the corresponding ``from_pretrained`` call.
    """

    # Default checkpoints; these are the values that were previously
    # hard-coded in __init__.
    DEFAULT_BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
    DEFAULT_TOPIC_GENERATOR_CHECKPOINT = "google/flan-t5-large"

    def __init__(
        self,
        *,
        blip_checkpoint: str = DEFAULT_BLIP_CHECKPOINT,
        topic_generator_checkpoint: str = DEFAULT_TOPIC_GENERATOR_CHECKPOINT,
    ) -> None:
        """Load both model/processor pairs.

        Args:
            blip_checkpoint: Identifier passed to
                ``BlipProcessor.from_pretrained`` and
                ``BlipForConditionalGeneration.from_pretrained``.
            topic_generator_checkpoint: Identifier passed to
                ``AutoTokenizer.from_pretrained`` and
                ``AutoModelForSeq2SeqLM.from_pretrained``.
        """
        # BLIP conditional-generation (captioning) model and its processor.
        self.blip_processor = BlipProcessor.from_pretrained(blip_checkpoint)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(
            blip_checkpoint
        )
        # Seq2seq model exposed as the "topic generator". The attribute is
        # called *processor* for symmetry with BLIP, but it holds a tokenizer.
        self.topic_generator_processor = AutoTokenizer.from_pretrained(
            topic_generator_checkpoint
        )
        self.topic_generator_model = AutoModelForSeq2SeqLM.from_pretrained(
            topic_generator_checkpoint
        )

    def load_blip(self):
        """Return the cached BLIP ``(model, processor)`` pair.

        Calls ``model.eval()`` on the cached model before returning it; no
        weights are reloaded.
        """
        model = self.blip_model
        model.eval()
        return model, self.blip_processor

    def load_topic_generator(self):
        """Return the cached topic-generator ``(model, tokenizer)`` pair.

        Calls ``model.eval()`` on the cached model before returning it; no
        weights are reloaded.
        """
        model = self.topic_generator_model
        model.eval()
        return model, self.topic_generator_processor
# Smoke test: build the loader and fetch both model/processor pairs.
# Guarded so that importing this module does not construct ModelLoader
# (whose __init__ loads both models) or run this test code.
if __name__ == "__main__":
    model_load = ModelLoader()
    # Use the load_blip() accessor: `blip_model` is the model object itself,
    # and calling it runs the model's forward pass instead of returning a
    # (model, processor) pair.
    blip_models, blip_processors = model_load.load_blip()
    topic_generator_models, topic_generator_processors = (
        model_load.load_topic_generator()
    )