rabindra-sss commited on
Commit
d6bdad1
·
verified ·
1 Parent(s): 079a7d3

Update backend.py

Browse files
Files changed (1) hide show
  1. backend.py +30 -30
backend.py CHANGED
@@ -1,30 +1,30 @@
1
- import torch
2
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
- from peft import PeftModel, PeftConfig
4
-
5
- # Load model and tokenizer only once at startup
6
- config = PeftConfig.from_pretrained("rabindra-sss/sentiment-distilbert")
7
- base_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
8
- model = PeftModel.from_pretrained(base_model, "rabindra-sss/sentiment-distilbert", config=config)
9
- tokenizer = AutoTokenizer.from_pretrained("rabindra-sss/sentiment-distilbert")
10
-
11
- # Ensure model is in evaluation mode for inference
12
- model.eval()
13
-
14
- # Define id2label mappings
15
- id2label = {0: "Negative", 1: "Positive"}
16
-
17
- def predict(text: str) -> str:
18
- # Tokenize the input text
19
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
20
-
21
- # Run the model to get logits
22
- with torch.no_grad():
23
- outputs = model(**inputs)
24
- logits = outputs.logits
25
-
26
- # Convert logits to predicted class
27
- predictions = torch.argmax(logits, dim=-1)
28
- predicted_label = id2label[predictions.item()]
29
-
30
- return predicted_label
 
1
+ import torch
2
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
3
+ from peft import PeftModel, PeftConfig
4
+
5
+ # Load model and tokenizer only once at startup
6
+ config = PeftConfig.from_pretrained("rabindra-sss/sentiment-distilbert/")
7
+ base_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
8
+ model = PeftModel.from_pretrained(base_model, "rabindra-sss/sentiment-distilbert/", config=config)
9
+ tokenizer = AutoTokenizer.from_pretrained("rabindra-sss/sentiment-distilbert/")
10
+
11
+ # Ensure model is in evaluation mode for inference
12
+ model.eval()
13
+
14
+ # Define id2label mappings
15
+ id2label = {0: "Negative", 1: "Positive"}
16
+
17
+ def predict(text: str) -> str:
18
+ # Tokenize the input text
19
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
20
+
21
+ # Run the model to get logits
22
+ with torch.no_grad():
23
+ outputs = model(**inputs)
24
+ logits = outputs.logits
25
+
26
+ # Convert logits to predicted class
27
+ predictions = torch.argmax(logits, dim=-1)
28
+ predicted_label = id2label[predictions.item()]
29
+
30
+ return predicted_label