ADKU commited on
Commit
b8e7c7f
·
verified ·
1 Parent(s): e86b928

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -1,22 +1,26 @@
1
  import os
2
  from fastapi import FastAPI
 
3
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
4
  import torch
5
 
6
- # Set Hugging Face cache directory to a writable path
7
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"
8
 
9
  app = FastAPI()
10
 
11
- model_name = "ADKU/ResearchGPT_model" # Replace with your actual Hugging Face model ID
 
12
 
13
- # Load model and tokenizer
14
- model = DistilBertForSequenceClassification.from_pretrained(model_name, cache_dir="/tmp/huggingface")
15
- tokenizer = DistilBertTokenizerFast.from_pretrained(model_name, cache_dir="/tmp/huggingface")
16
 
17
  @app.post("/predict/")
18
- async def predict(text: str):
19
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
20
  outputs = model(**inputs)
21
  prediction = torch.argmax(outputs.logits, dim=-1).item()
22
  return {"prediction": prediction}
 
 
 
 
 
1
  import os
2
  from fastapi import FastAPI
3
+ from pydantic import BaseModel
4
  from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
5
  import torch
6
 
 
7
  os.environ["HUGGINGFACE_HUB_CACHE"] = "/tmp/huggingface"
8
 
9
  app = FastAPI()
10
 
11
+ model = DistilBertForSequenceClassification.from_pretrained("ADKU/ResearchGPT_model", cache_dir="/tmp/huggingface")
12
+ tokenizer = DistilBertTokenizerFast.from_pretrained("ADKU/ResearchGPT_model", cache_dir="/tmp/huggingface")
13
 
14
+ class InputText(BaseModel):
15
+ inputs: str
 
16
 
17
  @app.post("/predict/")
18
+ async def predict(data: InputText):
19
+ inputs = tokenizer(data.inputs, return_tensors="pt", padding=True, truncation=True)
20
  outputs = model(**inputs)
21
  prediction = torch.argmax(outputs.logits, dim=-1).item()
22
  return {"prediction": prediction}
23
+
24
+ if __name__ == "__main__":
25
+ import uvicorn
26
+ uvicorn.run(app, host="0.0.0.0")