小型LLM PLaMo 2 1BをGoogle ColabでSFTしてみる

今回はPreferred Networksとその子会社のPreferred Elementsが共同で開発した1Bサイズの小型のLLM、PLaMo 2 1Bに対してSFTをするコードの紹介になります。

Google Colabの無料枠で推論を回す方法は前回記事にしましたので、そもそもPLaMo 2 1Bって何と思った方や推論を回してみたいという方はそちらをご覧ください。

また、今回説明に使うコードはこちらに置いてありますので、適宜参照してください。

https://github.com/shu65/plamo-2-1b-sft-example

Google Colabにおける一連の実行に関してはJupyter Notebookにまとめてありますので、細かい実行方法がわからないという方はこちらをご覧ください

https://github.com/shu65/plamo-2-1b-sft-example/blob/main/run_sft_google_colab.ipynb

Supervised Fine-Tuning(SFT)とは?

SFTを知らない方に簡単に説明すると、SFTは指示と想定されている回答のペアを用意し、LLMに対して学習を行い、指示に従いやすいモデルを作る方法になります。

特にPLaMo 2 1Bのような事前学習モデルでは、特に指示に従うように学習されていないケースもあり、そのまま利用した際、余計なことをだらだらと出力し続けたり、頓珍漢な回答が返ってきたりという問題が発生することがあります。

このため指示に適切にこたえてもらうための技術がいろいろあるのですが、そのうちの一つにSFTというものがあります。

Google ColabでPLaMo 2 1BをSFTする

それでは本題のGoogle ColabでPLaMo 2 1BをSFTする方法について説明します。今回はGPUメモリの関係上、おそらく無料で使えるT4だと無改造では実行できない気がするのでL4を使った説明をします。

L4 GPUの利用

まず、Google ColabでL4が使えるように、課金が必要になります。

課金についてはこちらをご覧ください。

https://colab.research.google.com/signup?hl=ja

今回のコードを動かすだけであれば「Pay As You Go」で100 コンピューティング ユニットを購入すれば十分です。この記事を執筆時点では1200円に満たない程度で購入できます。

課金が済んだら、メニューバーから「ランタイム」→「ランタイムのタイプを変更」をクリックします。すると無料枠では選択できないL4 GPUが選択できるようになっていると思うので、L4 GPUを選択します。

これでGPUを使う準備ができました。

実行環境準備

L4を利用するようにしたら、実行するコードのダウンロードやPythonパッケージのインストールを行います。

まずGithubよりコードをcloneしてきます

!git clone https://github.com/shu65/plamo-2-1b-sft-example.git

次に、PyTorchのバージョンを現在の最新版よりも前の以下のものに変更します。

!pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu124

この後は以下のようにPyTorch以外のPLaMo 2 1Bの実行に必要なパッケージやSFTに必要なパッケージなどをインストールします。

!pip install -r plamo-2-1b-sft-example/requirements.txt

ここまで実行すると2025/02/12現在以下のようなバージョンがインストールされました。

causal-conv1d                      1.5.0.post8
fastrlock                          0.8.3
mamba-ssm                          2.2.4
numba                              0.61.0
numba-cuda                         0.0.17.1
sentence-transformers              3.4.1
torch                              2.4.1+cu124
torchaudio                         2.4.1+cu124
torchsummary                       1.5.1
torchvision                        0.19.1+cu124
transformers                       4.48.2
trl                                0.14.0

これであとはSFTのコードを実行すれば、SFTをすることができます。このSFTの中身に関しては次で紹介していきます。

PLaMo 2 1BをSFTする

SFTをする部分はsft.py というスクリプトにまとめてあります。このスクリプトの重要な部分について簡単にですが説明していきます。

まず、今回はすぐに実行が終わるように少量の質問と回答のペアのデータを用います。

今回は日本語の指示学習でよく使われるkunishou/databricks-dolly-15k-jaというデータセットのうち、input がなくinstructionoutput のペアになっているデータのみを取り出しその一部だけを利用します。一つ例を見せると以下のようなデータを利用します。

{
  "output": "イコクエイラクブカ",
  "input": "",
  "index": "1",
  "category": "classification",
  "instruction": "魚の種類はどっち?イコクエイラクブカとロープ"
}

一部だけ取り出すコードは以下の通りです。

    dataset = datasets.load_dataset("kunishou/databricks-dolly-15k-ja")
    train_dataset = dataset["train"].filter(lambda data: data["input"] == "")

次にSFTConfig というSFTの実行の設定のクラスのインスタンスを用意します。具体的には以下の通りです。

    sft_args = SFTConfig(
        output_dir="./outputs",
        evaluation_strategy="no",
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        num_train_epochs=0.1,
        lr_scheduler_type="cosine",
        warmup_ratio=0.3,
        logging_steps=10,
        save_strategy="epoch",
        report_to="tensorboard",
        bf16=True,
        max_seq_length=1024,
        gradient_checkpointing=True,
        deepspeed='./deepspeed_config.json',
    )

重要なこととして、今回はGPUのメモリが少ないため、DeepSpeedのStage 3という学習時に一部のデータをCPU側に置いておくモードを利用します。

これによりGPUメモリが少ない環境でもSFTを回すことができます。

DeepSpeed周りの設定はdeepspeed_config.json に書いてありますので気になる方はご覧ください。

また、今回は学習データの10%だけを利用するようにしています。これはこの学習を早く終わらせるためであり、本来はもっと回す必要があると考えられますので、本気でSFTをする場合は注意してください。

次にデータをどのようなフォーマットでLLMに入力するかを指定するformatting_func という関数を用意します。今回は以下のようにしました。

INSTRUCTION_TEMPLATE = string.Template(
    """### Question:
${input}

### Answer:
${response}<|plamo:eos|>
"""
)


def formatting_func(examples):
    output_texts = []
    for i in range(len(examples['instruction'])):
        text = INSTRUCTION_TEMPLATE.substitute(input=examples['instruction'][i], response=examples['output'][i])
        output_texts.append(text)
    return output_texts

INSTRUCTION_TEMPLATE が今回のフォーマットで、### Question:\n の後に指示、### Answer:\n のあとに回答が続き、最後にend of sequenceである<|plamo:eos|> が来るようになっています。

また、学習時には回答部分だけを学習してほしいので、どこからが回答かがわかるように‎DataCollatorForCompletionOnlyLM のインスタンスも用意します。これは以下の通りです。

    data_collator = DataCollatorForCompletionOnlyLM(
        response_template=tokenizer.encode(" Answer:\n", add_special_tokens=False),
        tokenizer=tokenizer
    )

response_template のところで回答前の部分がどのようなtoken idになるかを指定する部分があるので、上記のように指定します。前後の文字の影響で指定したtoken idが出現しないケースがあるので、その時はいろいろresponse_template に指定する文字列を調整してみてください。

最後にSFTを実行するためのクラスの‎SFTTrainer を以下のように用意します。

    trainer = SFTTrainer(
        model=model,
        args=sft_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        formatting_func=formatting_func,
    )

そして、以下のように実行し、結果を保存します。

    trainer.train()
    trainer.save_model()

これで学習が終わるとSFTConfigoutput_dir で指定した./outputs に結果が出力されます。試しに私がGoogle Colabで実行した際は13分程度で学習が終わりました。コンピューティングユニットとしてはパッケージなどのインストールも含めて4だけ消費しました。

SFTされたモデルで推論してみる

最後にSFTされたモデルで推論するというのを行います。

これはPLaMo 2 1Bのexampleとほぼ同じでpromptだけ少し変えたものを例として用います。コードとしては以下の通りです。

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


model_name = "./plamo-2-1b-sft-example/outputs"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)


# プロンプトの準備
prompt = "### Question:\n埼玉の県庁所在地は何市?\n\n### Answer:\n"

# 推論の実行
inputs = tokenizer(prompt, return_tensors="pt")
generated_tokens = model.generate(
    **inputs,
    max_new_tokens=64,
    pad_token_id=tokenizer.pad_token_id,
)[0]
generated_text = tokenizer.decode(generated_tokens)
print(generated_text)

出力結果は以下のようになります。

<|plamo:bos|>### Question:
埼玉の県庁所在地は何市?

### Answer:
埼玉県の県庁所在地はさいたま市です。<|plamo:eos|>

ちゃんと学習で指定されたように### Answer:\n の後に質問に対する回答をし、その後<|plamo:eos|> を出力するということができています。

ちなみにSFTしていないモデルではどうなるかというと、以下のように余計なことを出力するうえ、出力が止まらないという状態になっています。

<|plamo:bos|>### Question:
埼玉の県庁所在地は何市?

### Answer:
さいたま市

### 解説
「県庁所在地」とは、都道府県庁が置かれている都市のことです。
「さいたま市」は埼玉県の県庁所在地です。

### 関連記事
### 取り急ぎお知らせ
「埼玉の県庁所在地は何市?」の解説は以上です。
「埼玉の県庁所在地は何市?」の解説は以上です。

このため、SFTでうまくフォーマットに従うよう学習できたと考えられます。

終わりに

今回はPLaMo 2 1Bを使ってSFTをする例を示しました。今回示したように簡単なSFTなら十分Google Colabで実行することができます。みなさんもぜひいろいろ試していただければと思います。

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

このサイトは reCAPTCHA によって保護されており、Google のプライバシーポリシー および 利用規約 に適用されます。

reCaptcha の認証期間が終了しました。ページを再読み込みしてください。