mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 03:51:11 +08:00
feat: rnn train
This commit is contained in:
parent
fa8ded2ecd
commit
771553b8b8
@ -230,7 +230,6 @@ class GinkaRNNDataset(Dataset):
|
|||||||
|
|
||||||
target = torch.LongTensor(item['map']) # [H, W]
|
target = torch.LongTensor(item['map']) # [H, W]
|
||||||
H, W = target.shape
|
H, W = target.shape
|
||||||
target = target.reshape(H * W) # [T]
|
|
||||||
tag_cond = torch.FloatTensor(item['tag'])
|
tag_cond = torch.FloatTensor(item['tag'])
|
||||||
val_cond = torch.FloatTensor(item['val'])
|
val_cond = torch.FloatTensor(item['val'])
|
||||||
val_cond[9] = val_cond[9] / H / W
|
val_cond[9] = val_cond[9] / H / W
|
||||||
|
|||||||
@ -49,7 +49,7 @@ class GinkaMapPatch(nn.Module):
|
|||||||
map: [B, H, W]
|
map: [B, H, W]
|
||||||
"""
|
"""
|
||||||
B, H, W = map.shape
|
B, H, W = map.shape
|
||||||
result = torch.zeros([B, 5, 5], dtype=torch.long, device=map.device)
|
result = torch.zeros([B, 5, 5], dtype=torch.long).to(map.device)
|
||||||
left = x - 2 if x >= 2 else 0
|
left = x - 2 if x >= 2 else 0
|
||||||
right = x + 3 if x < self.width - 2 else self.width
|
right = x + 3 if x < self.width - 2 else self.width
|
||||||
top = y - 4 if y >= 4 else 0
|
top = y - 4 if y >= 4 else 0
|
||||||
@ -89,8 +89,8 @@ class GinkaPosEmbedding(nn.Module):
|
|||||||
self.col_embedding = nn.Embedding(height, embed_dim)
|
self.col_embedding = nn.Embedding(height, embed_dim)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||||
row = self.row_embedding(x)
|
row = self.row_embedding(x).squeeze(1)
|
||||||
col = self.col_embedding(y)
|
col = self.col_embedding(y).squeeze(1)
|
||||||
|
|
||||||
return row, col
|
return row, col
|
||||||
|
|
||||||
@ -169,21 +169,21 @@ class GinkaRNNModel(nn.Module):
|
|||||||
B, C = val_cond.shape
|
B, C = val_cond.shape
|
||||||
|
|
||||||
# 张量声明
|
# 张量声明
|
||||||
now_tile = torch.LongTensor([self.start_tile], device=self.device).expand(B, -1)
|
now_tile = torch.LongTensor([self.start_tile]).to(self.device).expand(B)
|
||||||
|
|
||||||
map = torch.zeros([B, self.height, self.width], dtype=torch.int32, device=self.device)
|
map = torch.zeros([B, self.height, self.width], dtype=torch.int32).to(self.device)
|
||||||
output_logits = torch.zeros([B, self.height, self.width, self.tile_classes], device=self.device)
|
output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device)
|
||||||
hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden, device=self.device)
|
hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device)
|
||||||
|
|
||||||
# 条件编码,全局,所以只用一次
|
# 条件编码,全局,所以只用一次
|
||||||
cond = self.cond(val_cond)
|
cond = self.cond(val_cond)
|
||||||
|
|
||||||
for y in range(0, self.height):
|
for y in range(0, self.height):
|
||||||
for x in range(0, self.width):
|
for x in range(0, self.width):
|
||||||
x_tensor = torch.LongTensor([x], device=self.device)
|
x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
|
||||||
y_tensor = torch.LongTensor([y], device=self.device)
|
y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
|
||||||
# 位置编码、图块编码、地图局部编码
|
# 位置编码、图块编码、地图局部编码
|
||||||
tile_embed = self.tile_embedding(now_tile if use_self else target_map[:, y, x])
|
tile_embed = self.tile_embedding(target_map[:, y, x])
|
||||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
||||||
map_patch = self.map_patch(map, x, y)
|
map_patch = self.map_patch(map, x, y)
|
||||||
# 编码特征融合
|
# 编码特征融合
|
||||||
@ -194,9 +194,9 @@ class GinkaRNNModel(nn.Module):
|
|||||||
output_logits[:, y, x] = logits[:]
|
output_logits[:, y, x] = logits[:]
|
||||||
hidden[:] = h[:]
|
hidden[:] = h[:]
|
||||||
probs = F.softmax(logits, dim=1)
|
probs = F.softmax(logits, dim=1)
|
||||||
tile_id = torch.argmax(probs, dim=1)
|
tile_id = torch.argmax(probs, dim=1).detach()
|
||||||
map[:, y, x] = tile_id[:]
|
map[:, y, x] = tile_id[:]
|
||||||
now_tile[:] = tile_id[:]
|
# now_tile[:] = tile_id[:]
|
||||||
|
|
||||||
return output_logits, map
|
return output_logits, map
|
||||||
|
|
||||||
|
|||||||
@ -114,7 +114,8 @@ def train():
|
|||||||
val_cond = batch["val_cond"].to(device)
|
val_cond = batch["val_cond"].to(device)
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
with torch.autograd.set_detect_anomaly(True):
|
||||||
|
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
|
||||||
|
|
||||||
loss = criterion.rnn_loss(fake_logits, target_map)
|
loss = criterion.rnn_loss(fake_logits, target_map)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user