From 5b84c5252b62ce20038851e8564095e0103df169 Mon Sep 17 00:00:00 2001
From: AmitChaubey
Date: Mon, 13 Apr 2026 11:23:51 +0100
Subject: [PATCH] Refactor MNIST forward style and improve device selection

---
 mnist/main.py | 43 ++++++++++++++++++++++++++++---------------
 1 file changed, 28 insertions(+), 15 deletions(-)

diff --git a/mnist/main.py b/mnist/main.py
index dee5a384cb..b51e90bdec 100644
--- a/mnist/main.py
+++ b/mnist/main.py
@@ -17,23 +17,20 @@ def __init__(self):
         self.fc1 = nn.Linear(9216, 128)
         self.fc2 = nn.Linear(128, 10)
 
-    def forward(self, x):
-        x = self.conv1(x)
-        x = F.relu(x)
-        x = self.conv2(x)
-        x = F.relu(x)
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Run a forward pass and return class log-probabilities."""
+        x = F.relu(self.conv1(x))
+        x = F.relu(self.conv2(x))
         x = F.max_pool2d(x, 2)
         x = self.dropout1(x)
         x = torch.flatten(x, 1)
-        x = self.fc1(x)
-        x = F.relu(x)
+        x = F.relu(self.fc1(x))
         x = self.dropout2(x)
-        x = self.fc2(x)
-        output = F.log_softmax(x, dim=1)
-        return output
+        return F.log_softmax(self.fc2(x), dim=1)
 
 
-def train(args, model, device, train_loader, optimizer, epoch):
+def train(args, model: nn.Module, device, train_loader, optimizer, epoch: int) -> None:
+    """Train for one epoch."""
     model.train()
     for batch_idx, (data, target) in enumerate(train_loader):
         data, target = data.to(device), target.to(device)
@@ -50,7 +47,8 @@ def train(args, model, device, train_loader, optimizer, epoch):
                 break
 
 
-def test(model, device, test_loader):
+def test(model: nn.Module, device, test_loader) -> None:
+    """Evaluate model on the test set."""
     model.eval()
     test_loss = 0
     correct = 0
@@ -69,7 +67,12 @@ def test(model, device, test_loader):
         100. * correct / len(test_loader.dataset)))
 
 
-def main():
+def _mps_available() -> bool:
+    """Return True when MPS backend is available."""
+    return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+
+
+def main() -> None:
     # Training settings
     parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
     parser.add_argument('--batch-size', type=int, default=64, metavar='N',
@@ -94,12 +97,22 @@ def main():
                         help='For Saving the current Model')
     args = parser.parse_args()
 
-    use_accel = not args.no_accel and torch.accelerator.is_available()
+    use_accel = not args.no_accel and (
+        torch.accelerator.is_available() or torch.cuda.is_available() or _mps_available()
+    )
 
     torch.manual_seed(args.seed)
 
     if use_accel:
-        device = torch.accelerator.current_accelerator()
+        if torch.accelerator.is_available():
+            device = torch.accelerator.current_accelerator()
+        elif torch.cuda.is_available():
+            device = torch.device("cuda")
+        elif _mps_available():
+            device = torch.device("mps")
+        else:
+            device = torch.device("cpu")
+            use_accel = False
     else:
         device = torch.device("cpu")
 