Chronos ETTh1 minimal forecasting example

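  • Model: local chronos-t5-tiny
  • Data: ETTh1.csv
  • Goal: take the OT column, hold out its last 96 points, and forecast them from the stretch of history immediately before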
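The windowing described in the goal can be checked on a synthetic series before touching the real data. The last `prediction_length` points are held out as ground truth, and the `context_length` points immediately before them become the Chronos context. This is a minimal sketch assuming a plain pandas `Series`. `split_context_future` is a hypothetical helper, not part of Chronos, and it reproduces the slicing the notebook uses further down.

```python
import numpy as np
import pandas as pd


def split_context_future(series: pd.Series, context_length: int, prediction_length: int):
    """Hypothetical helper: return (history, future) taken from the end of `series`.

    `future` is the last `prediction_length` points. `history` is the
    `context_length` points immediately before it. Both get a fresh 0-based index.
    """
    if len(series) < context_length + prediction_length:
        raise ValueError("series is too short for the requested windows")
    end = len(series) - prediction_length  # first index of the held-out tail
    history = series.iloc[end - context_length:end].reset_index(drop=True)
    future = series.iloc[end:].reset_index(drop=True)
    return history, future


# Synthetic stand-in for the OT column: values 0, 1, ..., 999.
toy = pd.Series(np.arange(1000, dtype=float))
toy_history, toy_future = split_context_future(toy, context_length=512, prediction_length=96)
print(toy_history.shape, toy_future.shape)       # (512,) (96,)
print(toy_history.iloc[-1], toy_future.iloc[0])  # 903.0 904.0, so the windows are contiguous
```

The toy variables use distinct names so that running this cell does not overwrite the notebook's own `history` and `future`.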
In [ ]:
from pathlib import Path
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline
In [ ]:
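# Drop proxy variables inherited from the shell; the model and data below are loaded from local paths
for key in [
    'http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'all_proxy', 'ALL_PROXY'
]:
    os.environ.pop(key, None)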

csv_path = Path('data/ETTh1.csv')
model_path = Path('models/chronos-t5-tiny')
assert csv_path.exists(), f'Missing dataset: {csv_path}'
assert model_path.exists(), f'Missing model: {model_path}'
csv_path
In [ ]:
df = pd.read_csv(csv_path)
df.head()
In [ ]:
target = 'OT'
context_length = 512
prediction_length = 96

series = df[target].astype(float)
history = series.iloc[-(context_length + prediction_length):-prediction_length].reset_index(drop=True)
future = series.iloc[-prediction_length:].reset_index(drop=True)

history.shape, future.shape
In [ ]:
plt.figure(figsize=(10, 3))
plt.plot(history.values, label='history', color='royalblue')
plt.title('History window used as Chronos context')
plt.grid(alpha=0.3)
plt.legend()
plt.show()
In [ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = torch.bfloat16 if device == 'cuda' else torch.float32

pipeline = ChronosPipeline.from_pretrained(
    str(model_path),
    device_map=device,
    dtype=dtype,
)

device, dtype
In [ ]:
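context = torch.tensor(history.to_numpy(), dtype=torch.float32)
# chronos-t5 checkpoints are configured for horizons up to 64 steps. 96 exceeds that,
# and some chronos versions reject it unless limit_prediction_length=False is passed.
forecast = pipeline.predict(context, prediction_length, num_samples=20, limit_prediction_length=False)
# forecast has shape (batch, num_samples, prediction_length); keep the single series
samples = forecast[0].detach().cpu().numpy()
# Per-step 10%, 50% and 90% quantiles across the sample paths
low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
# Mean absolute error of the median forecast against the held-out points
mae = float(np.mean(np.abs(median - future.to_numpy())))
mae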
In [ ]:
plt.figure(figsize=(10, 4))
hist_index = np.arange(len(history))
pred_index = np.arange(len(history), len(history) + prediction_length)

plt.plot(hist_index, history, color='royalblue', label='history')
plt.plot(pred_index, future, color='black', linestyle='--', label='ground truth')
plt.plot(pred_index, median, color='tomato', label='median forecast')
plt.fill_between(pred_index, low, high, color='tomato', alpha=0.25, label='80% interval')
plt.title(f'Chronos forecast on ETTh1 ({target}), MAE={mae:.3f}')
plt.xlabel('time step')
plt.ylabel(target)
plt.grid(alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()
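The shaded band is bounded by the 10% and 90% quantiles of the sample paths, so it is nominally an 80% interval. A simple calibration check is the fraction of ground-truth points that fall inside it, reported alongside its average width. The sketch below runs on synthetic samples shaped like `forecast[0]` (`num_samples` by `prediction_length`). `interval_stats` and the toy data are illustrative assumptions, not Chronos output.

```python
import numpy as np


def interval_stats(samples: np.ndarray, truth: np.ndarray, low_q: float = 0.1, high_q: float = 0.9):
    """Hypothetical helper: summarize a sample-based forecast.

    samples: array of shape (num_samples, prediction_length), like forecast[0].
    truth:   array of shape (prediction_length,).
    Returns MAE of the median, mean interval width, and empirical coverage.
    """
    low, median, high = np.quantile(samples, [low_q, 0.5, high_q], axis=0)
    mae = float(np.mean(np.abs(median - truth)))
    width = float(np.mean(high - low))
    # Fraction of time steps whose true value lies inside [low, high]
    coverage = float(np.mean((truth >= low) & (truth <= high)))
    return {"mae": mae, "mean_width": width, "coverage": coverage}


rng = np.random.default_rng(0)
horizon = 96
toy_truth = np.sin(np.linspace(0, 4 * np.pi, horizon))
# Synthetic "sample paths": truth plus independent Gaussian noise, 20 paths like num_samples=20
toy_samples = toy_truth + rng.normal(scale=0.5, size=(20, horizon))
stats = interval_stats(toy_samples, toy_truth)
print(stats)
```

With the notebook's own variables, `interval_stats(samples, future.to_numpy())` should return the same MAE as the earlier `mae` cell, plus the width and coverage of the plotted band. With only 20 samples and a single 96-step window, the coverage figure is noisy.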

What to try next

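  • Change prediction_length to 24, 48, or 192
  • Switch the target column from OT to one of the other sensor columns
  • Record how the prediction interval changes under each setting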
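To compare settings as suggested above, one option is to wrap the split, forecast, and metrics in a loop over horizons. This sketch assumes a `forecast_fn(context, horizon)` callable that returns a `(num_samples, horizon)` array. So that it runs without the model, it is exercised with a hypothetical seasonal-naive forecaster plus noise. That stand-in is not Chronos, and `sweep_horizons` and `seasonal_naive_fn` are illustrative names.

```python
import numpy as np
import pandas as pd


def sweep_horizons(series, forecast_fn, horizons, context_length=512):
    """Hypothetical helper: for each horizon, hold out the last `horizon` points,
    forecast them from the preceding `context_length` points, and record metrics."""
    values = np.asarray(series, dtype=float)
    rows = []
    for h in horizons:
        end = len(values) - h  # the held-out tail starts here
        context = values[end - context_length:end]
        truth = values[end:]
        samples = forecast_fn(context, h)  # expected shape: (num_samples, h)
        low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
        rows.append({
            "prediction_length": h,
            "mae": float(np.mean(np.abs(median - truth))),
            "mean_width": float(np.mean(high - low)),
            "coverage": float(np.mean((truth >= low) & (truth <= high))),
        })
    return pd.DataFrame(rows)


def seasonal_naive_fn(context, horizon, period=24, num_samples=20, seed=0):
    """Stand-in forecaster (NOT Chronos): repeat the last `period` values and add noise."""
    rng = np.random.default_rng(seed)
    base = np.resize(context[-period:], horizon)  # tile the last cycle to cover the horizon
    return base + rng.normal(scale=0.1, size=(num_samples, horizon))


t = np.arange(2000)
toy_series = pd.Series(np.sin(2 * np.pi * t / 24))  # synthetic signal with period 24
print(sweep_horizons(toy_series, seasonal_naive_fn, [24, 48, 96, 192]))
```

With the notebook's pipeline, a Chronos-backed `forecast_fn` might be `lambda c, h: pipeline.predict(torch.tensor(c, dtype=torch.float32), h, num_samples=20, limit_prediction_length=False)[0].cpu().numpy()`, passed together with `df[target]`. Each horizon uses a different context window because the held-out tail changes length, which matches what happens when `prediction_length` is edited in the notebook.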