kusa04 commited on
Commit
e652020
·
verified ·
1 Parent(s): dbe9ae9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -3
app.py CHANGED
@@ -17,9 +17,6 @@ from transformers import (
17
  from transformers.pipelines import AggregationStrategy
18
 
19
  from functions import (
20
- load_sentiment_pipeline,
21
- KeyphraseExtractionPipeline,
22
- keyword_extractor,
23
  scrape_reddit_data,
24
  split_text_by_token_limit,
25
  safe_sentiment,
@@ -28,6 +25,45 @@ from functions import (
28
  )
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
 
 
17
  from transformers.pipelines import AggregationStrategy
18
 
19
  from functions import (
 
 
 
20
  scrape_reddit_data,
21
  split_text_by_token_limit,
22
  safe_sentiment,
 
25
  )
26
 
27
 
28
+ # ---------- Cached function for loading the model pipelines ----------
29
+ @st.cache_resource
30
+ def load_sentiment_pipeline(): # sentiment pipeline
31
+ tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
32
+ model = AutoModelForSequenceClassification.from_pretrained(
33
+ "cardiffnlp/twitter-roberta-base-sentiment-latest",
34
+ use_auth_token=st.secrets["hugging_face_token"]
35
+ )
36
+ sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=0) # -1 to 0
37
+ max_tokens = tokenizer.model_max_length
38
+
39
+ if max_tokens > 10000:
40
+ max_tokens = 512
41
+ return sentiment_pipeline, tokenizer, max_tokens
42
+
43
+
44
+ # class for keyword extraction
45
+ @st.cache_resource
46
+ class KeyphraseExtractionPipeline(TokenClassificationPipeline):
47
+ def __init__(self, model, *args, **kwargs):
48
+ super().__init__(
49
+ model=AutoModelForTokenClassification.from_pretrained(model),
50
+ tokenizer=AutoTokenizer.from_pretrained(model),
51
+ *args,
52
+ **kwargs
53
+ )
54
+
55
+ def postprocess(self, all_outputs):
56
+ results = super().postprocess(
57
+ all_outputs=all_outputs,
58
+ aggregation_strategy=AggregationStrategy.SIMPLE,
59
+ )
60
+ return np.unique([result.get("word").strip() for result in results])
61
+
62
+ @st.cache_resource
63
+ def keyword_extractor():
64
+ model_name = "ml6team/keyphrase-extraction-kbir-inspec"
65
+ extractor = KeyphraseExtractionPipeline(model=model_name)
66
+ return extractor
67
 
68
 
69