from transformers import Pipeline
class MultitaskTokenClassificationPipeline(Pipeline):
    """Custom ``transformers`` pipeline scaffold for multitask token classification.

    NOTE(review): ``preprocess``, ``_forward`` and ``postprocess`` currently
    pass the input through unchanged — the floret-based prediction and
    label/confidence formatting are not wired up yet.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into (preprocess, forward, postprocess) kwargs.

        Only the optional ``text`` kwarg is routed to ``preprocess``;
        ``_forward`` and ``postprocess`` take no extra parameters.

        :param kwargs: arbitrary keyword arguments passed to the pipeline call.
        :return: tuple of three dicts, one per pipeline stage.
        """
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        """Identity preprocessing: return the input ``text`` unchanged."""
        return text

    def _forward(self, text):
        """Run the model on ``text`` and pass the raw input through.

        The model's return value is currently discarded; the call is kept
        for whatever side effects it may have.
        TODO(review): capture and return the model outputs once
        ``postprocess`` is implemented to consume them.

        :param text: output of ``preprocess`` (the raw input text).
        :return: ``text`` unchanged.
        """
        self.model(text)
        return text

    def postprocess(self, text, **kwargs):
        """Identity postprocessing: return ``text`` unchanged.

        :param text: output of ``_forward`` (currently the raw input text).
        :param kwargs: unused.
        :return: ``text`` unchanged.
        """
        return text