diff --git a/ginka/dataset.py b/ginka/dataset.py index d811293..ea050d4 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -230,7 +230,6 @@ class GinkaRNNDataset(Dataset): target = torch.LongTensor(item['map']) # [H, W] H, W = target.shape - target = target.reshape(H * W) # [T] tag_cond = torch.FloatTensor(item['tag']) val_cond = torch.FloatTensor(item['val']) val_cond[9] = val_cond[9] / H / W diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 510b01f..8691daa 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -49,7 +49,7 @@ class GinkaMapPatch(nn.Module): map: [B, H, W] """ B, H, W = map.shape - result = torch.zeros([B, 5, 5], dtype=torch.long, device=map.device) + result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device) left = x - 2 if x >= 2 else 0 right = x + 3 if x < self.width - 2 else self.width top = y - 4 if y >= 4 else 0 @@ -89,8 +89,8 @@ class GinkaPosEmbedding(nn.Module): self.col_embedding = nn.Embedding(height, embed_dim) def forward(self, x: torch.Tensor, y: torch.Tensor): - row = self.row_embedding(x) - col = self.col_embedding(y) + row = self.row_embedding(x).squeeze(1) + col = self.col_embedding(y).squeeze(1) return row, col @@ -169,21 +169,21 @@ class GinkaRNNModel(nn.Module): B, C = val_cond.shape # 张量声明 - now_tile = torch.LongTensor([self.start_tile], device=self.device).expand(B, -1) + now_tile = torch.LongTensor([self.start_tile]).to(self.device).expand(B) - map = torch.zeros([B, self.height, self.width], dtype=torch.int32, device=self.device) - output_logits = torch.zeros([B, self.height, self.width, self.tile_classes], device=self.device) - hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden, device=self.device) + map = torch.zeros([B, self.height, self.width], dtype=torch.int32).to(self.device) + output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device) + hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device) # 条件编码,全局,所以只用一次 cond = self.cond(val_cond) for y in range(0, self.height): for x in range(0, self.width): - x_tensor = torch.LongTensor([x], device=self.device) - y_tensor = torch.LongTensor([y], device=self.device) + x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1) + y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1) # 位置编码、图块编码、地图局部编码 - tile_embed = self.tile_embedding(now_tile if use_self else target_map[:, y, x]) + tile_embed = self.tile_embedding(target_map[:, y, x]) row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor) map_patch = self.map_patch(map, x, y) # 编码特征融合 @@ -194,9 +194,9 @@ class GinkaRNNModel(nn.Module): output_logits[:, y, x] = logits[:] hidden[:] = h[:] probs = F.softmax(logits, dim=1) - tile_id = torch.argmax(probs, dim=1) + tile_id = torch.argmax(probs, dim=1).detach() map[:, y, x] = tile_id[:] - now_tile[:] = tile_id[:] + # now_tile[:] = tile_id[:] return output_logits, map diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index 8badc99..45b0d62 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -114,7 +114,8 @@ def train(): val_cond = batch["val_cond"].to(device) target_map = batch["target_map"].to(device) - fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) + with torch.autograd.set_detect_anomaly(True): + fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) loss = criterion.rnn_loss(fake_logits, target_map)