feat: 当前阶段条件注入

This commit is contained in:
unanmed 2025-05-02 15:20:58 +08:00
parent e3e496957c
commit d800a2382b
5 changed files with 22 additions and 8 deletions

View File

@@ -7,6 +7,13 @@ class ConditionEncoder(nn.Module):
super().__init__()
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
self.val_embed = nn.Linear(val_dim, hidden_dim)
self.stage_embed = nn.Sequential(
nn.Linear(1, 64),
nn.LayerNorm(64),
nn.ELU(),
nn.Linear(64, hidden_dim),
)
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
@@ -22,10 +29,11 @@ class ConditionEncoder(nn.Module):
nn.Linear(hidden_dim*2, out_dim)
)
def forward(self, tag, val):
def forward(self, tag, val, stage):
tag = self.tag_embed(tag)
val = self.val_embed(val)
feat = torch.stack([tag, val], dim=1)
stage = self.stage_embed(stage)
feat = torch.stack([tag, val, stage], dim=1)
feat = self.encoder(feat)
feat = torch.mean(feat, dim=1)
feat = self.fusion(feat)

View File

@@ -76,9 +76,11 @@ class MinamoModel(nn.Module):
self.head3 = MinamoScoreHead(512, 512)
def forward(self, map, graph, stage, tag_cond, val_cond):
B, D = tag_cond.shape
stage_tensor = torch.Tensor([stage]).expand(B, 1).to(map.device)
vision = self.vision_model(map)
topo = self.topo_model(graph)
cond = self.cond(tag_cond, val_cond)
cond = self.cond(tag_cond, val_cond, stage_tensor)
if stage == 1:
vision_score, topo_score = self.head1(vision, topo, graph, cond)
elif stage == 2:

View File

@@ -375,7 +375,7 @@ def immutable_penalty_loss(
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.2]):
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]):
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight

View File

@@ -21,7 +21,9 @@ class GinkaModel(nn.Module):
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
def forward(self, x, stage, tag_cond, val_cond, random=False):
cond = self.cond(tag_cond, val_cond)
B, D = tag_cond.shape
stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
cond = self.cond(tag_cond, val_cond, stage_tensor)
if random:
x_in = F.softmax(self.head(x, cond), dim=1)
else:

View File

@@ -325,9 +325,11 @@ def train():
low_loss_epochs = 0
if train_stage >= 2:
train_stage += 1
if train_stage == 5:
if stage_epoch % 5 == 1:
train_stage = 3
elif stage_epoch % 5 == 3:
train_stage = 4
elif stage_epoch % 5 == 0:
train_stage = 2
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch: