"""Hugging Face model wrapper that delegates prediction to a floret model."""

import json
import logging
import os
from typing import Optional, Tuple, Union

import floret
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
from transformers.modeling_outputs import TokenClassifierOutput

from .configuration_stacked import ImpressoConfig

logger = logging.getLogger(__name__)


def get_info(label_map: dict) -> dict:
    """Return a mapping from each task name to its number of labels.

    Args:
        label_map: Mapping of task name to a sized collection of labels.

    Returns:
        A dict with the same keys as *label_map* and ``len(labels)`` as values.
    """
    return {task: len(labels) for task, labels in label_map.items()}


# class MyCustomModel:
#     def __init__(self):
#         # Custom initialization
#         pass
#
#     @classmethod
#     def from_pretrained(cls, *args, **kwargs):
#         print("Ignoring weights and using custom initialization.")
#         return cls()


class SafeFloretWrapper:
    """
    A safe wrapper for floret model that keeps it off-device to avoid segmentation faults.
    This class is pure Python and never interacts with PyTorch tensors or devices.
    """

    def __init__(self, model_path):
        """Load the floret model stored at *model_path*."""
        print(f"Loading floret model from {model_path}")
        self.model_floret = floret.load_model(model_path)

    def predict(self, texts, k=1):
        """Run floret prediction on *texts* and return the top-*k* results.

        Args:
            texts: Text input forwarded unchanged to the floret model
                (strings, not tensors).
            k: Number of labels requested per text.

        Returns:
            The ``(predictions, probabilities)`` pair produced by floret.
        """
        # Floret expects strings, not tensors
        predictions, probabilities = self.model_floret.predict(texts, k=k)
        return predictions, probabilities


class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
    """``PreTrainedModel`` shell around a floret model.

    The floret model is loaded from ``config.filename`` and held in a
    :class:`SafeFloretWrapper`. The module itself carries no trained torch
    weights: ``state_dict`` is empty and ``load_state_dict`` is a no-op.
    """

    config_class = ImpressoConfig
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        """Build the model and load the floret model named by ``config.filename``."""
        super().__init__(config)
        self.config = config

        # A single placeholder parameter so that ``next(self.parameters())``
        # in the ``device`` property always has something to return.
        self.dummy_param = nn.Parameter(torch.zeros(1))
        # Load floret model
        self.safe_floret = SafeFloretWrapper(self.config.filename)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Run floret on raw text and return a placeholder tensor.

        Args:
            input_ids: Despite the name, a single string or a list of strings.
            attention_mask: Unused.
            **kwargs: Unused.

        Returns:
            A zero tensor of shape ``(1, 2)``. The floret predictions are only
            printed, not returned.

        Raises:
            ValueError: If *input_ids* is neither a ``str`` nor a list of ``str``.
        """
        # The wrapper lives in ``self.safe_floret``; there is no
        # ``self.model_floret`` attribute on this class.
        print(
            f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.safe_floret)}"
        )
        if isinstance(input_ids, str):
            # If the input is a single string, make it a list for floret
            texts = [input_ids]
        elif isinstance(input_ids, list) and all(isinstance(t, str) for t in input_ids):
            texts = input_ids
        else:
            raise ValueError(f"Unexpected input type: {type(input_ids)}")

        # Use the SafeFloretWrapper to get predictions
        predictions, probabilities = self.safe_floret.predict(texts)
        print(f"Predictions: {predictions}")
        print(f"Probabilities: {probabilities}")

        # Dummy tensor with shape (batch_size, num_classes)
        return torch.zeros((1, 2))

    def state_dict(self, *args, **kwargs):
        """Return an empty state dictionary; all arguments are ignored."""
        return {}

    def load_state_dict(self, state_dict, strict=True, assign=False):
        """Ignore *state_dict*; nothing is loaded.

        ``strict`` and ``assign`` are accepted only so that callers using the
        ``torch.nn.Module.load_state_dict`` keywords do not fail.
        """
        print("Ignoring state_dict since model has no parameters.")

    def get_floret_model(self):
        """Return the underlying floret model held by the wrapper."""
        return self.safe_floret.model_floret

    def get_extended_attention_mask(
        self, attention_mask, input_shape, device=None, dtype=torch.float
    ):
        """Turn a 0/1 attention mask into an additive mask.

        Args:
            attention_mask: Mask indexed as ``[:, None, None, :]``, or ``None``
                to use an all-ones mask of shape *input_shape*.
            input_shape: Shape used when *attention_mask* is ``None``.
            device: Device for the default all-ones mask.
            dtype: Dtype of the returned mask.

        Returns:
            A tensor holding ``0.0`` where the mask is 1 and ``-10000.0``
            where it is 0.
        """
        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        extended_attention_mask = attention_mask[:, None, None, :]
        extended_attention_mask = extended_attention_mask.to(dtype=dtype)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        return extended_attention_mask

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build the model from keyword arguments without loading weights.

        Positional arguments are ignored. The keyword arguments are passed
        to ``ImpressoConfig`` and the resulting config to the constructor.
        """
        print("Ignoring weights and using custom initialization.")

        # Manually create the config
        config = ImpressoConfig(**kwargs)

        # Pass the manually created config to the class
        model = cls(config)
        return model


# ---------------------------------------------------------------------------
# Earlier draft of the model class, kept commented out for reference.
# ---------------------------------------------------------------------------
# class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
#
#     config_class = ImpressoConfig
#     _keys_to_ignore_on_load_missing = [r"position_ids"]
#
#     def __init__(self, config):
#         super().__init__(config)
#         # self.num_token_labels_dict = get_info(config.label_map)
#         # self.config = config
#         # # print(f"I dont think it arrives here: {self.config}")
#         # self.bert = AutoModel.from_pretrained(
#         #     config.pretrained_config["_name_or_path"], config=config.pretrained_config
#         # )
#         self.model_floret = floret.load_model(self.config.filename)
#         # print(f"Model loaded: {self.model_floret}")
#         # if "classifier_dropout" not in config.__dict__:
#         #     classifier_dropout = 0.1
#         # else:
#         #     classifier_dropout = (
#         #         config.classifier_dropout
#         #         if config.classifier_dropout is not None
#         #         else config.hidden_dropout_prob
#         #     )
#         # self.dropout = nn.Dropout(classifier_dropout)
#         #
#         # # Additional transformer layers
#         # self.transformer_encoder = nn.TransformerEncoder(
#         #     nn.TransformerEncoderLayer(
#         #         d_model=config.hidden_size, nhead=config.num_attention_heads
#         #     ),
#         #     num_layers=2,
#         # )
#         # # For token classification, create a classifier for each task
#         # self.token_classifiers = nn.ModuleDict(
#         #     {
#         #         task: nn.Linear(config.hidden_size, num_labels)
#         #         for task, num_labels in self.num_token_labels_dict.items()
#         #     }
#         # )
#         #
#         # # Initialize weights and apply final processing
#         # self.post_init()
#
#     def get_floret_model(self):
#         return self.model_floret
#
#     @classmethod
#     def from_pretrained(cls, *args, **kwargs):
#         print("Ignoring weights and using custom initialization.")
#
#         # Manually create the config
#         config = ImpressoConfig()
#
#         # Pass the manually created config to the class
#         model = cls(config)
#         return model
#
#     # def forward(
#     #     self,
#     #     input_ids: Optional[torch.Tensor] = None,
#     #     attention_mask: Optional[torch.Tensor] = None,
#     #     token_type_ids: Optional[torch.Tensor] = None,
#     #     position_ids: Optional[torch.Tensor] = None,
#     #     head_mask: Optional[torch.Tensor] = None,
#     #     inputs_embeds: Optional[torch.Tensor] = None,
#     #     labels: Optional[torch.Tensor] = None,
#     #     token_labels: Optional[dict] = None,
#     #     output_attentions: Optional[bool] = None,
#     #     output_hidden_states: Optional[bool] = None,
#     #     return_dict: Optional[bool] = None,
#     # ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
#     #     r"""
#     #     token_labels (`dict` of `torch.LongTensor` of shape `(batch_size, seq_length)`, *optional*):
#     #         Labels for computing the token classification loss. Keys should match the tasks.
#     #     """
#     #     return_dict = (
#     #         return_dict if return_dict is not None else self.config.use_return_dict
#     #     )
#     #
#     #     bert_kwargs = {
#     #         "input_ids": input_ids,
#     #         "attention_mask": attention_mask,
#     #         "token_type_ids": token_type_ids,
#     #         "position_ids": position_ids,
#     #         "head_mask": head_mask,
#     #         "inputs_embeds": inputs_embeds,
#     #         "output_attentions": output_attentions,
#     #         "output_hidden_states": output_hidden_states,
#     #         "return_dict": return_dict,
#     #     }
#     #
#     #     if any(
#     #         keyword in self.config.name_or_path.lower()
#     #         for keyword in ["llama", "deberta"]
#     #     ):
#     #         bert_kwargs.pop("token_type_ids")
#     #         bert_kwargs.pop("head_mask")
#     #
#     #     outputs = self.bert(**bert_kwargs)
#     #
#     #     # For token classification
#     #     token_output = outputs[0]
#     #     token_output = self.dropout(token_output)
#     #
#     #     # Pass through additional transformer layers
#     #     token_output = self.transformer_encoder(token_output.transpose(0, 1)).transpose(
#     #         0, 1
#     #     )
#     #
#     #     # Collect the logits and compute the loss for each task
#     #     task_logits = {}
#     #     total_loss = 0
#     #     for task, classifier in self.token_classifiers.items():
#     #         logits = classifier(token_output)
#     #         task_logits[task] = logits
#     #         if token_labels and task in token_labels:
#     #             loss_fct = CrossEntropyLoss()
#     #             loss = loss_fct(
#     #                 logits.view(-1, self.num_token_labels_dict[task]),
#     #                 token_labels[task].view(-1),
#     #             )
#     #             total_loss += loss
#     #
#     #     if not return_dict:
#     #         output = (task_logits,) + outputs[2:]
#     #         return ((total_loss,) + output) if total_loss != 0 else output
#     #     print(f"Is there anobidy coming here?")
#     #     return TokenClassifierOutput(
#     #         loss=total_loss,
#     #         logits=task_logits,
#     #         hidden_states=outputs.hidden_states,
#     #         attentions=outputs.attentions,
#     #     )