# NOTE(review): removed non-Python page-scrape artifacts (file size, blame hashes, line-number gutter)
from transformers import Pipeline
class MultitaskTokenClassificationPipeline(Pipeline):
    """Custom ``transformers`` pipeline for multitask token classification.

    Current state (work in progress): the pipeline runs the model on the
    input text but discards its output, returning the input unchanged at
    every stage. The previous revision contained commented-out floret-based
    label/confidence extraction; that dead code and the debug ``print``
    calls have been removed.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call kwargs into (preprocess, forward, postprocess) dicts.

        Only the ``text`` kwarg is recognized; it is routed to
        :meth:`preprocess`. Forward and postprocess receive no kwargs.
        """
        preprocess_kwargs = {}
        if "text" in kwargs:
            preprocess_kwargs["text"] = kwargs["text"]
        return preprocess_kwargs, {}, {}

    def preprocess(self, text, **kwargs):
        """Return the raw text unchanged (no tokenization is done here)."""
        return text

    def _forward(self, text):
        """Run the model on *text* for its side effects only.

        NOTE(review): the model's return value is discarded and the input
        text is passed through unchanged — presumably unfinished; confirm
        whether predictions should be forwarded to :meth:`postprocess`.
        """
        self.model(text)
        return text

    def postprocess(self, text, **kwargs):
        """Return the (unmodified) text produced by :meth:`_forward`.

        NOTE(review): label/confidence formatting (floret predict →
        ``{"label": ..., "confidence": ...}``) was previously sketched here
        as commented-out code; reinstate once ``_forward`` produces output.
        """
        return text