train_lora.py 说明文档这是这套最小 LoRA 实验里最重要的脚本。
如果说 notebook 更像“老师带你一步一步做”,那么 train_lora.py 更像:
对于初学者来说,这份脚本最大的难点不是某一行特别复杂,而是:
所以读它时,不要一上来逐字死抠。
先抓住整体结构,再看细节,会轻松很多。
可以把它拆成 6 块:
只要你先把这 6 块分清楚,阅读难度会下降很多。
代码:
import argparse
import json
import os
from dataclasses import dataclass
from typing import Dict, List
import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)import argparseargparse 用来处理命令行参数。
也就是说,它负责让你能写出这种命令:
python train_lora.py --train_file data/sample_train.jsonimport json读取 JSON 数据文件。
import os处理目录、路径和系统相关操作。
from dataclasses import dataclass这里导入了 dataclass,但在这份脚本里实际上没有用到。
这对初学者是一个很好的提醒:
from typing import Dict, List这是类型标注相关的工具。
比如:
DictList它们的作用不是改变程序运行,而是:
import torch导入 PyTorch。
后面要用它判断 GPU 是否可用。
from datasets import Dataset导入 Hugging Face 的数据集类。
from peft import LoraConfig, get_peft_model导入 LoRA / PEFT 的核心工具。
from transformers import (...)导入本实验用到的 Hugging Face 训练组件:
AutoModelForCausalLMAutoTokenizerDataCollatorForLanguageModelingTrainerTrainingArguments这些在 notebook 里你已经见过,所以脚本版其实是在把 notebook 流程重新组织一次。
代码:
def load_jsonl_or_json(path: str) -> List[Dict]:
with open(path, "r", encoding="utf-8") as f:
text = f.read().strip()
if not text:
return []
if text[0] == "[":
return json.loads(text)
return [json.loads(line) for line in text.splitlines() if line.strip()]它的目标是:
这比 notebook 版更工程化,因为它更通用。
def load_jsonl_or_json(path: str) -> List[Dict]:定义一个函数。
path: str
path 是字符串路径-> List[Dict]
这只是类型提示,不影响运行。
with open(path, "r", encoding="utf-8") as f:打开文件。
"r" 表示只读encoding="utf-8" 表示按 UTF-8 读取text = f.read().strip()读出整个文件内容,并去掉首尾空白。
if not text:如果文件内容是空的,就进入这里。
return []空文件就返回空列表。
这是一种很稳妥的防御式写法。
if text[0] == "[":如果文件第一个字符是 [,通常说明它是:
例如:
[
{"instruction": "..."}
]return json.loads(text)直接把整个文本解析成 Python 列表。
return [json.loads(line) for line in text.splitlines() if line.strip()]如果不是 JSON 列表,那就按 JSONL 处理:
json.loads(...)这行是一个列表推导式。
意思是:
它体现了脚本版比 notebook 版更成熟的一点:
代码:
def build_text(example: Dict) -> str:
instruction = example.get("instruction", "")
input_text = example.get("input", "")
output_text = example.get("output", "")
parts = []
if instruction:
parts.append(f"Instruction: {instruction}")
if input_text:
parts.append(f"Input: {input_text}")
parts.append(f"Response: {output_text}")
return "\n".join(parts)它和 notebook 版的 build_text() 是同一个思想:
instruction = example.get("instruction", "")从样例字典里取出 instruction。
如果没有,就用空字符串 ""。
input_text = example.get("input", "")取出 input 字段,没有就给空字符串。
output_text = example.get("output", "")取出 output 字段,没有就给空字符串。
parts = []创建一个空列表,用来按顺序装各段文本。
if instruction:如果 instruction 不是空字符串,就继续。
parts.append(f"Instruction: {instruction}")把 instruction 拼成一行文本,加进列表。
if input_text:如果有 input,就继续。
parts.append(f"Input: {input_text}")把 input 拼成一行文本。
parts.append(f"Response: {output_text}")无论如何都把输出目标加进去。
return "\n".join(parts)用换行符把这些部分拼成最终训练文本。
它是整个训练脚本真正决定“模型看到什么”的地方。
很多初学者以为 LoRA 的重点只在 LoRA 本身,其实不是。
训练格式设计同样重要。
代码:
def main() -> None:
parser = argparse.ArgumentParser(description="Minimal LoRA fine-tuning demo for AutoDL course.")
parser.add_argument("--model_name_or_path", type=str, default="/root/course_lora/models/tiny-gpt2")
parser.add_argument("--train_file", type=str, required=True)
parser.add_argument("--validation_file", type=str, required=True)
parser.add_argument("--output_dir", type=str, default="outputs/tiny_lora_demo")
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument("--per_device_train_batch_size", type=int, default=2)
parser.add_argument("--per_device_eval_batch_size", type=int, default=2)
parser.add_argument("--learning_rate", type=float, default=5e-4)
parser.add_argument("--max_length", type=int, default=128)
parser.add_argument("--logging_steps", type=int, default=1)
parser.add_argument("--save_steps", type=int, default=10)
parser.add_argument("--eval_steps", type=int, default=10)
args = parser.parse_args()它在定义:
也就是说,这里决定了:
run_train.sh 能传什么parser = argparse.ArgumentParser(...)创建一个参数解析器。
description=... 的作用是:
parser.add_argument("--model_name_or_path", type=str, default="/root/course_lora/models/tiny-gpt2")定义参数:
--model_name_or_path这说明:
parser.add_argument("--train_file", type=str, required=True)训练集路径,必须传。
required=True 表示:
parser.add_argument("--validation_file", type=str, required=True)验证集路径,必须传。
parser.add_argument("--output_dir", type=str, default="outputs/tiny_lora_demo")输出目录。
parser.add_argument("--num_train_epochs", type=int, default=1)训练轮数。
parser.add_argument("--per_device_train_batch_size", type=int, default=2)训练 batch size。
parser.add_argument("--per_device_eval_batch_size", type=int, default=2)验证 batch size。
parser.add_argument("--learning_rate", type=float, default=5e-4)学习率。
parser.add_argument("--max_length", type=int, default=128)最大 token 长度。
parser.add_argument("--logging_steps", type=int, default=1)日志步长。
parser.add_argument("--save_steps", type=int, default=10)保存步长。
parser.add_argument("--eval_steps", type=int, default=10)评估步长。
args = parser.parse_args()真正开始解析命令行参数,并把结果保存到 args。
之后你就能用:
args.train_fileargs.learning_rate这种方式访问参数值。
代码:
os.makedirs(args.output_dir, exist_ok=True)
train_records = load_jsonl_or_json(args.train_file)
val_records = load_jsonl_or_json(args.validation_file)
if not train_records:
raise ValueError("Training file is empty.")
if not val_records:
raise ValueError("Validation file is empty.")os.makedirs(args.output_dir, exist_ok=True)确保输出目录存在。
makedirs:创建目录exist_ok=True:如果目录已经存在,不要报错train_records = load_jsonl_or_json(args.train_file)读取训练集。
val_records = load_jsonl_or_json(args.validation_file)读取验证集。
if not train_records:如果训练集为空,就进入这里。
raise ValueError("Training file is empty.")主动抛出错误,停止程序。
这是一种很好的工程习惯:
同理,验证集为空时也会报错。
Dataset代码:
train_dataset = Dataset.from_list([{"text": build_text(x)} for x in train_records])
val_dataset = Dataset.from_list([{"text": build_text(x)} for x in val_records])把 Python 列表转换成 Hugging Face Dataset。
这和 notebook 版完全一致,只是换成脚本写法。
[{"text": build_text(x)} for x in train_records]这是列表推导式。
意思是:
build_text(x) 拼成一段训练文本{"text": ...} 这种字典Dataset.from_list(...)把这个列表变成 Hugging Face Dataset。
验证集同理。
代码:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
def tokenize_fn(batch: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:
return tokenizer(
batch["text"],
truncation=True,
padding="max_length",
max_length=args.max_length,
)
train_dataset = train_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
val_dataset = val_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])它完成两件事:
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)根据命令行参数给出的模型路径或模型名加载 tokenizer。
if tokenizer.pad_token is None:如果 tokenizer 没有 pad token,就进入这里。
tokenizer.pad_token = tokenizer.eos_token用 eos_token 代替 pad token。
def tokenize_fn(batch: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:定义 tokenization 函数。
类型提示的意思是:
tokenizer(batch["text"], truncation=True, padding="max_length", max_length=args.max_length)这里三项最重要:
truncation=Truepadding="max_length"max_length=args.max_length也就是:
max_lengthtrain_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])在整个数据集上批量应用 tokenization,并删除原始文本列。
代码:
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)
lora_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()加载原始模型,然后在它上面接 LoRA。
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)加载 base model。
lora_config = LoraConfig(...)定义 LoRA 配置。
参数意义:
r=8
lora_alpha=16
lora_dropout=0.05
bias="none"
task_type="CAUSAL_LM"
model = get_peft_model(model, lora_config)把 LoRA adapter 接到模型上。
model.print_trainable_parameters()打印:
这一行非常值得学生认真看,因为它最直观地展示了:
代码:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
training_args = TrainingArguments(
output_dir=args.output_dir,
num_train_epochs=args.num_train_epochs,
per_device_train_batch_size=args.per_device_train_batch_size,
per_device_eval_batch_size=args.per_device_eval_batch_size,
learning_rate=args.learning_rate,
logging_steps=args.logging_steps,
save_steps=args.save_steps,
eval_steps=args.eval_steps,
eval_strategy="steps",
save_strategy="steps",
report_to="none",
fp16=torch.cuda.is_available(),
remove_unused_columns=False,
)data_collator 在做什么它负责把 token 化后的样本整理成适合语言模型训练的 batch。
这里:
mlm=False说明:
TrainingArguments 里最值得学生理解的参数output_dir
num_train_epochs
per_device_train_batch_size
per_device_eval_batch_size
learning_rate
logging_steps
save_steps
eval_steps
eval_strategy="steps"
save_strategy="steps"
report_to="none"
fp16=torch.cuda.is_available()
remove_unused_columns=False
Trainer 并启动训练代码:
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=data_collator,
)
trainer.train()
trainer.save_model(args.output_dir)
tokenizer.save_pretrained(args.output_dir)
print(f"Saved LoRA demo model to: {args.output_dir}")它完成整个训练闭环的最后部分:
Trainertrainer = Trainer(...)把下面这些东西接起来:
modelargstrain_dataseteval_datasetdata_collatortrainer.train()正式开始训练。
trainer.save_model(args.output_dir)把训练结果保存到输出目录。
在 LoRA 场景里,这一步重点保存的是:
tokenizer.save_pretrained(args.output_dir)把 tokenizer 也保存到输出目录,方便后续推理。
print(f"Saved LoRA demo model to: {args.output_dir}")打印最终保存位置。
代码:
if __name__ == "__main__":
main()这是 Python 脚本的标准入口写法。
它的意思是:
main()为什么要这样写?
因为这样做的好处是:
这份脚本本质上就是:
01_lora_demo.ipynb 的脚本化版本两者的主要差别不是算法,而是使用方式:
SFTTrainer 或量化版本就容易很多这份脚本里有一个未使用的导入:
from dataclasses import dataclass它不会影响脚本运行,但它提醒学生:
对于初学者,这是一个很重要的阅读策略。