mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 当前阶段条件注入
This commit is contained in:
parent
e3e496957c
commit
d800a2382b
@ -7,6 +7,13 @@ class ConditionEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
|
||||
self.val_embed = nn.Linear(val_dim, hidden_dim)
|
||||
self.stage_embed = nn.Sequential(
|
||||
nn.Linear(1, 64),
|
||||
nn.LayerNorm(64),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Linear(64, hidden_dim),
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(
|
||||
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
|
||||
@ -22,10 +29,11 @@ class ConditionEncoder(nn.Module):
|
||||
nn.Linear(hidden_dim*2, out_dim)
|
||||
)
|
||||
|
||||
def forward(self, tag, val):
|
||||
def forward(self, tag, val, stage):
|
||||
tag = self.tag_embed(tag)
|
||||
val = self.val_embed(val)
|
||||
feat = torch.stack([tag, val], dim=1)
|
||||
stage = self.stage_embed(stage)
|
||||
feat = torch.stack([tag, val, stage], dim=1)
|
||||
feat = self.encoder(feat)
|
||||
feat = torch.mean(feat, dim=1)
|
||||
feat = self.fusion(feat)
|
||||
|
||||
@ -76,9 +76,11 @@ class MinamoModel(nn.Module):
|
||||
self.head3 = MinamoScoreHead(512, 512)
|
||||
|
||||
def forward(self, map, graph, stage, tag_cond, val_cond):
|
||||
B, D = tag_cond.shape
|
||||
stage_tensor = torch.Tensor([stage]).expand(B, 1).to(map.device)
|
||||
vision = self.vision_model(map)
|
||||
topo = self.topo_model(graph)
|
||||
cond = self.cond(tag_cond, val_cond)
|
||||
cond = self.cond(tag_cond, val_cond, stage_tensor)
|
||||
if stage == 1:
|
||||
vision_score, topo_score = self.head1(vision, topo, graph, cond)
|
||||
elif stage == 2:
|
||||
|
||||
@ -375,7 +375,7 @@ def immutable_penalty_loss(
|
||||
return penalty
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.2]):
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]):
|
||||
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
|
||||
@ -21,7 +21,9 @@ class GinkaModel(nn.Module):
|
||||
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
||||
|
||||
def forward(self, x, stage, tag_cond, val_cond, random=False):
|
||||
cond = self.cond(tag_cond, val_cond)
|
||||
B, D = tag_cond.shape
|
||||
stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
|
||||
cond = self.cond(tag_cond, val_cond, stage_tensor)
|
||||
if random:
|
||||
x_in = F.softmax(self.head(x, cond), dim=1)
|
||||
else:
|
||||
|
||||
@ -325,9 +325,11 @@ def train():
|
||||
low_loss_epochs = 0
|
||||
|
||||
if train_stage >= 2:
|
||||
train_stage += 1
|
||||
|
||||
if train_stage == 5:
|
||||
if stage_epoch % 5 == 1:
|
||||
train_stage = 3
|
||||
elif stage_epoch % 5 == 3:
|
||||
train_stage = 4
|
||||
elif stage_epoch % 5 == 0:
|
||||
train_stage = 2
|
||||
|
||||
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user