train_lora.py 说明文档

这是这套最小 LoRA 实验里最重要的脚本。
如果说 notebook 更像“老师带你一步一步做”,那么 train_lora.py 更像:

对于初学者来说,这份脚本最大的难点不是某一行特别复杂,而是:

所以读它时,不要一上来逐字死抠。
先抓住整体结构,再看细节,会轻松很多。


一、这份脚本整体分成哪几部分

可以把它拆成 6 块:

  1. 导入依赖
  2. 读取 JSON 的辅助函数
  3. 把样例拼成训练文本的辅助函数
  4. 定义命令行参数
  5. 主训练流程
  6. 脚本入口

只要你先把这 6 块分清楚,阅读难度会下降很多。


二、第一部分:导入依赖

代码:

import argparse
import json
import os
from dataclasses import dataclass
from typing import Dict, List

import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

这些模块分别干什么

import argparse

argparse 用来处理命令行参数。

也就是说,它负责让你能写出这种命令:

python train_lora.py --train_file data/sample_train.json

import json

读取 JSON 数据文件。

import os

处理目录、路径和系统相关操作。

from dataclasses import dataclass

这里导入了 dataclass,但在这份脚本里实际上没有用到

这对初学者是一个很好的提醒:

from typing import Dict, List

这是类型标注相关的工具。

比如:

它们的作用不是改变程序运行,而是:

import torch

导入 PyTorch。

后面要用它判断 GPU 是否可用。

from datasets import Dataset

导入 Hugging Face 的数据集类。

from peft import LoraConfig, get_peft_model

导入 LoRA / PEFT 的核心工具。

from transformers import (...)

导入本实验用到的 Hugging Face 训练组件:

这些在 notebook 里你已经见过,所以脚本版其实是在把 notebook 流程重新组织一次。


三、第二部分:读取 JSON 的辅助函数

代码:

def load_jsonl_or_json(path: str) -> List[Dict]:
    with open(path, "r", encoding="utf-8") as f:
        text = f.read().strip()
    if not text:
        return []
    if text[0] == "[":
        return json.loads(text)
    return [json.loads(line) for line in text.splitlines() if line.strip()]

这个函数在做什么

它的目标是:

  1. 普通 JSON 列表
  2. JSONL(每行一条 JSON)

这比 notebook 版更工程化,因为它更通用。

逐行解释

def load_jsonl_or_json(path: str) -> List[Dict]:

定义一个函数。

这只是类型提示,不影响运行。

with open(path, "r", encoding="utf-8") as f:

打开文件。

text = f.read().strip()

读出整个文件内容,并去掉首尾空白。

if not text:

如果文件内容是空的,就进入这里。

return []

空文件就返回空列表。

这是一种很稳妥的防御式写法。

if text[0] == "[":

如果文件第一个字符是 [,通常说明它是:

例如:

[
  {"instruction": "..."}
]

return json.loads(text)

直接把整个文本解析成 Python 列表。

return [json.loads(line) for line in text.splitlines() if line.strip()]

如果不是 JSON 列表,那就按 JSONL 处理:

这行是一个列表推导式。

意思是:

这个函数为什么重要

它体现了脚本版比 notebook 版更成熟的一点:


四、第三部分:把样例拼成训练文本

代码:

def build_text(example: Dict) -> str:
    instruction = example.get("instruction", "")
    input_text = example.get("input", "")
    output_text = example.get("output", "")
    parts = []
    if instruction:
        parts.append(f"Instruction: {instruction}")
    if input_text:
        parts.append(f"Input: {input_text}")
    parts.append(f"Response: {output_text}")
    return "\n".join(parts)

这个函数在做什么

它和 notebook 版的 build_text() 是同一个思想:

逐行解释

instruction = example.get("instruction", "")

从样例字典里取出 instruction

如果没有,就用空字符串 ""

input_text = example.get("input", "")

取出 input 字段,没有就给空字符串。

output_text = example.get("output", "")

取出 output 字段,没有就给空字符串。

parts = []

创建一个空列表,用来按顺序装各段文本。

if instruction:

如果 instruction 不是空字符串,就继续。

parts.append(f"Instruction: {instruction}")

把 instruction 拼成一行文本,加进列表。

if input_text:

如果有 input,就继续。

parts.append(f"Input: {input_text}")

把 input 拼成一行文本。

parts.append(f"Response: {output_text}")

无论如何都把输出目标加进去。

return "\n".join(parts)

用换行符把这些部分拼成最终训练文本。

这个函数为什么关键

它是整个训练脚本真正决定“模型看到什么”的地方。

很多初学者以为 LoRA 的重点只在 LoRA 本身,其实不是。
训练格式设计同样重要。


五、第四部分:主函数开头和命令行参数

代码:

def main() -> None:
    parser = argparse.ArgumentParser(description="Minimal LoRA fine-tuning demo for AutoDL course.")
    parser.add_argument("--model_name_or_path", type=str, default="/root/course_lora/models/tiny-gpt2")
    parser.add_argument("--train_file", type=str, required=True)
    parser.add_argument("--validation_file", type=str, required=True)
    parser.add_argument("--output_dir", type=str, default="outputs/tiny_lora_demo")
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument("--per_device_train_batch_size", type=int, default=2)
    parser.add_argument("--per_device_eval_batch_size", type=int, default=2)
    parser.add_argument("--learning_rate", type=float, default=5e-4)
    parser.add_argument("--max_length", type=int, default=128)
    parser.add_argument("--logging_steps", type=int, default=1)
    parser.add_argument("--save_steps", type=int, default=10)
    parser.add_argument("--eval_steps", type=int, default=10)
    args = parser.parse_args()

这一段在做什么

它在定义:

也就是说,这里决定了:

逐行解释

parser = argparse.ArgumentParser(...)

创建一个参数解析器。

description=... 的作用是:

parser.add_argument("--model_name_or_path", type=str, default="/root/course_lora/models/tiny-gpt2")

定义参数:

这说明:

parser.add_argument("--train_file", type=str, required=True)

训练集路径,必须传。

required=True 表示:

parser.add_argument("--validation_file", type=str, required=True)

验证集路径,必须传。

parser.add_argument("--output_dir", type=str, default="outputs/tiny_lora_demo")

输出目录。

parser.add_argument("--num_train_epochs", type=int, default=1)

训练轮数。

parser.add_argument("--per_device_train_batch_size", type=int, default=2)

训练 batch size。

parser.add_argument("--per_device_eval_batch_size", type=int, default=2)

验证 batch size。

parser.add_argument("--learning_rate", type=float, default=5e-4)

学习率。

parser.add_argument("--max_length", type=int, default=128)

最大 token 长度。

parser.add_argument("--logging_steps", type=int, default=1)

日志步长。

parser.add_argument("--save_steps", type=int, default=10)

保存步长。

parser.add_argument("--eval_steps", type=int, default=10)

评估步长。

args = parser.parse_args()

真正开始解析命令行参数,并把结果保存到 args

之后你就能用:

这种方式访问参数值。


六、第五部分:准备目录、读取数据、检查空文件

代码:

    os.makedirs(args.output_dir, exist_ok=True)

    train_records = load_jsonl_or_json(args.train_file)
    val_records = load_jsonl_or_json(args.validation_file)

    if not train_records:
        raise ValueError("Training file is empty.")
    if not val_records:
        raise ValueError("Validation file is empty.")

逐行解释

os.makedirs(args.output_dir, exist_ok=True)

确保输出目录存在。

train_records = load_jsonl_or_json(args.train_file)

读取训练集。

val_records = load_jsonl_or_json(args.validation_file)

读取验证集。

if not train_records:

如果训练集为空,就进入这里。

raise ValueError("Training file is empty.")

主动抛出错误,停止程序。

这是一种很好的工程习惯:

同理,验证集为空时也会报错。


七、第六部分:构造 Dataset

代码:

    train_dataset = Dataset.from_list([{"text": build_text(x)} for x in train_records])
    val_dataset = Dataset.from_list([{"text": build_text(x)} for x in val_records])

这一段在做什么

把 Python 列表转换成 Hugging Face Dataset

这和 notebook 版完全一致,只是换成脚本写法。

逐行解释

[{"text": build_text(x)} for x in train_records]

这是列表推导式。

意思是:

Dataset.from_list(...)

把这个列表变成 Hugging Face Dataset

验证集同理。


八、第七部分:加载 tokenizer 并做 tokenization

代码:

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize_fn(batch: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:
        return tokenizer(
            batch["text"],
            truncation=True,
            padding="max_length",
            max_length=args.max_length,
        )

    train_dataset = train_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])
    val_dataset = val_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

这一段在做什么

它完成两件事:

  1. 加载 tokenizer
  2. 把训练文本变成 token id

逐行解释

tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

根据命令行参数给出的模型路径或模型名加载 tokenizer。

if tokenizer.pad_token is None:

如果 tokenizer 没有 pad token,就进入这里。

tokenizer.pad_token = tokenizer.eos_token

eos_token 代替 pad token。

def tokenize_fn(batch: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:

定义 tokenization 函数。

类型提示的意思是:

tokenizer(batch["text"], truncation=True, padding="max_length", max_length=args.max_length)

这里三项最重要:

也就是:

train_dataset.map(tokenize_fn, batched=True, remove_columns=["text"])

在整个数据集上批量应用 tokenization,并删除原始文本列。


九、第八部分:加载 base model,接 LoRA

代码:

    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

这一段在做什么

加载原始模型,然后在它上面接 LoRA。

逐行解释

model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

加载 base model。

lora_config = LoraConfig(...)

定义 LoRA 配置。

参数意义:

model = get_peft_model(model, lora_config)

把 LoRA adapter 接到模型上。

model.print_trainable_parameters()

打印:

这一行非常值得学生认真看,因为它最直观地展示了:


十、第九部分:定义 data collator 和训练参数

代码:

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        eval_steps=args.eval_steps,
        eval_strategy="steps",
        save_strategy="steps",
        report_to="none",
        fp16=torch.cuda.is_available(),
        remove_unused_columns=False,
    )

data_collator 在做什么

它负责把 token 化后的样本整理成适合语言模型训练的 batch。

这里:

说明:

TrainingArguments 里最值得学生理解的参数


十一、第十部分:构造 Trainer 并启动训练

代码:

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Saved LoRA demo model to: {args.output_dir}")

这一段在做什么

它完成整个训练闭环的最后部分:

  1. 构造 Trainer
  2. 启动训练
  3. 保存训练结果
  4. 保存 tokenizer

逐行解释

trainer = Trainer(...)

把下面这些东西接起来:

trainer.train()

正式开始训练。

trainer.save_model(args.output_dir)

把训练结果保存到输出目录。

在 LoRA 场景里,这一步重点保存的是:

tokenizer.save_pretrained(args.output_dir)

把 tokenizer 也保存到输出目录,方便后续推理。

print(f"Saved LoRA demo model to: {args.output_dir}")

打印最终保存位置。


十二、第十一部分:脚本入口

代码:

if __name__ == "__main__":
    main()

这一段在做什么

这是 Python 脚本的标准入口写法。

它的意思是:

为什么要这样写?

因为这样做的好处是:


十三、这份脚本和 notebook 的关系

这份脚本本质上就是:

两者的主要差别不是算法,而是使用方式:

notebook 版更适合

脚本版更适合


十四、这份脚本最值得学生真正理解的地方

  1. 训练不是一个黑箱,它就是“数据 -> tokenizer -> model -> LoRA -> Trainer -> 输出”
  2. 命令行参数并不是吓人的东西,它只是把训练配置显式写出来
  3. 脚本版通常比 notebook 版更接近真实工程
  4. 只要看懂这份最小脚本,后面再学更复杂的 SFTTrainer 或量化版本就容易很多

十五、一个很值得注意的小细节

这份脚本里有一个未使用的导入:

from dataclasses import dataclass

它不会影响脚本运行,但它提醒学生:

对于初学者,这是一个很重要的阅读策略。