Reparameterization Trick and Gumbel-Softmax

In reinforcement learning, categorical sampling is needed in many places: the neural network outputs a probability for each category, an action must be sampled from that distribution, and in some algorithms the sampling step must also support backpropagation so that gradients can be computed through it.

Alternatively, the network outputs a mean and standard deviation, a sample must be drawn from the corresponding normal distribution, and gradients again need to flow back through the sampling step.

However, ordinary categorical sampling and normal-distribution sampling are non-differentiable, so the gradient cannot be computed directly; other techniques are needed.

The Reparameterization Trick splits sampling into two parts: random + deterministic.

The random part supplies the randomness of the sample; the deterministic part, which is a function of the network's outputs, turns that noise into the final sample and is the path along which the gradient is backpropagated. The random part itself needs no gradient.

The network below outputs category probabilities. The take_action function samples an action according to those probabilities. In this case no backpropagation through the sample is needed, so it can be sampled directly.

import torch
import torch.nn.functional as F

class PolicyNet(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNet, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc2 = torch.nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return F.softmax(self.fc2(x), dim=1) ## Returns the probability of each action in this state

    def take_action(self, state): # Randomly sample according to the action probability distribution
        state = torch.tensor([state], dtype=torch.float)
        probs = self(state) ## Probability of each action in this state
        action_dist = torch.distributions.Categorical(probs) ## Build the categorical distribution to sample from
        action = action_dist.sample() ## Sample one action; each action is drawn with its probability in probs
        return action.item() ## Return the index of the sampled action
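
A minimal usage sketch (the dimensions and the state below are made-up values for illustration, not from the book):

## Minimal usage sketch; state_dim=4 and action_dim=2 are made-up values for illustration
policy = PolicyNet(state_dim=4, hidden_dim=128, action_dim=2)
state = [0.1, -0.2, 0.05, 0.0] ## a hypothetical observation
action = policy.take_action(state) ## an integer action index sampled from the output probabilities
print(action)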

Next is the continuous action space: the network outputs a mean and standard deviation, and the action is sampled from a normal distribution. take_action simply draws a sample from that distribution and does not need to backpropagate gradients, i.e. no differentiation is required and it can be sampled directly.

class PolicyNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim):
        super(PolicyNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim) ## Outputs the mean of the normal distribution for each action
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim) ## Outputs the standard deviation of the normal distribution for each action

    ## Outputs the probability distribution of the pendulum action; a normal distribution is assumed
    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = 2.0 * torch.tanh(self.fc_mu(x)) ## Predicted mean of each action's normal distribution, scaled to [-2, 2]
        std = F.softplus(self.fc_std(x)) ## Predicted standard deviation; softplus keeps it positive
        return mu, std # mean and standard deviation of the Gaussian distribution

    def take_action(self, state): # Randomly sample according to the action probability distribution
        state = torch.tensor([state], dtype=torch.float)
        mu, std = self(state) ## Mean and standard deviation in this state
        action_dist = torch.distributions.Normal(mu, std) ## Build the Gaussian distribution to sample from
        action = action_dist.sample() ## Draw a sample from the Gaussian distribution
        return [action.item()] ## Return the sampled action value, i.e. the magnitude of the torque
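
A correspondingly minimal usage sketch (Pendulum-style dimensions, made up for illustration):

## Minimal usage sketch; state_dim=3 and action_dim=1 are made-up, Pendulum-style values for illustration
policy = PolicyNetContinuous(state_dim=3, hidden_dim=128, action_dim=1)
state = [0.5, -0.5, 0.1] ## a hypothetical observation
action = policy.take_action(state) ## a list with one sampled torque value
print(action)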

Neither of the two examples above needs to differentiate through the sampling step.

Reparameterization Trick

Reparameterized sampling works as follows. Sampling directly from a normal distribution is non-differentiable; reparameterized sampling makes it differentiable. After reparameterization the sample is split into random + deterministic: the random part is a draw from the standard normal distribution, and the deterministic part is built from the input mean and standard deviation.

The random part is a draw from the standard normal distribution $\mathcal{N}(0,1)$; the target distribution $\mathcal{N}(\mu,\sigma^2)$ has density

$f(x)=\frac{1}{\sqrt{2\pi}\,\sigma}\,e^{-\frac{(x-\mu)^2}{2\sigma^2}}$

A standard normal sample, multiplied by the standard deviation and shifted by the mean, follows the target normal distribution. This is just the standardization of the normal distribution run in reverse:

$x \sim \mathcal{N}(\mu,\sigma^2) \quad\Longleftrightarrow\quad \frac{x-\mu}{\sigma} \sim \mathcal{N}(0,1)$

Reparameterized sampling is generally used for continuous distributions, most commonly normal-distribution sampling.
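
A minimal sketch of the idea, written out directly from the formula above rather than through a library call: sample standard normal noise, then build the sample deterministically from the mean and standard deviation so gradients can reach them.

import torch

mu = torch.tensor([0.5], requires_grad=True) ## mean, a quantity we want gradients for
std = torch.tensor([2.0], requires_grad=True) ## standard deviation, also needs gradients
eps = torch.randn(1) ## random part: a standard normal N(0, 1) sample, no gradient needed
x = mu + std * eps ## deterministic part: x follows N(mu, std^2)
x.sum().backward() ## gradients flow to mu and std, not to eps
print(mu.grad, std.grad) ## dx/dmu = 1, dx/dstd = eps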

Gumbel-Softmax

Gumbel-Softmax handles discrete categorical sampling. Ordinary categorical sampling is not differentiable. Gumbel-Softmax uses Gumbel-distributed noise as the random part and a softmax for normalization, so the logits (log values) stay on a differentiable path while the randomness comes from the noise.
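
As a small sanity check of the underlying Gumbel-max trick (an added illustration, not code from the book's repo): adding Gumbel(0,1) noise to the log-probabilities and taking the argmax reproduces ordinary categorical sampling.

import torch

probs = torch.tensor([0.2, 0.5, 0.3]) ## a made-up categorical distribution
logits = torch.log(probs)
counts = torch.zeros(3)
for _ in range(10000):
    g = -torch.log(-torch.log(torch.rand(3))) ## Gumbel(0, 1) noise
    counts[(logits + g).argmax()] += 1
print(counts / counts.sum()) ## approximately [0.2, 0.5, 0.3]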


Let's look at an application of the Reparameterization Trick.

dist.rsample() is reparameterized sampling, i.e. random + deterministic: the random part is a standard normal sample from N(0, 1), which is then multiplied by the standard deviation and shifted by the mean.

Random part: $\mathcal{N}(0,1)$, a standard normal sample $\epsilon$. Deterministic part: $\mu + \sigma\epsilon$. The random part needs no differentiation; only the deterministic mean and standard deviation are differentiated, and that part is differentiable.

import torch
import torch.nn.functional as F
from torch.distributions import Normal

## Constructing a policy network
class PolicyNetContinuous(torch.nn.Module):
    def __init__(self, state_dim, hidden_dim, action_dim, action_bound):
        super(PolicyNetContinuous, self).__init__()
        self.fc1 = torch.nn.Linear(state_dim, hidden_dim)
        self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)
        self.fc_std = torch.nn.Linear(hidden_dim, action_dim)
        self.action_bound = action_bound

    def forward(self, x):
        x = F.relu(self.fc1(x))
        mu = self.fc_mu(x)
        std = F.softplus(self.fc_std(x)) ## softplus keeps the standard deviation positive
        dist = Normal(mu, std)
        ## Reparameterized sampling: sample from the standard normal distribution, then multiply by the standard
        ## deviation and add the mean. The sample remains differentiable with respect to the mean and standard deviation.
        normal_sample = dist.rsample() # rsample() is reparameterized sampling; plain Gaussian sampling is not differentiable, rsample() is
        log_prob = dist.log_prob(normal_sample) ## log-probability of the sample under the Gaussian
        action = torch.tanh(normal_sample) ## squash the sample into (-1, 1) with tanh
        ## Correct the log-probability for the tanh squashing (change of variables);
        ## action is already tanh(normal_sample), so the correction uses action directly
        log_prob = log_prob - torch.log(1 - action.pow(2) + 1e-7)
        action = action * self.action_bound ## scale the action to the environment's range
        return action, log_prob
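
A minimal check (dimensions made up for illustration) that gradients really flow from the sampled action back to the network parameters through rsample():

## Minimal gradient-flow check; the dimensions are made-up values for illustration
net = PolicyNetContinuous(state_dim=3, hidden_dim=64, action_dim=1, action_bound=2.0)
state = torch.randn(1, 3)
action, log_prob = net(state)
action.sum().backward() ## backpropagate through the reparameterized sample
print(net.fc_mu.weight.grad is not None) ## True: the sampling step did not block the gradient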

Then there is Gumbel-Softmax sampling for discrete actions, used in the program below. The random part of the sampling is Gumbel-distributed noise; the deterministic part is the logits (log values) followed by a softmax.

import numpy as np
import torch
import torch.nn.functional as F

def onehot_from_logits(logits, eps=0.01):
    ''' Generate the one-hot form of the optimal action '''
    argmax_acs = (logits == logits.max(1, keepdim=True)[0]).float()
    # Generate random actions and convert them into one-hot form
    rand_acs = torch.autograd.Variable(torch.eye(logits.shape[1])[[
        np.random.choice(range(logits.shape[1]), size=logits.shape[0])
    ]],
                                       requires_grad=False).to(logits.device)
    # Use epsilon-greedy to choose between the greedy action and a random action
    return torch.stack([
        argmax_acs[i] if r > eps else rand_acs[i]
        for i, r in enumerate(torch.rand(logits.shape[0]))
    ])

def sample_gumbel(shape, eps=1e-20, tens_type=torch.FloatTensor):
    """Sampling from the Gumbel(0,1) distribution"""
    U = torch.autograd.Variable(tens_type(*shape).uniform_(),
                                requires_grad=False)
    return -torch.log(-torch.log(U + eps) + eps) ## Gumbel(0,1) sample: apply the inverse CDF -log(-log(u)) to uniform noise
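
## A quick sanity check (an added illustration, not from the book's repo): the mean of Gumbel(0,1) is the
## Euler-Mascheroni constant (about 0.5772), so a large batch of samples should average near that value.
g = sample_gumbel((100000,))
print(g.mean()) ## approximately 0.5772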

def gumbel_softmax_sample(logits, temperature):
    """ Sampling from the Gumbel-Softmax distribution"""
    ## Adding Gumbel noise to the logits (log values) amounts to drawing a (relaxed) categorical sample.
    ## The logits stay on a differentiable path while the Gumbel part carries no gradient, so category sampling
    ## is achieved and gradients can still be backpropagated; the randomness comes entirely from the Gumbel noise.
    y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data)).to(
        logits.device)
    return F.softmax(y / temperature, dim=1) ## softmax; the temperature controls smoothness: high is close to uniform, low is close to one-hot
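
## A small added illustration (made-up logits, not from the repo) of the temperature comment above:
## a high temperature pushes the sample towards uniform, a low temperature towards (nearly) one-hot.
example_logits = torch.log(torch.tensor([[0.2, 0.5, 0.3]]))
print(gumbel_softmax_sample(example_logits, temperature=10.0)) ## roughly uniform, e.g. about [[0.33, 0.34, 0.33]]
print(gumbel_softmax_sample(example_logits, temperature=0.1)) ## nearly one-hot at the sampled category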

def gumbel_softmax(logits, temperature=1.0):
    """Sample from the Gumbel-Softmax distribution and discretize it"""
    y = gumbel_softmax_sample(logits, temperature)
    y_hard = onehot_from_logits(y) ## one-hot version used for interaction later; it needs no gradient
    '''
    detach() blocks gradients, so nothing inside (y_hard.to(logits.device) - y) receives a gradient;
    only the trailing + y does. In the forward pass the -y and +y cancel, leaving y_hard,
    so the returned value is the one-hot y_hard while the gradient skips the non-differentiable
    one-hot step and flows through y instead.
    '''
    y = (y_hard.to(logits.device) - y).detach() + y
    # Return the one-hot y_hard, but with the gradient of y: we get a discrete action to interact
    # with the environment while still being able to backpropagate gradients correctly
    return y
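
A minimal usage sketch (the logits are made-up values) showing that the returned action is one-hot in the forward pass yet still carries a gradient back to the logits:

## Minimal usage sketch; the logits are made-up values for illustration
logits = torch.randn(1, 4, requires_grad=True)
action = gumbel_softmax(logits, temperature=1.0)
print(action) ## effectively a one-hot row, e.g. [[0., 1., 0., 0.]]
loss = (action * torch.arange(4.0)).sum() ## any loss built from the action
loss.backward() ## the gradient flows through the softmax part, not the one-hot part
print(logits.grad) ## non-zero gradients on the logits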

Hands-on Reinforcement Learning: https://hrl.boyuai.com/

"Hands-on Reinforcement Learning" (Zhang Weinan, Shen Jian, Yu Yong), Dangdang Books: product.dangdang.com/29391150.html

ZouJiu1/Hands-on-RL (GitHub): https://github.com/ZouJiu1/Hands-on-RL/tree/main

The Gumbel-Softmax Distribution, Emma Benjaminson (sassafras13.github.io)

pdf (openreview.net)

https://zhuanlan.zhihu.com/p/659339817