lang-detect / modeling_stacked.py
emanuelaboros's picture
testin the trick
f729b09
raw
history blame
10.2 kB
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from torch.nn import CrossEntropyLoss
from typing import Optional, Tuple, Union
import logging, json, os
import floret
from .configuration_stacked import ImpressoConfig
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def get_info(label_map):
    """Return a mapping from each task name to the number of labels it defines.

    Args:
        label_map: Mapping of task name -> collection of labels for that task.

    Returns:
        A new dict with the same keys (in the same order) whose values are
        ``len()`` of the corresponding label collections.
    """
    label_counts = {}
    for task_name, task_labels in label_map.items():
        label_counts[task_name] = len(task_labels)
    return label_counts
# class MyCustomModel:
# def __init__(self):
# # Custom initialization
# pass
#
# @classmethod
# def from_pretrained(cls, *args, **kwargs):
# print("Ignoring weights and using custom initialization.")
# return cls()
class SafeFloretWrapper:
    """
    A safe wrapper for floret model that keeps it off-device to avoid segmentation faults.
    This class is pure Python and never interacts with PyTorch tensors or devices;
    it only forwards plain Python inputs to the underlying floret model.
    """

    def __init__(self, model_path):
        """Load the floret model stored at ``model_path``.

        Args:
            model_path: Path passed verbatim to ``floret.load_model``.
        """
        # Use the module logger rather than print so library users control verbosity.
        logger.info("Loading floret model from %s", model_path)
        self.model_floret = floret.load_model(model_path)

    def predict(self, texts, k=1):
        """Return floret's ``(predictions, probabilities)`` for ``texts``.

        Args:
            texts: Input text(s) forwarded unchanged to floret's ``predict``.
                Floret expects strings, not tensors.
            k: Number of top labels requested per input.

        Returns:
            The ``(predictions, probabilities)`` pair produced by floret.
        """
        predictions, probabilities = self.model_floret.predict(texts, k=k)
        return predictions, probabilities
class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """``PreTrainedModel`` shell that delegates inference to a floret model.

    The class holds no real weights:

    * ``from_pretrained`` skips weight loading;
    * ``state_dict`` is always empty and ``load_state_dict`` is a no-op;
    * the single ``nn.Parameter`` (``dummy_param``) exists only so that
      ``parameters()`` / ``device`` have something to report.

    Predictions come from a ``SafeFloretWrapper`` built from ``config.filename``.
    """

    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        # PreTrainedModel.__init__ already stores ``config`` on ``self.config``.
        super().__init__(config)
        # Placeholder parameter so ``device`` / ``parameters()`` do not fail on
        # a module whose actual model lives outside PyTorch.
        self.dummy_param = nn.Parameter(torch.zeros(1))
        # Load the floret model from the path stored on the config.
        self.safe_floret = SafeFloretWrapper(self.config.filename)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run floret on raw text and return a placeholder tensor.

        Args:
            input_ids: A single string, or a list/tuple of strings. Despite the
                Hugging Face-style name, token-id tensors are not accepted.
            attention_mask: Unused; accepted for interface compatibility.
            **kwargs: Unused.

        Returns:
            ``torch.zeros((batch_size, 2))``, where ``batch_size`` is the number of
            input texts. The floret predictions are only logged, not returned.

        Raises:
            ValueError: If ``input_ids`` is not a string or a sequence of strings.
        """
        # The previous debug print read ``self.model_floret``, which is never set,
        # so every call raised AttributeError. Log the wrapper that actually exists.
        logger.debug(
            "forward received %r (%s) ----- %s",
            input_ids,
            type(input_ids),
            type(self.safe_floret),
        )
        if isinstance(input_ids, str):
            # A single string becomes a one-element batch for floret.
            texts = [input_ids]
        elif isinstance(input_ids, (list, tuple)) and all(
            isinstance(t, str) for t in input_ids
        ):
            texts = list(input_ids)
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # Use the SafeFloretWrapper to get predictions.
        predictions, probabilities = self.safe_floret.predict(texts)
        logger.debug("Predictions: %s", predictions)
        logger.debug("Probabilities: %s", probabilities)

        # Dummy tensor with shape (batch_size, num_classes), as originally documented.
        return torch.zeros((len(texts), 2))

    def state_dict(self, *args, **kwargs):
        """Return an empty state dict: there are no weights to save."""
        return {}

    def load_state_dict(self, state_dict, strict=True):
        """Ignore ``state_dict``; this model has no weights to restore."""
        logger.info("Ignoring state_dict since model has no parameters.")

    def get_floret_model(self):
        """Return the ``SafeFloretWrapper`` used for predictions.

        The wrapper exposes ``predict(texts, k=1)``. The raw floret model is
        available as its ``model_floret`` attribute.
        """
        # Previously returned ``self.model_floret``, an attribute that is never set.
        return self.safe_floret

    def get_extended_attention_mask(
        self, attention_mask, input_shape, device=None, dtype=torch.float
    ):
        """Build an additive attention mask of shape ``(batch, 1, 1, seq)``.

        Positions where ``attention_mask`` is 1 map to 0.0. Positions where it
        is 0 map to -10000.0. If ``attention_mask`` is None, an all-ones mask of
        ``input_shape`` is used. The ``[:, None, None, :]`` indexing requires a
        2-D ``(batch, seq)`` mask.
        """
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    @property
    def device(self):
        """Device of the first parameter (``dummy_param``)."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build the model from a fresh ``ImpressoConfig`` without loading weights.

        Positional arguments (e.g. the checkpoint path) are ignored. All keyword
        arguments are forwarded to ``ImpressoConfig``.

        NOTE(review): a ``config=`` kwarg passed by ``AutoModel.from_pretrained``
        is forwarded as an attribute of a new default config rather than being
        used directly. Confirm this is intended.
        """
        logger.info("Ignoring weights and using custom initialization.")
        # Manually create the config.
        config = ImpressoConfig(**kwargs)
        # Pass the manually created config to the class.
        model = cls(config)
        return model
# class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
#
# config_class = ImpressoConfig
# _keys_to_ignore_on_load_missing = [r"position_ids"]
#
# def __init__(self, config):
# super().__init__(config)
# # self.num_token_labels_dict = get_info(config.label_map)
# # self.config = config
# # # print(f"I dont think it arrives here: {self.config}")
# # self.bert = AutoModel.from_pretrained(
# # config.pretrained_config["_name_or_path"], config=config.pretrained_config
# # )
# self.model_floret = floret.load_model(self.config.filename)
# # print(f"Model loaded: {self.model_floret}")
# # if "classifier_dropout" not in config.__dict__:
# # classifier_dropout = 0.1
# # else:
# # classifier_dropout = (
# # config.classifier_dropout
# # if config.classifier_dropout is not None
# # else config.hidden_dropout_prob
# # )
# # self.dropout = nn.Dropout(classifier_dropout)
# #
# # # Additional transformer layers
# # self.transformer_encoder = nn.TransformerEncoder(
# # nn.TransformerEncoderLayer(
# # d_model=config.hidden_size, nhead=config.num_attention_heads
# # ),
# # num_layers=2,
# # )
#
# # For token classification, create a classifier for each task
# # self.token_classifiers = nn.ModuleDict(
# # {
# # task: nn.Linear(config.hidden_size, num_labels)
# # for task, num_labels in self.num_token_labels_dict.items()
# # }
# # )
# #
# # # Initialize weights and apply final processing
# # self.post_init()
#
# def get_floret_model(self):
# return self.model_floret
#
# @classmethod
# def from_pretrained(cls, *args, **kwargs):
# print("Ignoring weights and using custom initialization.")
#
# # Manually create the config
# config = ImpressoConfig()
#
# # Pass the manually created config to the class
# model = cls(config)
# return model
#
# # def forward(
# # self,
# # input_ids: Optional[torch.Tensor] = None,
# # attention_mask: Optional[torch.Tensor] = None,
# # token_type_ids: Optional[torch.Tensor] = None,
# # position_ids: Optional[torch.Tensor] = None,
# # head_mask: Optional[torch.Tensor] = None,
# # inputs_embeds: Optional[torch.Tensor] = None,
# # labels: Optional[torch.Tensor] = None,
# # token_labels: Optional[dict] = None,
# # output_attentions: Optional[bool] = None,
# # output_hidden_states: Optional[bool] = None,
# # return_dict: Optional[bool] = None,
# # ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
# # r"""
# # token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
# # Labels for computing the token classification loss. Keys should match the tasks.
# # """
# # return_dict = (
# # return_dict if return_dict is not None else self.config.use_return_dict
# # )
# #
# # bert_kwargs = {
# # "input_ids": input_ids,
# # "attention_mask": attention_mask,
# # "token_type_ids": token_type_ids,
# # "position_ids": position_ids,
# # "head_mask": head_mask,
# # "inputs_embeds": inputs_embeds,
# # "output_attentions": output_attentions,
# # "output_hidden_states": output_hidden_states,
# # "return_dict": return_dict,
# # }
# #
# # if any(
# # keyword in self.config.name_or_path.lower()
# # for keyword in ["llama", "deberta"]
# # ):
# # bert_kwargs.pop("token_type_ids")
# # bert_kwargs.pop("head_mask")
# #
# # outputs = self.bert(**bert_kwargs)
# #
# # # For token classification
# # token_output = outputs[0]
# # token_output = self.dropout(token_output)
# #
# # # Pass through additional transformer layers
# # token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
# # 0, 1
# # )
# #
# # # Collect the logits and compute the loss for each task
# # task_logits = {}
# # total_loss = 0
# # for task, classifier in self.token_classifiers.items():
# # logits = classifier(token_output)
# # task_logits[task] = logits
# # if token_labels and task in token_labels:
# # loss_fct = CrossEntropyLoss()
# # loss = loss_fct(
# # logits.view(-1, self.num_token_labels_dict[task]),
# # token_labels[task].view(-1),
# # )
# # total_loss += loss
# #
# # if not return_dict:
# # output = (task_logits,) + outputs[2:]
# # return ((total_loss,) + output) if total_loss != 0 else output
# # print(f"Is there anobidy coming here?")
# # return TokenClassifierOutput(
# # loss=total_loss,
# # logits=task_logits,
# # hidden_states=outputs.hidden_states,
# # attentions=outputs.attentions,
# # )