emanuelaboros committed
Commit 7138c9f · 1 Parent(s): dd2cf99

checking if other model things are needed

Files changed (3):
  1. config.json +1 -1
  2. configuration_lang.py +1 -1
  3. modeling_lang.py +3 -26
config.json CHANGED
@@ -141,7 +141,7 @@
   },
   "layer_norm_eps": 1e-12,
   "max_position_embeddings": 512,
-  "model_type": "stacked_bert",
+  "model_type": "lang_detect",
   "num_attention_heads": 8,
   "num_hidden_layers": 8,
   "pad_token_id": 0,
configuration_lang.py CHANGED
@@ -3,7 +3,7 @@ import torch
 
 
 class ImpressoConfig(PretrainedConfig):
-    model_type = "stacked_bert"
+    model_type = "lang_detect"
 
     def __init__(
         self,
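
With the rename, any code that previously registered "stacked_bert" with the auto classes must now register "lang_detect" instead. A sketch of the registration pair, assuming both modules are importable as top-level modules (modeling_lang.py uses a relative import, so in practice this may require running inside the repo package):

    from transformers import AutoConfig, AutoModel
    from configuration_lang import ImpressoConfig
    from modeling_lang import ExtendedMultitaskModelForTokenClassification

    # Route the renamed model_type to the custom config class, and the
    # config class to the floret-backed model, so AutoModel can resolve both.
    AutoConfig.register("lang_detect", ImpressoConfig)
    AutoModel.register(ImpressoConfig, ExtendedMultitaskModelForTokenClassification)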
modeling_lang.py CHANGED
@@ -1,10 +1,7 @@
-from transformers.modeling_outputs import TokenClassifierOutput
 import torch
 import torch.nn as nn
-from transformers import PreTrainedModel, AutoModel, AutoConfig, BertConfig
-from torch.nn import CrossEntropyLoss
-from typing import Optional, Tuple, Union
-import logging, json, os
+from transformers import PreTrainedModel
+import logging
 import floret
 from .configuration_lang import ImpressoConfig
 
@@ -26,9 +23,6 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
 
     #
     def forward(self, input_ids, attention_mask=None, **kwargs):
-        # print(
-        #     f"Check if it arrives here: {input_ids}, ---, {type(input_ids)} ----- {type(self.model_floret)}"
-        # )
         if isinstance(input_ids, str):
             # If the input is a single string, make it a list for floret
             texts = [input_ids]
@@ -37,13 +31,11 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
         else:
             raise ValueError(f"Unexpected input type: {type(input_ids)}")
 
-        # Use the SafeFloretWrapper to get predictions
         predictions, probabilities = self.model_floret.predict(texts, k=1)
-        # print(f"Predictions: {predictions}, Probabilities: {probabilities}")
         return (
             predictions,
             probabilities,
-        )  # Dummy tensor with shape (batch_size, num_classes)
+        )
 
     def state_dict(self, *args, **kwargs):
         # Return an empty state dictionary
@@ -53,21 +45,6 @@ class ExtendedMultitaskModelForTokenClassification(PreTrainedModel):
         # Ignore loading since there are no parameters
         pass
 
-        # print("Ignoring state_dict since model has no parameters.")
-
-    # def get_floret_model(self):
-    #     return self.model_floret
-
-    # def get_extended_attention_mask(
-    #     self, attention_mask, input_shape, device=None, dtype=torch.float
-    # ):
-    #     if attention_mask is None:
-    #         attention_mask = torch.ones(input_shape, device=device)
-    #     extended_attention_mask = attention_mask[:, None, None, :]
-    #     extended_attention_mask = extended_attention_mask.to(dtype=dtype)
-    #     extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-    #     return extended_attention_mask
-
     @property
     def device(self):
         return next(self.parameters()).device
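
A hypothetical smoke test that the renamed model still loads end to end; the repo id below is a placeholder, and trust_remote_code=True is required because both custom classes ship with the checkpoint:

    from transformers import AutoModel

    # Placeholder repo id; substitute the actual Hub repo this commit belongs to.
    model = AutoModel.from_pretrained(
        "impresso-project/<repo-id>",
        trust_remote_code=True,
    )
    # forward() accepts a raw string or a list of strings and delegates to floret.
    predictions, probabilities = model("Ein kurzer Beispielsatz.")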