Chapter 0: TL;DR¶

Spoilers¶

In this short chapter, we'll get right to it and fine-tune a small-ish language model, Microsoft's Phi-3 Mini 4K Instruct, to translate English into Yoda-speak. You can think of this initial chapter as a recipe you can simply follow. It is a "shoot first, ask questions later" kind of chapter.

You'll learn how to:

  • load a quantized model using BitsAndBytes
  • configure low-rank adapters (LoRA) using Hugging Face's peft
  • load and format a dataset
  • fine-tune the model using the supervised fine-tuning trainer (SFTTrainer) from Hugging Face's trl
  • use the fine-tuned model to generate a few sentences

Setup¶

For better reproducibility during training, use the pinned versions below, the same versions used in the book:

In [ ]:
# Original versions
#!pip install transformers==4.46.2 peft==0.13.2 accelerate==1.1.1 trl==0.12.1 bitsandbytes==0.45.2 datasets==3.1.0 huggingface-hub==0.26.2 safetensors==0.4.5 pandas==2.2.2 matplotlib==3.8.0 numpy==1.26.4
# Updated versions - October 2025
!pip install transformers==4.56.1 peft==0.17.0 accelerate==1.10.0 trl==0.23.1 bitsandbytes==0.47.0 datasets==4.0.0 huggingface-hub==0.34.4 safetensors==0.6.2 pandas==2.2.2 matplotlib==3.10.0 numpy==2.0.2
In [ ]:
# If you're running on Colab
#!pip install datasets bitsandbytes trl
In [2]:
# If you're running on runpod.io's Jupyter Template
!pip install datasets bitsandbytes trl transformers peft huggingface-hub accelerate safetensors pandas matplotlib
Successfully installed accelerate-1.13.0 bitsandbytes-0.49.2 datasets-4.6.1 dill-0.4.0 multiprocess-0.70.18 peft-0.18.1 pyarrow-23.0.1 safetensors-0.7.0 tokenizers-0.22.2 transformers-5.3.0 trl-0.29.0 xxhash-3.6.0

Imports¶

In [3]:
import os
import torch
from contextlib import nullcontext
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

Loading a Quantized Base Model¶

We start by loading a quantized model, so it takes up less space in the GPU's RAM. A quantized model replaces the original weights with approximate values represented by fewer bits. A simple and straightforward way to quantize a model is to turn its 32-bit floating-point (FP32) weights into 4-bit NormalFloat (NF4) values. This change alone reduces the model's memory footprint by roughly a factor of eight.
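That factor of eight follows from back-of-the-envelope arithmetic (illustrative only; the parameter count is approximate):

```python
# Rough memory math for quantization: each FP32 weight takes 4 bytes,
# while a 4-bit representation takes half a byte.
n_params = 3.8e9  # Phi-3 Mini has roughly 3.8B parameters

fp32_gb = n_params * 4 / 1e9    # gigabytes at full precision
nf4_gb = n_params * 0.5 / 1e9   # gigabytes at 4 bits per weight

print(f"FP32: {fp32_gb:.1f} GB, NF4: {nf4_gb:.1f} GB, ratio: {fp32_gb / nf4_gb:.0f}x")
# → FP32: 15.2 GB, NF4: 1.9 GB, ratio: 8x
```

In practice the real footprint is somewhat larger than the 4-bit figure, since not every layer gets quantized, as we'll see shortly.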

We can use an instance of BitsAndBytesConfig as the quantization_config argument while loading a model with the from_pretrained() method. To keep things flexible, so you can try any other model of your choice, we're using Hugging Face's AutoModelForCausalLM: the repo ID you choose determines the model that gets loaded.

Without further ado, here's our quantized model being loaded:

In [4]:
bnb_config = BitsAndBytesConfig(
   load_in_4bit=True,
   bnb_4bit_quant_type="nf4",
   bnb_4bit_use_double_quant=True,
   bnb_4bit_compute_dtype=torch.float32
)
repo_id = 'microsoft/Phi-3-mini-4k-instruct'
model = AutoModelForCausalLM.from_pretrained(repo_id,
                                             device_map="cuda:0",
                                             quantization_config=bnb_config
)
config.json:   0%|          | 0.00/967 [00:00<?, ?B/s]
model.safetensors.index.json: 0.00B [00:00, ?B/s]
Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

"The Phi-3-Mini-4K-Instruct is a 3.8B parameters, lightweight, state-of-the-art open model trained with the Phi-3 datasets that includes both synthetic data and the filtered publicly available websites data with a focus on high-quality and reasoning dense properties. The model belongs to the Phi-3 family with the Mini version in two variants 4K and 128K which is the context length (in tokens) that it can support."
Source: Hugging Face Hub

Once the model is loaded, you can see how much space it occupies in memory using the get_memory_footprint() method.

In [5]:
print(model.get_memory_footprint()/1e6)
2206.341312

Even though it's been quantized, the model still takes up a bit more than 2 gigabytes of RAM. The quantization procedure focuses on the linear layers within the Transformer decoder blocks (also referred to as "layers" in some cases):

In [6]:
model
Out[6]:
Phi3ForCausalLM(
  (model): Phi3Model(
    (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
    (layers): ModuleList(
      (0-31): 32 x Phi3DecoderLayer(
        (self_attn): Phi3Attention(
          (o_proj): Linear4bit(in_features=3072, out_features=3072, bias=False)
          (qkv_proj): Linear4bit(in_features=3072, out_features=9216, bias=False)
        )
        (mlp): Phi3MLP(
          (gate_up_proj): Linear4bit(in_features=3072, out_features=16384, bias=False)
          (down_proj): Linear4bit(in_features=8192, out_features=3072, bias=False)
          (activation_fn): SiLU()
        )
        (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
        (resid_attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (norm): Phi3RMSNorm((3072,), eps=1e-05)
    (rotary_emb): Phi3RotaryEmbedding()
  )
  (lm_head): Linear(in_features=3072, out_features=32064, bias=False)
)

A quantized model can be used directly for inference, but it cannot be trained any further. Those pesky Linear4bit layers take up much less space, which is the whole point of quantization; however, we cannot update them.

We need to add something else to our mix, a sprinkle of adapters.

Setting Up Low-Rank Adapters (LoRA)¶

Low-rank adapters can be attached to each and every one of the quantized layers. The adapters are mostly regular Linear layers that can be easily updated as usual. The clever trick in this case is that these adapters are significantly smaller than the layers that have been quantized.

Since the quantized layers are frozen (they cannot be updated), setting up LoRA adapters on a quantized model drastically reduces the number of trainable parameters, often to 1% or less of the model's total parameter count.
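To see why the reduction is so dramatic, consider a single adapted layer. The sketch below uses the dimensions of Phi-3's o_proj layer and a rank of 8 (the same rank we'll configure shortly); the arithmetic is illustrative only:

```python
# A Linear(in, out) layer stores in*out weights; its LoRA adapter adds
# two small matrices, A (in x r) and B (r x out), i.e. r*(in + out) weights.
in_features, out_features, r = 3072, 3072, 8  # Phi-3's o_proj with rank 8

base_params = in_features * out_features
lora_params = r * (in_features + out_features)

print(f"{100 * lora_params / base_params:.2f}% of the base layer")
# → 0.52% of the base layer
```

The smaller the rank, the fewer trainable parameters the adapter adds, at the cost of a less expressive adaptation.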

We can set up LoRA adapters in three easy steps:

  • Call prepare_model_for_kbit_training() to improve numerical stability during training.
  • Create an instance of LoraConfig.
  • Apply the configuration to the quantized base model using the get_peft_model() method.

Let's try it out with our model:

In [7]:
model = prepare_model_for_kbit_training(model)

config = LoraConfig(
    r=8,                   # the rank of the adapter, the lower the fewer parameters you'll need to train
    lora_alpha=16,         # multiplier, usually 2*r
    bias="none",           # BEWARE: training biases *modifies* base model's behavior
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
    # Newer models, such as Phi-3 at time of writing, may require
    # manually setting target modules
    target_modules=['o_proj', 'qkv_proj', 'gate_up_proj', 'down_proj'],
)

model = get_peft_model(model, config)
model
Out[7]:
PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Phi3ForCausalLM(
      (model): Phi3Model(
        (embed_tokens): Embedding(32064, 3072, padding_idx=32000)
        (layers): ModuleList(
          (0-31): 32 x Phi3DecoderLayer(
            (self_attn): Phi3Attention(
              (o_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (qkv_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=9216, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=9216, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
            )
            (mlp): Phi3MLP(
              (gate_up_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=16384, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=16384, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (down_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=8192, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=8192, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (activation_fn): SiLU()
            )
            (input_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
            (post_attention_layernorm): Phi3RMSNorm((3072,), eps=1e-05)
            (resid_attn_dropout): Dropout(p=0.0, inplace=False)
            (resid_mlp_dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (norm): Phi3RMSNorm((3072,), eps=1e-05)
        (rotary_emb): Phi3RotaryEmbedding()
      )
      (lm_head): Linear(in_features=3072, out_features=32064, bias=False)
    )
  )
)

As shown above, each of the four target modules (qkv_proj, o_proj, gate_up_proj, and down_proj) now carries its own pair of LoRA layers.

Did you get the following error?

ValueError: Please specify `target_modules` in `peft_config`

In most cases, you don't need to specify target_modules at all if you're using a well-known model: the peft library automatically chooses the appropriate targets. However, there may be a gap between the time a popular model is released and the time the library is updated to support it. So, if you get the error above, look for the quantized layers in your model and list their names in the target_modules argument.
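One way to find those names is to walk the model's modules and collect the leaf names of every Linear4bit instance. Here's a hedged sketch (find_target_modules is our own helper, not a peft function), demonstrated on a toy model with regular Linear layers standing in for the quantized ones:

```python
import torch.nn as nn

def find_target_modules(model, cls_name="Linear4bit"):
    """Collect the (leaf) names of every module of the given class."""
    names = set()
    for name, module in model.named_modules():
        if module.__class__.__name__ == cls_name:
            names.add(name.split(".")[-1])  # keep only the final name component
    return sorted(names)

# Toy stand-in: regular Linear layers playing the role of Linear4bit
class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.qkv_proj = nn.Linear(8, 24)
        self.o_proj = nn.Linear(8, 8)

print(find_target_modules(ToyBlock(), cls_name="Linear"))
# → ['o_proj', 'qkv_proj']
```

On a real quantized model, you would call find_target_modules(model) with the default cls_name="Linear4bit" and pass the result to LoraConfig.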

The quantized layers (Linear4bit) have turned into lora.Linear4bit modules where the quantized layer itself became the base_layer with some regular Linear layers (lora_A and lora_B) added to the mix.

These extra layers on their own would make the model only slightly larger. However, the model preparation function (prepare_model_for_kbit_training()) also cast every non-quantized layer to full precision (FP32), resulting in a model roughly 20% larger:

In [8]:
print(model.get_memory_footprint()/1e6)
2651.074752

Since most parameters are frozen, only a tiny fraction of the total number of parameters is currently trainable, thanks to LoRA!

In [9]:
trainable_parms, tot_parms = model.get_nb_trainable_parameters()
print(f'Trainable parameters:             {trainable_parms/1e6:.2f}M')
print(f'Total parameters:                 {tot_parms/1e6:.2f}M')
print(f'Fraction of trainable parameters: {100*trainable_parms/tot_parms:.2f}%')
Trainable parameters:             12.58M
Total parameters:                 3833.66M
Fraction of trainable parameters: 0.33%
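The 12.58M figure can be verified with simple arithmetic: each adapted layer gets two small matrices, A (in_features × r) and B (r × out_features), with rank r=8. Using the dimensions from the model printout and Phi-3 Mini's 32 transformer layers:

```python
# LoRA adds two small matrices per adapted layer: A (in x r) and B (r x out)
def lora_params(in_features, out_features, r=8):
    return in_features * r + r * out_features

# Target modules and their dimensions, taken from the model printout
# (hidden size 3072, intermediate size 8192, fused qkv and gate_up projections)
per_layer = (lora_params(3072, 9216)     # qkv_proj
             + lora_params(3072, 3072)   # o_proj
             + lora_params(3072, 16384)  # gate_up_proj
             + lora_params(8192, 3072))  # down_proj

total = 32 * per_layer  # Phi-3 Mini has 32 transformer layers
print(f'{total/1e6:.2f}M')  # 12.58M, matching the trainer's count
```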

The model is ready to be fine-tuned, but we are still missing one key component: our dataset.

Formatting Your Dataset¶

"Like Yoda, speak, you must. Hrmmm."

Master Yoda

The dataset yoda_sentences consists of 720 sentences translated from English to Yoda-speak. The dataset is hosted on the Hugging Face Hub and we can easily load it using the load_dataset() method from the Hugging Face datasets library:

In [10]:
dataset = load_dataset("dvgodoy/yoda_sentences", split="train")
dataset
README.md:   0%|          | 0.00/531 [00:00<?, ?B/s]
sentences.csv: 0.00B [00:00, ?B/s]
Generating train split:   0%|          | 0/720 [00:00<?, ? examples/s]
Out[10]:
Dataset({
    features: ['sentence', 'translation', 'translation_extra'],
    num_rows: 720
})

The dataset has three columns:

  • original English sentence (sentence)
  • basic translation to Yoda-speak (translation)
  • enhanced translation including typical Yesss and Hrrmm interjections (translation_extra)
In [11]:
dataset[0]
Out[11]:
{'sentence': 'The birch canoe slid on the smooth planks.',
 'translation': 'On the smooth planks, the birch canoe slid.',
 'translation_extra': 'On the smooth planks, the birch canoe slid. Yes, hrrrm.'}

The SFTTrainer we'll be using to fine-tune the model can automatically handle datasets in the conversational format:

{"messages":[
  {"role": "system", "content": "<general directives>"},
  {"role": "user", "content": "<prompt text>"},
  {"role": "assistant", "content": "<ideal generated text>"}
]}

IMPORTANT UPDATE: unfortunately, more recent versions of the trl library no longer properly support the "instruction" (prompt/completion) format, which means the chat template was not being applied to the dataset. To avoid this issue, we can convert the dataset to the "conversational" format.


In [15]:
# Adapted from trl.extras.dataset_formatting.instructions_formatting_function
# Converts dataset from prompt/completion format (not supported anymore)
# to the conversational format
def format_dataset(examples):
    if isinstance(examples["prompt"], list):
        output_texts = []
        for i in range(len(examples["prompt"])):
            converted_sample = [
                {"role": "user", "content": examples["prompt"][i]},
                {"role": "assistant", "content": examples["completion"][i]},
            ]
            output_texts.append(converted_sample)
        return {'messages': output_texts}
    else:
        converted_sample = [
            {"role": "user", "content": examples["prompt"]},
            {"role": "assistant", "content": examples["completion"]},
        ]
        return {'messages': converted_sample}
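As a quick sanity check, we can run the conversion on a single prompt/completion pair; the values below are copied from the dataset[0] output shown earlier:

```python
# Same conversion as format_dataset above, applied to a single example
def format_dataset(examples):
    if isinstance(examples["prompt"], list):  # batched input
        return {"messages": [
            [{"role": "user", "content": p},
             {"role": "assistant", "content": c}]
            for p, c in zip(examples["prompt"], examples["completion"])
        ]}
    return {"messages": [
        {"role": "user", "content": examples["prompt"]},
        {"role": "assistant", "content": examples["completion"]},
    ]}

# First row of the dataset (values from the dataset[0] output above)
row = {"prompt": "The birch canoe slid on the smooth planks.",
       "completion": "On the smooth planks, the birch canoe slid. Yes, hrrrm."}
print(format_dataset(row)["messages"])
```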
In [16]:
dataset = dataset.rename_column("sentence", "prompt")
dataset = dataset.rename_column("translation_extra", "completion")
dataset = dataset.map(format_dataset)
dataset = dataset.remove_columns(['prompt', 'completion', 'translation'])
messages = dataset[0]['messages']
messages
Map:   0%|          | 0/720 [00:00<?, ? examples/s]

Tokenizer¶

Before moving on to the actual training, we still need to load the tokenizer that corresponds to our model. The tokenizer is an important part of this process: it determines how to convert text into tokens in exactly the same way used to train the model.

For instruction/chat models, the tokenizer also contains its corresponding chat template that specifies:

  • Which special tokens should be used, and where they should be placed.
  • Where the system directives, user prompt, and model response should be placed.
  • What the generation prompt is, that is, the special token that triggers the model's response (more on that in the "Querying the Model" section).

IMPORTANT UPDATE: due to changes in the default collator used by the SFTTrainer class while building the dataset, the EOS token (which, in Phi-3, is the same as the PAD token) was also masked in the labels, leaving the model unable to learn when to stop generating tokens.

To address this change, we can assign the UNK token as the PAD token, so the EOS token becomes unique and is therefore no longer masked in the labels.


In [17]:
tokenizer = AutoTokenizer.from_pretrained(repo_id)
tokenizer.pad_token = tokenizer.unk_token
tokenizer.pad_token_id = tokenizer.unk_token_id

tokenizer.chat_template
tokenizer_config.json: 0.00B [00:00, ?B/s]
tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]
tokenizer.json: 0.00B [00:00, ?B/s]
added_tokens.json:   0%|          | 0.00/306 [00:00<?, ?B/s]
special_tokens_map.json:   0%|          | 0.00/599 [00:00<?, ?B/s]
Out[17]:
"{% for message in messages %}{% if message['role'] == 'system' %}{{'<|system|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'user' %}{{'<|user|>\n' + message['content'] + '<|end|>\n'}}{% elif message['role'] == 'assistant' %}{{'<|assistant|>\n' + message['content'] + '<|end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}"

Never mind the seemingly overcomplicated template. It simply organizes the messages into a coherent block with the appropriate tags, as shown below (tokenize=False ensures we get readable text back instead of a numeric sequence of token IDs):

In [18]:
print(tokenizer.apply_chat_template(messages, tokenize=False))
<|user|>
The birch canoe slid on the smooth planks.<|end|>
<|assistant|>
On the smooth planks, the birch canoe slid. Yes, hrrrm.<|end|>
<|endoftext|>

Notice that each interaction is wrapped in either <|user|> or <|assistant|> tokens at the beginning and <|end|> at the end. Moreover, the <|endoftext|> token indicates the end of the whole block.
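For illustration only, here is the same rendering logic written in plain Python. This is a sketch of what the template does; in practice, the tokenizer renders the Jinja template itself:

```python
# Plain-Python sketch of Phi-3's chat template (illustration only)
def render_phi3(messages, add_generation_prompt=False, eos_token="<|endoftext|>"):
    text = ""
    for m in messages:
        # each turn: role tag, content, and an <|end|> marker
        text += f"<|{m['role']}|>\n{m['content']}<|end|>\n"
    # either open the assistant's turn or close the whole block
    return text + ("<|assistant|>\n" if add_generation_prompt else eos_token)

messages = [
    {"role": "user",
     "content": "The birch canoe slid on the smooth planks."},
    {"role": "assistant",
     "content": "On the smooth planks, the birch canoe slid. Yes, hrrrm."},
]
print(render_phi3(messages))  # matches the apply_chat_template output above
```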

Different models will have different templates and tokens to indicate the beginning and end of sentences and blocks.

We're now ready to tackle the actual fine-tuning!

Fine-Tuning with SFTTrainer¶

Fine-tuning a model, whether large or otherwise, follows exactly the same training procedure as training a model from scratch. We could write our own training loop in pure PyTorch, or we could use Hugging Face's Trainer to fine-tune our model.

It is much easier, however, to use SFTTrainer instead (which uses Trainer underneath, by the way), since it takes care of most of the nitty-gritty details for us, as long as we provide it with the following four arguments:

  • a model
  • a tokenizer
  • a dataset
  • a configuration object

We've already got the first three elements; let's work on the last one.

SFTConfig¶

There are many parameters that we can set in the configuration object. We have divided them into four groups:

  • Memory usage optimization parameters related to gradient accumulation and checkpointing
  • Dataset-related arguments, such as the max_seq_length required by your data, and whether or not you are packing the sequences
  • Typical training parameters such as the learning_rate and the num_train_epochs
  • Environment and logging parameters such as output_dir (this will be the name of the model if you choose to push it to the Hugging Face Hub once it's trained), logging_dir, and logging_steps.
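A note on gradient accumulation: gradients from several micro-batches are summed before each optimizer step, so the effective batch size is the product of the two settings. A quick illustration (the accumulation value here is hypothetical; our configuration uses 1):

```python
per_device_train_batch_size = 16  # micro-batch size
gradient_accumulation_steps = 4   # hypothetical value, for illustration only
# gradients are accumulated over 4 micro-batches before each weight update
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_size)  # 64
```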

While the learning rate is a very important parameter (as a starting point, you can try the learning rate used to train the base model in the first place), it's actually the maximum sequence length that's more likely to cause out-of-memory issues.

Make sure to always pick the shortest possible max_seq_length that makes sense for your use case. In ours, the sentences—both in English and Yoda-speak—are quite short, and a sequence of 64 tokens is more than enough to cover the prompt, the completion, and the added special tokens.

Flash attention (which, unfortunately, isn't supported in Colab) allows for more flexibility when working with longer sequences, helping to avoid potential OOM errors.


IMPORTANT UPDATE: The release of trl version 0.20 brought several changes to the SFTConfig:

  • packing is performed differently than it used to be, unless packing_strategy='wrapped' is set;
  • the max_seq_length argument was renamed to max_length;
  • bf16 now defaults to True but, at the time of this update (Oct/2025), the library didn't check whether BF16 was actually available, so the argument is set explicitly in the configuration.

In [21]:
sft_config = SFTConfig(
    ## GROUP 1: Memory usage
    # These arguments will squeeze the most out of your GPU's RAM
    # Checkpointing
    gradient_checkpointing=True,
    # this saves a LOT of memory
    # Set this to avoid exceptions in newer versions of PyTorch
    gradient_checkpointing_kwargs={'use_reentrant': False},
    # Gradient Accumulation / Batch size
    # Actual batch (for updating) is same (1x) as micro-batch size
    gradient_accumulation_steps=1,
    # The initial (micro) batch size to start off with
    per_device_train_batch_size=16,
    # If batch size would cause OOM, halves its size until it works
    auto_find_batch_size=True,

    ## GROUP 2: Dataset-related
    max_length=64, # renamed in v0.20
    # Dataset
    # packing a dataset means no padding is needed
    packing=True,
    packing_strategy='wrapped', # added to approximate original packing behavior

    ## GROUP 3: These are typical training parameters
    num_train_epochs=10,
    learning_rate=3e-4,
    # Optimizer
    # 8-bit Adam optimizer - doesn't help much if you're using LoRA!
    optim='paged_adamw_8bit',

    ## GROUP 4: Logging parameters
    logging_steps=10,
    logging_dir='./logs',
    output_dir='./phi3-mini-yoda-adapter',
    report_to='none',

    # ensures bf16 (the new default) is only used when it is actually available
    bf16=torch.cuda.is_bf16_supported(including_emulation=False)
)

SFTTrainer¶

"It is training time!"

The Hulk


IMPORTANT UPDATE: Up to version 0.23 of trl, there was a known issue where training failed if the LoRA configuration had already been applied to the model, as the trainer froze the whole model, including the adapters.

If the model already contained the adapters, as in our case, training would only work if the underlying original model (model.base_model.model) was used together with the peft_config argument.

This issue was fixed in version 0.23.1 of trl, released in October 2025.


We can now finally create an instance of the supervised fine-tuning trainer:

In [22]:
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=sft_config,
    train_dataset=dataset,
)
/usr/local/lib/python3.12/dist-packages/peft/mapping_func.py:73: UserWarning: You are trying to modify a model with PEFT for a second time. If you want to reload the model with a different config, make sure to call `.unload()` before.
  warnings.warn(
/usr/local/lib/python3.12/dist-packages/peft/tuners/tuners_utils.py:196: UserWarning: Already found a `peft_config` attribute in the model. This will lead to having multiple adapters in the model. Make sure to know what you are doing!
  warnings.warn(
Tokenizing train dataset:   0%|          | 0/720 [00:00<?, ? examples/s]
Packing train dataset:   0%|          | 0/720 [00:00<?, ? examples/s]

The SFTTrainer has already preprocessed our dataset, so we can take a look inside and see how each mini-batch was assembled:

In [23]:
dl = trainer.get_train_dataloader()
batch = next(iter(dl))
In [24]:
batch['input_ids'][0], batch['labels'][0]
Out[24]:
(tensor([ 3974, 29892,  4337,   278,   325,   271, 29892,   366,  1818, 29889,
         32007, 32000, 32010,   450,   289,   935,   310,   278,   282,   457,
          5447,   471,   528,  4901,   322,  6501, 29889, 32007, 32001, 26399,
          1758,  4317, 29889,  1383,  4901,   322,  6501, 29892,   278,   289,
           935,   310,   278,   282,   457,  5447,   471, 29889, 32007, 32000,
         32010,   951,  5989,  2507, 17354,   322, 13328,   297,   278,  6416,
         29889, 32007, 32001,   512], device='cuda:0'),
 tensor([ 3974, 29892,  4337,   278,   325,   271, 29892,   366,  1818, 29889,
         32007, 32000, 32010,   450,   289,   935,   310,   278,   282,   457,
          5447,   471,   528,  4901,   322,  6501, 29889, 32007, 32001, 26399,
          1758,  4317, 29889,  1383,  4901,   322,  6501, 29892,   278,   289,
           935,   310,   278,   282,   457,  5447,   471, 29889, 32007, 32000,
         32010,   951,  5989,  2507, 17354,   322, 13328,   297,   278,  6416,
         29889, 32007, 32001,   512], device='cuda:0'))

The labels were added automatically, and they're exactly the same as the inputs. Thus, this is a case of self-supervised fine-tuning.

The shifting of the labels will be handled automatically as well; there's no need to be concerned about it.
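To make the shifting concrete, here's a toy sketch using the first few token IDs from the batch above: the prediction made at position i is scored against the token at position i + 1.

```python
# Toy illustration of causal LM label shifting
input_ids = [3974, 29892, 4337, 278, 325]
labels = list(input_ids)  # labels start out identical to the inputs

# the loss compares the prediction at position i with the label at i + 1
inputs_for_prediction = input_ids[:-1]
shifted_labels = labels[1:]
print(list(zip(inputs_for_prediction, shifted_labels)))
```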

Although this is a 3.8-billion-parameter model, the configuration above allows us to squeeze training, using a mini-batch of eight, into an old setup with a consumer-grade GPU such as a GTX 1060 with only 6 GB of RAM. True story! On that setup, the training process takes about 35 minutes to complete.

Next, we call the train() method and wait:

In [25]:
trainer.train()
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
[220/220 25:28, Epoch 10/10]
Step Training Loss
10 2.845800
20 1.824200
30 1.604200
40 1.520200
50 1.397400
60 1.291100
70 1.185800
80 0.984700
90 0.889100
100 0.619500
110 0.587300
120 0.423400
130 0.429100
140 0.370600
150 0.352000
160 0.321200
170 0.309000
180 0.303000
190 0.274600
200 0.263400
210 0.243700
220 0.244900

Out[25]:
TrainOutput(global_step=220, training_loss=0.831096840988506, metrics={'train_runtime': 1536.2273, 'train_samples_per_second': 2.22, 'train_steps_per_second': 0.143, 'total_flos': 4890970340720640.0, 'train_loss': 0.831096840988506})
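A quick sanity check on the reported metrics: 220 optimizer steps over 10 epochs means 22 steps per epoch, and the runtime works out to roughly 25.6 minutes on this particular GPU.

```python
# Numbers from the TrainOutput above
train_runtime = 1536.2273  # seconds
global_step = 220
num_train_epochs = 10

print(f'{train_runtime / 60:.1f} minutes')  # 25.6 minutes
print(global_step // num_train_epochs)      # 22 optimizer steps per epoch
```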

Querying the Model¶

Now, our model should be able to produce a Yoda-like sentence as a response to any short sentence we give it.

The model requires its inputs to be properly formatted. We need to build a list of "messages"—ours, from the user, in this case—and prompt the model to answer by indicating it's its turn to write.

This is the purpose of the add_generation_prompt argument: it adds <|assistant|> to the end of the conversation, so the model can predict the next word—and continue doing so until it predicts an <|endoftext|> token.

The helper function below assembles a message (in the conversational format) and applies the chat template to it, appending the generation prompt to its end.

In [26]:
def gen_prompt(tokenizer, sentence):
    converted_sample = [
        {"role": "user", "content": sentence},
    ]
    prompt = tokenizer.apply_chat_template(converted_sample,
                                           tokenize=False,
                                           add_generation_prompt=True)
    return prompt

Let's try generating a prompt for an example sentence:

In [27]:
sentence = 'The Force is strong in you!'
prompt = gen_prompt(tokenizer, sentence)
print(prompt)
<|user|>
The Force is strong in you!<|end|>
<|assistant|>

The prompt seems about right; let's use it to generate a completion. The helper function below does the following:

  • It tokenizes the prompt into a tensor of token IDs (add_special_tokens is set to False because the tokens were already added by the chat template).
  • It sets the model to evaluation mode.
  • It calls the model's generate() method to produce the output (generated token IDs).
    • If the model was trained using mixed-precision, we wrap the generation in the autocast() context manager, which automatically handles conversion between data types.
  • It decodes the generated token IDs back into readable text.
In [28]:
def generate(model, tokenizer, prompt, max_new_tokens=64, skip_special_tokens=False):
    tokenized_input = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(model.device)

    model.eval()
    # if it was trained using mixed precision, uses autocast context
    ctx = torch.autocast(device_type=model.device.type, dtype=model.dtype) \
          if model.dtype in [torch.float16, torch.bfloat16] else nullcontext()
    with ctx:  
        generation_output = model.generate(**tokenized_input,
                                           eos_token_id=tokenizer.eos_token_id,
                                           max_new_tokens=max_new_tokens)

    output = tokenizer.batch_decode(generation_output,
                                    skip_special_tokens=skip_special_tokens)
    return output[0]

Now, we can finally try out our model and see if it's indeed capable of generating Yoda-speak.

In [29]:
print(generate(model, tokenizer, prompt))
<|user|> The Force is strong in you!<|end|><|assistant|> Strong in you, the Force is.<|end|><|endoftext|>

Awesome! It works! Like Yoda, the model speaks. Hrrrmm.

Congratulations, you've fine-tuned your first LLM!

Now, you've got a small adapter that can be loaded into an instance of the Phi-3 Mini 4K Instruct model to turn it into a Yoda translator! How cool is that?

Saving the Adapter¶

Once the training is completed, you can save the adapter (and the tokenizer) to disk by calling the trainer's save_model() method. It will save everything to the specified folder:

In [30]:
trainer.save_model('local-phi3-mini-yoda-adapter')

The files that were saved include:

  • the adapter configuration (adapter_config.json) and weights (adapter_model.safetensors)—the adapter itself is just 50 MB in size
  • the training arguments (training_args.bin)
  • the tokenizer (tokenizer.json and tokenizer.model), its configuration (tokenizer_config.json), its special tokens (added_tokens.json and special_tokens_map.json), and its chat template (chat_template.jinja)
  • a README file
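The 50 MB figure checks out: the adapter stores its 12.58M trainable parameters in full precision, at four bytes each (an estimate that matches the observed file size):

```python
# Estimating the adapter's size on disk
trainable_params = 12.58e6  # from get_nb_trainable_parameters() earlier
bytes_per_param = 4         # FP32
size_mb = trainable_params * bytes_per_param / 1e6
print(f'{size_mb:.1f} MB')  # ~50.3 MB
```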
In [31]:
os.listdir('local-phi3-mini-yoda-adapter')
Out[31]:
['tokenizer.model',
 'README.md',
 'special_tokens_map.json',
 'adapter_config.json',
 'adapter_model.safetensors',
 'tokenizer_config.json',
 'tokenizer.json',
 'chat_template.jinja',
 'training_args.bin',
 'added_tokens.json']

If you'd like to share your adapter with everyone, you can also push it to the Hugging Face Hub. First, log in using a token that has permission to write:

In [ ]:
from huggingface_hub import login
login()

The code above will ask you to enter an access token:

Figure 0.1 - Logging into the Hugging Face Hub

A successful login should look like this (pay attention to the permissions):

Figure 0.2 - Successful Login

Then, you can use the trainer's push_to_hub() method to upload everything to your account in the Hub. The model will be named after the output_dir argument of the training arguments:

In [ ]:
trainer.push_to_hub()

There you go! Our model is out there in the world, and anyone can use it to translate English into Yoda speak.

That's a wrap!