# HowJanusSeesItself / model_loader.py
# Author: thomasgauthier — commit fc91aa0 ("ZeroGPU fixes")
import torch
from transformers import AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
import spaces
@spaces.GPU
def load_model_and_processor(model_path):
    """Load the Janus multimodal model and its chat processor.

    Args:
        model_path: Identifier passed unchanged to both
            ``VLChatProcessor.from_pretrained`` and
            ``AutoModelForCausalLM.from_pretrained`` (e.g. a Hub repo id or
            a local checkpoint directory).

    Returns:
        A ``(vl_gpt, vl_chat_processor)`` tuple:
        ``vl_gpt`` is the model cast to ``torch.bfloat16``, moved to CUDA
        and switched to eval mode. ``vl_chat_processor`` is the processor
        loaded from the same path. Its tokenizer is reachable as
        ``vl_chat_processor.tokenizer``.
    """
    # NOTE(review): ZeroGPU guidance is to load models at module scope and
    # decorate only the inference functions with @spaces.GPU; confirm the
    # CUDA model returned from this decorated call remains usable afterwards.
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)

    # trust_remote_code executes model code shipped with the checkpoint, so
    # only point model_path at repositories you trust.
    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    # The explicit cast is a no-op for tensors already loaded as bfloat16,
    # but it guarantees a uniform dtype before the move to the GPU.
    # eval() turns off training-mode behaviour for inference.
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda().eval()
    return vl_gpt, vl_chat_processor