hasibzunair's picture
initial files
46fdf2a
raw
history blame
651 Bytes
import logging
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
from ..tresnet import TResnetM, TResnetL, TResnetXL
def create_model(args):
    """Create a TResNet model selected by ``args.model_name``.

    Args:
        args: Configuration object. It must provide ``model_name`` (one of
            ``tresnet_m``, ``tresnet_l``, ``tresnet_xl``, matched
            case-insensitively) and ``num_classes``. The object itself is also
            forwarded to the model constructor inside the params dict.

    Returns:
        The model instance built by the selected constructor.

    Raises:
        SystemExit: with code -1 when ``args.model_name`` is not one of the
            supported names.

    Note:
        ``args.model_name`` is lowercased in place, so the caller's ``args``
        object is modified.
    """
    # Both entries are passed unchanged to whichever constructor is chosen.
    model_params = {'args': args, 'num_classes': args.num_classes}
    # Normalise in place. This is kept for compatibility, since callers may
    # read the lowercased name back from ``args``.
    args.model_name = args.model_name.lower()

    # Supported model names mapped to their constructors.
    constructors = {
        'tresnet_m': TResnetM,
        'tresnet_l': TResnetL,
        'tresnet_xl': TResnetXL,
    }
    constructor = constructors.get(args.model_name)
    if constructor is None:
        # Report through the module logger instead of print. Raise SystemExit
        # directly rather than calling the ``exit`` builtin: ``exit`` is
        # injected by the ``site`` module and is not guaranteed to exist
        # (e.g. ``python -S``, frozen executables).
        logger.error(
            "model: %s not found !! Supported models: %s",
            args.model_name,
            ", ".join(sorted(constructors)),
        )
        raise SystemExit(-1)
    return constructor(model_params)