LLM – SFT fine-tuning workflow

Table of Contents

I. Introduction

II. Workflow breakdown by process

1. Workflow code

2. Step-by-step breakdown

◆ Hyperparameter initialization

◆ Dataset initialization

◆ Loading and quantization

◆ Dataset preprocessing

◆ DataCollator

◆ Model fine-tuning (SFT)

III. Summary

I. Introduction

In earlier articles we analyzed each step of the LLM pipeline and gave code for it. This article combines that code into a single workflow. It presents the complete SFT workflow inside the framework so that readers can get familiar with the overall LLM training process.


The dataset and code in this article are mainly based on the GitHub project LLaMA-Efficient-Tuning.

II. Workflow breakdown by process

1. Workflow code

Only the SFT fine-tuning workflow is given here. For more complete code, refer to the Git project mentioned in the introduction, or to the HF Transformers example linked at the top of the code.

```python
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py

from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments

from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer

if TYPE_CHECKING:
    # Only needed for type hints, so these imports cost nothing at runtime
    from transformers import TrainerCallback
    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments

# 1. Get parameters through parser
def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[List["TrainerCallback"]] = None
):
    # 2. Load the dataset(s) in batches and merge them
    dataset = get_dataset(model_args, data_args)
    # 3. Load the model (optionally 8/4-bit) and attach LoRA
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
    # 4. Preprocess the dataset for SFT (tokenize, mask prompt tokens in labels)
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
    # 5. Data collator: pads input_ids and labels within each batch
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )

    # 6. Training args conversion
    # Override the decoding parameters of Seq2SeqTrainer
    training_args_dict = training_args.to_dict()
    training_args_dict.update(dict(
        generation_max_length=training_args.generation_max_length or data_args.max_target_length,
        generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
    ))
    training_args = Seq2SeqTrainingArguments(**training_args_dict)

    # 7. Initialize our Trainer
    trainer = Seq2SeqPeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **split_dataset(dataset, data_args, training_args)  # -> train_dataset / eval_dataset
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    # 8. Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()  # write the fine-tuned (LoRA) weights to output_dir
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # 9. Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
            metrics.pop("eval_loss", None)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # 10. Prediction
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
            predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
```

2. Step-by-step breakdown

Hyperparameter initialization

Model, Data, Training, Generate Arguments hyperparameter analysis: https://blog.csdn.net/BIT_666/article/details/132755841?spm=1001.2014.3001.5501

Besides the model name or path, the arguments mainly carry the training, fine-tuning and generation parameters.
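The following minimal, stdlib-only sketch illustrates how these argument groups stay separate. It routes one flat config into several dataclasses, which is similar in spirit to what transformers' HfArgumentParser does. The classes ModelArgs, DataArgs and FinetuneArgs, the helper split_config, and all field values are hypothetical placeholders, not llmtuner's real definitions.

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Tuple, Type


# Hypothetical, heavily simplified argument groups (not llmtuner's real classes)
@dataclass
class ModelArgs:
    model_name_or_path: str = ""


@dataclass
class DataArgs:
    dataset: str = ""
    max_source_length: int = 512


@dataclass
class FinetuneArgs:
    finetuning_type: str = "lora"
    lora_rank: int = 8
    lora_target: str = "W_pack"


def split_config(config: Dict[str, Any], groups: Tuple[Type, ...]) -> Tuple[Any, ...]:
    """Route each key of a flat config to the dataclass that declares it."""
    remaining = dict(config)
    instances = []
    for group in groups:
        names = {f.name for f in fields(group)}
        kwargs = {k: remaining.pop(k) for k in list(remaining) if k in names}
        instances.append(group(**kwargs))
    if remaining:
        # Reject unknown keys instead of silently ignoring them
        raise ValueError(f"Unknown arguments: {sorted(remaining)}")
    return tuple(instances)


# Placeholder values for illustration only
model_args, data_args, finetuning_args = split_config(
    {"model_name_or_path": "path/to/baichuan-7b", "dataset": "alpaca_zh", "lora_rank": 8},
    (ModelArgs, DataArgs, FinetuneArgs),
)
print(finetuning_args)
```

Keeping the groups in separate dataclasses lets each stage (loading, preprocessing, LoRA setup) receive only the parameters it needs.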

Dataset initialization

Load datasets in batches and merge: https://blog.csdn.net/BIT_666/article/details/132825731?spm=1001.2014.3001.5501

data_args contains the dataset-related parameters. Here we load the alpaca_data_zh_51k.json dataset.

To see what the data looks like, print the first 5 rows:

```python
def show(dataset):
    # Print the first 5 rows of a Hugging Face datasets.Dataset
    show_info = dataset.select(range(5))
    for row in show_info:
        print(row)
```

The dataset's features list many columns. We mainly care about prompt (the instruction), query (the question) and response (the answer).
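In LLaMA-Efficient-Tuning, the dataset configuration roughly maps an Alpaca-style record (instruction / input / output) onto these prompt / query / response columns. The plain-Python sketch below mimics that renaming. The COLUMN_MAP dict, the to_sft_columns helper and the sample record are hypothetical illustrations and are not taken from alpaca_data_zh_51k.json.

```python
from typing import Dict, Optional

# Hypothetical mapping from Alpaca-style field names to the unified SFT columns
COLUMN_MAP = {"instruction": "prompt", "input": "query", "output": "response"}


def to_sft_columns(record: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Rename Alpaca fields to prompt/query/response; missing fields become None."""
    return {target: record.get(source) for source, target in COLUMN_MAP.items()}


# Made-up example record, not taken from the real dataset
sample = {"instruction": "Translate into English.", "input": "你好", "output": "Hello"}
print(to_sft_columns(sample))
```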

Loading and Quantization

Model Load_in_8bit or 4bit: https://blog.csdn.net/BIT_666/article/details/132490630?spm=1001.2014.3001.5501

The detailed logic of this function is covered in the article linked above. It mainly reads the model parameters from model_args and the fine-tuning parameters, such as lora_target and lora_rank, from finetuning_args. The model is loaded with HF's Auto classes, and the LoRA model is built with the PEFT library.
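PEFT applies LoRA to each targeted layer. The frozen weight W gets a trainable low-rank update, so y = W x + (alpha / r) * B A x, where A is r x d_in and B is d_out x r. Below is a minimal pure-Python sketch of that forward pass that uses no PEFT and no tensors. The helpers matvec and lora_forward are hypothetical illustrations. B starts at zero, as in the LoRA paper and PEFT's default init, so the adapted layer initially reproduces the base layer exactly.

```python
import random
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Plain matrix-vector product."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def lora_forward(W: Matrix, A: Matrix, B: Matrix, x: Vector, alpha: float) -> Vector:
    """y = W x + (alpha / r) * B (A x); W stays frozen, only A and B are trained."""
    r = len(A)                         # LoRA rank = number of rows of A
    base = matvec(W, x)                # frozen pretrained path
    update = matvec(B, matvec(A, x))   # low-rank path: d_in -> r -> d_out
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


d_in, d_out, r = 6, 4, 2
random.seed(0)
W = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_out)]
A = [[random.uniform(-1, 1) for _ in range(d_in)] for _ in range(r)]
B = [[0.0] * r for _ in range(d_out)]  # zero init: no change before training
x = [1.0, 2.0, 0.5, -1.0, 0.0, 3.0]

print(lora_forward(W, A, B, x, alpha=16) == matvec(W, x))
```

Only A and B (r * (d_in + d_out) values) receive gradients. This is why the trainable share of parameters stays tiny.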

Base Model For Baichuan:

After the model is loaded, its configuration is printed. It shows the model type, some special token IDs, the SiLU activation function mentioned earlier, and so on. We did not use a quantized model here, but the newer Baichuan2 offers both online and offline 8-bit and 4-bit quantization.
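The following minimal sketch gives some intuition for what 8-bit loading does. It applies symmetric absmax int8 quantization to a single weight row. This is a simplified stand-in and not the scheme that bitsandbytes or Baichuan2 actually use; LLM.int8(), for example, quantizes per vector and keeps outlier features in 16-bit. The helpers quantize_absmax and dequantize are hypothetical.

```python
from typing import List, Tuple


def quantize_absmax(weights: List[float]) -> Tuple[List[int], float]:
    """Map floats to int8 codes in [-127, 127] using one scale = max|w| / 127."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(w / scale))) for w in weights]
    return q, scale


def dequantize(q: List[int], scale: float) -> List[float]:
    """Recover approximate floats from the int8 codes."""
    return [v * scale for v in q]


row = [0.12, -0.5, 0.03, 0.9, -0.27]
q, scale = quantize_absmax(row)
restored = dequantize(q, scale)
max_err = max(abs(a - b) for a, b in zip(row, restored))
print(q, scale, max_err)
```

Each weight now takes 1 byte plus one shared float scale. The rounding error per weight is at most half a quantization step (scale / 2).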

LoRA Info For Baichuan:

Since this is SFT fine-tuning, a LoRA module is added through PEFT. Here lora_target is 'W_pack', and the ratio of trainable parameters to total parameters is also printed.
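The trainable parameter count can be checked by hand. LoRA on a linear layer of size d_in -> d_out adds r * (d_in + d_out) parameters, since A is r x d_in and B is d_out x r. The sketch below makes three assumptions that come from the public Baichuan-7B configuration rather than from this article. First, W_pack fuses Q, K and V into one 4096 -> 3 * 4096 linear layer. Second, there are 32 decoder layers. Third, the LoRA rank is 8. The helper lora_trainable_params is hypothetical.

```python
def lora_trainable_params(d_in: int, d_out: int, rank: int, num_layers: int) -> int:
    """Trainable LoRA parameters: A (rank x d_in) + B (d_out x rank) per targeted layer."""
    return rank * (d_in + d_out) * num_layers


hidden = 4096  # assumed Baichuan-7B hidden size
params = lora_trainable_params(d_in=hidden, d_out=3 * hidden, rank=8, num_layers=32)
print(params)  # 4194304
```

To get the printed ratio, divide this number by the model's total parameter count. The result is a small fraction of a percent for a 7B model.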

Dataset preprocessing

Process Dataset For LLM With PT, SFT, RM: https://blog.csdn.net/BIT_666/article/details/132830908?spm=1001.2014.3001.5501

Data preprocessing needs the tokenizer that matches the model, so the model and tokenizer have to be loaded first. A recent article introduced how datasets are processed in the three modes SFT, PT and RM. Running similar code shows what the first 5 rows look like after preprocessing.

After preprocessing, the dataset only contains what SFT needs. input_ids holds the token ids of the input, where the input is prompt + "\t" + query + response, and labels masks every part except the response.

Take the first record as an example: input_ids is prompt + query + response, while in the labels the prompt and query tokens are replaced with IGNORE_INDEX (-100). The token at the very end has id 2, so every sequence ends with 2.
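The label masking described above fits in a few lines. The prompt/query positions get IGNORE_INDEX (-100), which PyTorch's cross-entropy loss skips by default, so only the response tokens and the final id-2 token are supervised. The helper build_sft_example and all token ids are hypothetical. A real run takes them from the Baichuan tokenizer, and the llmtuner code also handles templates, truncation and multi-turn history.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # label value ignored by the cross-entropy loss


def build_sft_example(source_ids: List[int], target_ids: List[int], eos_id: int = 2) -> Dict[str, List[int]]:
    """Concatenate prompt+query and response; supervise only the response and the final token."""
    input_ids = source_ids + target_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(source_ids) + target_ids + [eos_id]
    return {"input_ids": input_ids, "labels": labels}


# Made-up token ids: 4 tokens for prompt + query, 3 tokens for the response
example = build_sft_example(source_ids=[101, 102, 103, 104], target_ids=[201, 202, 203])
print(example)
```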


DataCollator

DataCollator sample generation: https://blog.csdn.net/BIT_666/article/details/131701620?spm=1001.2014.3001.5501

The trainer also needs a data_collator to assemble the training batches. Here the tokenizer and the corresponding label_pad_token_id are specified.
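DataCollatorForSeq2Seq roughly pads every example in a batch to the longest length. input_ids is padded with the tokenizer's pad token and gets an attention mask. labels is padded with label_pad_token_id, which is IGNORE_INDEX (-100) when ignore_pad_token_for_loss is set, so padding never contributes to the loss. The pure-Python sketch below imitates that behaviour with right padding. The helper pad_batch is hypothetical, and pad id 0 is an assumption.

```python
from typing import Dict, List

IGNORE_INDEX = -100


def pad_batch(features: List[Dict[str, List[int]]], pad_token_id: int = 0,
              label_pad_token_id: int = IGNORE_INDEX) -> Dict[str, List[List[int]]]:
    """Right-pad input_ids/labels to the longest example and build attention masks.

    Assumes labels have the same length as input_ids, as in SFT preprocessing.
    """
    max_len = max(len(f["input_ids"]) for f in features)
    batch: Dict[str, List[List[int]]] = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_pad)
    return batch


batch = pad_batch([
    {"input_ids": [5, 6, 7], "labels": [-100, 6, 7]},
    {"input_ids": [8, 9], "labels": [-100, 9]},
])
print(batch)
```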

Model fine-tuning (SFT)

Baichuan7B Lora training detailed explanation: https://blog.csdn.net/BIT_666/article/details/131675165?spm=1001.2014.3001.5501

The trainer, Seq2SeqPeftTrainer, mainly inherits from transformers' Seq2SeqTrainer.

split_dataset divides the dataset into train and eval parts.
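The workflow unpacks the result of split_dataset with ** into the trainer's constructor, so it has to return a dict keyed by train_dataset and/or eval_dataset. Below is a pure-Python sketch of that contract. It assumes a val_size fraction and a fixed shuffle seed. split_train_eval is a hypothetical stand-in that works on a plain list, not the llmtuner function, which operates on a Hugging Face Dataset.

```python
import random
from typing import Dict, List, Sequence


def split_train_eval(rows: Sequence[dict], val_size: float, do_train: bool = True,
                     seed: int = 42) -> Dict[str, List[dict]]:
    """Return Trainer kwargs: train_dataset and/or eval_dataset."""
    rows = list(rows)
    if not do_train:
        return {"eval_dataset": rows}        # evaluation/prediction only
    if val_size <= 0:
        return {"train_dataset": rows}       # no held-out split requested
    shuffled = rows[:]
    random.Random(seed).shuffle(shuffled)    # deterministic shuffle
    n_eval = max(1, int(len(shuffled) * val_size))
    return {"train_dataset": shuffled[n_eval:], "eval_dataset": shuffled[:n_eval]}


data = [{"id": i} for i in range(10)]
parts = split_train_eval(data, val_size=0.2)
print({k: len(v) for k, v in parts.items()})
```

Returning keyword names rather than a tuple lets the same call feed training-only, eval-only, or train+eval runs without changing the trainer construction.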

Model training log:

III. Summary

Working on large models is like riding a roller coaster. They seem very powerful, yet structurally they are just stacked Transformers. They seem easy to train and fine-tune, yet doing so takes serious compute and money. The workflow code looks very logical, yet it hides many small details that are worth learning. It is a tangled learning process; just keep watching and learning.
