|
import logging |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
from ..tresnet import TResnetM, TResnetL, TResnetXL |
|
|
|
|
|
def create_model(args):
    """Create the TResNet model selected by ``args.model_name``.

    Args:
        args: Configuration object exposing ``model_name`` and
            ``num_classes`` attributes. ``model_name`` is matched
            case-insensitively against ``'tresnet_m'``, ``'tresnet_l'``
            and ``'tresnet_xl'``.

    Returns:
        The object returned by ``TResnetM``, ``TResnetL`` or ``TResnetXL``.
        The selected constructor is called with a single dict holding the
        keys ``'args'`` and ``'num_classes'``.

    Raises:
        SystemExit: With exit status ``-1`` when ``args.model_name`` does
            not name a supported model. A message is printed first.

    Side effects:
        ``args.model_name`` is overwritten in place with its lower-cased
        form.
    """
    # Function-local import: ``sys.exit`` is always available, whereas the
    # bare ``exit`` helper is installed by the ``site`` module and is absent
    # under ``python -S`` or in frozen/embedded interpreters.
    import sys

    model_params = {'args': args, 'num_classes': args.num_classes}

    # Normalise once so the comparisons below are case-insensitive. The
    # attribute is deliberately updated on the caller's object, as before.
    args.model_name = args.model_name.lower()

    if args.model_name == 'tresnet_m':
        model = TResnetM(model_params)
    elif args.model_name == 'tresnet_l':
        model = TResnetL(model_params)
    elif args.model_name == 'tresnet_xl':
        model = TResnetXL(model_params)
    else:
        print("model: {} not found !!".format(args.model_name))
        sys.exit(-1)

    return model
|
|