Update app.py
app.py
CHANGED
@@ -1,5 +1,5 @@
 from collections import Counter
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import ThreadPoolExecutor # palarell processing
 import matplotlib.pyplot as plt
 import pandas as pd
 import praw  # Reddit's API
@@ -18,55 +18,14 @@ from transformers import (
 from transformers.pipelines import AggregationStrategy

 from functions import (
+    load_sentiment_pipeline,
     scrape_reddit_data,
-    split_text_by_token_limit,
     safe_sentiment,
-    safe_sentiment_batch,
     analyze_detail,
     preprocess_text
 )


-# ---------- Cached function for loading the model pipelines ----------
-@st.cache_resource
-def load_sentiment_pipeline():  # sentiment pipeline
-    tokenizer = AutoTokenizer.from_pretrained("cardiffnlp/twitter-roberta-base-sentiment-latest")
-    model = AutoModelForSequenceClassification.from_pretrained(
-        "cardiffnlp/twitter-roberta-base-sentiment-latest",
-        use_auth_token=st.secrets["hugging_face_token"]
-    )
-    sentiment_pipeline = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=0)  # -1 to 0
-    max_tokens = tokenizer.model_max_length
-
-    if max_tokens > 10000:
-        max_tokens = 512
-    return sentiment_pipeline, tokenizer, max_tokens
-
-
-# class for keyword extraction
-@st.cache_resource
-class KeyphraseExtractionPipeline(TokenClassificationPipeline):
-    def __init__(self, model, *args, **kwargs):
-        super().__init__(
-            model=AutoModelForTokenClassification.from_pretrained(model),
-            tokenizer=AutoTokenizer.from_pretrained(model),
-            *args,
-            **kwargs
-        )
-
-    def postprocess(self, all_outputs):
-        results = super().postprocess(
-            all_outputs=all_outputs,
-            aggregation_strategy=AggregationStrategy.SIMPLE,
-        )
-        return np.unique([result.get("word").strip() for result in results])
-
-@st.cache_resource
-def keyword_extractor():
-    model_name = "ml6team/keyphrase-extraction-kbir-inspec"
-    extractor = KeyphraseExtractionPipeline(model=model_name)
-    return extractor
-



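The deleted definitions are not lost to the app: the updated import list pulls load_sentiment_pipeline from functions, so the loader has presumably been relocated into functions.py (the keyphrase-extraction pipeline and keyword_extractor are simply removed from app.py, and this diff does not show whether they live elsewhere now). Below is a minimal sketch of what the relocated loader could look like in functions.py, assuming it keeps the same checkpoint, Streamlit caching decorator, and token-limit guard as the code removed above; this is an illustration, not the commit's actual functions.py.

# functions.py -- hypothetical sketch of the relocated, cached sentiment loader
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


@st.cache_resource  # load the model once per Streamlit process and reuse it across reruns
def load_sentiment_pipeline():
    # Same checkpoint as the removed code; the Hub token is read from Streamlit secrets.
    model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        use_auth_token=st.secrets["hugging_face_token"],
    )
    sentiment_pipeline = pipeline(
        "sentiment-analysis", model=model, tokenizer=tokenizer, device=0
    )
    # Some tokenizers report a huge sentinel value for model_max_length;
    # fall back to the model's real 512-token context in that case.
    max_tokens = tokenizer.model_max_length
    if max_tokens > 10000:
        max_tokens = 512
    return sentiment_pipeline, tokenizer, max_tokens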