|
import warnings

from transformers import pipeline

from vista3d_config import VISTA3DConfig
from vista3d_model import VISTA3DModel, register_my_model
from vista3d_pipeline import VISTA3DPipeline, register_simple_pipeline
|
|
|
|
|
class HuggingFacePipelineHelper:
    """Build VISTA3D inference pipelines on top of Hugging Face ``transformers``.

    Two entry points are offered:

    * :meth:`get_pipeline` runs ``register_my_model`` and
      ``register_simple_pipeline`` and then asks ``transformers.pipeline``
      for the task named ``pipeline_name``.
    * :meth:`init_pipeline` builds a ``VISTA3DPipeline`` directly from a
      pretrained checkpoint, optionally overriding config attributes first.
    """

    def __init__(self, pipeline_name: str = "vista3d"):
        """Store the task name later passed to ``transformers.pipeline``.

        Args:
            pipeline_name: Task name used by :meth:`get_pipeline`.
        """
        self.pipeline_name = pipeline_name

    def __model_register(self):
        """Register the custom model by delegating to ``register_my_model``."""
        register_my_model()

    def __pipeline_register(self):
        """Register the custom pipeline by delegating to ``register_simple_pipeline``."""
        register_simple_pipeline()

    def get_pipeline(self):
        """Register the VISTA3D model and pipeline, then instantiate the pipeline.

        Returns:
            Whatever ``transformers.pipeline(self.pipeline_name)`` returns.
        """
        # Registration has to run before ``pipeline`` resolves the task name.
        self.__model_register()
        self.__pipeline_register()
        return pipeline(self.pipeline_name)

    def _update_config(self, config, config_dict):
        """Copy values from *config_dict* onto matching attributes of *config*.

        Only keys that already exist as attributes on *config* are applied.
        Any other key is skipped and reported with a ``UserWarning`` so that
        misspelled option names do not fail silently.

        Args:
            config: Object whose attributes are overwritten in place.
            config_dict: Mapping of attribute name -> new value. ``None`` or an
                empty mapping means "no changes".

        Returns:
            The same *config* object (mutated in place), for convenience.
        """
        if not config_dict:
            return config
        for key, value in config_dict.items():
            if hasattr(config, key):
                # Assign unconditionally: a ``!=`` pre-check adds nothing and
                # raises for array/tensor-valued attributes.
                setattr(config, key, value)
            else:
                warnings.warn(
                    f"Ignoring unknown config key {key!r} for "
                    f"{type(config).__name__}.",
                    stacklevel=2,
                )
        return config

    def init_pipeline(self, pretrained_model_name_or_path: str, **kwargs):
        """Create a ``VISTA3DPipeline`` around a pretrained VISTA3D model.

        Args:
            pretrained_model_name_or_path: Hub id or local path of the
                checkpoint, forwarded to ``VISTA3DModel.from_pretrained``.
            **kwargs: Forwarded to ``VISTA3DPipeline``. The optional
                ``config_dict`` entry is removed first and used to override
                attributes of a default ``VISTA3DConfig`` before loading.

        Returns:
            A ``VISTA3DPipeline`` wrapping the loaded model.
        """
        config = VISTA3DConfig()
        config_dict = kwargs.pop("config_dict", None)
        self._update_config(config, config_dict)
        # NOTE(review): assumes VISTA3DModel is a transformers PreTrainedModel,
        # whose ``from_pretrained`` is a classmethod returning a *new* model
        # with the checkpoint weights loaded. It does not load into an existing
        # instance, so the returned object must be the one we keep. Passing
        # ``config`` keeps the overrides applied above.
        model = VISTA3DModel.from_pretrained(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            config=config,
        )
        return VISTA3DPipeline(model, **kwargs)
|
|