0-finetune-chronos-etth.py 代码说明

这份脚本在做什么

这份脚本做的事情其实可以用一句话概括:

它不是完整研究级训练框架,而是一个课堂版 demo。
它的目标是让编程基础还不强的同学,也能看清楚下面这条路径:

  1. 先准备一段历史序列和对应未来序列
  2. 把它们切成很多小窗口
  3. 用 Chronos 自带 tokenizer 把数值变成 token
  4. T5ForConditionalGeneration 做 teacher-forcing 训练
  5. 保存微调后的模型
  6. 比较微调前后的预测结果

如果你把它和前面:

对照着看,会更容易理解:

也就是说,这份脚本是在上一份 demo 的基础上往前走了一步。

先看整体结构

这份脚本可以分成 10 个部分:

  1. 导入依赖
  2. 定义一个数据样本结构 WindowExample
  3. 定义滑动窗口数据集 SlidingWindowDataset
  4. 读取 Chronos 配置,构造 tokenizer
  5. 定义 collate_fn
  6. 选择设备
  7. 定义一个“用 pipeline 做预测”的辅助函数
  8. 主函数里解析参数
  9. 主函数里做训练
  10. 主函数里做微调前后预测对比并保存结果

第 1 部分:导入依赖

import argparse
import json
import math
import os
from dataclasses import dataclass
from pathlib import Path

这些是 Python 标准库:

再看下面这一组:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from transformers import AutoConfig, AutoModelForSeq2SeqLM

from chronos import ChronosConfig, ChronosPipeline

它们的角色分别是:

第 2 部分:WindowExample

@dataclass
class WindowExample:
    context: torch.Tensor
    future: torch.Tensor

这个类非常简单,它只是把一个样本拆成两个部分:

你可以把它理解成:

为什么要这么写?

因为时间序列训练最自然的监督信号就是:

第 3 部分:SlidingWindowDataset

这一段是整份脚本最重要的部分之一。

先看这个类在干什么

它把一整条长时间序列切成很多训练样本。

例如:

这样就能从一条长序列里切出很多 (context, future) 对。

__init__ 里的参数是什么意思

def __init__(
    self,
    values: np.ndarray,
    context_length: int,
    prediction_length: int,
    max_windows: int,
    stride: int,
) -> None:

先看核心变量

self.examples = []
total = context_length + prediction_length
upper = len(values) - total + 1

最关键的 for 循环

for start in range(0, max(upper, 0), stride):

这表示:

取出 context 和 future

context = values[start : start + context_length]
future = values[start + context_length : start + total]

这是最标准的时间序列监督学习切法:

这个判断在防什么

if len(future) < prediction_length:
    break

作用是:

否则最后一个样本会不完整。

为什么要转成 torch.tensor

self.examples.append(
    WindowExample(
        context=torch.tensor(context, dtype=torch.float32),
        future=torch.tensor(future, dtype=torch.float32),
    )
)

因为后面训练要用 PyTorch,所以这里先统一转成:

并且用:

表示这是浮点数时间序列。

这一句在做什么

if len(self.examples) >= max_windows:
    break

这一步很重要,尤其在课堂 demo 里。

它的作用是:

也就是说,这个脚本追求的是:

而不是先追求最优结果。

__len____getitem__

def __len__(self) -> int:
    return len(self.examples)

def __getitem__(self, idx: int) -> WindowExample:
    return self.examples[idx]

这是 PyTorch Dataset 的标准接口:

第 4 部分:build_tokenizer

def build_tokenizer(model_dir: Path):
    cfg = AutoConfig.from_pretrained(model_dir)
    chronos_cfg = ChronosConfig(**cfg.chronos_config)
    return chronos_cfg.create_tokenizer(), chronos_cfg

这段代码非常关键,因为它说明:

它还带了一层 Chronos 专用配置:

逐行看:

这里最重要的理解是:

第 5 部分:build_collate_fn

这一段是第二个最关键的地方。

为什么需要 collate_fn

因为 DataLoader 每次拿到的是一批 WindowExample
但模型真正需要的是:

collate_fn 的作用就是:

先把 batch 里的张量堆起来

contexts = torch.stack([item.context for item in batch], dim=0)
futures = torch.stack([item.future for item in batch], dim=0)

这里的 torch.stack 表示:

结果就会像:

这一句最重要

input_ids, attention_mask, scale = tokenizer.context_input_transform(contexts)

这一步在做:

返回三个东西:

为什么要 scale

因为 Chronos 不是直接把原始数值硬离散化,而是:

这样不同量级的序列更容易共享一套 token 空间。

labels 怎么来

labels, labels_mask = tokenizer.label_input_transform(futures, scale)

这一句表示:

这一步非常像:

为什么要把一部分 label 改成 -100

labels = labels.clone()
labels[~labels_mask] = -100

在 Hugging Face 的 seq2seq 训练里:

所以这里是在说:

最终返回的是什么

return {
    "input_ids": input_ids.long(),
    "attention_mask": attention_mask.long(),
    "labels": labels.long(),
}

这就是标准的 Hugging Face 训练输入格式。

第 6 部分:pick_device

def pick_device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

这一段很朴素:

对学生来说,这也是一个很好的工程习惯:

第 7 部分:forecast_with_pipeline

这段函数的作用是:

为什么要单独写这个函数?

因为我们后面要做两次:

  1. 用微调前的模型预测
  2. 用微调后的模型预测

这样就能直接比较:

逐行看

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

这里表示:

再看:

pipeline = ChronosPipeline.from_pretrained(
    str(model_dir),
    device_map=device,
    dtype=dtype,
)

这表示:

这里的 model_dir 可能是:

这正是这份脚本最漂亮的地方:

真正做预测的是这句

forecast = pipeline.predict(context, prediction_length, num_samples=num_samples)

这里:

返回结果不是一条序列,而是:

所以后面我们才会再去算:

第 8 部分:quantiles

def quantiles(samples: np.ndarray):
    return np.quantile(samples, [0.1, 0.5, 0.9], axis=0)

这一段非常短,但很重要。

它做的是:

这里:

这样就能画出:

第 9 部分:main() 里的参数

这一部分先用 argparse 定义命令行参数。

基础资源参数

--base-model
--csv
--target

窗口参数

--context-length
--prediction-length
--stride

数据量控制参数

--max-train-windows
--max-val-windows

这两个参数最主要是为了:

训练参数

--epochs
--batch-size
--learning-rate

推理与输出参数

--num-samples
--output-dir
--plot-path
--metrics-path

第 10 部分:清理代理环境变量

for key in [...]:
    os.environ.pop(key, None)

这一段和你前面本地 Ollama demo 里看到的很像。

作用是:

课堂上最简单的理解就是:

第 11 部分:把路径准备好

base_model = Path(args.base_model).resolve()
csv_path = Path(args.csv).resolve()
output_dir = Path(args.output_dir)
plot_path = Path(args.plot_path)
metrics_path = Path(args.metrics_path)

这一步就是:

后面这几句:

output_dir.mkdir(parents=True, exist_ok=True)
plot_path.parent.mkdir(parents=True, exist_ok=True)
metrics_path.parent.mkdir(parents=True, exist_ok=True)

作用是:

这是一种很实用的工程习惯:

第 12 部分:基础检查

if not base_model.exists():
    raise FileNotFoundError(...)
if not csv_path.exists():
    raise FileNotFoundError(...)

这是在提前防止最常见错误:

比起让程序后面莫名报错,这种“提前检查并报清楚”更适合教学。

第 13 部分:构造 tokenizer 并检查预测长度

tokenizer, chronos_cfg = build_tokenizer(base_model)
if args.prediction_length != chronos_cfg.prediction_length:
    raise ValueError(...)

这个检查非常关键。

为什么?

因为当前 chronos-t5-tiny 的配置本身规定了:

如果你硬把训练标签长度改成别的值,tokenizer 这边就不一致了。

所以这里是在提醒学生:

第 14 部分:读数据并选目标列

df = pd.read_csv(csv_path)
if args.target not in df.columns:
    raise ValueError(...)
values = df[args.target].astype(float).to_numpy()

逐行解释:

为什么最后要变成 NumPy?

因为后面自己切窗口时,NumPy 一维数组最直接。

第 15 部分:划分 train / val / test

split_train = int(len(values) * 0.8)
split_val = int(len(values) * 0.9)

这里使用的是最简单的时间顺序切分:

这比随机打乱更符合时间序列场景,因为:

为什么 val_valuestest_values 前面多切了一段

val_values = values[split_train - args.context_length - args.prediction_length : split_val]
test_values = values[split_val - args.context_length - args.prediction_length :]

这是为了保证:

也就是说,你不能只给验证集剩一小段未来,还得留足够长的历史上下文。

第 16 部分:构造 Dataset 和 DataLoader

train_ds = SlidingWindowDataset(...)
val_ds = SlidingWindowDataset(...)
collate_fn = build_collate_fn(tokenizer)
train_loader = DataLoader(...)
val_loader = DataLoader(...)

这一段的逻辑是:

  1. 先把原始长序列切成很多 (context, future) 样本
  2. 再把这些样本批量喂给模型

这里:

这是 PyTorch 里非常常见的模式。

第 17 部分:加载模型

device = pick_device()
model = AutoModelForSeq2SeqLM.from_pretrained(str(base_model))
model.to(device)
model.train()

逐行看:

这里最值得学生记住的是:

第 18 部分:优化器和损失记录

optimizer = AdamW(model.parameters(), lr=args.learning_rate)
train_losses = []
val_losses = []

第 19 部分:训练循环

这一段是整份脚本最像“标准训练代码”的地方。

外层 epoch 循环

for epoch in range(args.epochs):

意思是:

内层 batch 循环

for batch in train_loader:
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**batch)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    epoch_losses.append(loss.item())

逐行解释:

这就是最标准的 PyTorch 训练闭环。

为什么还要验证

训练完一轮后,这段代码会跑验证集:

model.eval()
with torch.no_grad():
    ...
model.train()

这里:

这一步的意义是:

第 20 部分:保存模型

model.save_pretrained(output_dir)
AutoConfig.from_pretrained(base_model).save_pretrained(output_dir)

第一句表示:

第二句表示:

为什么配置也要存?

因为后面要重新用 ChronosPipeline.from_pretrained(output_dir) 加载这个模型。
如果只有权重、没有配置,很多框架就不知道该怎么还原它。

第 21 部分:取测试窗口

history = test_values[: args.context_length]
future = test_values[args.context_length : args.context_length + args.prediction_length]

这表示:

它和训练时切窗口的思想完全一样,只不过这里我们只拿一个样本来做展示。

第 22 部分:比较微调前后预测

before = forecast_with_pipeline(base_model, ...)
after = forecast_with_pipeline(output_dir, ...)

这两句是整份脚本的教学高潮。

它们分别表示:

这一步让学生能非常直观地看到:

第 23 部分:算 MAE

mae_before = float(np.mean(np.abs(med_b - future)))
mae_after = float(np.mean(np.abs(med_a - future)))

这里的 MAE 就是:

为什么拿中位数预测 med_b / med_a 去和真实值比?

因为:

第 24 部分:画图

plt.plot(x_hist, history, ...)
plt.plot(x_pred, future, ...)
plt.plot(x_pred, med_b, ...)
plt.fill_between(...)
plt.plot(x_pred, med_a, ...)
plt.fill_between(...)

这张图会同时画出:

所以它非常适合课堂上讲:

如果学生不太会看数字指标,这张图往往比 MAE 更直观。

第 25 部分:保存 metrics

metrics = {
    ...
}
metrics_path.write_text(json.dumps(metrics, indent=2))

这一段做的是:

为什么这一步重要?

因为真实实验里,只看终端输出是不够的。
你需要一个能复查的结果文件。

第 26 部分:最后打印结果

print(f"Saved fine-tuned model to: {output_dir}")
print(f"Saved comparison plot to: {plot_path}")
print(f"Saved metrics to: {metrics_path}")
print(f"MAE before fine-tune: {mae_before:.3f}")
print(f"MAE after fine-tune:  {mae_after:.3f}")

这一步很适合教学,因为它会把学生最关心的 5 件事一次性说清楚:

  1. 模型存到哪里了
  2. 图存到哪里了
  3. 指标存到哪里了
  4. 微调前误差是多少
  5. 微调后误差是多少

这份脚本最想让你理解什么

如果把整份脚本压缩成 4 句话,它真正想让你理解的是:

  1. 时间序列 foundation model 也可以像文本模型一样继续微调。
  2. Chronos 的关键不是直接喂原始数值,而是先做 Chronos tokenizer 的数值 token 化。
  3. 最小微调闭环并不神秘,本质上还是“数据集 -> DataLoader -> 模型 -> loss -> optimizer”。
  4. fine-tune 最后是否有价值,不看感觉,要看“微调前 vs 微调后”的结果对比。

如果你是第一次读这份代码,最该先盯住哪几段

如果你觉得整份脚本还是长,最推荐先盯住下面 5 个位置:

  1. SlidingWindowDataset 看清楚时间序列样本是怎样切出来的
  2. build_collate_fn 看清楚数值是怎样变成 token 的
  3. AutoModelForSeq2SeqLM.from_pretrained(...) 看清楚 Chronos 底层其实是什么模型
  4. 训练循环 看清楚 PyTorch 的最基本闭环
  5. before / after 看清楚为什么最终一定要做微调前后对比