Spaces:
Sleeping
Sleeping
update
Browse files- app.py +39 -6
- static/index.html +37 -12
app.py
CHANGED
@@ -1,44 +1,77 @@
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.responses import FileResponse
|
3 |
from fastapi.staticfiles import StaticFiles
|
|
|
4 |
from transformers import pipeline
|
5 |
import os
|
6 |
import uvicorn
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
# Set cache directory to a writable location
|
9 |
cache_dir = "/tmp/hf_cache"
|
10 |
os.environ["HF_HOME"] = cache_dir
|
11 |
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
|
12 |
-
os.environ["TRANSFORMERS_CACHE"] = cache_dir
|
13 |
|
14 |
# Create the cache directory if it doesn't exist
|
15 |
if not os.path.exists(cache_dir):
|
16 |
os.makedirs(cache_dir, exist_ok=True)
|
17 |
|
18 |
app = FastAPI()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
20 |
|
21 |
# Load the zero-shot classification model with explicit cache directory
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
@app.get("/")
|
29 |
async def index():
|
|
|
30 |
return FileResponse("static/index.html")
|
31 |
|
32 |
@app.post("/classify")
|
33 |
async def classify_text(data: dict):
|
|
|
34 |
try:
|
35 |
text = data.get("document")
|
36 |
labels = data.get("labels")
|
37 |
if not text or not labels:
|
|
|
38 |
return {"error": "Please provide both text and labels"}, 400
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
result = classifier(text, labels, multi_label=False)
|
|
|
40 |
return {"labels": result["labels"], "scores": result["scores"]}, 200
|
41 |
except Exception as e:
|
|
|
42 |
return {"error": str(e)}, 500
|
43 |
|
44 |
if __name__ == "__main__":
|
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.responses import FileResponse
|
3 |
from fastapi.staticfiles import StaticFiles
|
4 |
+
from fastapi.middleware.cors import CORSMiddleware
|
5 |
from transformers import pipeline
|
6 |
import os
|
7 |
import uvicorn
|
8 |
+
import logging
|
9 |
+
|
10 |
+
# Set up logging
|
11 |
+
logging.basicConfig(level=logging.INFO)
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
|
14 |
# Set cache directory to a writable location
|
15 |
cache_dir = "/tmp/hf_cache"
|
16 |
os.environ["HF_HOME"] = cache_dir
|
17 |
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
|
18 |
+
os.environ["TRANSFORMERS_CACHE"] = cache_dir
|
19 |
|
20 |
# Create the cache directory if it doesn't exist
|
21 |
if not os.path.exists(cache_dir):
|
22 |
os.makedirs(cache_dir, exist_ok=True)
|
23 |
|
24 |
app = FastAPI()
|
25 |
+
|
26 |
+
# Add CORS middleware to allow frontend requests
|
27 |
+
app.add_middleware(
|
28 |
+
CORSMiddleware,
|
29 |
+
allow_origins=["*"], # Allow all origins (you can restrict this later)
|
30 |
+
allow_credentials=True,
|
31 |
+
allow_methods=["*"],
|
32 |
+
allow_headers=["*"],
|
33 |
+
)
|
34 |
+
|
35 |
app.mount("/static", StaticFiles(directory="static"), name="static")
|
36 |
|
37 |
# Load the zero-shot classification model with explicit cache directory
|
38 |
+
logger.info("Loading the model...")
|
39 |
+
try:
|
40 |
+
classifier = pipeline(
|
41 |
+
"zero-shot-classification",
|
42 |
+
model="facebook/bart-large-mnli",
|
43 |
+
cache_dir=cache_dir
|
44 |
+
)
|
45 |
+
logger.info("Model loaded successfully!")
|
46 |
+
except Exception as e:
|
47 |
+
logger.error(f"Error loading model: {str(e)}")
|
48 |
+
raise
|
49 |
|
50 |
@app.get("/")
|
51 |
async def index():
|
52 |
+
logger.info("Serving index.html")
|
53 |
return FileResponse("static/index.html")
|
54 |
|
55 |
@app.post("/classify")
|
56 |
async def classify_text(data: dict):
|
57 |
+
logger.info(f"Received classify request with data: {data}")
|
58 |
try:
|
59 |
text = data.get("document")
|
60 |
labels = data.get("labels")
|
61 |
if not text or not labels:
|
62 |
+
logger.warning("Missing text or labels in request")
|
63 |
return {"error": "Please provide both text and labels"}, 400
|
64 |
+
|
65 |
+
# Convert labels to list if it's a string
|
66 |
+
if isinstance(labels, str):
|
67 |
+
labels = [label.strip() for label in labels.split(",")]
|
68 |
+
|
69 |
+
logger.info(f"Classifying text: {text[:50]}... with labels: {labels}")
|
70 |
result = classifier(text, labels, multi_label=False)
|
71 |
+
logger.info(f"Classification result: {result}")
|
72 |
return {"labels": result["labels"], "scores": result["scores"]}, 200
|
73 |
except Exception as e:
|
74 |
+
logger.error(f"Error during classification: {str(e)}")
|
75 |
return {"error": str(e)}, 500
|
76 |
|
77 |
if __name__ == "__main__":
|
static/index.html
CHANGED
@@ -15,7 +15,7 @@
|
|
15 |
padding: 20px;
|
16 |
display: flex;
|
17 |
justify-content: center;
|
18 |
-
min-height: 100vh;
|
19 |
}
|
20 |
.container {
|
21 |
max-width: 1000px;
|
@@ -124,14 +124,14 @@
|
|
124 |
table {
|
125 |
width: 100%;
|
126 |
border-collapse: collapse;
|
127 |
-
background: linear-gradient(135deg, #000, #1a1a2e);
|
128 |
border-radius: 8px;
|
129 |
overflow: hidden;
|
130 |
}
|
131 |
th, td {
|
132 |
padding: 12px;
|
133 |
text-align: left;
|
134 |
-
color: #fff;
|
135 |
}
|
136 |
th {
|
137 |
font-weight: bold;
|
@@ -140,7 +140,7 @@
|
|
140 |
display: none;
|
141 |
position: relative;
|
142 |
height: 300px;
|
143 |
-
background: linear-gradient(135deg, #000, #1a1a2e);
|
144 |
border-radius: 8px;
|
145 |
padding: 10px;
|
146 |
}
|
@@ -203,7 +203,7 @@
|
|
203 |
<div class="labels-section">
|
204 |
<label>Labels:</label>
|
205 |
<input type="text" id="labels" placeholder="Enter labels (comma-separated)" value="mystery, drama, fantasy, history">
|
206 |
-
<button
|
207 |
</div>
|
208 |
</div>
|
209 |
<div class="results">
|
@@ -238,35 +238,60 @@
|
|
238 |
}
|
239 |
});
|
240 |
|
|
|
|
|
241 |
async function classifyText() {
|
242 |
const textInput = document.getElementById('document').value;
|
243 |
const labelsInput = document.getElementById('labels').value;
|
244 |
-
const labels = labelsInput.split(',').map(l => l.trim()).filter(l => l);
|
245 |
const loadingElement = document.getElementById('loading');
|
246 |
const chartContainer = document.querySelector('.chart-container');
|
247 |
|
248 |
-
|
249 |
-
|
|
|
|
|
|
|
|
|
|
|
250 |
return;
|
251 |
}
|
252 |
|
|
|
|
|
|
|
|
|
|
|
253 |
loadingElement.style.display = 'block';
|
254 |
chartContainer.style.display = 'none';
|
255 |
|
256 |
try {
|
257 |
-
const response = await fetch('/classify', {
|
258 |
method: 'POST',
|
259 |
-
headers: {
|
260 |
-
|
|
|
|
|
261 |
});
|
262 |
|
263 |
-
|
|
|
|
|
|
|
|
|
|
|
264 |
|
265 |
const results = await response.json();
|
|
|
|
|
|
|
|
|
|
|
|
|
266 |
updateTable(results);
|
267 |
updateChart(results);
|
268 |
chartContainer.style.display = 'block';
|
269 |
} catch (error) {
|
|
|
270 |
alert('Error: ' + error.message);
|
271 |
} finally {
|
272 |
loadingElement.style.display = 'none';
|
|
|
15 |
padding: 20px;
|
16 |
display: flex;
|
17 |
justify-content: center;
|
18 |
+
min-height: 100vh;
|
19 |
}
|
20 |
.container {
|
21 |
max-width: 1000px;
|
|
|
124 |
table {
|
125 |
width: 100%;
|
126 |
border-collapse: collapse;
|
127 |
+
background: linear-gradient(135deg, #000, #1a1a2e);
|
128 |
border-radius: 8px;
|
129 |
overflow: hidden;
|
130 |
}
|
131 |
th, td {
|
132 |
padding: 12px;
|
133 |
text-align: left;
|
134 |
+
color: #fff;
|
135 |
}
|
136 |
th {
|
137 |
font-weight: bold;
|
|
|
140 |
display: none;
|
141 |
position: relative;
|
142 |
height: 300px;
|
143 |
+
background: linear-gradient(135deg, #000, #1a1a2e);
|
144 |
border-radius: 8px;
|
145 |
padding: 10px;
|
146 |
}
|
|
|
203 |
<div class="labels-section">
|
204 |
<label>Labels:</label>
|
205 |
<input type="text" id="labels" placeholder="Enter labels (comma-separated)" value="mystery, drama, fantasy, history">
|
206 |
+
<button id="classifyBtn">Classify</button>
|
207 |
</div>
|
208 |
</div>
|
209 |
<div class="results">
|
|
|
238 |
}
|
239 |
});
|
240 |
|
241 |
+
document.getElementById("classifyBtn").addEventListener("click", classifyText);
|
242 |
+
|
243 |
async function classifyText() {
|
244 |
const textInput = document.getElementById('document').value;
|
245 |
const labelsInput = document.getElementById('labels').value;
|
|
|
246 |
const loadingElement = document.getElementById('loading');
|
247 |
const chartContainer = document.querySelector('.chart-container');
|
248 |
|
249 |
+
// Validate inputs
|
250 |
+
if (!textInput || textInput === textarea.getAttribute('placeholder')) {
|
251 |
+
alert('Please enter some text to classify.');
|
252 |
+
return;
|
253 |
+
}
|
254 |
+
if (!labelsInput) {
|
255 |
+
alert('Please enter at least one label.');
|
256 |
return;
|
257 |
}
|
258 |
|
259 |
+
// Keep labels as a comma-separated string (as expected by the backend)
|
260 |
+
const labels = labelsInput.trim();
|
261 |
+
console.log("Sending request with text:", textInput.substring(0, 50) + "...");
|
262 |
+
console.log("Labels:", labels);
|
263 |
+
|
264 |
loadingElement.style.display = 'block';
|
265 |
chartContainer.style.display = 'none';
|
266 |
|
267 |
try {
|
268 |
+
const response = await fetch('/classify', {
|
269 |
method: 'POST',
|
270 |
+
headers: {
|
271 |
+
'Content-Type': 'application/json',
|
272 |
+
},
|
273 |
+
body: JSON.stringify({ document: textInput, labels: labels }),
|
274 |
});
|
275 |
|
276 |
+
console.log("Response status:", response.status);
|
277 |
+
|
278 |
+
if (!response.ok) {
|
279 |
+
const errorText = await response.text();
|
280 |
+
throw new Error(`Classification failed with status ${response.status}: ${errorText}`);
|
281 |
+
}
|
282 |
|
283 |
const results = await response.json();
|
284 |
+
console.log("Received results:", results);
|
285 |
+
|
286 |
+
if (results.error) {
|
287 |
+
throw new Error(results.error);
|
288 |
+
}
|
289 |
+
|
290 |
updateTable(results);
|
291 |
updateChart(results);
|
292 |
chartContainer.style.display = 'block';
|
293 |
} catch (error) {
|
294 |
+
console.error("Error during classification:", error);
|
295 |
alert('Error: ' + error.message);
|
296 |
} finally {
|
297 |
loadingElement.style.display = 'none';
|