Token-classification / core /classificator.py
Arthour
Dependencies fix
c9ac1af
from dataclasses import dataclass
from typing import List, Dict, Any
import requests
class ClassificationError(Exception):
    """Error raised when a classification request cannot be completed."""
@dataclass
class Classification:
    """A single labelled span produced by the classificator.

    `entity` is the label name; `start` and `end` are the span positions
    taken from the raw classification items.
    """

    entity: str
    start: int
    end: int

    def dict(self) -> Dict[str, Any]:
        """Return the span as a plain dictionary keyed by field name."""
        field_names = ('entity', 'start', 'end')
        return {name: getattr(self, name) for name in field_names}
class Classificator:
    """Client for a remote token-classification endpoint.

    The text is posted to the configured endpoint and the per-token
    predictions it returns are merged into `Classification` spans.
    """

    # Seconds to wait for the endpoint when the configuration carries no
    # 'timeout' entry; without any timeout a stalled endpoint would block
    # the caller forever.
    DEFAULT_TIMEOUT = 30

    def __init__(self, config: Dict[str, Any]):
        """
        Initialize the classificator with the given configuration.

        The configuration must provide 'endpoint_url' and
        'auth_endpoint_token'. An optional 'timeout' entry (seconds)
        overrides DEFAULT_TIMEOUT.
        """
        self._config = config

    def classify(self, text: str) -> List[Classification]:
        """
        Classify the text and return the merged entity spans.

        Raises ClassificationError when the request to the endpoint fails.
        """
        raw_data = self.send_request(text)
        return self.post_process(raw_data)

    def send_request(self, text: str) -> List[Dict[str, Any]]:
        """
        Post the text to the endpoint and return the decoded JSON list.

        Raises ClassificationError when the request fails, the endpoint
        answers with an HTTP error status, the body is not valid JSON, or
        the decoded payload is not a list.
        """
        headers = {
            'Authorization': self._config['auth_endpoint_token'],
            'Content-Type': 'application/json',
        }
        try:
            response = requests.post(
                self._config['endpoint_url'],
                headers=headers,
                json={'inputs': text},
                timeout=self._config.get('timeout', self.DEFAULT_TIMEOUT),
            )
            # Error statuses usually carry a JSON error object; reject them
            # here instead of handing that object to post_process.
            response.raise_for_status()
            payload = response.json()
        except Exception as err:
            # Keep the original cause attached for debugging.
            raise ClassificationError('Classification failed') from err
        if not isinstance(payload, list):
            raise ClassificationError(
                f'Classification failed: unexpected response {payload!r}'
            )
        return payload

    @staticmethod
    def post_process(raw_data: List[Dict[str, Any]]) -> List[Classification]:
        """
        Merge per-token predictions into entity spans.

        raw_data is a list of dictionaries such as
        {'entity': 'B-Evaluation', 'score': 0.86011535, 'index': 1, 'word': 'Things', 'start': 0, 'end': 6}
        and the result is a list of classifications such as
        Classification(entity='Evaluation', start=0, end=6)

        A token opens a new span when it is the first one, when its label
        carries the 'B-' prefix, or when its entity name differs from the
        current span; otherwise it extends the current span's end. Labels
        without a 'B-'/'I-' prefix are used as they are, and 'O' tokens
        are skipped and close the current span.
        """
        classifications: List[Classification] = []
        current_entity = None
        for item in raw_data:
            label = item['entity']
            if label == 'O':
                # Outside token: belongs to no span and ends the running one.
                current_entity = None
                continue
            if label[:2] in ('B-', 'I-'):
                prefix, entity = label[:2], label[2:]
            else:
                prefix, entity = '', label
            # 'B-' always begins a new entity, even when the previous span
            # has the same name; otherwise separate entities of one type
            # would be fused into a single span.
            starts_new_span = (
                current_entity is None
                or prefix == 'B-'
                or entity != current_entity
            )
            if starts_new_span:
                current_entity = entity
                classifications.append(
                    Classification(
                        entity=entity,
                        start=item['start'],
                        end=item['end'],
                    )
                )
            else:
                classifications[-1].end = item['end']
        return classifications