An automated recovery procedure for loss spikes during long-running training — detect a spike, roll back to a recent checkpoint, skip a window of batches, resume. Modeled on the PaLM training paper. Activate when the user asks "loss spike", "training spiked then crashed", "recover from divergence", "PaLM rollback recipe", or experiences instability mid-run.
npx claudepluginhub curryfromuestc/curry-train --plugin curry-train

This skill uses the workspace's default tool permissions.
A specific recovery recipe for training instabilities that produce a loss spike: rather than killing the run or letting it diverge, roll back to a recent checkpoint, skip a small window of training batches, and resume. Empirically robust for transformer training at scale.
"When training spikes, can I recover the run instead of restarting from scratch?"
Often: yes, with the right recipe.
The recipe is appropriate when the run was healthy until a sharp, sustained spike, especially one that is plausibly data-correlated (see the PaLM rationale below).

The recipe is not appropriate for true divergence (kill the run via stage3-kill-criterion and re-tune) or for misconfigurations that should have been caught before launch (stage1-preflight-asserts).

This recipe is non-trivial because it requires a detector for sustained spikes, a checkpoint cadence fine-grained enough to roll back to, and a data loader that can skip a window of batches. The detector:
```python
import collections
from typing import Optional


class LossSpikeWatchdog:
    """Flags a loss spike that is sustained relative to the recent rolling minimum."""

    def __init__(self, *, threshold_ratio=5.0, sustain_steps=100,
                 min_history=1000):
        self.recent = collections.deque(maxlen=min_history)
        self.threshold_ratio = threshold_ratio
        self.sustain_steps = sustain_steps
        self._spike_start = None

    def update(self, loss: float, step: int) -> Optional[dict]:
        self.recent.append(loss)
        if len(self.recent) < self.recent.maxlen:
            # Not enough history yet to establish a baseline.
            return None
        rolling_min = min(self.recent)
        if loss > self.threshold_ratio * rolling_min:
            if self._spike_start is None:
                self._spike_start = step
            if step - self._spike_start >= self.sustain_steps:
                # Spike sustained: trigger rollback.
                trigger_step = self._spike_start
                self._spike_start = None  # reset so one spike fires only once
                return {
                    "rollback_to_step": max(0, trigger_step - 100),
                    "skip_batches": 200,
                    "reason": "loss_spike_sustained",
                }
        else:
            self._spike_start = None
        return None
```
The training loop checks the watchdog every step; on trigger, it loads the rollback checkpoint and applies the data-skip.
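A minimal sketch of that loop, assuming a deterministic step-to-batch mapping; `trainer`, `dataset.batch_at`, `total_steps`, and `load_checkpoint` are hypothetical stand-ins for your own harness, not curry-train APIs:

```python
watchdog = LossSpikeWatchdog()
step = 0
while step < total_steps:
    loss = trainer.train_step(dataset.batch_at(step))  # deterministic step -> batch
    event = watchdog.update(loss, step)
    if event is None:
        step += 1
        continue
    # Restore model/optimizer state, then jump past the suspect data window.
    trainer.load_checkpoint(event["rollback_to_step"])
    step = event["rollback_to_step"] + event["skip_batches"]
    watchdog = LossSpikeWatchdog()  # reset loss history after rollback
```

The jump from rollback_to_step directly to rollback_to_step + skip_batches is what implements the data-skip: the batches in between are never trained on.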
Loss spike rollback and stage3-kill-criterion work together:
A reasonable composite policy: rollback up to 2 times; on the 3rd spike within the same run, kill.
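A sketch of that counter, with the kill expressed as an exception; the class name and the choice to raise are illustrative:

```python
class RollbackBudget:
    """Allow up to max_rollbacks recoveries; the next sustained spike kills the run."""

    def __init__(self, max_rollbacks: int = 2):
        self.max_rollbacks = max_rollbacks
        self.used = 0

    def decide(self, event: dict) -> dict:
        if self.used >= self.max_rollbacks:
            # Third sustained spike in the same run: stop and re-tune
            # (see stage3-kill-criterion).
            raise RuntimeError(f"kill after {self.used} rollbacks: {event['reason']}")
        self.used += 1
        return event  # proceed with this rollback
```

In the loop sketch above, this slots in as `event = budget.decide(event)` just before loading the checkpoint.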
Confirm checkpoint cadence is fine-grained enough. The user needs a checkpoint every ~100 steps for this to work; see stage5-checkpoint-cadence.
Confirm the data loader supports skipping by index. If it's a DataLoader over a deterministic, map-style dataset, this is easy. If it's a stream (WebDataset, MosaicML streaming), confirm the offset semantics before relying on the skip.
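A sketch of both cases, assuming batch-level skip granularity; `skipped_batches` and `skipped_stream` are illustrative helpers:

```python
import itertools

def skipped_batches(dataset, start_batch: int, batch_size: int):
    # Map-style dataset: jump straight to the first sample of the target batch.
    start = start_batch * batch_size
    for i in range(start, len(dataset), batch_size):
        yield [dataset[j] for j in range(i, min(i + batch_size, len(dataset)))]

def skipped_stream(batch_iter, n_skip: int):
    # Pure stream: no random access, so consume and discard n_skip batches.
    return itertools.islice(batch_iter, n_skip, None)
```

Note that the stream case still pays the I/O cost of reading the skipped batches, which is worth checking before relying on large skip windows.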
Wire up LossSpikeWatchdog to the training loop's step callback (on_step_end in curry_train.loop).
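A sketch of that wiring; beyond the on_step_end name, the callback signature and the loop's rollback hook are assumptions:

```python
watchdog = LossSpikeWatchdog()

def on_step_end(step: int, metrics: dict) -> None:
    # Assumed signature: the loop passes the step index and a metrics dict.
    event = watchdog.update(metrics["loss"], step)
    if event is not None:
        loop.request_rollback(event)  # hypothetical hook on the loop object
```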
Add the watchdog's trigger event to the run journal — every rollback should be visible after the fact for analysis.
After a successful rollback, log the rollback metadata (from-step, to-step, batches-skipped, reason). This is essential for runs-diff and for distinguishing "clean run" from "recovered run" in subsequent analysis.
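For instance, as an append-only JSON-lines record; the field names follow the prose above, and the journal path is illustrative:

```python
import json
import time

def journal_rollback(event: dict, current_step: int, path: str = "run_journal.jsonl"):
    record = {
        "ts": time.time(),
        "kind": "rollback",
        "from_step": current_step,
        "to_step": event["rollback_to_step"],
        "batches_skipped": event["skip_batches"],
        "reason": event["reason"],
    }
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")
```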
The PaLM paper observed that loss spikes correlated with specific data batches in combination with the model's parameter state at that point, rather than with universally bad data. Without the skip, resuming from rollback would re-encounter the same trigger; with the skip, the run progresses past it.
If your data is well-curated and you don't observe data-correlated spikes, you may not need the skip — but it's cheap insurance to leave it in.
- skills/stage5-checkpoint-cadence — defines the lightweight-checkpoint cadence required for rollback.
- skills/stage3-kill-criterion — the upstream kill rule when rollback can't save the run.
- skills/stage5-run-journal — journals every rollback event.