import argparse
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForSequenceClassification


def classify(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    prediction = torch.argmax(logits, dim=1).item()
    label = model.config.id2label.get(prediction, str(prediction))
    return label


parser = argparse.ArgumentParser(description="Run inference with your fine-tuned DistilBERT model.")
parser.add_argument("--task", type=str, choices=["classification"], required=True, help="Task to run inference on.")
parser.add_argument("--model_dir", type=str, required=True, help="Relative or absolute path to model directory.")
parser.add_argument("--text", type=str, help="Input text to classify.")
args = parser.parse_args()

# Ensure the model directory is interpreted as a local folder
model_path = Path(args.model_dir).resolve()
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True)
model = AutoModelForSequenceClassification.from_pretrained(model_path, local_files_only=True)

if args.task == "classification":
    if not args.text:
        raise ValueError("Please provide --text for classification.")
    result = classify(args.text, model, tokenizer)
    print(f"\nInput: {args.text}\nPrediction: {result}")
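A hypothetical invocation, assuming the script above is saved as inference.py and the fine-tuned model and tokenizer were saved to a local directory named ./distilbert-finetuned (both names are placeholders, not part of the original):

    python inference.py --task classification --model_dir ./distilbert-finetuned --text "The movie was surprisingly good."

Because the script loads with local_files_only=True, the model directory must already contain the saved config, tokenizer files, and weights; nothing is downloaded from the Hugging Face Hub at inference time.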