# Code_Mate / main.py
from flask import Flask, request, jsonify

import torch
import torch.serialization
from transformers import RobertaTokenizer, RobertaForSequenceClassification

# Initialize Flask app
app = Flask(__name__)
# Load the trained model and tokenizer
tokenizer = RobertaTokenizer.from_pretrained("microsoft/codebert-base")
torch.serialization.add_safe_globals([RobertaForSequenceClassification])
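# Note: newer PyTorch releases default torch.load to weights_only=True, which refuses
# to unpickle arbitrary classes. Registering RobertaForSequenceClassification as a safe
# global above and passing weights_only=False below are two ways to allow the full
# pickled model object to load.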
model = torch.load("model.pth", map_location=torch.device('cpu'), weights_only=False) # Load the trained model
# Ensure the model is in evaluation mode
model.eval()
@app.route("/")
def home():
    return request.url
# @app.route("/predict", methods=["POST"])
@app.route("/predict")
def predict():
try:
# Debugging: print input code to check if the request is received correctly
print("Received code:", request.get_json()["code"])
data = request.get_json()
if "code" not in data:
return jsonify({"error": "Missing 'code' parameter"}), 400
code_input = data["code"]
        # Tokenize the input code using the CodeBERT tokenizer
        inputs = tokenizer(
            code_input,
            return_tensors='pt',
            truncation=True,
            padding='max_length',
            max_length=512
        )
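        # `inputs` is a dict of tensors (input_ids, attention_mask), each of
        # shape (1, 512) because of the fixed max_length padding.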
        # Make prediction using the model
        with torch.no_grad():
            outputs = model(**inputs)
            prediction = outputs.logits.squeeze().item()  # Extract the predicted score (single float)
        print(f"Predicted score: {prediction}")  # Debugging: print prediction
        return jsonify({"predicted_score": prediction})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
# Run the Flask app
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
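# Example client call (a minimal sketch, assuming the server is reachable at
# http://localhost:7860 and the `requests` package is installed):
#
#     import requests
#     resp = requests.post(
#         "http://localhost:7860/predict",
#         json={"code": "def add(a, b):\n    return a + b"},
#     )
#     print(resp.json())  # {"predicted_score": <float>}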