mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: rnn train
This commit is contained in:
parent
fa8ded2ecd
commit
771553b8b8
@ -230,7 +230,6 @@ class GinkaRNNDataset(Dataset):
|
||||
|
||||
target = torch.LongTensor(item['map']) # [H, W]
|
||||
H, W = target.shape
|
||||
target = target.reshape(H * W) # [T]
|
||||
tag_cond = torch.FloatTensor(item['tag'])
|
||||
val_cond = torch.FloatTensor(item['val'])
|
||||
val_cond[9] = val_cond[9] / H / W
|
||||
|
||||
@ -49,7 +49,7 @@ class GinkaMapPatch(nn.Module):
|
||||
map: [B, H, W]
|
||||
"""
|
||||
B, H, W = map.shape
|
||||
result = torch.zeros([B, 5, 5], dtype=torch.long, device=map.device)
|
||||
result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device)
|
||||
left = x - 2 if x >= 2 else 0
|
||||
right = x + 3 if x < self.width - 2 else self.width
|
||||
top = y - 4 if y >= 4 else 0
|
||||
@ -89,8 +89,8 @@ class GinkaPosEmbedding(nn.Module):
|
||||
self.col_embedding = nn.Embedding(height, embed_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||
row = self.row_embedding(x)
|
||||
col = self.col_embedding(y)
|
||||
row = self.row_embedding(x).squeeze(1)
|
||||
col = self.col_embedding(y).squeeze(1)
|
||||
|
||||
return row, col
|
||||
|
||||
@ -169,21 +169,21 @@ class GinkaRNNModel(nn.Module):
|
||||
B, C = val_cond.shape
|
||||
|
||||
# 张量声明
|
||||
now_tile = torch.LongTensor([self.start_tile], device=self.device).expand(B, -1)
|
||||
now_tile = torch.LongTensor([self.start_tile]).to(self.device).expand(B)
|
||||
|
||||
map = torch.zeros([B, self.height, self.width], dtype=torch.int32, device=self.device)
|
||||
output_logits = torch.zeros([B, self.height, self.width, self.tile_classes], device=self.device)
|
||||
hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden, device=self.device)
|
||||
map = torch.zeros([B, self.height, self.width], dtype=torch.int32).to(self.device)
|
||||
output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device)
|
||||
hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device)
|
||||
|
||||
# 条件编码,全局,所以只用一次
|
||||
cond = self.cond(val_cond)
|
||||
|
||||
for y in range(0, self.height):
|
||||
for x in range(0, self.width):
|
||||
x_tensor = torch.LongTensor([x], device=self.device)
|
||||
y_tensor = torch.LongTensor([y], device=self.device)
|
||||
x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
|
||||
y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
|
||||
# 位置编码、图块编码、地图局部编码
|
||||
tile_embed = self.tile_embedding(now_tile if use_self else target_map[:, y, x])
|
||||
tile_embed = self.tile_embedding(target_map[:, y, x])
|
||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
||||
map_patch = self.map_patch(map, x, y)
|
||||
# 编码特征融合
|
||||
@ -194,9 +194,9 @@ class GinkaRNNModel(nn.Module):
|
||||
output_logits[:, y, x] = logits[:]
|
||||
hidden[:] = h[:]
|
||||
probs = F.softmax(logits, dim=1)
|
||||
tile_id = torch.argmax(probs, dim=1)
|
||||
tile_id = torch.argmax(probs, dim=1).detach()
|
||||
map[:, y, x] = tile_id[:]
|
||||
now_tile[:] = tile_id[:]
|
||||
# now_tile[:] = tile_id[:]
|
||||
|
||||
return output_logits, map
|
||||
|
||||
|
||||
@ -114,7 +114,8 @@ def train():
|
||||
val_cond = batch["val_cond"].to(device)
|
||||
target_map = batch["target_map"].to(device)
|
||||
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||
with torch.autograd.set_detect_anomaly(True):
|
||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||
|
||||
loss = criterion.rnn_loss(fake_logits, target_map)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user