fsdp_activation_checkpointing

Tier 1 · 70% confidence


agent: ai_agents

When does this happen?

When FSDP is used with `activation_checkpointing=True` in the FSDP config and `gradient_checkpointing=False`, training fails with a `torch.utils.checkpoint.CheckpointError` reporting that the recomputed tensors have different metadata than they had during the original forward pass.
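The following is a minimal sketch in plain PyTorch, not Transformers or FSDP code, of why a cache can trigger this error. `ToyCachedLayer` and `backward_succeeds` are hypothetical names. The layer stands in for an attention block that appends keys to a cache on every forward call. Non-reentrant checkpointing (`use_reentrant=False`) reruns the forward pass during backward. With the cache on, that rerun concatenates one extra entry, so the saved tensors come back with a different shape. This should reproduce the same kind of metadata-mismatch error.

```python
import torch
from torch.utils.checkpoint import checkpoint


class ToyCachedLayer(torch.nn.Module):
    """Hypothetical stand-in for an attention layer that appends keys to a cache."""

    def __init__(self, dim: int, use_cache: bool):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.use_cache = use_cache
        self.cache = []  # plays the role of past_key_values

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        k = self.proj(x)
        if self.use_cache:
            past = self.cache
            # Side effect: the cache grows every time forward() runs,
            # including the recomputation that checkpointing triggers in backward.
            self.cache = past + [k.detach()]
            k = torch.cat(past + [k], dim=0)
        # Multiplication saves `k` for backward, so its shape is checked on recompute.
        return (k * k).sum()


def backward_succeeds(use_cache: bool) -> bool:
    torch.manual_seed(0)
    layer = ToyCachedLayer(dim=4, use_cache=use_cache)
    x = torch.randn(3, 4)
    loss = checkpoint(layer, x, use_reentrant=False)
    try:
        loss.backward()
        return True
    except RuntimeError as err:  # CheckpointError is a subclass of RuntimeError
        print(type(err).__name__, str(err).splitlines()[0])
        return False


backward_succeeds(use_cache=True)   # recomputation sees a longer cache and fails
backward_succeeds(use_cache=False)  # recomputation matches the original forward
```

With the cache disabled, the recomputation repeats the original forward exactly, so backward completes.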

How others solved it

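Set `use_cache=False` in the model keyword arguments whenever activation checkpointing is enabled, not only when `gradient_checkpointing` is. Change the condition from `use_cache=not gradient_checkpointing` to `use_cache=not (gradient_checkpointing or activation_checkpointing)`, where `activation_checkpointing` is read from the FSDP config. The likely reason is that many Transformers models switch the KV cache off themselves when their own gradient checkpointing is active. FSDP applies activation checkpointing by wrapping the layers instead, so the cache stays on and the recomputed forward pass sees a different cache state than the original one.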

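```python
from transformers import AutoModelForCausalLM

# Disable the KV cache whenever either checkpointing mechanism is on
# (`sft_config` is the project's training config, not shown here).
model_kwargs = dict(
    attn_implementation=sft_config.attn_implementation,
    torch_dtype=sft_config.torch_dtype,
    use_cache=not (sft_config.gradient_checkpointing or sft_config.fsdp_config.activation_checkpointing),
)
model = AutoModelForCausalLM.from_pretrained(sft_config.model_name_or_path, **model_kwargs)
```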
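If `fsdp_config` is a plain dict rather than an object with attributes, the attribute lookup `sft_config.fsdp_config.activation_checkpointing` will fail. Hugging Face `TrainingArguments`, for example, stores `fsdp_config` as a dict with an `activation_checkpointing` key. A small helper can accept both shapes. `resolve_use_cache` is a hypothetical name, and the config shapes it handles are assumptions about your setup. Adapt the lookup to whatever your config class actually exposes.

```python
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any


def resolve_use_cache(gradient_checkpointing: bool, fsdp_config: Any) -> bool:
    """Hypothetical helper: return False when either checkpointing mechanism is on."""
    if fsdp_config is None:
        activation_checkpointing = False
    elif isinstance(fsdp_config, Mapping):
        # Dict-style config, e.g. {"activation_checkpointing": True, ...}
        activation_checkpointing = bool(fsdp_config.get("activation_checkpointing", False))
    else:
        # Attribute-style config object, as used in the snippet above
        activation_checkpointing = bool(getattr(fsdp_config, "activation_checkpointing", False))
    return not (gradient_checkpointing or activation_checkpointing)


# The failing combination from this pattern: FSDP checkpointing on, HF checkpointing off.
print(resolve_use_cache(False, {"activation_checkpointing": True}))  # False
print(resolve_use_cache(False, SimpleNamespace(activation_checkpointing=True)))  # False
print(resolve_use_cache(False, {}))  # True
```

In the model kwargs you would then pass `use_cache=resolve_use_cache(sft_config.gradient_checkpointing, sft_config.fsdp_config)`.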

