kusa04 committed on
Commit
a9d660f
·
verified ·
1 Parent(s): 052eebe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -1
app.py CHANGED
@@ -18,7 +18,6 @@ from transformers import (
18
  from transformers.pipelines import AggregationStrategy
19
 
20
  from functions import (
21
- load_sentiment_pipeline,
22
  scrape_reddit_data,
23
  safe_sentiment,
24
  analyze_detail,
@@ -26,6 +25,46 @@ from functions import (
26
  )
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31
 
 
18
  from transformers.pipelines import AggregationStrategy
19
 
20
  from functions import (
 
21
  scrape_reddit_data,
22
  safe_sentiment,
23
  analyze_detail,
 
25
  )
26
 
27
 
28
+ # ---------- Cached function for loading the model pipelines ----------
29
@st.cache_resource
def load_sentiment_pipeline():  # sentiment pipeline
    """Build the sentiment-analysis pipeline and cache it as a Streamlit resource.

    Returns:
        A ``(sentiment_pipeline, tokenizer, max_tokens)`` tuple, where
        ``max_tokens`` is the tokenizer's ``model_max_length``, replaced by
        512 whenever the tokenizer reports a value above 10000.
    """
    # Local import keeps this block self-contained; torch is already a
    # runtime requirement of the PyTorch model classes used below.
    import torch

    model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # NOTE(review): recent transformers releases deprecate `use_auth_token`
    # in favour of `token` -- confirm the pinned version before switching.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        use_auth_token=st.secrets["hugging_face_token"],
    )

    # Prefer the first GPU, but fall back to CPU (-1) so the app still starts
    # on hosts without CUDA instead of failing while building the pipeline.
    device = 0 if torch.cuda.is_available() else -1
    sentiment_pipeline = pipeline(
        "sentiment-analysis",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )

    max_tokens = tokenizer.model_max_length
    # Presumably guards against tokenizers that report a huge placeholder
    # value when no real limit is configured -- TODO confirm.
    if max_tokens > 10000:
        max_tokens = 512
    return sentiment_pipeline, tokenizer, max_tokens
42
+
43
+
44
+ # class for keyword extraction
45
class KeyphraseExtractionPipeline(TokenClassificationPipeline):
    """Token-classification pipeline that returns the unique keyphrases found.

    The class is deliberately NOT decorated with ``st.cache_resource``: that
    decorator wraps callables, and applying it here would rebind this name to
    a cached-function wrapper instead of a real class (breaking ``isinstance``
    checks and subclassing). Cache instances through a factory function such
    as ``keyword_extractor`` instead.
    """

    def __init__(self, model, *args, **kwargs):
        """Load the model and tokenizer named by *model*.

        Args:
            model: Model identifier passed to ``from_pretrained`` for both
                the token-classification model and its tokenizer.
            *args: Extra positional arguments forwarded to the parent class.
            **kwargs: Extra keyword arguments forwarded to the parent class.
        """
        super().__init__(
            model=AutoModelForTokenClassification.from_pretrained(model),
            tokenizer=AutoTokenizer.from_pretrained(model),
            *args,
            **kwargs,
        )

    def postprocess(self, all_outputs):
        """Aggregate token predictions and return the distinct keyphrases.

        Args:
            all_outputs: Raw model outputs handed over by the pipeline.

        Returns:
            A NumPy array holding each whitespace-stripped ``"word"`` value
            once, as produced by ``np.unique``.
        """
        results = super().postprocess(
            all_outputs=all_outputs,
            aggregation_strategy=AggregationStrategy.SIMPLE,
        )
        return np.unique([result.get("word").strip() for result in results])
61
+
62
@st.cache_resource
def keyword_extractor():
    """Create the keyphrase-extraction pipeline, cached as a Streamlit resource.

    Returns:
        A ``KeyphraseExtractionPipeline`` built for the
        ``ml6team/keyphrase-extraction-kbir-inspec`` model.
    """
    return KeyphraseExtractionPipeline(
        model="ml6team/keyphrase-extraction-kbir-inspec"
    )
67
+ # ---------------------------------------------------------------------------
68
 
69
 
70