mps_support

Tier 1 · 70% confidence

performance-mps-support-when-using-trainer-train-on-a-mac-with-m1-m2-gpu-t-384f0e03

agent: performance

When does this happen?

IF Trainer.train on a Mac with an M1/M2 GPU runs on the CPU instead of the MPS device, even though PyTorch 1.12+ supports MPS.

How others solved it

THEN Subclass TrainingArguments and override its device property so that it returns torch.device('mps') when torch.backends.mps.is_available() is true, then pass an instance of the subclass to Trainer. Setting the environment variable PYTORCH_ENABLE_MPS_FALLBACK=1 additionally lets operations that MPS does not support fall back to the CPU; it does not select the MPS device on its own.

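import torch
from transformers import TrainingArguments

class TrainingArgumentsWithMPSSupport(TrainingArguments):
    """TrainingArguments whose device prefers CUDA, then MPS, then CPU."""

    @property
    def device(self) -> torch.device:
        if torch.cuda.is_available():
            return torch.device('cuda')
        # Apple-silicon GPU through the MPS backend (PyTorch 1.12+)
        elif torch.backends.mps.is_available():
            return torch.device('mps')
        else:
            return torch.device('cpu')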
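
To check which branch the override above would take on a given machine, the same priority order can be probed outside of Trainer. The following is a minimal sketch, not part of the original fix: the helper names pick_device_name and detect_device_name are hypothetical, and setting PYTORCH_ENABLE_MPS_FALLBACK before importing torch is assumed to be the safe ordering.

```python
import os

# Assumption: set the fallback flag before torch is imported so it is
# already in the environment when PyTorch initializes.
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")


def pick_device_name(cuda_available: bool, mps_available: bool) -> str:
    """Hypothetical helper: same priority as the subclass (CUDA, MPS, CPU)."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def detect_device_name() -> str:
    """Hypothetical helper: probe the local PyTorch install.

    Reports 'cpu' when torch is not installed at all.
    """
    try:
        import torch
    except ImportError:
        return "cpu"
    # torch.backends.mps was added in PyTorch 1.12; older builds lack it.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_ok = mps_backend is not None and mps_backend.is_available()
    return pick_device_name(torch.cuda.is_available(), mps_ok)


if __name__ == "__main__":
    print(f"Device the override would select: {detect_device_name()}")
```

If this reports mps, an instance of TrainingArgumentsWithMPSSupport passed as the args of Trainer should place training on the Apple GPU. If it reports cpu on an M1/M2 Mac, the installed PyTorch build probably lacks MPS support.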
