Shitao commited on
Commit
712acfa
·
verified ·
1 Parent(s): ef2764c

Update OmniGen/model.py

Browse files
Files changed (1) hide show
  1. OmniGen/model.py +7 -2
OmniGen/model.py CHANGED
@@ -9,6 +9,7 @@ from typing import Dict
9
  from diffusers.loaders import PeftAdapterMixin
10
  from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
11
  from huggingface_hub import snapshot_download
 
12
 
13
  from OmniGen.transformer import Phi3Config, Phi3Transformer
14
 
@@ -187,14 +188,18 @@ class OmniGen(nn.Module, PeftAdapterMixin):
187
 
188
  @classmethod
189
  def from_pretrained(cls, model_name):
190
- if not os.path.exists(os.path.join(model_name, 'model.pt')):
191
  cache_folder = os.getenv('HF_HUB_CACHE')
192
  model_name = snapshot_download(repo_id=model_name,
193
  cache_dir=cache_folder,
194
  ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
195
  config = Phi3Config.from_pretrained(model_name)
196
  model = cls(config)
197
- ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
 
 
 
 
198
  model.load_state_dict(ckpt)
199
  return model
200
 
 
9
  from diffusers.loaders import PeftAdapterMixin
10
  from timm.models.vision_transformer import PatchEmbed, Attention, Mlp
11
  from huggingface_hub import snapshot_download
12
+ from safetensors.torch import load_file
13
 
14
  from OmniGen.transformer import Phi3Config, Phi3Transformer
15
 
 
188
 
189
  @classmethod
190
  def from_pretrained(cls, model_name):
191
+ if not os.path.exists(model_name):
192
  cache_folder = os.getenv('HF_HUB_CACHE')
193
  model_name = snapshot_download(repo_id=model_name,
194
  cache_dir=cache_folder,
195
  ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
196
  config = Phi3Config.from_pretrained(model_name)
197
  model = cls(config)
198
+ if os.path.exists(os.path.join(model_name, 'model.safetensors')):
199
+ print("Loading safetensors")
200
+ ckpt = load_file(os.path.join(model_name, 'model.safetensors'))
201
+ else:
202
+ ckpt = torch.load(os.path.join(model_name, 'model.pt'), map_location='cpu')
203
  model.load_state_dict(ckpt)
204
  return model
205