diff --git a/ginka/common/cond.py b/ginka/common/cond.py
index ac64021..bd6c218 100644
--- a/ginka/common/cond.py
+++ b/ginka/common/cond.py
@@ -7,6 +7,13 @@ class ConditionEncoder(nn.Module):
         super().__init__()
         self.tag_embed = nn.Linear(tag_dim, hidden_dim)
         self.val_embed = nn.Linear(val_dim, hidden_dim)
+        self.stage_embed = nn.Sequential(
+            nn.Linear(1, 64),
+            nn.LayerNorm(64),
+            nn.ELU(),
+
+            nn.Linear(64, hidden_dim),
+        )
         self.encoder = nn.TransformerEncoder(
             nn.TransformerEncoderLayer(
                 d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
@@ -22,10 +29,11 @@ class ConditionEncoder(nn.Module):
             nn.Linear(hidden_dim*2, out_dim)
         )
 
-    def forward(self, tag, val):
+    def forward(self, tag, val, stage):
         tag = self.tag_embed(tag)
         val = self.val_embed(val)
-        feat = torch.stack([tag, val], dim=1)
+        stage = self.stage_embed(stage)
+        feat = torch.stack([tag, val, stage], dim=1)
         feat = self.encoder(feat)
         feat = torch.mean(feat, dim=1)
         feat = self.fusion(feat)
diff --git a/ginka/critic/model.py b/ginka/critic/model.py
index 9ef80c3..72415fd 100644
--- a/ginka/critic/model.py
+++ b/ginka/critic/model.py
@@ -76,9 +76,11 @@ class MinamoModel(nn.Module):
         self.head3 = MinamoScoreHead(512, 512)
 
     def forward(self, map, graph, stage, tag_cond, val_cond):
+        B, D = tag_cond.shape
+        stage_tensor = torch.Tensor([stage]).expand(B, 1).to(map.device)
         vision = self.vision_model(map)
         topo = self.topo_model(graph)
-        cond = self.cond(tag_cond, val_cond)
+        cond = self.cond(tag_cond, val_cond, stage_tensor)
         if stage == 1:
             vision_score, topo_score = self.head1(vision, topo, graph, cond)
         elif stage == 2:
diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py
index 7660d18..2d13ef9 100644
--- a/ginka/generator/loss.py
+++ b/ginka/generator/loss.py
@@ -375,7 +375,7 @@ def immutable_penalty_loss(
     return penalty
 
 class WGANGinkaLoss:
-    def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.2]):
+    def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]):
         # weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失
         self.lambda_gp = lambda_gp # 梯度惩罚系数
         self.weight = weight
diff --git a/ginka/generator/model.py b/ginka/generator/model.py
index 71871a9..acb1826 100644
--- a/ginka/generator/model.py
+++ b/ginka/generator/model.py
@@ -21,7 +21,9 @@ class GinkaModel(nn.Module):
         self.output = GinkaOutput(base_ch, out_ch, (13, 13))
 
     def forward(self, x, stage, tag_cond, val_cond, random=False):
-        cond = self.cond(tag_cond, val_cond)
+        B, D = tag_cond.shape
+        stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
+        cond = self.cond(tag_cond, val_cond, stage_tensor)
         if random:
             x_in = F.softmax(self.head(x, cond), dim=1)
         else:
diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py
index 1f3d419..14e1e80 100644
--- a/ginka/train_wgan.py
+++ b/ginka/train_wgan.py
@@ -325,9 +325,11 @@ def train():
                 low_loss_epochs = 0
 
             if train_stage >= 2:
-                train_stage += 1
-
-                if train_stage == 5:
+                if stage_epoch % 5 == 1:
+                    train_stage = 3
+                elif stage_epoch % 5 == 3:
+                    train_stage = 4
+                elif stage_epoch % 5 == 0:
                     train_stage = 2
 
             if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch: