feat: rnn 随机采用自身生成的内容

This commit is contained in:
unanmed 2026-01-16 23:01:37 +08:00
parent d0decfc63a
commit cfc022724a
2 changed files with 11 additions and 6 deletions

View File

@ -1,4 +1,5 @@
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -182,7 +183,7 @@ class GinkaRNNModel(nn.Module):
self.feat_fusion = GinkaInputFusion()
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self=False):
def forward(self, val_cond: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
"""
val_cond: [B, val_dim]
target_map: [B, H, W]
@ -207,6 +208,7 @@ class GinkaRNNModel(nn.Module):
# 位置编码、图块编码、地图局部编码
tile_embed = self.tile_embedding(now_tile)
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
use_self = random.random() < use_self_probility
map_patch = self.map_patch(map if use_self else target_map, x, y)
# 编码特征融合
feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)

View File

@ -7,19 +7,18 @@ import torch.optim as optim
import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .common.cond import ConditionEncoder
from .generator.rnn import GinkaRNNModel
from .dataset import GinkaRNNDataset
from .generator.loss import RNNGinkaLoss
from shared.image import matrix_to_image_cv
# 手工标注标签定义
# 手工标注标签定义(暂时不用)
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
# 8. 剧情层, 9. 水层, 10. 爽塔, 11. Boss层, 12. 纯Boss层, 13. 多房间, 14. 多走廊, 15. 道具风
# 16. 区域入口, 17. 区域连接, 18. 有机关门, 19. 道具层, 20. 斜向对称, 21. 左右通道, 22. 上下通道, 23. 多机关门
# 24. 中心对称, 25. 部分对称, 26. 鱼骨
# 自动标注标签定义
# 自动标注标签定义(暂时不用)
# 0. 左右对称, 1. 上下对称, 2. 中心对称, 3. 斜向对称, 4. 伪对称, 5. 多房间, 6. 多走廊
# 32. 平面塔, 33. 转换塔, 34. 道具塔
@ -59,6 +58,10 @@ os.makedirs("result/ginka_rnn_img", exist_ok=True)
disable_tqdm = not sys.stdout.isatty()
def gt_prob(epoch: int, max_epoch: int) -> float:
progress = epoch / max_epoch
return 0.1 + 0.9 * progress
def parse_arguments():
parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False)
@ -114,7 +117,7 @@ def train():
val_cond = batch["val_cond"].to(device)
target_map = batch["target_map"].to(device)
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs))
loss = criterion.rnn_loss(fake_logits, target_map)
@ -158,7 +161,7 @@ def train():
val_cond = batch["val_cond"].to(device)
target_map = batch["target_map"].to(device)
fake_logits, fake_map = ginka_rnn(val_cond, target_map, False)
fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epoch))
val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()