import logging

import torch
import torch.nn as nn
import floret
from transformers import PreTrainedModel

from .configuration_stacked import ImpressoConfig

logger = logging.getLogger(__name__)


def get_info(label_map):
    """Return a mapping from task name to the number of labels defined for that task."""
    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
    return num_token_labels_dict
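
# Usage sketch for get_info; the tasks and labels below are illustrative
# assumptions, not the actual Impresso label map:
#
#     label_map = {"ner": ["O", "B-PER", "I-PER"], "lang": ["de", "fr", "en"]}
#     get_info(label_map)  # -> {"ner": 3, "lang": 3}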


class SafeFloretWrapper:
    """
    A safe wrapper around a floret model that keeps it off-device to avoid
    segmentation faults. This class is pure Python and never interacts with
    PyTorch tensors or devices.
    """

    def __init__(self, model_path):
        logger.info(f"Loading floret model from {model_path}")
        self.model_floret = floret.load_model(model_path)

    def predict(self, texts, k=1):
        # Floret expects plain strings, not tensors.
        predictions, probabilities = self.model_floret.predict(texts, k=k)
        return predictions, probabilities
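
# Usage sketch for SafeFloretWrapper; the model path and sentence are
# hypothetical placeholders:
#
#     wrapper = SafeFloretWrapper("models/langident.bin")
#     labels, probs = wrapper.predict(["Le Temps est un journal suisse."], k=1)
#     # `labels` holds floret label strings (e.g. "__label__fr") and `probs`
#     # the corresponding confidence scores.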


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        # A single dummy parameter so that .parameters() is non-empty and the
        # device property below has something to report.
        self.dummy_param = nn.Parameter(torch.zeros(1))
        # Load the floret model through the off-device wrapper.
        self.safe_floret = SafeFloretWrapper(self.config.filename)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        # Floret works on raw text, so `input_ids` is expected to be a string
        # or a list of strings rather than a tensor of token ids.
        if isinstance(input_ids, str):
            texts = [input_ids]
        elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # Use the SafeFloretWrapper to get predictions.
        predictions, probabilities = self.safe_floret.predict(texts)
        logger.debug(f"Predictions: {predictions}")
        logger.debug(f"Probabilities: {probabilities}")

        # Dummy tensor with shape (batch_size, num_classes); the floret
        # predictions computed above are only logged here, not returned.
        return torch.zeros((1, 2))

    def state_dict(self, *args, **kwargs):
        # Return an empty state dict: the floret model lives on disk and the
        # only torch parameter is the dummy one, which carries no weights.
        return {}

    def load_state_dict(self, state_dict, strict=True):
        # Nothing to load, since state_dict() above is always empty.
        logger.info("Ignoring state_dict since the model has no parameters to load.")

    def get_floret_model(self):
        return self.safe_floret.model_floret

    def get_extended_attention_mask(
        self, attention_mask, input_shape, device=None, dtype=torch.float
    ):
        # Build a broadcastable additive mask: 0.0 for positions to attend to
        # and -10000.0 for padded positions.
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask
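
    # Example of the mask transformation (a single sequence of length 3 with
    # one padded position):
    #
    #     mask = torch.tensor([[1, 1, 0]])                 # (batch, seq_len)
    #     model.get_extended_attention_mask(mask, mask.shape)
    #     # -> shape (1, 1, 1, 3): [[[[0.0, 0.0, -10000.0]]]]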

    @property
    def device(self):
        # Report the device of the dummy parameter (the model's only tensor).
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        logger.info("Ignoring pretrained weights and using custom initialization.")
        # Manually create the config and pass it to the regular constructor.
        config = ImpressoConfig(**kwargs)
        model = cls(config)
        return model
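

# End-to-end usage sketch. The config keyword, model path, and sentence below
# are illustrative assumptions about how ImpressoConfig is set up, not a
# documented API:
#
#     model = ExtendedMultitaskModelForTokenClassification.from_pretrained(
#         filename="models/langident.bin"  # hypothetical floret model path
#     )
#     _ = model(["Le Temps est un journal suisse."])  # returns the dummy tensor
#     labels, probs = model.safe_floret.predict(
#         ["Le Temps est un journal suisse."], k=1
#     )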