Notes to Future-Me

Accelerating PyTorch on Apple Silicon

Posted: 2024-06-03
Tags: Python, PyTorch

Actionable Takeaways

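Apple Silicon GPUs are supported out of the box. Since PyTorch 1.12, the standard macOS builds include the MPS (Metal Performance Shaders) backend, so there is nothing extra to install. You only need to put your tensors and models on the "mps" device.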
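Before relying on it, it can help to confirm what your install actually sees. Here is a minimal check using the torch.backends.mps helpers. is_built() reports whether this PyTorch binary was compiled with MPS support. is_available() reports whether an MPS device is also usable at runtime. The printed messages are illustrative wording, not PyTorch output.

```python
import torch

# is_built(): this PyTorch binary was compiled with MPS support.
# is_available(): MPS is built AND usable right now on this machine
# (requires macOS 12.3+ and an MPS-capable device).
built = torch.backends.mps.is_built()
available = torch.backends.mps.is_available()

if available:
    print("MPS is ready: use device='mps' for tensors and models.")
elif built:
    print("MPS is compiled in, but no usable MPS device (check macOS version/hardware).")
else:
    print("This PyTorch build was compiled without MPS support.")
```

On a Linux or Windows machine, both calls simply return False, so the check is safe to leave in cross-platform code.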

Code to use the GPU via the mps device[1]:

import torch

mps_device = torch.device("mps")

# Create a Tensor directly on the mps device
x = torch.ones(5, device=mps_device)

# Or
x = torch.ones(5, device="mps")

# Any operation on these tensors runs on the GPU
y = x * 2

# Move your model to mps just like any other device
# (YourFavoriteNet is a placeholder for your own nn.Module)
model = YourFavoriteNet()
model.to(mps_device)

# Now every call runs on the GPU
pred = model(x)
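YourFavoriteNet is only a placeholder, so here is a runnable sketch of the same flow. It uses a tiny nn.Linear layer as a hypothetical stand-in for the model. It falls back to the CPU when MPS is unavailable, so it runs on any machine. It also shows the step that is easy to forget: moving results back to the CPU before handing them to NumPy or plain Python.

```python
import torch
from torch import nn

# Use MPS when present; fall back to the CPU so the sketch runs anywhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Hypothetical stand-in for YourFavoriteNet: one linear layer, 5 inputs -> 2 outputs.
model = nn.Linear(5, 2).to(device)

# Inputs must live on the same device as the model's parameters.
x = torch.ones(5, device=device)
pred = model(x)

# Copy the result back to the CPU before converting it to NumPy or Python values.
pred_cpu = pred.detach().cpu()
print(pred.device, pred_cpu.device, tuple(pred_cpu.shape))
```

Mixing devices (for example, passing a CPU tensor to a model whose weights are on mps) raises a RuntimeError, which is why both the model and the input are created against the same device object.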

Or detect and select the best available device automatically, preferring CUDA, then MPS, then the CPU:

import torch

device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")
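The same preference order can be wrapped in a small reusable helper. In this sketch, pick_device is a hypothetical name, not a PyTorch API. It returns a torch.device rather than a string, and uses it for a small matrix multiply.

```python
import torch


def pick_device() -> torch.device:
    """Return the preferred device: CUDA first, then Apple's MPS, else the CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


device = pick_device()

# Both operands are created on the chosen device, so the matmul runs there too.
x = torch.randn(3, 4, device=device)
w = torch.randn(4, 2, device=device)
y = x @ w
print(f"Using {device} device; result shape {tuple(y.shape)}")
```

Returning a torch.device keeps the choice in one place. The same object can be passed to tensor factories (device=...) and to model.to(...), and its .type attribute gives a plain "cuda", "mps", or "cpu" for logging or branching.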