diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 7d430d9..7a5a409 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -1,4 +1,5 @@ import time +import random import torch import torch.nn as nn import torch.nn.functional as F @@ -182,7 +183,7 @@ class GinkaRNNModel(nn.Module): self.feat_fusion = GinkaInputFusion() self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden) - def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False): + def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self_probility=0): """ val_cond: [B, val_dim] target_map: [B, H, W] @@ -207,6 +208,7 @@ class GinkaRNNModel(nn.Module): # 位置编码、图块编码、地图局部编码 tile_embed = self.tile_embedding(now_tile) row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor) + use_self = random.random() < use_self_probility map_patch = self.map_patch(map if use_self else target_map, x, y) # 编码特征融合 feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch) diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index d450e9b..e8eaf53 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -7,19 +7,18 @@ import torch.optim as optim import cv2 from torch_geometric.loader import DataLoader from tqdm import tqdm -from .common.cond import ConditionEncoder from .generator.rnn import GinkaRNNModel from .dataset import GinkaRNNDataset from .generator.loss import RNNGinkaLoss from shared.image import matrix_to_image_cv -# 手工标注标签定义: +# 手工标注标签定义(暂时不用): # 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, # 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风 # 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门 # 24. 中心对称, 25. 部分对称, 26. 鱼骨 -# 自动标注标签定义: +# 自动标注标签定义(暂时不用): # 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊 # 32. 平面塔, 33. 转换塔, 34. 道具塔 @@ -59,6 +58,10 @@ os.makedirs("result/ginka_rnn_img", exist_ok=True) disable_tqdm = not sys.stdout.isatty() +def gt_prob(epoch: int, max_epoch: int) -> float: + progress = epoch / max_epoch + return 0.1 + 0.9 * progress + def parse_arguments(): parser = argparse.ArgumentParser(description="training codes") parser.add_argument("--resume", type=bool, default=False) @@ -114,7 +117,7 @@ def train(): val_cond = batch["val_cond"].to(device) target_map = batch["target_map"].to(device) - fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) + fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs)) loss = criterion.rnn_loss(fake_logits, target_map) @@ -158,7 +161,7 @@ def train(): val_cond = batch["val_cond"].to(device) target_map = batch["target_map"].to(device) - fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) + fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epoch)) val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()