kusa04 committed on
Commit
052eebe
·
verified ·
1 Parent(s): 590b498

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -43
app.py CHANGED
@@ -1,5 +1,5 @@
1
  from collections import Counter
2
- from concurrent.futures import ThreadPoolExecutor # parallel processing
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
  import praw # Reddit's API
@@ -18,55 +18,14 @@ from transformers import (
18
  from transformers.pipelines import AggregationStrategy
19
 
20
  from functions import (
 
21
  scrape_reddit_data,
22
- split_text_by_token_limit,
23
  safe_sentiment,
24
- safe_sentiment_batch,
25
  analyze_detail,
26
  preprocess_text
27
  )
28
 
29
 
30
- # ---------- Cached function for loading the model pipelines ----------
31
- @st.cache_resource
32
- def load_sentiment_pipeline(): # sentiment pipeline
33
- tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
34
- model = AutoModelForSequenceClassification.from_pretrained(
35
- "cardiffnlp/twitter-roberta-base-sentiment-latest",
36
- use_auth_token=st.secrets["hugging_face_token"]
37
- )
38
- sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=0) # -1 to 0
39
- max_tokens = tokenizer.model_max_length
40
-
41
- if max_tokens > 10000:
42
- max_tokens = 512
43
- return sentiment_pipeline, tokenizer, max_tokens
44
-
45
-
46
- # class for keyword extraction
47
- @st.cache_resource
48
- class KeyphraseExtractionPipeline(TokenClassificationPipeline):
49
- def __init__(self, model, *args, **kwargs):
50
- super().__init__(
51
- model=AutoModelForTokenClassification.from_pretrained(model),
52
- tokenizer=AutoTokenizer.from_pretrained(model),
53
- *args,
54
- **kwargs
55
- )
56
-
57
- def postprocess(self, all_outputs):
58
- results = super().postprocess(
59
- all_outputs=all_outputs,
60
- aggregation_strategy=AggregationStrategy.SIMPLE,
61
- )
62
- return np.unique([result.get("word").strip() for result in results])
63
-
64
- @st.cache_resource
65
- def keyword_extractor():
66
- model_name = "ml6team/keyphrase-extraction-kbir-inspec"
67
- extractor = KeyphraseExtractionPipeline(model=model_name)
68
- return extractor
69
-
70
 
71
 
72
 
 
1
  from collections import Counter
2
+ from concurrent.futures import ThreadPoolExecutor # parallel processing
3
  import matplotlib.pyplot as plt
4
  import pandas as pd
5
  import praw # Reddit's API
 
18
  from transformers.pipelines import AggregationStrategy
19
 
20
  from functions import (
21
+ load_sentiment_pipeline,
22
  scrape_reddit_data,
 
23
  safe_sentiment,
 
24
  analyze_detail,
25
  preprocess_text
26
  )
27
 
28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
 
31