kusa04 committed on
Commit
dbe9ae9
·
verified ·
1 Parent(s): 65db17e

Update functions.py

Browse files
Files changed (1) hide show
  1. functions.py +37 -37
functions.py CHANGED
@@ -17,44 +17,44 @@ from transformers import (
17
  from transformers.pipelines import AggregationStrategy
18
 
19
 
20
- # ---------- Cached function for loading the model pipelines ----------
21
- @st.cache_resource
22
- def load_sentiment_pipeline(): # sentiment pipeline
23
- tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
24
- model = AutoModelForSequenceClassification.from_pretrained(
25
- "cardiffnlp/twitter-roberta-base-sentiment-latest",
26
- use_auth_token=st.secrets["hugging_face_token"]
27
- )
28
- sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=0) # -1 to 0
29
- max_tokens = tokenizer.model_max_length
30
 
31
- if max_tokens > 10000:
32
- max_tokens = 512
33
- return sentiment_pipeline, tokenizer, max_tokens
34
-
35
-
36
- # class for keyword extraction
37
- @st.cache_resource
38
- class KeyphraseExtractionPipeline(TokenClassificationPipeline):
39
- def __init__(self, model, *args, **kwargs):
40
- super().__init__(
41
- model=AutoModelForTokenClassification.from_pretrained(model),
42
- tokenizer=AutoTokenizer.from_pretrained(model),
43
- *args,
44
- **kwargs
45
- )
46
-
47
- def postprocess(self, all_outputs):
48
- results = super().postprocess(
49
- all_outputs=all_outputs,
50
- aggregation_strategy=AggregationStrategy.SIMPLE,
51
- )
52
- return np.unique([result.get("word").strip() for result in results])
53
-
54
- def keyword_extractor():
55
- model_name = "ml6team/keyphrase-extraction-kbir-inspec"
56
- extractor = KeyphraseExtractionPipeline(model=model_name)
57
- return extractor
58
 
59
 
60
  # Function to normalize text by replacing multiple spaces/newlines with a single space
 
17
  from transformers.pipelines import AggregationStrategy
18
 
19
 
20
+ # # ---------- Cached function for loading the model pipelines ----------
21
+ # @st.cache_resource
22
+ # def load_sentiment_pipeline(): # sentiment pipeline
23
+ # tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
24
+ # model = AutoModelForSequenceClassification.from_pretrained(
25
+ # "cardiffnlp/twitter-roberta-base-sentiment-latest",
26
+ # use_auth_token=st.secrets["hugging_face_token"]
27
+ # )
28
+ # sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=0) # -1 to 0
29
+ # max_tokens = tokenizer.model_max_length
30
 
31
+ # if max_tokens > 10000:
32
+ # max_tokens = 512
33
+ # return sentiment_pipeline, tokenizer, max_tokens
34
+
35
+
36
+ # # class for keyword extraction
37
+ # @st.cache_resource
38
+ # class KeyphraseExtractionPipeline(TokenClassificationPipeline):
39
+ # def __init__(self, model, *args, **kwargs):
40
+ # super().__init__(
41
+ # model=AutoModelForTokenClassification.from_pretrained(model),
42
+ # tokenizer=AutoTokenizer.from_pretrained(model),
43
+ # *args,
44
+ # **kwargs
45
+ # )
46
+
47
+ # def postprocess(self, all_outputs):
48
+ # results = super().postprocess(
49
+ # all_outputs=all_outputs,
50
+ # aggregation_strategy=AggregationStrategy.SIMPLE,
51
+ # )
52
+ # return np.unique([result.get("word").strip() for result in results])
53
+
54
+ # def keyword_extractor():
55
+ # model_name = "ml6team/keyphrase-extraction-kbir-inspec"
56
+ # extractor = KeyphraseExtractionPipeline(model=model_name)
57
+ # return extractor
58
 
59
 
60
  # Function to normalize text by replacing multiple spaces/newlines with a single space