Stanford-TH committed on
Commit a6ed790
1 Parent(s): d5249c9

commit files to HF hub

Files changed (2)
  1. config.json +19 -2
  2. genre_pipe.py +102 -0
config.json CHANGED
@@ -1,10 +1,26 @@
 {
+  "_name_or_path": "Stanford-TH/GenrePrediction",
   "architectures": [
     "GenreModel"
   ],
   "auto_map": {
-    "AutoConfig": "genre_configuration.GenreConfig",
-    "AutoModel": "genre_model.GenreModel"
+    "AutoConfig": "Stanford-TH/GenrePrediction--genre_configuration.GenreConfig",
+    "AutoModel": "Stanford-TH/GenrePrediction--genre_model.GenreModel"
+  },
+  "custom_pipelines": {
+    "weighted-genre-classification": {
+      "default": {
+        "model": {
+          "pt": "Stanford-TH/GenrePrediction"
+        }
+      },
+      "impl": "genre_pipe.GenrePredictionPipeline",
+      "pt": [
+        "AutoModel"
+      ],
+      "tf": [],
+      "type": "text"
+    }
   },
   "id2label": {
     "0": "Action",
@@ -25,6 +41,7 @@
     "8": "Fantasy",
     "9": "History"
   },
+  "label2id": null,
   "model_type": "custom-bert-base-uncased",
   "torch_dtype": "float32",
   "transformers_version": "4.39.3"
genre_pipe.py ADDED
@@ -0,0 +1,102 @@
+from typing import Any, Dict
+from transformers import Pipeline
+from transformers import AutoTokenizer
+from transformers.utils import ModelOutput
+import numpy as np
+import unicodedata
+import re
+import torch
+
+class Preprocess_Text:
+    @staticmethod
+    def remove_tags(sentence):
+        return re.sub('<.*?>', ' ', sentence)
+
+    @staticmethod
+    def remove_accents(sentence):
+        return unicodedata.normalize('NFD', sentence).encode('ascii', 'ignore').decode("utf-8")
+
+    @staticmethod
+    def remove_punctuation(sentence):
+        sentence = re.sub(r'[?|!|\'|"|#]', '', sentence)
+        sentence = re.sub(r'[.,;:(){}[\]\\/<>|-]', ' ', sentence)
+        return sentence.replace("\n", " ")
+
+    @staticmethod
+    def keep_alpha(sentence):
+        return re.sub('[^a-z A-Z]+', ' ', sentence)
+
+    @staticmethod
+    def lower_case(sentence):
+        return sentence.lower()
+
+    def __call__(self, text):  # clean the raw description before tokenization
+        text = self.remove_tags(text)
+        text = self.remove_accents(text)
+        text = self.remove_punctuation(text)
+        text = self.keep_alpha(text)
+        text = self.lower_case(text)
+        return text
+
+class GenrePredictionPipeline(Pipeline):
+    def _sanitize_parameters(self, **kwargs):
+        preprocess_kwargs = {}
+        if "text" in kwargs:
+            preprocess_kwargs['text'] = kwargs['text']
+        return preprocess_kwargs, {}, {}
+
+    def preprocess(self, text, **kwargs):
+        text_preprocessing_obj = Preprocess_Text()
+        processed_description = text_preprocessing_obj(text)
+
+        try:
+            if isinstance(processed_description, str):
+                tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+                encoded_text = tokenizer.encode_plus(
+                    processed_description, None, add_special_tokens=True, max_length=512,
+                    padding='max_length', return_token_type_ids=True, truncation=True,
+                    return_tensors=self.framework, return_overflowing_tokens=True)
+
+                maximum_overflowed_samples = len(encoded_text.pop('overflow_to_sample_mapping'))  # number of 512-token chunks
+
+                try:
+                    numbers = [[x for x in encoded_text.word_ids(batch_index=i) if x is not None][-1]
+                               for i in range(maximum_overflowed_samples)]  # last word id seen in each chunk
+                except IndexError:
+                    return None, torch.zeros(17, dtype=torch.float32)
+
+                sequence_length = numbers[-1]
+                weights = [numbers[0]] + [numbers[i] - numbers[i-1] for i in range(1, len(numbers))]
+                weights = (torch.tensor(weights) / sequence_length).to(self.device)  # normalize weights by total word count
+                return {"model_inputs": encoded_text, "weights": weights, "max_length": sequence_length}
+            else:
+                raise AttributeError()
+        except Exception as error:
+            print("Wrong format {}".format(str(error)))
+            return -1
+
+    def _forward(self, model_inputs):
+        weights, max_length = model_inputs.pop('weights'), model_inputs.pop('max_length')
+        with torch.no_grad():
+            outputs = self.model(**model_inputs['model_inputs'])
+
+        return {"model_outputs": outputs, "weights": weights, "max_length": max_length}
+
+    def postprocess(self, model_outputs: ModelOutput, **postprocess_parameters: Dict) -> Any:
+        # Apply sigmoid activation and calculate weighted logits
+        print(model_outputs, postprocess_parameters)
+        logits = torch.sigmoid(model_outputs.pop('model_outputs'))
+        probabilities = logits * model_outputs.pop('weights').unsqueeze(1)
+
+        probabilities = probabilities.sum(dim=0)  # weighted sum over chunks
+
+        top_scores, top_indices = torch.topk(probabilities, 3)  # Get the top 3 scores and their indices
+
+        print(top_scores, top_indices)
+
+        top_genres = [self.model.config.id2label[str(idx.item())] for idx in top_indices.squeeze()]
+        top_scores = top_scores.detach().cpu().numpy()
+
+        genre_scores = {genre: score for genre, score in zip(top_genres, top_scores.squeeze())}
+
+        return genre_scores
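genre_pipe.py handles descriptions longer than BERT's 512-token window by letting the tokenizer overflow into multiple chunks, scoring each chunk with the model, and weighting each chunk's sigmoid scores by the share of words it contributes before summing. A standalone sketch of that aggregation, with made-up numbers (not part of the commit):

    import torch

    # Last word id seen in each overflowing chunk, as computed in preprocess() above (illustrative values).
    numbers = [510, 890, 1200]          # three chunks covering ~1200 words
    sequence_length = numbers[-1]

    # Words contributed by each chunk, normalized by the total word count.
    weights = [numbers[0]] + [numbers[i] - numbers[i - 1] for i in range(1, len(numbers))]
    weights = torch.tensor(weights, dtype=torch.float32) / sequence_length  # tensor([0.4250, 0.3167, 0.2583])

    # Per-chunk sigmoid scores for (here) 4 hypothetical genres, shape (num_chunks, num_labels).
    chunk_probs = torch.sigmoid(torch.randn(3, 4))

    # Weighted sum across chunks gives one aggregate score per genre, as in postprocess().
    genre_scores = (chunk_probs * weights.unsqueeze(1)).sum(dim=0)
    print(genre_scores.shape)  # torch.Size([4])

Chunks that cover more of the plot description therefore contribute proportionally more to the final top-3 genre scores.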