# inference.py
import torch
from torchvision import transforms
import json
from pathlib import Path
from model import MNISTModel
import os
import sys


class Inferencer:
    def __init__(self, input_dir: str = 'input_data', output_dir: str = 'output_data'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model, _ = self._load_model()
        self.input_dir = Path(input_dir)
        self.output_dir = Path(output_dir)
        # Standard MNIST normalization (mean/std); available for preprocessing
        # raw digit images into input tensors.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])

    def _load_model(self, model_path='best_model.pth'):
        """Load the trained model."""
        model = MNISTModel().to(self.device)
        model.load_state_dict(
            torch.load(model_path, map_location=self.device, weights_only=True)
        )
        model.eval()
        return model, self.device

    def predict(self, input_tensor: torch.Tensor):
        """Make a prediction on the input tensor."""
        with torch.no_grad():
            # Add a batch dimension for single samples
            if input_tensor.dim() == 3:
                input_tensor = input_tensor.unsqueeze(0)
            input_tensor = input_tensor.to(self.device)
            output = self.model(input_tensor)
            probs = torch.softmax(output, dim=1)
            prediction = output.argmax(1).item()
            confidence = probs[0][prediction].item()
            return prediction, confidence

    def process_input(self):
        """Process all saved input tensors in the input directory."""
        # Create output directory if it doesn't exist
        os.makedirs(self.output_dir, exist_ok=True)

        results = []
        # Process each tensor file in the input directory
        for file_path in sorted(self.input_dir.glob('*.pt')):
            try:
                # Load tensor
                input_tensor = torch.load(file_path)
                # Get prediction
                prediction, confidence = self.predict(input_tensor)
                results.append({
                    "filename": file_path.name,
                    "prediction": prediction,
                    "confidence": confidence
                })
            except Exception as e:
                print(f"Error processing {file_path}: {str(e)}", file=sys.stderr)

        # Save results
        with open(self.output_dir / 'results.json', 'w') as f:
            json.dump(results, f, indent=2)

        return results


def main():
    # Accept input/output directories as arguments
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--input-dir', default='input_data')
    parser.add_argument('--output-dir', default='output_data')
    args = parser.parse_args()

    inferencer = Inferencer(args.input_dir, args.output_dir)
    results = inferencer.process_input()
    print(f"Processed {len(results)} inputs")


if __name__ == "__main__":
    main()
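
# Usage sketch (an illustration, not part of the original pipeline): it assumes a
# trained 'best_model.pth' next to this script and 28x28 grayscale digit tensors
# saved with torch.save() in the input directory, e.g.:
#
#     from PIL import Image
#     from torchvision import transforms
#     import torch
#     prep = transforms.Compose([
#         transforms.Grayscale(), transforms.Resize((28, 28)),
#         transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)),
#     ])
#     torch.save(prep(Image.open('digit.png')), 'input_data/digit.pt')  # 'digit.png' is a hypothetical input
#
# Then run:
#     python inference.py --input-dir input_data --output-dir output_data
#
# Predictions are written to <output-dir>/results.json.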