File size: 1,102 Bytes
67d6f5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from transformers import BlipProcessor, BlipForConditionalGeneration, AutoTokenizer, AutoModelForSeq2SeqLM


class ModelLoader:
    """Load and hold a BLIP image-captioning model and a FLAN-T5 topic generator.

    Both checkpoints are loaded eagerly in ``__init__`` via ``from_pretrained``.
    The ``load_*`` methods hand back the already-loaded ``(model, processor)``
    pair after switching the model to evaluation mode.
    """

    # Default Hugging Face Hub checkpoint identifiers.
    BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
    TOPIC_GENERATOR_CHECKPOINT = "google/flan-t5-large"

    def __init__(
        self,
        blip_checkpoint: str = BLIP_CHECKPOINT,
        topic_generator_checkpoint: str = TOPIC_GENERATOR_CHECKPOINT,
    ):
        """Load both models and their processors.

        Args:
            blip_checkpoint: Hub ID or local path of the BLIP captioning
                checkpoint (used for both processor and model).
            topic_generator_checkpoint: Hub ID or local path of the
                seq2seq topic-generator checkpoint (used for both tokenizer
                and model).
        """
        self.blip_processor = BlipProcessor.from_pretrained(blip_checkpoint)
        self.blip_model = BlipForConditionalGeneration.from_pretrained(blip_checkpoint)
        # For a seq2seq text model the "processor" is its tokenizer.
        self.topic_generator_processor = AutoTokenizer.from_pretrained(topic_generator_checkpoint)
        self.topic_generator_model = AutoModelForSeq2SeqLM.from_pretrained(topic_generator_checkpoint)

    @staticmethod
    def _prepare(model, processor):
        """Put *model* in evaluation mode and return ``(model, processor)``."""
        # eval() switches a torch module to inference behaviour (e.g. dropout off).
        model.eval()
        return model, processor

    def load_blip(self):
        """Return the BLIP ``(model, processor)`` pair, model in eval mode."""
        return self._prepare(self.blip_model, self.blip_processor)

    def load_topic_generator(self):
        """Return the topic-generator ``(model, tokenizer)`` pair, model in eval mode."""
        return self._prepare(self.topic_generator_model, self.topic_generator_processor)
        

# Smoke test: load both models only when this file is executed directly, so
# that importing the module does not trigger the heavy checkpoint loading.
if __name__ == "__main__":
    model_load = ModelLoader()
    # Use load_blip() to obtain the (model, processor) pair; calling the
    # blip_model attribute would run the model's forward pass with no inputs.
    blip_models, blip_processors = model_load.load_blip()
    topic_generator_models, topic_generator_processors = model_load.load_topic_generator()