speech-analysis / inference.py
shibly100's picture
Upload 5 files
b931ae8 verified
import argparse
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def classify(text: str, model, tokenizer) -> str:
    """Classify a single piece of text and return its predicted label.

    Args:
        text: The input to classify. It must encode to a single sequence,
            because the predicted class index is extracted with ``.item()``,
            which only works on a one-element tensor.
        model: A sequence-classification model. It is called with the encoded
            inputs as keyword arguments, must return an object exposing
            ``.logits``, and must provide ``config.id2label``.
        tokenizer: A callable that encodes ``text`` into a mapping of PyTorch
            tensors (it is invoked with ``return_tensors="pt"``).

    Returns:
        The label that ``model.config.id2label`` maps the highest-scoring
        class index to, or the index rendered as a string when the mapping
        has no entry for it.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

    # The tokenizer hands back its tensors without knowing where the model
    # lives; move them to the model's device so a model that was placed on an
    # accelerator does not fail with a device-mismatch error. Objects that do
    # not expose ``device`` are left alone.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        logits = model(**inputs).logits

    prediction = torch.argmax(logits, dim=1).item()
    return model.config.id2label.get(prediction, str(prediction))
def main():
    """Command-line entry point: load a local model and classify ``--text``.

    Raises:
        ValueError: If the classification task is selected without ``--text``.
        FileNotFoundError: If ``--model_dir`` does not resolve to a directory.
    """
    parser = argparse.ArgumentParser(description="Run inference with your fine-tuned DistilBERT model.")
    parser.add_argument("--task", type=str, choices=["classification"], required=True, help="Task to run inference on.")
    parser.add_argument("--model_dir", type=str, required=True, help="Relative or absolute path to model directory.")
    parser.add_argument("--text", type=str, help="Input text to classify.")
    args = parser.parse_args()

    # Check the cheap arguments first so a missing --text fails immediately
    # instead of after the tokenizer and model have been loaded from disk.
    if args.task == "classification" and not args.text:
        raise ValueError("Please provide --text for classification.")

    # Ensure the model directory is interpreted as a local folder
    model_path = Path(args.model_dir).resolve()
    if not model_path.is_dir():
        raise FileNotFoundError(f"Model directory not found: {model_path}")

    tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_path, local_files_only=True)

    if args.task == "classification":
        result = classify(args.text, model, tokenizer)
        print(f"\nInput: {args.text}\nPrediction: {result}")


# Only run the CLI when executed as a script, so that importing this module
# (e.g. to reuse ``classify``) does not parse arguments or load a model.
if __name__ == "__main__":
    main()