Contour · #contour #log-log-axes #optimum-overlay #method-comparison

Step Law · Hyperparameter contour plot of learning rate × batch size, with each method's predicted optimum

Reproduction of arXiv:2503.04715 Figure 1: the log-log learning-rate × batch-size hyperparameter landscape for a 1B-parameter LLM trained on 100B tokens. Five concentric relative-loss contours (+0.125% / +0.25% / +0.5% / +1% / +2%) are overlaid with the global minimum (red ✕), Step Law's prediction (gold ★), DeepSeek Law (cyan ▲), Porian Law (purple ■), and the Microsoft and OpenAI law lines. The colour bar on the right reads absolute loss; the legend, fixed at the upper right, lists all six entries.
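Each ring is a relative threshold, so its position in absolute loss depends on the minimum. A minimal sketch, assuming a ring at +p% marks a loss of L_min · (1 + p/100); the name `absolute_thresholds` is ours, and the 2.073 used here is only the baseline constant from the synthetic reproduction script, not a value reported in the paper.

```python
# Relative contour levels from the figure, in percent above the minimum loss.
LEVELS_PCT = (0.125, 0.25, 0.5, 1.0, 2.0)


def absolute_thresholds(min_loss: float, levels_pct=LEVELS_PCT) -> list[float]:
    """Absolute loss on each ring, assuming ring p marks L_min * (1 + p / 100)."""
    return [min_loss * (1 + p / 100) for p in levels_pct]


# 2.073 is the synthetic script's baseline constant, not a loss reported in the paper.
for p, loss in zip(LEVELS_PCT, absolute_thresholds(2.073)):
    print(f"+{p:.3f}% ring -> loss {loss:.4f}")
```

With a minimum near 2.07, even the outermost +2% ring sits only about 0.04 above the optimum in absolute loss.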

@paper · From the paper

Predictable Scale: Part I, Step Law — Optimal Hyperparameter Scaling Law in Large Language Model Pre-training

Houyi Li et al. (StepFun) · arXiv 2025
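The two Step Law fits quoted in the reproduction script's docstring are plain power laws and can be evaluated directly. A minimal sketch for the Figure 1 setting of N = 1e9 parameters and D = 1e11 tokens; the function names are ours, and treating B as a token count is an assumption based on the script's 2e4 to 2e6 y-axis range (check the paper for its exact units).

```python
def step_law_lr(n_params: float, n_tokens: float) -> float:
    """Step Law optimal learning rate: eta(N, D) = 1.79 * N^-0.713 * D^0.307."""
    return 1.79 * n_params ** -0.713 * n_tokens ** 0.307


def step_law_batch(n_tokens: float) -> float:
    """Step Law optimal batch size: B(D) = 0.58 * D^0.571 (assumed tokens per batch)."""
    return 0.58 * n_tokens ** 0.571


if __name__ == "__main__":
    N, D = 1e9, 1e11  # 1B parameters, 100B training tokens (Figure 1 setting)
    print(f"eta* = {step_law_lr(N, D):.3e}")
    print(f"B*   = {step_law_batch(D):.3e}")
```

For that setting the sketch gives a learning rate of about 1.6 × 10⁻³ and a batch size of about 1.1 × 10⁶, which is where the script centres its synthetic loss bowl. Note the asymmetry: the optimal learning rate falls with model size N but rises with data D, while the batch size depends on D alone.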

[Figure: original from the paper | reproduction rendered by predictscale_contour.py]

predictscale_contour.py
"""Reproduction of Step Law (Predictable Scale) Figure 1.

Anonymised reproduction of arXiv:2503.04715 Figure 1:
A learning-rate x batch-size hyperparameter contour plot for a 1B-parameter
model trained on 100B tokens. Line contours mark relative-loss increases
(in percent) above the optimum, with the global minimum (red X), Step Law's
predicted optimum (yellow star), and the other methods' markers overlaid.

All data is synthesized from the published power-law form
    eta(N, D) = 1.79 * N^-0.713 * D^0.307
    B(D)      = 0.58 * D^0.571
together with a quadratic relative-loss model in (log_eta, log_B)-space,
so no checkpoints, training logs, or external assets are needed.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.colors import LinearSegmentedColormap

plt.rcParams.update({
    "font.family": "serif",
    "font.size": 9.5,
    "axes.linewidth": 0.8,
})

N = 1.0e9
D = 1.0e11

eta_star = 1.79 * (N ** -0.713) * (D ** 0.307)
B_star = 0.58 * (D ** 0.571)

LR = np.logspace(np.log10(2e-4), np.log10(8e-3), 220)
BS = np.logspace(np.log10(2e4), np.log10(2e6), 220)
LRg, BSg = np.meshgrid(LR, BS)

# Offsets from the Step Law optimum, in decades (log10 units).
x = np.log10(LRg) - np.log10(eta_star)
y = np.log10(BSg) - np.log10(B_star)

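# Anisotropic quadratic in coordinates rotated by -22 degrees: a synthetic
# stand-in for the relative-loss (%) landscape. The form is positive definite,
# so its minimum (0.094%) lies exactly at (eta_star, B_star).
theta = np.deg2rad(-22)
xr = x * np.cos(theta) - y * np.sin(theta)
yr = x * np.sin(theta) + y * np.cos(theta)
rel_pct = 0.094 + 1.4 * (xr ** 2) + 4.5 * (yr ** 2) + 0.6 * xr * yr
# Absolute loss; only consumed by the transparent contourf below.
loss = 2.073 + 0.0009 * rel_pct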

LEVELS = [0.125, 0.25, 0.5, 1.0, 2.0]
contour_colors = ["#4d3b8a", "#7a3a8c", "#c4407c", "#e07a3a", "#f0a83a"]

fig, ax = plt.subplots(figsize=(7.4, 5.4))

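# Filled layer drawn fully transparent (alpha=0.0): only the line contours
# are visible, and the colour bar below is built independently.
cf = ax.contourf(LRg, BSg, loss, levels=40, cmap=LinearSegmentedColormap.from_list(
    "step_law_loss",
    ["#4a3585", "#785a9c", "#a587b2", "#cba4b3", "#e0bdaa", "#eed7b5", "#f3e8ce"],
), alpha=0.0)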
cs = ax.contour(LRg, BSg, rel_pct, levels=LEVELS, colors=contour_colors,
                linewidths=1.6)
fmt = {lv: f"+{lv:.3f}%" for lv in LEVELS}
ax.clabel(cs, inline=True, fmt=fmt, fontsize=7.2, inline_spacing=4)

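# Marker positions are illustrative offsets from (eta_star, B_star); they are
# not computed from the DeepSeek or Porian formulas.
ax.scatter([eta_star * 1.05], [B_star * 1.02], marker="*", s=170,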
           color="#e8c84a", edgecolor="#7a5a10", linewidth=0.8, zorder=6,
           label="Ours (Step Law)")
ax.scatter([eta_star * 0.95], [B_star * 0.96], marker="X", s=120,
           color="#cf3b3b", edgecolor="#5a1414", linewidth=0.6, zorder=7,
           label="Global Minimum")
ax.scatter([eta_star * 0.55], [B_star * 1.4], marker="^", s=110,
           color="#26b6c4", edgecolor="#114a5a", linewidth=0.6, zorder=5,
           label="DeepSeek Law")
ax.scatter([eta_star * 1.35], [B_star * 1.3], marker="s", s=90,
           color="#a73da6", edgecolor="#3e0d3e", linewidth=0.6, zorder=5,
           label="Porian Law")

# Microsoft Law: dashed vertical lines at illustrative LR multiples.
ax.axvline(eta_star * 0.18, ls="--", color="#e6624a", lw=1.0, alpha=0.85)
ax.axvline(eta_star * 0.21, ls="--", color="#e6624a", lw=1.0, alpha=0.85)
ax.text(eta_star * 0.20, 5e4, "+2.000%", color="#e6624a", fontsize=7.0,
        rotation=90, ha="right", va="bottom")

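# Stand-alone colour bar: its colormap and 2.06-2.18 range are fixed here and
# are not derived from the synthetic `loss` array.
cb = plt.colorbar(plt.cm.ScalarMappable(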
    norm=plt.Normalize(vmin=2.06, vmax=2.18),
    cmap=LinearSegmentedColormap.from_list(
        "loss_bar",
        ["#3a5a90", "#6f86b6", "#b9b8d2", "#e8b9b3", "#dc6a72", "#b8395f"],
    ),
), ax=ax, pad=0.012, fraction=0.04, aspect=22)
cb.set_label("Loss", rotation=270, labelpad=12, fontsize=10)
cb.ax.tick_params(labelsize=8)

ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xlim(2e-4, 8e-3)
ax.set_ylim(2e4, 2e6)
ax.set_xticks([5e-4, 1e-3, 5e-3])
ax.set_xticklabels([r"$5\times 10^{-4}$", r"$10^{-3}$", r"$5\times 10^{-3}$"])
ax.set_yticks([1e5, 1e6])
ax.set_xlabel("Learning Rate", fontweight="bold", fontsize=10)
ax.set_ylabel("Batch Size", fontweight="bold", fontsize=10)
ax.grid(True, which="both", linewidth=0.3, color="#bbb", linestyle=":")

leg_handles = [
    Line2D([0], [0], marker="X", color="none", markerfacecolor="#cf3b3b",
           markeredgecolor="#5a1414", markersize=10, label="Global Minimum"),
    Line2D([0], [0], marker="*", color="none", markerfacecolor="#e8c84a",
           markeredgecolor="#7a5a10", markersize=12, label="Ours (Step Law)"),
    Line2D([0], [0], marker="^", color="none", markerfacecolor="#26b6c4",
           markeredgecolor="#114a5a", markersize=10, label="DeepSeek Law"),
    Line2D([0], [0], marker="s", color="none", markerfacecolor="#a73da6",
           markeredgecolor="#3e0d3e", markersize=9, label="Porian Law"),
    Line2D([0], [0], color="#e6624a", lw=1.2, ls="--",
           label="Microsoft Law"),
    # Legend-only entry: no OpenAI Law line is drawn on the axes above.
    Line2D([0], [0], color="#f0a83a", lw=1.4, ls="-.",
           label="OpenAI Law"),
]
ax.legend(handles=leg_handles, loc="upper right", frameon=True, fontsize=7.5,
          edgecolor="#888", framealpha=0.95)

plt.tight_layout()
plt.savefig("predictscale_contour_repro.png", dpi=200, bbox_inches="tight")
print("saved predictscale_contour_repro.png")
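Because the landscape above is synthetic, it helps to check what it actually implies. A minimal stdlib-only sketch that re-implements the same `rel_pct` formula for a single point (constants copied from the script; the helper names `is_positive_definite` and `lr_tolerance` are ours and hypothetical). It confirms that the 0.094% floor sits exactly at (eta_star, B_star), that the quadratic form is positive definite so the rings are closed, nested ellipses on log-log axes, and, with the batch size held at B_star, how far the learning rate can move before crossing a given ring.

```python
import math

# Constants copied from the reproduction script (synthetic, not fitted to data).
ETA_STAR = 1.79 * 1e9 ** -0.713 * 1e11 ** 0.307  # Step Law LR for N=1e9, D=1e11
B_STAR = 0.58 * 1e11 ** 0.571                      # Step Law batch size for D=1e11
THETA = math.radians(-22)                          # tilt of the loss bowl
A, C, XY = 1.4, 4.5, 0.6                           # xr^2, yr^2 and xr*yr coefficients
FLOOR = 0.094                                      # rel_pct at the bottom of the bowl


def rel_pct(lr: float, bs: float) -> float:
    """Same relative-loss formula as the script, for a single (LR, batch) point."""
    x = math.log10(lr) - math.log10(ETA_STAR)
    y = math.log10(bs) - math.log10(B_STAR)
    xr = x * math.cos(THETA) - y * math.sin(THETA)
    yr = x * math.sin(THETA) + y * math.cos(THETA)
    return FLOOR + A * xr ** 2 + C * yr ** 2 + XY * xr * yr


def is_positive_definite() -> bool:
    """True if [[A, XY/2], [XY/2, C]] is positive definite.

    Rotation preserves definiteness, so the contours in (log LR, log B)
    are then closed, nested ellipses around the optimum.
    """
    return A > 0 and A * C - (XY / 2) ** 2 > 0


def lr_tolerance(level: float) -> float:
    """Factor by which LR may be scaled up or down at B = B_STAR before reaching `level`."""
    if level <= FLOOR:
        raise ValueError("level must exceed the bowl's floor")
    c, s = math.cos(THETA), math.sin(THETA)
    # With y = 0: xr = x*c and yr = x*s, so rel_pct = FLOOR + k * x^2 (symmetric in x).
    k = A * c * c + C * s * s + XY * c * s
    return 10 ** math.sqrt((level - FLOOR) / k)


if __name__ == "__main__":
    print("floor at optimum:", rel_pct(ETA_STAR, B_STAR))
    print("positive definite:", is_positive_definite())
    for level in (0.125, 0.5, 2.0):
        print(f"+{level}% ring: LR within ~{lr_tolerance(level):.2f}x either way")
```

In this synthetic model, holding the batch size at B_star leaves roughly a 3x band of learning rates (in either direction) inside the +0.5% ring; that figure reflects the script's chosen coefficients, not a measurement from the paper.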
uploaded by @Trae1ounG