ADKU commited on
Commit
e86b928
·
verified ·
1 Parent(s): 7f97f1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -2
app.py CHANGED
@@ -1,11 +1,18 @@
 
1
  from fastapi import FastAPI
2
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
3
  import torch
4
 
 
 
 
5
  app = FastAPI()
6
 
7
- model = DistilBertForSequenceClassification.from_pretrained("ADKU/ResearchGPT_model")
8
- tokenizer = DistilBertTokenizerFast.from_pretrained("ADKU/ResearchGPT_model")
 
 
 
9
 
10
  @app.post("/predict/")
11
  async def predict(text: str):
 
1
+ import os
2
  from fastapi import FastAPI
3
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
4
  import torch
5
 
6
+ # Set Hugging Face cache directory to a writable path
7
+ os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"
8
+
9
  app = FastAPI()
10
 
11
+ model_name = "ADKU/ResearchGPT_model" # Replace with your actual Hugging Face model ID
12
+
13
+ # Load model and tokenizer
14
+ model = DistilBertForSequenceClassification.from_pretrained(model_name, cache_dir="/tmp/huggingface")
15
+ tokenizer = DistilBertTokenizerFast.from_pretrained(model_name, cache_dir="/tmp/huggingface")
16
 
17
  @app.post("/predict/")
18
  async def predict(text: str):