mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 11:01:12 +08:00
feat: 当前阶段条件注入
This commit is contained in:
parent
e3e496957c
commit
d800a2382b
@ -7,6 +7,13 @@ class ConditionEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
|
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
|
||||||
self.val_embed = nn.Linear(val_dim, hidden_dim)
|
self.val_embed = nn.Linear(val_dim, hidden_dim)
|
||||||
|
self.stage_embed = nn.Sequential(
|
||||||
|
nn.Linear(1, 64),
|
||||||
|
nn.LayerNorm(64),
|
||||||
|
nn.ELU(),
|
||||||
|
|
||||||
|
nn.Linear(64, hidden_dim),
|
||||||
|
)
|
||||||
self.encoder = nn.TransformerEncoder(
|
self.encoder = nn.TransformerEncoder(
|
||||||
nn.TransformerEncoderLayer(
|
nn.TransformerEncoderLayer(
|
||||||
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
|
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
|
||||||
@ -22,10 +29,11 @@ class ConditionEncoder(nn.Module):
|
|||||||
nn.Linear(hidden_dim*2, out_dim)
|
nn.Linear(hidden_dim*2, out_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, tag, val):
|
def forward(self, tag, val, stage):
|
||||||
tag = self.tag_embed(tag)
|
tag = self.tag_embed(tag)
|
||||||
val = self.val_embed(val)
|
val = self.val_embed(val)
|
||||||
feat = torch.stack([tag, val], dim=1)
|
stage = self.stage_embed(stage)
|
||||||
|
feat = torch.stack([tag, val, stage], dim=1)
|
||||||
feat = self.encoder(feat)
|
feat = self.encoder(feat)
|
||||||
feat = torch.mean(feat, dim=1)
|
feat = torch.mean(feat, dim=1)
|
||||||
feat = self.fusion(feat)
|
feat = self.fusion(feat)
|
||||||
|
|||||||
@ -76,9 +76,11 @@ class MinamoModel(nn.Module):
|
|||||||
self.head3 = MinamoScoreHead(512, 512)
|
self.head3 = MinamoScoreHead(512, 512)
|
||||||
|
|
||||||
def forward(self, map, graph, stage, tag_cond, val_cond):
|
def forward(self, map, graph, stage, tag_cond, val_cond):
|
||||||
|
B, D = tag_cond.shape
|
||||||
|
stage_tensor = torch.Tensor([stage]).expand(B, 1).to(map.device)
|
||||||
vision = self.vision_model(map)
|
vision = self.vision_model(map)
|
||||||
topo = self.topo_model(graph)
|
topo = self.topo_model(graph)
|
||||||
cond = self.cond(tag_cond, val_cond)
|
cond = self.cond(tag_cond, val_cond, stage_tensor)
|
||||||
if stage == 1:
|
if stage == 1:
|
||||||
vision_score, topo_score = self.head1(vision, topo, graph, cond)
|
vision_score, topo_score = self.head1(vision, topo, graph, cond)
|
||||||
elif stage == 2:
|
elif stage == 2:
|
||||||
|
|||||||
@ -375,7 +375,7 @@ def immutable_penalty_loss(
|
|||||||
return penalty
|
return penalty
|
||||||
|
|
||||||
class WGANGinkaLoss:
|
class WGANGinkaLoss:
|
||||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.2]):
|
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]):
|
||||||
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失
|
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失
|
||||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
|
|||||||
@ -21,7 +21,9 @@ class GinkaModel(nn.Module):
|
|||||||
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
||||||
|
|
||||||
def forward(self, x, stage, tag_cond, val_cond, random=False):
|
def forward(self, x, stage, tag_cond, val_cond, random=False):
|
||||||
cond = self.cond(tag_cond, val_cond)
|
B, D = tag_cond.shape
|
||||||
|
stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
|
||||||
|
cond = self.cond(tag_cond, val_cond, stage_tensor)
|
||||||
if random:
|
if random:
|
||||||
x_in = F.softmax(self.head(x, cond), dim=1)
|
x_in = F.softmax(self.head(x, cond), dim=1)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -325,9 +325,11 @@ def train():
|
|||||||
low_loss_epochs = 0
|
low_loss_epochs = 0
|
||||||
|
|
||||||
if train_stage >= 2:
|
if train_stage >= 2:
|
||||||
train_stage += 1
|
if stage_epoch % 5 == 1:
|
||||||
|
train_stage = 3
|
||||||
if train_stage == 5:
|
elif stage_epoch % 5 == 3:
|
||||||
|
train_stage = 4
|
||||||
|
elif stage_epoch % 5 == 0:
|
||||||
train_stage = 2
|
train_stage = 2
|
||||||
|
|
||||||
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
|
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user