File size: 1,304 Bytes
98f2419 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
#!/usr/bin/env python
# coding=utf-8
"""
A model that maps "text_only" data to floats.
"""
from lmflow.models.regression_model import RegressionModel
from lmflow.datasets.dataset import Dataset
class TextRegressionModel(RegressionModel):
    r"""
    Regression model whose inference is delegated to a registered callable.

    No regression logic lives in this class: a callable is attached with
    `register_inference_function` and later invoked by `inference`.

    Parameters
    ------------
    model_args :
        Model arguments such as model name, path, revision, etc.
        Accepted by the constructor but not read by it.

    args : Optional.
        Positional arguments.

    kwargs : Optional.
        Keyword arguments.
    """

    def __init__(
        self,
        model_args,
        *args,
        **kwargs
    ):
        """
        Creates the instance with no inference function registered.

        :param model_args: model arguments such as model name, path, revision, etc.
        """
        # NOTE: the parent initializer is not invoked here; the only state
        # set up is the (initially empty) inference function slot.
        self.inference_func = None

    def register_inference_function(self, inference_func):
        """
        Stores the callable that `inference` will delegate to.

        :param inference_func: callable that is invoked with the dataset
            handed to `inference`; replaces any previously registered one.
        """
        self.inference_func = inference_func

    def inference(self, inputs: Dataset):
        """
        Runs the registered inference function on a dataset.

        :inputs: Dataset object, only accept type "text_only".
        :return: whatever the registered function returns, or None when
            no function has been registered.
        """
        registered = self.inference_func
        if registered is None:
            # Nothing registered: fall through without raising.
            return None
        return registered(inputs)
|