AavV4 commited on
Commit
59dceb9
·
verified ·
1 Parent(s): ed5e9fa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -0
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+
5
+ # Define model paths based on uploaded repo files
6
+ MODEL_PATH = "trained_model2/distilroberta_model.pth"
7
+ TOKENIZER_DIR = "trained_model2/distilroberta_tokenizer"
8
+
9
+ # Load tokenizer
10
+ tokenizer_rl = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
11
+
12
+ # Load model
13
+ model_rl = AutoModelForSequenceClassification.from_pretrained('distilroberta-base', num_labels=2)
14
+ model_rl.load_state_dict(torch.load(MODEL_PATH, map_location=torch.device('cpu')))
15
+ model_rl.eval()
16
+
17
+ # RL model classification function
18
+ def classify_with_rl(text):
19
+ inputs = tokenizer_rl(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
20
+ with torch.no_grad():
21
+ outputs = model_rl(**inputs)
22
+ probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
23
+ return {"spam_probability": max(0, min(1, float(probs[0][1])))}
24
+
25
+ # Create API
26
+ iface = gr.Interface(fn=classify_with_rl, inputs=gr.Textbox(), outputs="json")
27
+
28
+ # Launch API
29
+ if __name__ == "__main__":
30
+ iface.launch()