reason
Tracing the whole model produces an event file that is too large to be opened in the browser, so tracing each function of the model separately is an alternative method.
Download pytorch_utils.py here.
import torch
import torch.nn as nn

from pytorch_utils.modules import MLP

EPS = 1e-8


class WorkingMemory(nn.Module):
    """Fixed-size memory of `num_cells` cells that is read and updated per step.

    At every time step the module predicts whether the current hidden state
    is (part of) an entity, how strongly it corefers with each active cell,
    and which cell should be overwritten otherwise; the cells and their
    usage scores are then updated with those probabilities.

    Args:
        device: Device on which freshly created tensors (initial memory,
            masks, noise) are placed. These tensors are combined with the
            module's parameters, so this must match where the module lives.
        mem_type: 'learned' uses a learned initial memory, 'key_val' uses a
            learned, never-updated key prefix per cell; any other value
            (e.g. the default 'vanilla') starts from an all-zero memory.
        num_cells: Number of memory cells (M).
        mem_size: Size of one memory cell (H).
        mlp_size: Hidden size of the entity and similarity MLPs.
        dropout_rate: Dropout applied to each incoming hidden state.
        key_size: Number of leading cell dimensions reserved for the key
            when `mem_type == 'key_val'`.
        usage_decay_rate: Multiplicative decay applied to usage every step.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, device='cpu', mem_type='vanilla', num_cells=10,
                 mem_size=300, mlp_size=300, dropout_rate=0.5, key_size=20,
                 usage_decay_rate=0.98, **kwargs):
        super(WorkingMemory, self).__init__()
        self.device = device
        self.mem_type = mem_type
        self.num_cells = num_cells
        self.mem_size = mem_size
        self.mlp_size = mlp_size
        self.usage_decay_rate = usage_decay_rate

        # Dropout module
        self.drop_module = nn.Dropout(p=dropout_rate, inplace=False)

        if self.mem_type == 'learned':
            self.init_mem = nn.Parameter(
                torch.zeros(self.num_cells, self.mem_size))
        elif self.mem_type == 'key_val':
            self.key_size = key_size
            # Learned key part of every cell; the value part starts at zero.
            self.init_key = nn.Parameter(
                torch.zeros(self.num_cells, key_size))

        # MLP to determine entity or not
        self.entity_mlp = MLP(self.mem_size, mlp_size, 1,
                              num_layers=2, bias=True)
        # Merges past memory and current candidate to write new memory
        self.U_key = nn.Linear(2 * mem_size, mem_size, bias=True)
        # MLP to determine coref similarity between current token and memory.
        # Input is [mem, query, query * mem, usage] -> 3 * H + 1 features.
        self.sim_mlp = MLP(3 * self.mem_size + 1, mlp_size, 1,
                           num_layers=2, bias=True)

        # Fixed (non-trainable) temperature of the Gumbel-softmax.
        self.gumbel_temperature = nn.Parameter(
            torch.tensor([1.0]), requires_grad=False)

    def initialize_memory(self, batch_size):
        """Build the initial memory and usage for a batch.

        Returns:
            Tuple `(init_mem, init_usage)` of shapes B x M x H and B x M.
            Usage always starts at zero; the memory content depends on
            `mem_type` (see the class docstring).
        """
        if self.mem_type == 'learned':
            init_mem = self.init_mem.unsqueeze(dim=0)
            init_mem = init_mem.repeat(batch_size, 1, 1)
        elif self.mem_type == 'key_val':
            init_val = torch.zeros(
                batch_size, self.num_cells, self.mem_size - self.key_size,
                device=self.device)
            init_key = self.init_key.unsqueeze(dim=0)
            init_key = init_key.repeat(batch_size, 1, 1)
            init_mem = torch.cat([init_key, init_val], dim=2)
        else:
            init_mem = torch.zeros(
                batch_size, self.num_cells, self.mem_size,
                device=self.device)

        init_usage = torch.zeros(
            batch_size, self.num_cells, device=self.device)
        return (init_mem, init_usage)

    def sample_gumbel(self, shape, eps=EPS):
        """Return noise -log(-log(U + eps) + eps) with U uniform in [0, 1)."""
        # Sampled with the default (CPU) generator and then moved, so seeded
        # runs draw the same values regardless of `self.device`.
        U = torch.rand(shape).to(self.device)
        return -torch.log(-torch.log(U + eps) + eps)

    def pick_overwrite_cell(self, usage, sim_score):
        """Pick cell to overwrite.

        - Prefer unused cells.
        - Break ties using similarity.

        Args:
            usage: B x M usage scores.
            sim_score: B x M similarity scores.

        Returns:
            B x M tensor: a Gumbel-softmax distribution over cells in
            training mode, a one-hot vector per row in eval mode.
        """
        norm_sim_score = nn.functional.softmax(sim_score, dim=-1)

        # Assign overwrite scores to each cell.
        # (1) Prefer cells which have not been used.
        # (2) Among the unused cells, prefer ones with the higher similarity
        #     score (useful for memory with learned initialization).
        # (3) Otherwise prefer cells with least usage.
        overwrite_score = (
            (usage == 0.0).float() * norm_sim_score * 1e5) + (1 - usage)

        if self.training:
            logits = torch.log(overwrite_score * (1 - EPS) + EPS)
            gumbel_noise = self.sample_gumbel(usage.size())
            y = nn.functional.softmax(
                (logits + gumbel_noise) / self.gumbel_temperature, dim=-1)
        else:
            max_val = torch.max(overwrite_score, dim=-1, keepdim=True)[0]
            # Randomize the max: give every cell a random positive weight,
            # zero out the non-maximal cells and take the argmax.
            random_weight = torch.empty(
                overwrite_score.shape).uniform_(0.01, 1).to(self.device)
            index = torch.argmax(
                random_weight * (overwrite_score == max_val).float(),
                dim=-1, keepdim=True)
            y = torch.zeros_like(overwrite_score).scatter_(-1, index, 1.0)

        return y

    def predict_entity_prob(self, cur_hidden_state):
        """Predict whether the current word is (part of) an entity or not.

        Returns:
            Tuple `(ent_score, ent_prob)`: the raw MLP score and the
            probability obtained by a softmax over the scores [0, ent_score].
        """
        ent_score = self.entity_mlp(cur_hidden_state)
        # Perform a softmax over scores of 0 and ent_score
        comb_score = torch.cat(
            [torch.zeros_like(ent_score), ent_score], dim=1)
        # Numerically stable softmax
        max_score, _ = torch.max(comb_score, dim=1, keepdim=True)
        ent_prob = nn.functional.softmax(comb_score - max_score, dim=1)
        # We only care about the 2nd column i.e. corresponding to ent_score
        ent_prob = torch.unsqueeze(ent_prob[:, 1], dim=1)
        return ent_score, ent_prob

    def get_coref_mask(self, usage):
        """Return a float mask that is 1 for cells with usage > 0.

        No coreference is possible with empty cells.
        """
        return (usage > 0).float()

    def predict_coref_overwrite(self, mem_vectors, query_vector, usage,
                                ent_prob):
        """Calculate similarity between query_vector and mem_vectors.

        Args:
            mem_vectors: B x M x H
            query_vector: B x M x H
            usage: B x M
            ent_prob: entity probability, multiplied into both outputs.

        Returns:
            Tuple `(indv_coref_prob, overwrite_prob)` of shapes B x M and
            B x 1: the probability of coreferring with each cell and the
            probability of starting a new entity (overwriting a cell).
        """
        pairwise_vec = torch.cat(
            [mem_vectors, query_vector, query_vector * mem_vectors,
             torch.unsqueeze(usage, dim=2)], dim=-1)
        # B x M x 1 -> B x M
        sim_score = torch.squeeze(self.sim_mlp(pairwise_vec), dim=-1)

        batch_size = query_vector.shape[0]
        # The extra (M + 1)-th column is the "no coref" option, score 0.
        base_score = torch.zeros((batch_size, 1), device=self.device)
        comb_score = torch.cat([sim_score, base_score], dim=1)  # B x (M + 1)

        coref_mask = self.get_coref_mask(usage)  # B x M
        # Coref only possible when the cell is active
        mult_mask = torch.cat(
            [coref_mask, torch.ones((batch_size, 1), device=self.device)],
            dim=-1)

        # Zero out the inactive cell scores and then add a big negative value
        comb_score = comb_score * mult_mask + (1 - mult_mask) * (-1e4)

        # Numerically stable softmax
        max_cell_score, _ = torch.max(comb_score, dim=1, keepdim=True)
        init_probs = nn.functional.softmax(
            comb_score - max_cell_score, dim=1)

        # Make sure the inactive cells are really zero even after logit of
        # -1e4, then renormalize.
        masked_probs = init_probs * mult_mask
        norm_probs = (
            masked_probs
            / (torch.sum(masked_probs, dim=-1, keepdim=True) + EPS))

        coref_over_probs = ent_prob * norm_probs
        indv_coref_prob = coref_over_probs[:, :self.num_cells]
        overwrite_prob = coref_over_probs[:, self.num_cells]
        overwrite_prob = torch.unsqueeze(overwrite_prob, dim=1)

        return indv_coref_prob, overwrite_prob

    def forward(self, data_w):
        """Read excerpts.

        Args:
            data_w: Pair `(hidden_state_list, input_mask_list)` where
                hidden_state_list is a list of B x H tensors and
                input_mask_list is a list of B sized tensors.

        Returns:
            Tuple `(ent_list, usage_list, coref_list, overwrite_list)` with
            one entry per time step, or None if any predicted probability
            exceeds 1 (or is NaN) at some step.
        """
        hidden_state_list = data_w[0]
        input_mask_list = data_w[1]
        batch_size = hidden_state_list[0].shape[0]

        if self.mem_type == 'key_val':
            # Get initialized key vectors
            init_key = self.init_key.unsqueeze(dim=0)
            init_key = init_key.repeat(batch_size, 1, 1)

        # Initialize memory
        mem_vectors, usage = self.initialize_memory(batch_size)

        # Store all updates
        ent_list, usage_list, coref_list, overwrite_list = [], [], [], []

        for cur_hidden_state, cur_input_mask in zip(hidden_state_list,
                                                    input_mask_list):
            query_vector = self.drop_module(cur_hidden_state)

            _, ent_prob = self.predict_entity_prob(query_vector)
            # Masked-out positions get zero entity probability.
            ent_prob = ent_prob * torch.unsqueeze(cur_input_mask, dim=1)
            ent_list.append(ent_prob * (1 - EPS) + EPS)

            rep_query_vector = query_vector.unsqueeze(dim=1)
            # B x M x H
            rep_query_vector = rep_query_vector.repeat(1, self.num_cells, 1)

            indv_coref_prob, new_ent_prob = self.predict_coref_overwrite(
                mem_vectors=mem_vectors, query_vector=rep_query_vector,
                usage=usage, ent_prob=ent_prob)
            coref_list.append(indv_coref_prob * (1 - EPS) + EPS)

            # Overwriting Prob - B x M
            # NOTE(review): this rebuilds the same pairwise input that
            # predict_coref_overwrite already scored; kept as a separate
            # call so results are unchanged.
            pairwise_vec = torch.cat(
                [mem_vectors, rep_query_vector,
                 rep_query_vector * mem_vectors,
                 torch.unsqueeze(usage, dim=2)], dim=-1)
            init_sim_score = torch.squeeze(
                self.sim_mlp(pairwise_vec), dim=-1)
            overwrite_prob = (
                new_ent_prob
                * self.pick_overwrite_cell(usage, init_sim_score))

            # Explicit check instead of `assert`, so it is not stripped when
            # Python runs with -O. `max <= 1` is also False for NaN.
            probs_valid = all(
                bool(torch.max(prob) <= 1)
                for prob in (overwrite_prob, indv_coref_prob, ent_prob))
            if not probs_valid:
                print("Assertion Error happened! Trying best to recover")
                return None

            overwrite_list.append(overwrite_prob * (1 - EPS) + EPS)

            comb_inp = torch.cat([rep_query_vector, mem_vectors], dim=-1)
            mem_candidate = torch.tanh(self.U_key(comb_inp))

            # B x M x H: overwrite with the query, keep the old content, or
            # blend in the coref candidate, weighted by the probabilities.
            updated_mem_vectors = (
                torch.unsqueeze(overwrite_prob, dim=2) * rep_query_vector
                + torch.unsqueeze(
                    1 - overwrite_prob - indv_coref_prob, dim=2) * mem_vectors
                + torch.unsqueeze(indv_coref_prob, dim=2) * mem_candidate
            )

            if self.mem_type == 'key_val':
                # Don't update the key dimensions.
                # Only update the later dimensions.
                updated_mem_vectors = torch.cat(
                    [init_key, updated_mem_vectors[:, :, self.key_size:]],
                    dim=2)

            # Update usage: decay the old usage, add this step's writes and
            # cap the result at 1.
            updated_usage = torch.clamp(
                overwrite_prob + indv_coref_prob
                + self.usage_decay_rate * usage,
                max=1.0)
            usage_list.append(updated_usage)

            # Update memory
            mem_vectors, usage = updated_mem_vectors, updated_usage

        return (ent_list, usage_list, coref_list, overwrite_list)


class PICK(nn.Module):
    """Expose `WorkingMemory.pick_overwrite_cell` as a module's forward.

    This lets that single method be passed to tools that expect an
    `nn.Module`, such as `SummaryWriter.add_graph`.
    """

    def __init__(self, memory):
        super(PICK, self).__init__()
        self.memory = memory

    def forward(self, pick_data):
        """Unpack `(usage, sim_score)` and return the picked-cell tensor."""
        usage, sim_score = pick_data
        return self.memory.pick_overwrite_cell(usage, sim_score)


if __name__ == "__main__":
    from torch.utils.tensorboard import SummaryWriter

    # Only pick_overwrite_cell is traced here, through the PICK wrapper.
    memory = WorkingMemory()
    model = PICK(memory)

    pick_data = (torch.zeros(32, 10), torch.randn(32, 10))
    # Smoke run before tracing.
    result = model(pick_data)

    # The context manager closes the writer even if add_graph raises.
    with SummaryWriter('structure/pick_overwrite_cell') as writer:
        writer.add_graph(model, (pick_data, ))
        writer.add_text("text", "hello, this is a text info", global_step=2)