ganna217 commited on
Commit
f7df7d4
·
1 Parent(s): c8e56ea
Files changed (1) hide show
  1. app.py +16 -20
app.py CHANGED
@@ -1,45 +1,41 @@
1
- from fastapi import FastAPI, Request
2
- from fastapi.responses import HTMLResponse
3
  from fastapi.staticfiles import StaticFiles
4
- from fastapi.templating import Jinja2Templates
5
  from transformers import pipeline
6
  import os
7
  import uvicorn
8
 
9
- app = FastAPI()
 
 
 
 
 
 
 
10
 
11
- # Mount the templates directory for serving HTML
12
- templates = Jinja2Templates(directory="templates")
13
 
14
  # Load the zero-shot classification model
15
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
16
 
17
- # Route to serve index.html
18
- @app.get("/", response_class=HTMLResponse)
19
- async def index(request: Request):
20
- return templates.TemplateResponse("index.html", {"request": request})
21
 
22
- # Route to handle text classification requests
23
  @app.post("/classify")
24
  async def classify_text(data: dict):
25
  try:
26
  text = data.get("document")
27
  labels = data.get("labels")
28
-
29
  if not text or not labels:
30
  return {"error": "Please provide both text and labels"}, 400
31
-
32
- # Perform classification
33
  result = classifier(text, labels, multi_label=False)
34
- response = {
35
- "labels": result["labels"],
36
- "scores": result["scores"]
37
- }
38
- return response, 200
39
  except Exception as e:
40
  return {"error": str(e)}, 500
41
 
42
- # Run the app on Hugging Face's required port
43
  if __name__ == "__main__":
44
  port = int(os.environ.get("PORT", 7860))
45
  uvicorn.run(app, host="0.0.0.0", port=port)
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import FileResponse
3
  from fastapi.staticfiles import StaticFiles
 
4
  from transformers import pipeline
5
  import os
6
  import uvicorn
7
 
8
+ # Set cache directory to a writable location
9
+ cache_dir = "/tmp/hf_cache"
10
+ os.environ["HF_HOME"] = cache_dir
11
+ os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
12
+
13
+ # Create the cache directory if it doesn't exist
14
+ if not os.path.exists(cache_dir):
15
+ os.makedirs(cache_dir, exist_ok=True)
16
 
17
+ app = FastAPI()
18
+ app.mount("/static", StaticFiles(directory="static"), name="static")
19
 
20
  # Load the zero-shot classification model
21
  classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
22
 
23
+ @app.get("/")
24
+ async def index():
25
+ return FileResponse("static/index.html")
 
26
 
 
27
  @app.post("/classify")
28
  async def classify_text(data: dict):
29
  try:
30
  text = data.get("document")
31
  labels = data.get("labels")
 
32
  if not text or not labels:
33
  return {"error": "Please provide both text and labels"}, 400
 
 
34
  result = classifier(text, labels, multi_label=False)
35
+ return {"labels": result["labels"], "scores": result["scores"]}, 200
 
 
 
 
36
  except Exception as e:
37
  return {"error": str(e)}, 500
38
 
 
39
  if __name__ == "__main__":
40
  port = int(os.environ.get("PORT", 7860))
41
  uvicorn.run(app, host="0.0.0.0", port=port)