echarlaix HF staff commited on
Commit
242da95
1 Parent(s): eeab0c8

Get openvino exporter config

Browse files
Files changed (1) hide show
  1. export.py +6 -10
export.py CHANGED
@@ -62,22 +62,18 @@ def convert_openvino(model_id: str, task: str, folder: str) -> List:
62
  if not isinstance(ov_model, OVStableDiffusionPipeline):
63
  try:
64
  model = TasksManager.get_model_from_task(task, model_id)
65
- onnx_config_class = TasksManager.get_exporter_config_constructor(
66
- exporter="onnx",
67
  model=model,
68
  task=task,
69
  model_name=model_id,
70
  model_type=model.config.model_type.replace("_", "-"),
71
  )
72
- onnx_config = onnx_config_class(model.config)
73
- inputs = onnx_config.generate_dummy_inputs(framework="pt")
74
 
75
- if isinstance(ov_model, (OVModelForCausalLM, OVModelForSeq2SeqLM)):
76
- ov_outputs = ov_model.generate(**inputs)
77
- outputs = model.generate(**inputs)
78
- else:
79
- ov_outputs = ov_model(**inputs)
80
- outputs = model(**inputs)
81
 
82
  if isinstance(outputs, torch.Tensor):
83
  outputs = {"logits": outputs}
 
62
  if not isinstance(ov_model, OVStableDiffusionPipeline):
63
  try:
64
  model = TasksManager.get_model_from_task(task, model_id)
65
+ exporter_config_class = TasksManager.get_exporter_config_constructor(
66
+ exporter="openvino",
67
  model=model,
68
  task=task,
69
  model_name=model_id,
70
  model_type=model.config.model_type.replace("_", "-"),
71
  )
72
+ openvino_config = exporter_config_class(model.config)
73
+ inputs = openvino_config.generate_dummy_inputs(framework="pt")
74
 
75
+ ov_outputs = ov_model(**inputs)
76
+ outputs = model(**inputs)
 
 
 
 
77
 
78
  if isinstance(outputs, torch.Tensor):
79
  outputs = {"logits": outputs}