# Chronos ETTh1 最小预测示例
- 模型：本地 `chronos-t5-tiny`
- 数据：`ETTh1.csv`
- 目标：取 `OT` 列，用最后一段历史预测未来 96 个点
In [ ]:
from pathlib import Path
import os
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline
In [ ]:
# Clear any proxy settings inherited from the shell. Presumably this keeps
# library network calls from going through a proxy; the model itself is
# loaded from the local path below.
for key in (
    'http_proxy', 'https_proxy', 'HTTP_PROXY', 'HTTPS_PROXY', 'all_proxy', 'ALL_PROXY',
):
    os.environ.pop(key, None)

# Both paths are relative, so they are resolved against the current working
# directory of the kernel.
csv_path = Path('data/ETTh1.csv')
model_path = Path('models/chronos-t5-tiny')

# Explicit exceptions instead of `assert`: asserts are stripped under
# `python -O`, which would let a missing file fail later with a less clear error.
for label, path in (('dataset', csv_path), ('model', model_path)):
    if not path.exists():
        raise FileNotFoundError(f'Missing {label}: {path} (cwd: {Path.cwd()})')
csv_path
In [ ]:
# Load the full ETTh1 table and preview its first rows as a sanity check.
df = pd.read_csv(filepath_or_buffer=csv_path)
df.head(5)
In [ ]:
target = 'OT'
context_length = 512
prediction_length = 96

series = df[target].astype(float)

# Split at an explicit index rather than with negative slices: with negative
# slicing, prediction_length == 0 would make `iloc[-0:]` the whole series.
split = len(series) - prediction_length
if prediction_length < 1 or split < 1:
    raise ValueError(
        'Need prediction_length >= 1 and at least one history point; got '
        f'prediction_length={prediction_length}, len(series)={len(series)}'
    )

# `history` is the last `context_length` points before the held-out tail
# (shorter if the series itself is shorter). `future` is the held-out tail
# used as ground truth.
history = series.iloc[max(0, split - context_length):split].reset_index(drop=True)
future = series.iloc[split:].reset_index(drop=True)
history.shape, future.shape
In [ ]:
# Plot the raw context window that will be fed to Chronos.
fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(history.values, label='history', color='royalblue')
ax.set_title('History window used as Chronos context')
ax.grid(alpha=0.3)
ax.legend()
plt.show()
In [ ]:
# Use the GPU in bfloat16 when one is available; otherwise use the CPU in float32.
if torch.cuda.is_available():
    device, dtype = 'cuda', torch.bfloat16
else:
    device, dtype = 'cpu', torch.float32

pipeline = ChronosPipeline.from_pretrained(
    str(model_path),
    device_map=device,
    # NOTE(review): the Chronos README passes `torch_dtype=`, while newer
    # transformers releases accept `dtype=`. Confirm that the installed
    # version honours this kwarg rather than silently ignoring it.
    dtype=dtype,
)
device, dtype
In [ ]:
# Feed the history window to Chronos as a float32 tensor and draw 20 sample paths.
context = torch.tensor(history.to_numpy(), dtype=torch.float32)
forecast = pipeline.predict(context, prediction_length=prediction_length, num_samples=20)

# forecast[0] holds the sample paths for our single series, presumably
# shaped (num_samples, prediction_length). TODO confirm against the Chronos API docs.
samples = forecast[0].detach().cpu().numpy()

# Per-step 10% / 50% / 90% quantiles across samples give the band and the point forecast.
quantiles = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)
low, median, high = quantiles

# Mean absolute error of the median forecast against the held-out tail.
abs_err = np.abs(median - future.to_numpy())
mae = float(abs_err.mean())
mae
In [ ]:
# Overlay history, ground truth, median forecast and the 10%-90% quantile band.
# The forecast x-positions continue directly after the history indices.
n_hist = len(history)
hist_index = np.arange(n_hist)
pred_index = n_hist + np.arange(prediction_length)

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(hist_index, history, color='royalblue', label='history')
ax.plot(pred_index, future, color='black', linestyle='--', label='ground truth')
ax.plot(pred_index, median, color='tomato', label='median forecast')
ax.fill_between(pred_index, low, high, color='tomato', alpha=0.25, label='80% interval')
ax.set_title(f'Chronos forecast on ETTh1 ({target}), MAE={mae:.3f}')
ax.set_xlabel('time step')
ax.set_ylabel(target)
ax.grid(alpha=0.3)
ax.legend()
fig.tight_layout()
plt.show()
## 你还可以继续尝试的修改
- 把 `prediction_length` 改成 24、48、192
- 把目标列从 `OT` 改成其他传感器列
- 记录不同设置下预测区间的变化