ezequiellopez committed
Commit e145e85 · 1 Parent(s): b0a3f00

setting up
Files changed:
- .env +2 -0
- README.md +1 -1
- app/main.py +127 -0
- requirements.txt +3 -1
.env
ADDED
@@ -0,0 +1,2 @@
+REDIS_PORT=6379
+FASTAPI_PORT=7860
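These two variables wire the app to Redis and pin the API to port 7860, the default port a Hugging Face Space serves on. A minimal launch sketch (not part of this commit; the runner script and host binding are assumptions) that reads the same .env file and serves the app with uvicorn:

    # run.py — hypothetical runner, placed next to the .env file
    import os
    import uvicorn
    from dotenv import load_dotenv

    load_dotenv(".env")  # pick up FASTAPI_PORT (and REDIS_PORT for the app)

    if __name__ == "__main__":
        uvicorn.run("app.main:app", host="0.0.0.0",
                    port=int(os.getenv("FASTAPI_PORT", "7860")))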
README.md
CHANGED
@@ -1,5 +1,5 @@
 ---
-title:
+title: MPIB Prosocial
 emoji: π
 colorFrom: indigo
 colorTo: blue
app/main.py
ADDED
@@ -0,0 +1,127 @@
+# Import required libraries
+from fastapi import FastAPI
+from pydantic import BaseModel
+from typing import List, Optional
+import redis
+from transformers import pipeline
+from dotenv import load_dotenv
+import os
+
+# Load environment variables from .env file
+load_dotenv('../.env')
+
+# Access environment variables
+redis_port = os.getenv("REDIS_PORT")
+fastapi_port = os.getenv("FASTAPI_PORT")
+
+print("Redis port:", redis_port)
+print("FastAPI port:", fastapi_port)
+
+# Initialize FastAPI app and Redis client (fall back to 6379 if REDIS_PORT is unset)
+app = FastAPI()
+redis_client = redis.Redis(host='redis', port=int(redis_port or 6379))
+
+# Load a zero-shot classification pipeline backed by BART-large-MNLI
+model = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
+
+
+def score_text_with_labels(model, text: list, labels: list, multi: bool = True):
+    # Run zero-shot classification and keep only the per-label scores
+    return [result['scores'] for result in model(text, labels, multi_label=multi)]
+
+
+def smooth_sequence(tweets_scores, window_size):
+    # Calculate the sum of scores across labels for each tweet
+    tweet_sum_scores = [(sum(scores), index) for index, scores in enumerate(tweets_scores)]
+    # Sort tweets by summed score, then by original index to stabilize ties
+    sorted_tweets = sorted(tweet_sum_scores, key=lambda x: (x[0], x[1]))
+    # Extract the original indices of the tweets after sorting
+    sorted_indices = [index for _, index in sorted_tweets]
+    # Build the new sequence from the sorted indices
+    smoothed_sequence = [tweets_scores[index] for index in sorted_indices]
+    return smoothed_sequence
+
+
+def rerank_on_label(label: str):
+    # Placeholder, not implemented yet
+    return 200
+
+
+# Define Pydantic models
+class Item(BaseModel):
+    id: str  # /rerank/ stores items in Redis under this id
+    title: Optional[str] = None
+    text: str
+
+
+class RerankedItems(BaseModel):
+    ranked_ids: List[str]
+    new_items: List[dict]
+
+
+# Define a health check endpoint
+@app.get("/")
+async def health_check():
+    return {"status": "ok"}
+
+
+# Define FastAPI routes and logic
+@app.post("/rerank/")
+async def rerank_items(items: List[Item]) -> RerankedItems:
+    reranked_ids = []
+
+    # Process each item
+    for item in items:
+        # Classify the item using the zero-shot BART model
+        labels = classify_item(item.text)
+
+        # Save the item with its labels in Redis (empty string when title is missing)
+        redis_client.hset(item.id, mapping={"title": item.title or "", "text": item.text, "labels": ",".join(labels)})
+
+        # Add the item id to the reranked list
+        reranked_ids.append(item.id)
+
+    # Sort the items by model confidence; items without a score sort last
+    reranked_ids.sort(key=lambda x: redis_client.zscore("classified_items", x) or 0.0, reverse=True)
+
+    # Return the reranked items
+    return {"ranked_ids": reranked_ids, "new_items": []}  # ignore "new_items" for now
+
+
+# Define an endpoint to classify items and return their label scores
+@app.post("/classify/")
+async def classify_and_save(items: List[Item]) -> List[List[float]]:
+    # Candidate label sets tried so far; the last one is active
+    #labels = ["factful", "civic", "constructive", "politics", "health", "news"]
+    #labels = ["health", "politics", "news", "non-health non-politics non-news"]
+    labels = ["something else", "news feed, news articles, breaking news", "politics and politicians", "healthcare and health"]
+    texts = [item.text for item in items]
+
+    # Score every text against every candidate label
+    scores = score_text_with_labels(model=model, text=texts, labels=labels, multi=True)
+    return scores
+
+
+# Function to classify item text with the zero-shot pipeline
+def classify_item(text: str) -> List[str]:
+    # Candidate labels are assumed here; the original never defined them
+    candidate_labels = ["news", "politics", "health", "something else"]
+    # The pipeline returns labels sorted by score, highest first
+    result = model(text, candidate_labels, multi_label=False)
+    return [result["labels"][0]]
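With the service running, /classify/ takes a JSON list of items and returns one score list per item, ordered like the candidate labels. A hypothetical client sketch using the `requests` library (not part of the commit; the host, port, and `id` values are assumptions):

    import requests

    items = [
        {"id": "1", "text": "The senate passed the new healthcare bill today."},
        {"id": "2", "text": "My cat knocked the coffee over again."},
    ]
    # Expect one list of label scores per item, in candidate-label order
    resp = requests.post("http://localhost:7860/classify/", json=items)
    print(resp.json())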
requirements.txt
CHANGED
@@ -4,4 +4,6 @@ transformers
 python-dotenv
 dotenv-cli
 pandas
-uvicorn
+uvicorn
+pydantic
+redis
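The new `redis` dependency backs the `redis_client` in app/main.py. A quick connectivity-check sketch (assumes a Redis server is reachable under the hostname `redis`, e.g. from a linked container; swap in `localhost` otherwise):

    import redis

    client = redis.Redis(host="redis", port=6379)
    print(client.ping())  # True when the server is reachable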