Large model | NEFTune: the benefits of introducing random noise into large-model training

The paper reports that adding moderate random noise to the input embeddings (inputs_embeds) during the model's forward pass brings significant benefits.

Paper: https://arxiv.org/pdf/2310.05914.pdf
Github: https://github.com/neelsjain/NEFTune

Article directory

  • Large model | NEFTune: the benefits of introducing random noise into large-model training
  • Theory
  • 1. Practical methods
    • 1.1 Waiting for Hugging Face to release this feature
    • 1.2 Directly encapsulate the model
    • 1.3 Rewrite compute_loss

Theory

The core idea is that after the input passes through the embedding layer, uniformly distributed noise is added. The noise is sampled from the range

$[-\frac{\alpha}{\sqrt{Ld}}, \frac{\alpha}{\sqrt{Ld}}]$

where $\alpha$ is the noise hyperparameter, $L$ is the input length, and $d$ is the embedding dimension (i.e., the hidden dimension).
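As a concrete illustration, the sampling rule maps directly to a few lines of PyTorch (a minimal sketch; the function name is ours, not from the paper's repo):

import torch

def sample_neftune_noise(embeds: torch.Tensor, alpha: float) -> torch.Tensor:
    # embeds: B x L x d batch of input embeddings
    L, d = embeds.size(1), embeds.size(2)
    mag = alpha / (L * d) ** 0.5  # alpha / sqrt(L * d)
    # Uniform noise in [-mag, mag], same shape/dtype/device as the embeddings
    return torch.zeros_like(embeds).uniform_(-mag, mag)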

On the AlpacaEval leaderboard, with GPT-4 as the judge, fine-tuning Llama2-7B on several instruction datasets with NEFTune shows a significant improvement over direct fine-tuning.

It can alleviate overfitting during the instruction fine-tuning stage and make better use of the knowledge learned during pre-training.

1. Practical methods

1.1 Waiting for Hugging Face to release this feature

Progress: waiting for Hugging Face to officially release this feature (as of 2023-10-26).

[10/17/2023] NEFTune has been integrated into Hugging Face's TRL (Transformer Reinforcement Learning) library. See the announcement.
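For reference, a minimal usage sketch (assuming a TRL version from around that announcement, where SFTTrainer accepts the argument directly; in later releases it moved to SFTConfig, so check your version's docs):

from trl import SFTTrainer

# train_dataset is assumed to be a prepared instruction dataset with a "text" field
trainer = SFTTrainer(
    model="meta-llama/Llama-2-7b-hf",  # or an already-loaded model object
    train_dataset=train_dataset,
    dataset_text_field="text",
    neftune_noise_alpha=5,  # enables NEFTune noise injection during SFT
)
trainer.train()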

1.2 Directly encapsulate the model

Progress: directly wrap the model as follows. The idea is to override model.embed_tokens.forward(). In practice this does not work and raises a stack overflow error: inside the wrapper, orig_embed(x) goes through the module's __call__, which dispatches back to the patched forward and recurses indefinitely (a hook-based workaround is sketched after the code).

import torch
from torch.nn import functional as F

def NEFTune(model, noise_alpha=5):
    def noised_embed(orig_embed, noise_alpha):
        def new_func(x):
            # during training, we add noise to the embedding
            # during generation, we don't add noise to the embedding
            if model.training:
                embed_init = orig_embed(x)
                dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
                mag_norm = noise_alpha/torch.sqrt(dims)
                return embed_init + torch.zeros_like(embed_init).uniform_(-mag_norm, mag_norm)
            else:
                return orig_embed(x)
        return new_func
    ##### NOTE: this is for a LLaMA model #####
    ##### For a different model, you need to change the attribute path to the embedding #####
    model.base_model.model.model.embed_tokens.forward = noised_embed(model.base_model.model.model.embed_tokens, noise_alpha)
    return model
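A variant that avoids the recursion described above (a sketch of ours, not the author's code; it is close in spirit to how the later Hugging Face integration works) registers a forward hook on the embedding layer instead of replacing its forward method, so the original embedding computation runs untouched:

import math
import torch

def neftune_forward_hook(module, inputs, output):
    # Runs after the embedding layer's own forward, so there is no recursion.
    # Noise is added only in training mode; generation sees clean embeddings.
    if module.training:
        dims = output.size(1) * output.size(2)  # L * d
        mag_norm = module.neftune_noise_alpha / math.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output  # a non-None return value replaces the module's output

embed_layer = model.get_input_embeddings()  # resolves the embedding attribute path for any architecture
embed_layer.neftune_noise_alpha = 5
hook_handle = embed_layer.register_forward_hook(neftune_forward_hook)
# Call hook_handle.remove() to disable the noise, e.g. before inference.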

1.3 Rewrite compute_loss

Progress: the loss can be computed normally, but the optimizer reports an error, possibly related to precision or gradient flow; not yet resolved.

Since the loss function here is self-written, the noise code can be inserted before the model(**inputs) call. Note that input_ids were originally passed into the model; now, because the noise is added to the embeddings, the model is called with inputs_embeds instead. The code is as follows:

class TargetLMLossNeft(Loss):

    def __init__(self, ignore_index):
        super().__init__()
        self.ignore_index = ignore_index
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def __call__(self, model, inputs, training_args, return_outputs=False):
        input_ids = inputs['input_ids'] # B x L [3, 964]
        attention_mask = inputs['attention_mask'] # B x L
        target_mask = inputs['target_mask'] # B x L

        ### -------------------------- add noise to embeds
        neftune_alpha = 5
        embed_device = model.base_model.model.model.embed_tokens.weight.device
        embeds_init = model.base_model.model.model.embed_tokens.forward(input_ids).to(embed_device) # first pass through the embedding layer: B x L x hidden_state
        # embed_device = model.model.embed_tokens.weight.device
        # embeds_init = model.model.embed_tokens.forward(input_ids).to(embed_device)

        input_mask = attention_mask.to(embeds_init) # B x L
        input_lengths = torch.sum(input_mask, 1) # B, the actual (unpadded) length of each sample

        noise_ = torch.zeros_like(embeds_init).uniform_(-1, 1) # B x L x hidden_state, uniformly distributed in [-1, 1]
        delta = noise_ * input_mask.unsqueeze(2) # broadcast the mask from B x L to B x L x 1 so padded positions get no noise
        dims = input_lengths * embeds_init.size(-1) # L * d per sample, using each sample's real length
        mag = neftune_alpha / torch.sqrt(dims)
        delta = (delta * mag.view(-1, 1, 1)).detach() # B x L x hidden_state; the noise itself carries no gradient
        inputs_embeds = delta + embeds_init
        ### -------------------------- add noise to embeds
        

        # Forward pass: input_ids were originally passed in; now the noised inputs_embeds are passed instead.
        # outputs = model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        outputs = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, return_dict=True)
        logits = outputs["logits"] if isinstance(outputs, dict) else outputs[0] # normally torch.float32
        # logits.requires_grad = True # strange: why is it False by default? Because of the detach() above?

        # Set the non-target part of the labels to ignore_index so only the target tokens contribute to the loss
        labels = torch.where(target_mask == 1, input_ids, self.ignore_index)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss = self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # float32
        loss.requires_grad = True # workaround for the requires_grad issue above; see the note after the code
        return (loss, outputs) if return_outputs else loss
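A hedged guess at the unresolved optimizer/requires_grad issue (not verified against the author's setup): if the base weights are frozen (e.g., LoRA) and gradient checkpointing is enabled, checkpointed segments whose inputs do not require grad return outputs with requires_grad == False, which matches the symptom in the comments above; forcing loss.requires_grad = True then silences the error but severs the gradient path to the trainable parameters. transformers exposes a helper for exactly this case:

# Sketch, assuming a PEFT/LoRA model with gradient checkpointing enabled.
# Registers a hook so the embedding output requires grad, letting autograd
# reach the adapter weights even when inputs_embeds is constructed manually.
model.enable_input_require_grads()

With that in place, the manual loss.requires_grad = True workaround should no longer be needed.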