fsdp_dtype_mismatch

Tier 1 · 70% confidence

infrastructure-fsdp-dtype-mismatch-fsdp-multi-gpu-training-with-bfloat16-dtype-and-tr-976db46f

agent: infrastructure

When does this happen?

IF FSDP multi-GPU training with bfloat16 dtype and transformers >=4.46.2 fails with 'expected dtype float for `end` but got dtype c10::BFloat16'.
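To check whether an environment falls in the reported version range, a minimal standard-library sketch could look like the following. It assumes the affected range is exactly what this pattern states (transformers >= 4.46.2) and encodes only that lower bound. The constant `AFFECTED_FROM` and the helpers `release_tuple` and `in_reported_range` are hypothetical names. The pattern does not say which released version, if any, ships the PR #34645 fix, so a later release or a source build is also reported as in range. Treat a hit as "matches the reported version range", not as proof that the bug is present.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Assumption: lower bound taken from this pattern (transformers >= 4.46.2).
# No upper bound is encoded because the first fixed release is not stated.
AFFECTED_FROM = (4, 46, 2)


def release_tuple(ver):
    """Leading numeric release components, e.g. '4.47.0.dev0' -> (4, 47, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", ver)
    if match is None:
        raise ValueError(f"unparseable version string: {ver!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def in_reported_range(ver):
    """True if a transformers version is at or above the reported lower bound."""
    return release_tuple(ver) >= AFFECTED_FROM


if __name__ == "__main__":
    try:
        installed = version("transformers")
    except PackageNotFoundError:
        print("transformers is not installed")
    else:
        status = "within" if in_reported_range(installed) else "below"
        print(f"transformers {installed} is {status} the reported range (>= 4.46.2)")
```

Tuple comparison does the ordering here, so `4.46.10` correctly sorts above `4.46.2`, which a plain string comparison would get wrong.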

How others solved it

THEN upgrade transformers to the latest development version via `pip install git+https://github.com/huggingface/transformers`, which includes the fix from PR #34645. Alternatively, downgrade transformers to 4.45.2 and TRL to 0.11.3 to avoid the regression.
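If you take the downgrade route, a small standard-library sketch like this one can confirm that the installed packages actually match the two pins named above before relaunching a multi-GPU job. The helper names `installed_versions` and `pin_mismatches` are hypothetical. The check is an exact string comparison against the pins, so any other version string, including a different patch release, is reported as a mismatch.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins from the downgrade workaround described in this pattern.
KNOWN_GOOD_PINS = {"transformers": "4.45.2", "trl": "0.11.3"}


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def pin_mismatches(installed, pins=KNOWN_GOOD_PINS):
    """Return {package: (installed, expected)} for every package that differs from its pin."""
    return {
        name: (installed.get(name), expected)
        for name, expected in pins.items()
        if installed.get(name) != expected
    }


if __name__ == "__main__":
    mismatches = pin_mismatches(installed_versions(KNOWN_GOOD_PINS))
    if mismatches:
        for name, (have, want) in mismatches.items():
            print(f"{name}: installed {have}, workaround pin is {want}")
    else:
        print("transformers and trl match the downgrade pins")
```

Running this on every node, not only the launcher, catches the case where one worker image still carries the newer transformers release.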

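The issue stems from the removal of the `.float()` upcast in the logits computation: in v4.46.2 the model computes `logits = self.lm_head(...)` without casting to float32, so in bfloat16 training the logits stay in bfloat16 and cause the dtype mismatch under FSDP. The fix restores the proper dtype handling.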
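The dtype effect described above can be seen in isolation with a toy bfloat16 `nn.Linear` standing in for `lm_head`. This is a minimal sketch assuming PyTorch is installed. The tensor sizes and the names `hidden_states`, `reference` and `lerp_error` are hypothetical, and the sketch does not reproduce the FSDP or TRL code path. It only shows that dropping `.float()` leaves the logits in bfloat16, and that a dtype-strict op such as `torch.lerp` rejects a float32/bfloat16 pair, which is one way to get an error worded like the one reported here.

```python
import torch
import torch.nn as nn

# Hypothetical toy sizes: batch 2, sequence 4, hidden 16, vocab 32.
hidden_states = torch.randn(2, 4, 16, dtype=torch.bfloat16)
lm_head = nn.Linear(16, 32, bias=False).to(torch.bfloat16)

# With the `.float()` upcast, the logits come back as float32.
logits_upcast = lm_head(hidden_states).float()

# Without the cast, the logits keep the weights' dtype, i.e. bfloat16.
logits_raw = lm_head(hidden_states)

print(logits_upcast.dtype)  # torch.float32
print(logits_raw.dtype)     # torch.bfloat16

# torch.lerp requires `input` and `end` to share a dtype, so mixing a
# float32 tensor with the bfloat16 logits raises a RuntimeError.
reference = torch.zeros_like(logits_raw, dtype=torch.float32)
try:
    torch.lerp(reference, logits_raw, 0.5)
    lerp_error = None
except RuntimeError as err:
    lerp_error = str(err)

print(lerp_error)
```

The upcast costs memory on large vocabularies, which is a plausible reason it was dropped, but any downstream code that assumed float32 logits then receives bfloat16 instead.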
