ganna217 commited on
Commit
612cae5
·
1 Parent(s): 0b077c2
Files changed (2) hide show
  1. app.py +39 -6
  2. static/index.html +37 -12
app.py CHANGED
@@ -1,44 +1,77 @@
1
  from fastapi import FastAPI
2
  from fastapi.responses import FileResponse
3
  from fastapi.staticfiles import StaticFiles
 
4
  from transformers import pipeline
5
  import os
6
  import uvicorn
 
 
 
 
 
7
 
8
  # Set cache directory to a writable location
9
  cache_dir = "/tmp/hf_cache"
10
  os.environ["HF_HOME"] = cache_dir
11
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
12
- os.environ["TRANSFORMERS_CACHE"] = cache_dir # Add this for backward compatibility
13
 
14
  # Create the cache directory if it doesn't exist
15
  if not os.path.exists(cache_dir):
16
  os.makedirs(cache_dir, exist_ok=True)
17
 
18
  app = FastAPI()
 
 
 
 
 
 
 
 
 
 
19
  app.mount("/static", StaticFiles(directory="static"), name="static")
20
 
21
  # Load the zero-shot classification model with explicit cache directory
22
- classifier = pipeline(
23
- "zero-shot-classification",
24
- model="facebook/bart-large-mnli",
25
- cache_dir=cache_dir
26
- )
 
 
 
 
 
 
27
 
28
  @app.get("/")
29
  async def index():
 
30
  return FileResponse("static/index.html")
31
 
32
  @app.post("/classify")
33
  async def classify_text(data: dict):
 
34
  try:
35
  text = data.get("document")
36
  labels = data.get("labels")
37
  if not text or not labels:
 
38
  return {"error": "Please provide both text and labels"}, 400
 
 
 
 
 
 
39
  result = classifier(text, labels, multi_label=False)
 
40
  return {"labels": result["labels"], "scores": result["scores"]}, 200
41
  except Exception as e:
 
42
  return {"error": str(e)}, 500
43
 
44
  if __name__ == "__main__":
 
1
  from fastapi import FastAPI
2
  from fastapi.responses import FileResponse
3
  from fastapi.staticfiles import StaticFiles
4
+ from fastapi.middleware.cors import CORSMiddleware
5
  from transformers import pipeline
6
  import os
7
  import uvicorn
8
+ import logging
9
+
10
+ # Set up logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
 
14
  # Set cache directory to a writable location
15
  cache_dir = "/tmp/hf_cache"
16
  os.environ["HF_HOME"] = cache_dir
17
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
18
+ os.environ["TRANSFORMERS_CACHE"] = cache_dir
19
 
20
  # Create the cache directory if it doesn't exist
21
  if not os.path.exists(cache_dir):
22
  os.makedirs(cache_dir, exist_ok=True)
23
 
24
  app = FastAPI()
25
+
26
+ # Add CORS middleware to allow frontend requests
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"], # Allow all origins (you can restrict this later)
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
  app.mount("/static", StaticFiles(directory="static"), name="static")
36
 
37
  # Load the zero-shot classification model with explicit cache directory
38
+ logger.info("Loading the model...")
39
+ try:
40
+ classifier = pipeline(
41
+ "zero-shot-classification",
42
+ model="facebook/bart-large-mnli",
43
+ cache_dir=cache_dir
44
+ )
45
+ logger.info("Model loaded successfully!")
46
+ except Exception as e:
47
+ logger.error(f"Error loading model: {str(e)}")
48
+ raise
49
 
50
  @app.get("/")
51
  async def index():
52
+ logger.info("Serving index.html")
53
  return FileResponse("static/index.html")
54
 
55
  @app.post("/classify")
56
  async def classify_text(data: dict):
57
+ logger.info(f"Received classify request with data: {data}")
58
  try:
59
  text = data.get("document")
60
  labels = data.get("labels")
61
  if not text or not labels:
62
+ logger.warning("Missing text or labels in request")
63
  return {"error": "Please provide both text and labels"}, 400
64
+
65
+ # Convert labels to list if it's a string
66
+ if isinstance(labels, str):
67
+ labels = [label.strip() for label in labels.split(",")]
68
+
69
+ logger.info(f"Classifying text: {text[:50]}... with labels: {labels}")
70
  result = classifier(text, labels, multi_label=False)
71
+ logger.info(f"Classification result: {result}")
72
  return {"labels": result["labels"], "scores": result["scores"]}, 200
73
  except Exception as e:
74
+ logger.error(f"Error during classification: {str(e)}")
75
  return {"error": str(e)}, 500
76
 
77
  if __name__ == "__main__":
static/index.html CHANGED
@@ -15,7 +15,7 @@
15
  padding: 20px;
16
  display: flex;
17
  justify-content: center;
18
- min-height: 100vh; /* Ensure background covers full height */
19
  }
20
  .container {
21
  max-width: 1000px;
@@ -124,14 +124,14 @@
124
  table {
125
  width: 100%;
126
  border-collapse: collapse;
127
- background: linear-gradient(135deg, #000, #1a1a2e); /* Match body background */
128
  border-radius: 8px;
129
  overflow: hidden;
130
  }
131
  th, td {
132
  padding: 12px;
133
  text-align: left;
134
- color: #fff; /* White text for contrast with dark background */
135
  }
136
  th {
137
  font-weight: bold;
@@ -140,7 +140,7 @@
140
  display: none;
141
  position: relative;
142
  height: 300px;
143
- background: linear-gradient(135deg, #000, #1a1a2e); /* Match body background */
144
  border-radius: 8px;
145
  padding: 10px;
146
  }
@@ -203,7 +203,7 @@
203
  <div class="labels-section">
204
  <label>Labels:</label>
205
  <input type="text" id="labels" placeholder="Enter labels (comma-separated)" value="mystery, drama, fantasy, history">
206
- <button onclick="classifyText()">Classify</button>
207
  </div>
208
  </div>
209
  <div class="results">
@@ -238,35 +238,60 @@
238
  }
239
  });
240
 
 
 
241
  async function classifyText() {
242
  const textInput = document.getElementById('document').value;
243
  const labelsInput = document.getElementById('labels').value;
244
- const labels = labelsInput.split(',').map(l => l.trim()).filter(l => l);
245
  const loadingElement = document.getElementById('loading');
246
  const chartContainer = document.querySelector('.chart-container');
247
 
248
- if (!textInput || textInput === textarea.getAttribute('placeholder') || !labels.length) {
249
- alert('Please enter both text and labels');
 
 
 
 
 
250
  return;
251
  }
252
 
 
 
 
 
 
253
  loadingElement.style.display = 'block';
254
  chartContainer.style.display = 'none';
255
 
256
  try {
257
- const response = await fetch('/classify', { // Updated to relative path for FastAPI
258
  method: 'POST',
259
- headers: { 'Content-Type': 'application/json' },
260
- body: JSON.stringify({ document: textInput, labels })
 
 
261
  });
262
 
263
- if (!response.ok) throw new Error('Classification failed');
 
 
 
 
 
264
 
265
  const results = await response.json();
 
 
 
 
 
 
266
  updateTable(results);
267
  updateChart(results);
268
  chartContainer.style.display = 'block';
269
  } catch (error) {
 
270
  alert('Error: ' + error.message);
271
  } finally {
272
  loadingElement.style.display = 'none';
 
15
  padding: 20px;
16
  display: flex;
17
  justify-content: center;
18
+ min-height: 100vh;
19
  }
20
  .container {
21
  max-width: 1000px;
 
124
  table {
125
  width: 100%;
126
  border-collapse: collapse;
127
+ background: linear-gradient(135deg, #000, #1a1a2e);
128
  border-radius: 8px;
129
  overflow: hidden;
130
  }
131
  th, td {
132
  padding: 12px;
133
  text-align: left;
134
+ color: #fff;
135
  }
136
  th {
137
  font-weight: bold;
 
140
  display: none;
141
  position: relative;
142
  height: 300px;
143
+ background: linear-gradient(135deg, #000, #1a1a2e);
144
  border-radius: 8px;
145
  padding: 10px;
146
  }
 
203
  <div class="labels-section">
204
  <label>Labels:</label>
205
  <input type="text" id="labels" placeholder="Enter labels (comma-separated)" value="mystery, drama, fantasy, history">
206
+ <button id="classifyBtn">Classify</button>
207
  </div>
208
  </div>
209
  <div class="results">
 
238
  }
239
  });
240
 
241
+ document.getElementById("classifyBtn").addEventListener("click", classifyText);
242
+
243
  async function classifyText() {
244
  const textInput = document.getElementById('document').value;
245
  const labelsInput = document.getElementById('labels').value;
 
246
  const loadingElement = document.getElementById('loading');
247
  const chartContainer = document.querySelector('.chart-container');
248
 
249
+ // Validate inputs
250
+ if (!textInput || textInput === textarea.getAttribute('placeholder')) {
251
+ alert('Please enter some text to classify.');
252
+ return;
253
+ }
254
+ if (!labelsInput) {
255
+ alert('Please enter at least one label.');
256
  return;
257
  }
258
 
259
+ // Keep labels as a comma-separated string (as expected by the backend)
260
+ const labels = labelsInput.trim();
261
+ console.log("Sending request with text:", textInput.substring(0, 50) + "...");
262
+ console.log("Labels:", labels);
263
+
264
  loadingElement.style.display = 'block';
265
  chartContainer.style.display = 'none';
266
 
267
  try {
268
+ const response = await fetch('/classify', {
269
  method: 'POST',
270
+ headers: {
271
+ 'Content-Type': 'application/json',
272
+ },
273
+ body: JSON.stringify({ document: textInput, labels: labels }),
274
  });
275
 
276
+ console.log("Response status:", response.status);
277
+
278
+ if (!response.ok) {
279
+ const errorText = await response.text();
280
+ throw new Error(`Classification failed with status ${response.status}: ${errorText}`);
281
+ }
282
 
283
  const results = await response.json();
284
+ console.log("Received results:", results);
285
+
286
+ if (results.error) {
287
+ throw new Error(results.error);
288
+ }
289
+
290
  updateTable(results);
291
  updateChart(results);
292
  chartContainer.style.display = 'block';
293
  } catch (error) {
294
+ console.error("Error during classification:", error);
295
  alert('Error: ' + error.message);
296
  } finally {
297
  loadingElement.style.display = 'none';