gradient_accumulation_cross_entropy

Tier 1 · 70% confidence

When does this happen?

IF gradient norms under gradient accumulation in LM training come out lower than expected when compared with Megatron/DeepSpeed baselines.

How others solved it

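THEN normalize each micro-batch's cross-entropy loss by that micro-batch's own number of non-padding label tokens (labels not equal to `-100`), rather than by the count summed over all micro-batches in the gradient accumulation step. Modify `get_batch_samples` to return a list of per-micro-batch counts, and pass the matching entry as `num_items_in_batch` to `compute_loss` for each micro-batch. The excerpt below shows the `get_batch_samples` change and the renamed variable in the training loop; the `compute_loss` call site is not included.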
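As a rough illustration of what "per-micro-batch counts" means, the sketch below uses plain Python lists in place of tensors. The helper names `count_label_tokens` and `per_micro_batch_counts` are hypothetical and are not part of `Trainer`; only the `-100` ignore value and the `"labels"` key come from the patch.

```python
IGNORE_INDEX = -100  # label value excluded from the loss (padding / masked tokens)


def count_label_tokens(labels):
    """Count labels in one micro-batch that are not IGNORE_INDEX.

    `labels` is a list of sequences (lists of ints) standing in for a tensor.
    """
    return sum(1 for seq in labels for tok in seq if tok != IGNORE_INDEX)


def per_micro_batch_counts(batch_samples):
    """Return one count per micro-batch instead of a single summed total."""
    return [count_label_tokens(batch["labels"]) for batch in batch_samples]


# Two micro-batches with different amounts of padding.
batch_samples = [
    {"labels": [[5, 7, -100, -100], [3, -100, -100, -100]]},  # 3 real tokens
    {"labels": [[1, 2, 3, 4], [9, 8, 7, -100]]},              # 7 real tokens
]

counts = per_micro_batch_counts(batch_samples)  # one entry per micro-batch
total = sum(counts)                             # what the old code returned
```

Here the old behaviour corresponds to `total`, while the proposed change keeps `counts` and hands `counts[i]` to the loss for micro-batch `i`.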

diff --git a/trainer.py b/trainer.py
index a19737c..ddecf05 100755
--- a/trainer.py
+++ b/trainer.py
@@ -2416,7 +2416,7 @@ class Trainer:
             epoch_iterator = iter(epoch_dataloader)
             # We chunkify the epoch iterator into gradient accumulation steps `n` batches
             remainder = num_examples % args.gradient_accumulation_steps
-            num_items_in_batch = None
+            num_items_in_batches = None
             if remainder == 0:
                 remainder = args.gradient_accumulation_steps
             update_step = -1
@@ -2424,7 +2424,8 @@ class Trainer:
             for _ in range(total_updates):
                 update_step += 1
                 num_batches = args.gradient_accumulation_steps if update_step != (total_updates - 1) else remainder
-                batch_samples, num_items_in_batch = self.get_batch_samples(epoch_iterator, num_batches)
+                batch_samples, num_items_in_batches = self.get_batch_samples(epoch_iterator, num_batches)
                 for i, inputs in enumerate(batch_samples):
                     step += 1
                     total_batched_samples += 1
@@ -5039,7 +5040,7 @@ class Trainer:
 
     def get_batch_samples(self, epoch_iterator, num_batches):
         batch_samples = []
-        num_items_in_batch = None
+        num_items_in_batches = None
         for _ in range(num_batches):
             try:
                 batch_samples += [next(epoch_iterator)]
@@ -5053,10 +5054,12 @@ class Trainer:
         if len(batch_samples) > 0 and "labels" in batch_samples[0]:
             # For now we don't support object detection
             try:
-                num_items_in_batch = sum([(batch["labels"].ne(-100)).sum() for batch in batch_samples])
+                num_items_in_batches = [(batch["labels"].ne(-100)).sum() for batch in batch_samples]
             except (TypeError, AttributeError):
                 pass
 
         if self.args.average_tokens_across_devices:
-            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item()
-        return batch_samples, num_items_in_batch
+            num_items_in_batches = self.accelerator.gather(num_items_in_batches).sum().item()
+        return batch_samples, num_items_in_batches
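To see why the choice of denominator changes the scale of the accumulated loss (and therefore of the gradient, since the gradient is linear in the loss), here is a small arithmetic sketch. It assumes micro-batch losses are simply summed across the accumulation step with no further division by the number of steps; the excerpt above does not show that part, so treat this as an assumption. The function names and the per-token loss values are made up for illustration.

```python
def accumulated_loss_global(token_losses_per_mb):
    """Each micro-batch divides its summed token loss by the TOTAL token count."""
    total_tokens = sum(len(mb) for mb in token_losses_per_mb)
    return sum(sum(mb) / total_tokens for mb in token_losses_per_mb)


def accumulated_loss_per_micro_batch(token_losses_per_mb):
    """Each micro-batch divides its summed token loss by its OWN token count."""
    return sum(sum(mb) / len(mb) for mb in token_losses_per_mb)


# Hypothetical per-token losses: 3 tokens in the first micro-batch, 7 in the second.
token_losses = [[2.0, 2.0, 2.0], [1.0] * 7]

global_norm = accumulated_loss_global(token_losses)           # (6 + 7) / 10, about 1.3
per_mb_norm = accumulated_loss_per_micro_batch(token_losses)  # 6/3 + 7/7 = 3.0
```

Under these assumptions the per-micro-batch variant yields a larger accumulated value, which is consistent with the lower norms reported for the summed-count version. Note also that the `average_tokens_across_devices` branch in the excerpt still calls `.sum().item()` on the gathered value; with a list of counts that call would appear to fail or collapse the list back to one number, so a per-micro-batch variant would likely need to gather each count separately.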
