sharpenb's picture
bdad6d44046fc7851385dc26915bb1f65ba8ad10bc585fd6f826f0fd8bf2da57
09b8c12 verified
raw
history blame
167 Bytes
from torch import nn
# Maps a backend name to the fully-connected (linear) layer class it provides.
# The 'torch' entry is always present.
FC_CLASS_REGISTRY = {'torch': nn.Linear}

# Optionally register Transformer Engine's Linear under the 'te' key. This is
# best-effort: if the package is unavailable or fails to load, the registry
# keeps only the 'torch' entry.
try:
    import transformer_engine.pytorch as te
    FC_CLASS_REGISTRY['te'] = te.Linear
except Exception:
    # Deliberately broad, since loading an optional extension can fail with
    # more than ImportError. Unlike a bare `except:`, this still lets
    # KeyboardInterrupt and SystemExit propagate.
    pass