# VISTA3D-HF / hugging_face_pipeline.py
# Author: BinLiunls — commit 08efd84 ("init version")
import warnings

from transformers import pipeline

from vista3d_config import VISTA3DConfig
from vista3d_model import VISTA3DModel, register_my_model
from vista3d_pipeline import VISTA3DPipeline, register_simple_pipeline
class HuggingFacePipelineHelper:
    """Build VISTA3D pipelines, either through the transformers registry or directly.

    ``get_pipeline`` registers the custom model and pipeline and then resolves
    the pipeline by name via ``transformers.pipeline``; ``init_pipeline``
    constructs a ``VISTA3DPipeline`` around a model loaded from a checkpoint.
    """

    def __init__(self, pipeline_name: str = "vista3d"):
        # Task name passed to ``transformers.pipeline`` by ``get_pipeline``.
        self.pipeline_name = pipeline_name

    def __model_register(self):
        """Register the VISTA3D model via ``register_my_model``."""
        register_my_model()

    def __pipeline_register(self):
        """Register the VISTA3D pipeline via ``register_simple_pipeline``."""
        register_simple_pipeline()

    def get_pipeline(self):
        """Register model and pipeline, then return ``pipeline(self.pipeline_name)``."""
        self.__model_register()
        self.__pipeline_register()
        return pipeline(self.pipeline_name)

    def _update_config(self, config, config_dict):
        """Overwrite attributes of ``config`` with matching entries of ``config_dict``.

        Only keys that already exist as attributes on ``config`` are applied.
        Unknown keys are skipped with a ``UserWarning`` so that typos do not
        go unnoticed. ``config`` is modified in place and also returned.

        Args:
            config: Configuration object to update (e.g. a ``VISTA3DConfig``).
            config_dict: Mapping of attribute name to new value, or ``None`` /
                empty to leave ``config`` unchanged.

        Returns:
            The same ``config`` object.
        """
        if not config_dict:
            return config
        for key, value in config_dict.items():
            if hasattr(config, key):
                # Assign unconditionally: comparing old and new values with
                # ``!=`` first can raise for array-like values (ambiguous
                # truth value), and re-assigning an equal value is harmless.
                setattr(config, key, value)
            else:
                warnings.warn(
                    f"Ignoring unknown config key {key!r} for {type(config).__name__}",
                    stacklevel=2,
                )
        return config

    def init_pipeline(self, pretrained_model_name_or_path: str, **kwargs):
        """Create a ``VISTA3DPipeline`` around a pretrained ``VISTA3DModel``.

        Args:
            pretrained_model_name_or_path: Checkpoint identifier or local path
                forwarded to ``VISTA3DModel.from_pretrained``.
            **kwargs: Forwarded to ``VISTA3DPipeline``. The optional
                ``config_dict`` entry is removed first and used to override
                fields of a default ``VISTA3DConfig``.

        Returns:
            A ``VISTA3DPipeline`` wrapping the loaded model.
        """
        config = VISTA3DConfig()
        config_dict = kwargs.pop("config_dict", None)
        self._update_config(config, config_dict)
        # ``from_pretrained`` on transformers' PreTrainedModel (which
        # VISTA3DModel presumably subclasses — confirm) is a classmethod that
        # *returns* a new model rather than loading weights into an existing
        # instance, so its result must be kept. Passing ``config`` preserves
        # the ``config_dict`` overrides instead of using the checkpoint's own.
        model = VISTA3DModel.from_pretrained(
            pretrained_model_name_or_path, config=config
        )
        return VISTA3DPipeline(model, **kwargs)