Table of Contents
1. Introduction
2. Workflow breakdown by process
1. Workflow code
2. Workflow breakdown
◆ Hyperparameter initialization
◆ Dataset initialization
◆ Loading and quantization
◆ Dataset preprocessing
◆ DataCollator
◆ Model fine-tuning (SFT)
3. Summary
1. Introduction
In previous articles, we analyzed and wrote example code for each step of the LLM training pipeline. This article combines that code into a single workflow and presents the complete workflow within the framework, so that readers can become familiar with the end-to-end LLM training process.
Tips:
The dataset and code in this article are mainly based on the GitHub project LLaMA-Efficient-Tuning.
2. Workflow breakdown by process
1. Workflow code
Only the SFT fine-tuning workflow is shown here. For the complete code, refer to the GitHub project mentioned in the introduction, or to the Hugging Face Transformers example referenced at the top of the code.
# Inspired by: https://github.com/huggingface/transformers/blob/v4.29.2/examples/pytorch/summarization/run_summarization.py

from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments
from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.extras.misc import get_logits_processor
from llmtuner.extras.ploting import plot_loss
from llmtuner.tuner.core import load_model_and_tokenizer
from llmtuner.tuner.sft.metric import ComputeMetrics
from llmtuner.tuner.sft.trainer import Seq2SeqPeftTrainer

if TYPE_CHECKING:
    from transformers import TrainerCallback
    from llmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments


# 1. Get parameters through the parser
def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
    callbacks: Optional[List["TrainerCallback"]] = None
):
    # 2. Get the batch dataset
    dataset = get_dataset(model_args, data_args)
    # 3. Load the LoRA model, optionally quantized
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
    # 4. Process the dataset
    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
    # 5. Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
    )

    # 6. Training args conversion
    # Override the decoding parameters of Seq2SeqTrainer
    training_args_dict = training_args.to_dict()
    training_args_dict.update(dict(
        generation_max_length=training_args.generation_max_length or data_args.max_target_length,
        generation_num_beams=data_args.eval_num_beams or training_args.generation_num_beams
    ))
    training_args = Seq2SeqTrainingArguments(**training_args_dict)

    # Initialize our Trainer
    trainer = Seq2SeqPeftTrainer(
        finetuning_args=finetuning_args,
        model=model,
        args=training_args,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks,
        compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
        **split_dataset(dataset, data_args, training_args)
    )

    # Keyword arguments for `model.generate`
    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
        trainer.save_model()
        if trainer.is_world_process_zero() and model_args.plot_loss:
            plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])

    # Evaluation
    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
        if training_args.predict_with_generate:
            # eval_loss will be wrong if predict_with_generate is enabled
            metrics.pop("eval_loss", None)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # Predict
    if training_args.do_predict:
        predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
        if training_args.predict_with_generate:
            # predict_loss will be wrong if predict_with_generate is enabled
            predict_results.metrics.pop("predict_loss", None)
        trainer.log_metrics("predict", predict_results.metrics)
        trainer.save_metrics("predict", predict_results.metrics)
        trainer.save_predictions(predict_results)
2. Workflow breakdown
◆ Hyperparameter initialization
Model, Data, Training, and Generating Arguments hyperparameter analysis https://blog.csdn.net/BIT_666/article/details/132755841?spm=1001.2014.3001.5501
In addition to the model name or path, we mainly pass in the training arguments, fine-tuning arguments, generation arguments, and so on.
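Under the hood, these argument groups are plain dataclasses parsed by transformers' HfArgumentParser. The sketch below is a minimal, simplified version: the field names and defaults are only illustrative, while the real ModelArguments, DataArguments, FinetuningArguments, and GeneratingArguments live in llmtuner.hparams and contain many more options.

from dataclasses import dataclass, field
from transformers import HfArgumentParser, Seq2SeqTrainingArguments

@dataclass
class ModelArguments:
    # Illustrative field only; the real dataclass has many more options.
    model_name_or_path: str = field(default="baichuan-inc/Baichuan-7B")

@dataclass
class DataArguments:
    dataset: str = field(default="alpaca_zh")
    max_target_length: int = field(default=512)

@dataclass
class FinetuningArguments:
    lora_target: str = field(default="W_pack")
    lora_rank: int = field(default=8)

# Parse command-line flags (e.g. --output_dir, --do_train) into the dataclasses.
parser = HfArgumentParser(
    (ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments)
)
model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()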
◆ Dataset initialization
Load datasets in batches and merge https://blog.csdn.net/BIT_666/article/details/132825731?spm=1001.2014.3001.5501
data_args contains the dataset-related parameters. Here we load the alpaca_data_zh_51k.json dataset:
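As a rough sketch of what get_dataset does for a single local file, the snippet below loads an alpaca-style JSON file with the datasets library. The file path is an assumption; the real get_dataset also merges multiple datasets and maps their columns to prompt/query/response.

from datasets import load_dataset

# Assumed local path; the real path comes from data_args and the dataset registry.
raw_dataset = load_dataset(
    "json",
    data_files="data/alpaca_data_zh_51k.json",
    split="train",
)
print(raw_dataset)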
Let’s print the first 5 rows to inspect the dataset:
def show(dataset):
    show_info = dataset.select(range(5))
    print(show_info)
    for row in show_info:
        print(row)
The features list many columns; we mainly care about the prompt (instruction), the query (question), and the response (reply).
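For intuition, a single row looks roughly like the invented example below after column mapping; only prompt, query, and response matter for SFT.

# Invented example row, for illustration only.
row = {
    "prompt": "Summarize the following paragraph.",  # the instruction
    "query": "Large language models are ...",        # optional extra input, may be empty
    "response": "The paragraph explains that ...",   # the expected answer
}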
◆ Loading and Quantization
Model load_in_8bit or 4bit https://blog.csdn.net/BIT_666/article/details/132490630?spm=1001.2014.3001.5501
The detailed logic of this function is covered in the link above. It is mainly responsible for reading the model parameters from model_args and the fine-tuning parameters, such as lora_target and lora_rank, from finetuning_args. The base model is loaded through Hugging Face’s Auto classes, and the LoRA model is built with the PEFT library.
Base Model For Baichuan:
The model configuration printed after loading shows the model type, some special token IDs, and the previously mentioned SiLU activation function. We did not use quantization here, but the new Baichuan2 offers both 8-bit and 4-bit online quantization as well as offline quantization options to choose from.
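For reference, a quantized load would look roughly like the sketch below, using transformers' BitsAndBytesConfig (backed by bitsandbytes); the checkpoint name and config values are assumptions, and again, the run in this article did not quantize.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "baichuan-inc/Baichuan-7B"  # assumed checkpoint name

# 4-bit NF4 quantization; switch to load_in_8bit=True for 8-bit loading.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config,
    trust_remote_code=True,
)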
LoRA Info For Baichuan:
Since this is SFT fine-tuning, the LoRA module is added through PEFT. Here lora_target is ‘W_pack’, and the ratio of trainable fine-tuning parameters to total parameters is also printed.
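A minimal sketch of attaching such a LoRA adapter with the peft library is shown below; the rank, alpha, and dropout values are illustrative defaults, not necessarily the ones used here.

from transformers import AutoModelForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained(
    "baichuan-inc/Baichuan-7B", trust_remote_code=True  # assumed checkpoint
)

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,                        # lora_rank (illustrative)
    lora_alpha=32,              # illustrative scaling factor
    lora_dropout=0.1,
    target_modules=["W_pack"],  # lora_target: Baichuan's fused QKV projection
)

model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # prints trainable vs. total parameter counts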
◆ Dataset preprocessing
Process Dataset For LLM With PT, SFT, RM https://blog.csdn.net/BIT_666/article/details/132830908?spm=1001.2014.3001.5501
Because data preprocessing requires the model’s tokenizer, the model and tokenizer must be loaded first. Our recent article introduces the processing methods for the three dataset modes: SFT, PT, and RM. Run the code as before to see what the first 5 rows of data look like after preprocessing:
After processing, the dataset only contains the fields required for SFT. input_ids holds the token ids of the input, where the input is prompt + "\t" + query + response, and labels masks out everything except the response.
Taking the first record as an example, input_ids is prompt + query + response, while label_ids replaces the non-response tokens with the IGNORE_INDEX value of -100. The corresponding end-of-sequence token id is 2, so every sequence ends with 2.
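A simplified sketch of this per-example construction is given below, assuming a loaded tokenizer; the real preprocess_dataset also handles prompt templates, history, and length truncation.

IGNORE_INDEX = -100  # tokens with this label are ignored by the loss

def build_sft_example(tokenizer, prompt, query, response):
    # Source part: prompt (+ "\t" + query); its labels are masked out.
    source = prompt + "\t" + query if query else prompt
    source_ids = tokenizer.encode(source, add_special_tokens=False)
    target_ids = tokenizer.encode(response, add_special_tokens=False)

    input_ids = source_ids + target_ids + [tokenizer.eos_token_id]
    labels = [IGNORE_INDEX] * len(source_ids) + target_ids + [tokenizer.eos_token_id]
    return {"input_ids": input_ids, "labels": labels}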
◆ DataCollator
DataCollator sample generation https://blog.csdn.net/BIT_666/article/details/131701620?spm=1001.2014.3001.5501
The trainer also needs a data_collator to build the training batches; it is given the tokenizer and the corresponding pad token ids.
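The toy example below shows the effect, with invented token ids: input_ids are padded with the tokenizer’s pad token, while labels are padded with IGNORE_INDEX so padding never contributes to the loss.

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

IGNORE_INDEX = -100
tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan-7B", trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # make sure padding is possible

collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)
batch = collator([
    {"input_ids": [31106, 72, 31423], "labels": [-100, 72, 31423]},  # invented ids
    {"input_ids": [31106, 72], "labels": [-100, 72]},
])
print(batch["input_ids"].shape)  # both rows padded to the same length
print(batch["labels"][1])        # the shorter row's labels are padded with -100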
◆ Model fine-tuning (SFT)
Baichuan-7B LoRA training detailed explanation https://blog.csdn.net/BIT_666/article/details/131675165?spm=1001.2014.3001.5501
Training is handled by a trainer that mainly inherits from transformers’ Seq2SeqTrainer:
split_dataset is responsible for splitting the dataset into train and eval parts:
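Conceptually, it behaves like the sketch below, which carves out an eval split with Dataset.train_test_split when evaluation is enabled; the split size here is an assumption, as the real implementation reads it from the data arguments.

def split_dataset(dataset, data_args, training_args):
    # Return kwargs that can be unpacked straight into the Trainer.
    if training_args.do_eval:
        split = dataset.train_test_split(test_size=0.01, seed=training_args.seed)  # assumed ratio
        return {"train_dataset": split["train"], "eval_dataset": split["test"]}
    return {"train_dataset": dataset}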
Model training log:
3. Summary
Working on large models feels like a roller coaster. The models seem very powerful, yet the architecture is just a stack of Transformers; training and fine-tuning look easy, yet they demand serious engineering effort and financial resources; the workflow code looks straightforwardly logical, yet it hides many small details that are worth studying. The learning process can be tangled, so just keep watching and keep learning.