perf: 改进输入部分

This commit is contained in:
unanmed 2025-05-07 15:38:31 +08:00
parent d800a2382b
commit 55f09fb37b
7 changed files with 148 additions and 220 deletions

View File

@ -62,3 +62,35 @@ class GCNBlock(nn.Module):
offset = i * num_nodes_per_batch offset = i * num_nodes_per_batch
batch_edge_index.append(edge_index + offset) batch_edge_index.append(edge_index + offset)
return torch.cat(batch_edge_index, dim=1) return torch.cat(batch_edge_index, dim=1)
class ConvFusionModule(nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
super().__init__()
self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
self.gcn = GCNBlock(in_ch, hidden_ch, in_ch, w, h)
self.fusion = DoubleConvBlock([in_ch*2, hidden_ch*2, out_ch])
def forward(self, x):
x1 = self.cnn(x)
x2 = self.gcn(x)
x = torch.cat([x1, x2], dim=1)
x = self.fusion(x)
return x
class DoubleFCModule(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ELU(),
nn.Linear(hidden_dim, out_dim),
nn.LayerNorm(out_dim),
nn.ELU()
)
def forward(self, x):
x = self.fc(x)
return x

View File

@ -1,19 +1,14 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .common import DoubleFCModule
class ConditionEncoder(nn.Module): class ConditionEncoder(nn.Module):
def __init__(self, tag_dim, val_dim, hidden_dim, out_dim): def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
super().__init__() super().__init__()
self.tag_embed = nn.Linear(tag_dim, hidden_dim) self.tag_embed = DoubleFCModule(tag_dim, hidden_dim*2, hidden_dim)
self.val_embed = nn.Linear(val_dim, hidden_dim) self.val_embed = DoubleFCModule(val_dim, hidden_dim*2, hidden_dim)
self.stage_embed = nn.Sequential( self.stage_embed = DoubleFCModule(1, hidden_dim*2, hidden_dim)
nn.Linear(1, 64),
nn.LayerNorm(64),
nn.ELU(),
nn.Linear(64, hidden_dim),
)
self.encoder = nn.TransformerEncoder( self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer( nn.TransformerEncoderLayer(
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4, d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,

View File

@ -1,18 +1,12 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..common.common import GCNBlock, DoubleConvBlock from ..common.common import ConvFusionModule
from ..common.cond import ConditionInjector from ..common.cond import ConditionInjector
class RandomInputHead(nn.Module): class RandomInputHead(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.conv = DoubleConvBlock([32, 64, 128]) self.enc = ConvFusionModule(32, 256, 256, 32, 32)
self.gcn = GCNBlock(32, 128, 128, 32, 32)
self.fusion = nn.Sequential(
nn.Conv2d(256, 256, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(256),
nn.ELU(),
)
self.out_conv = nn.Sequential( self.out_conv = nn.Sequential(
nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'), nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(128), nn.InstanceNorm2d(128),
@ -24,33 +18,45 @@ class RandomInputHead(nn.Module):
self.inject = ConditionInjector(256, 256) self.inject = ConditionInjector(256, 256)
def forward(self, x, cond): def forward(self, x, cond):
x_cnn = self.conv(x) x = self.enc(x)
x_gcn = self.gcn(x)
x = torch.cat([x_cnn, x_gcn], dim=1)
x = self.fusion(x)
x = self.inject(x, cond) x = self.inject(x, cond)
x = self.out_conv(x) x = self.out_conv(x)
return x return x
class InputUpsample(nn.Module):
def __init__(self, in_ch, hidden_ch=64, out_ch=64):
super().__init__()
self.net = nn.Sequential(
nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
nn.ELU(),
nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26
nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
nn.ELU(),
nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32
nn.Conv2d(hidden_ch, out_ch, kernel_size=3, padding=1),
nn.ELU(),
)
def forward(self, x): # [B, C, 13, 13]
x = self.net(x) # [B, C, 32, 32]
return x
class GinkaInput(nn.Module): class GinkaInput(nn.Module):
def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)): def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
super().__init__() super().__init__()
self.out_size = out_size self.out_size = out_size
self.fc = nn.Sequential( self.enc1 = ConvFusionModule(in_ch, in_ch*4, in_ch, in_size[0], in_size[1])
nn.Linear(in_size[0] * in_size[1], out_size[0] * out_size[1]), self.upsample = InputUpsample(in_ch, in_ch*2, out_ch)
nn.LayerNorm(out_size[0] * out_size[1]), self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1])
nn.ELU() self.inject1 = ConditionInjector(256, in_ch)
) self.inject2 = ConditionInjector(256, out_ch)
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
nn.ELU()
)
def forward(self, x): def forward(self, x, cond):
B, C, H, W = x.shape x = self.enc1(x)
x = x.view(B, C, H * W) x = self.inject1(x, cond)
x = self.fc(x) x = self.upsample(x)
x = x.view(B, C, self.out_size[0], self.out_size[1]) x = self.enc2(x)
x = self.conv(x) x = self.inject2(x, cond)
return x return x

View File

@ -11,13 +11,20 @@ from ..critic.model import MinamoModel
CLASS_NUM = 32 CLASS_NUM = 32
ILLEGAL_MAX_NUM = 30 ILLEGAL_MAX_NUM = 30
STAGE_ALLOWED = [ STAGE_CHANGEABLE = [
[], [],
[0, 1, 2, 29, 30], [0, 1, 2, 29, 30],
[3, 4, 5, 6, 26, 27, 28], [3, 4, 5, 6, 26, 27, 28],
list(range(7, 26)) list(range(7, 26))
] ]
STAGE_ALLOWED = [
[],
STAGE_CHANGEABLE[1],
[*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2]],
[*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2], *STAGE_CHANGEABLE[3]]
]
DENSITY_MAP = [ DENSITY_MAP = [
[1, *list(range(3, 30))], [1, *list(range(3, 30))],
[1], [1],
@ -32,6 +39,27 @@ DENSITY_MAP = [
[29, 30] [29, 30]
] ]
DENSITY_WEIGHTS = [
1,
1.5,
0.5,
5,
4,
3,
3,
3,
5,
10,
20
]
DENSITY_STAGE = [
[],
[1, 2, 10],
[1, 2, 3, 4, 10],
list(range(0, 11))
]
def get_not_allowed(classes: list[int], include_illegal=False): def get_not_allowed(classes: list[int], include_illegal=False):
res = list() res = list()
for num in range(0, CLASS_NUM): for num in range(0, CLASS_NUM):
@ -44,37 +72,6 @@ def get_not_allowed(classes: list[int], include_illegal=False):
return res return res
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[*list(range(0, 29)), 30]):
"""
强制地图最外圈像素必须为指定类别墙或箭头
参数:
pred: 模型输出的概率分布形状 [B, C, H, W]
allowed_classes: 允许出现在外圈的类别列表
返回:
loss: 标量损失值
"""
B, C, H, W = pred.shape
# 创建外圈mask [H, W]
border_mask = torch.zeros((H, W), dtype=torch.bool, device=pred.device)
border_mask[0, :] = True # 第一行
border_mask[-1, :] = True # 最后一行
border_mask[:, 0] = True # 第一列
border_mask[:, -1] = True # 最后一列
# 提取所有允许和不允许类别的概率和 [B, H, W]
unallowed_probs = pred[:, get_not_allowed(allowed_classes, include_illegal=True), :, :].sum(dim=1)
# 获取外圈区域允许类别的概率 [B, N_pixels]
border_unallowed = unallowed_probs[:, border_mask]
target = torch.zeros_like(border_unallowed)
loss_unallowed = F.mse_loss(border_unallowed, target)
return loss_unallowed
def inner_constraint_loss(pred: torch.Tensor, allowed=list(range(0, 30))): def inner_constraint_loss(pred: torch.Tensor, allowed=list(range(0, 30))):
"""限定内部允许出现的图块种类 """限定内部允许出现的图块种类
@ -159,93 +156,6 @@ def entrance_constraint_loss(
) )
return total_loss return total_loss
def adaptive_count_loss(
pred_probs: torch.Tensor,
target_map: torch.Tensor,
class_list: list = list(range(32)),
margin_ratio: float = 0.1, # 降低margin比例以更严格
zero_margin_scale: float = 0.1, # 减少零类别的margin
lambda_entropy: float = 0.2, # 增大熵约束权重
lambda_local: float = 0.2,
lambda_max: float = 0, # 新增最大概率约束
grid_size: int = 4, # 减小局部网格尺寸
eps: float = 1e-3
) -> torch.Tensor:
"""
改进版自适应图块数量约束损失增强局部匹配和概率确定性
"""
B, C, H, W = pred_probs.shape
device = pred_probs.device
total_loss = 0.0
valid_classes = 0
# 预计算地图面积
map_area = math.sqrt(H * W)
# 动态调整零类别的margin基于预测中最小的非零概率
min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean()
dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area
# 计算每个类别的数量损失
for cls in class_list:
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测数量
true_count = target_map[:, cls].sum(dim=(1,2)) # 真实数量
zero_mask = (true_count == 0)
dynamic_margin = torch.where(
zero_mask,
dynamic_zero_margin,
margin_ratio * true_count
)
safe_true = true_count + eps * zero_mask
abs_error = torch.abs(pred_count - true_count)
rel_error = abs_error / safe_true
# 调整损失函数形状,远离目标时惩罚更大
loss_per_class = torch.where(
abs_error <= dynamic_margin,
rel_error ** 2, # 近目标时二次损失
(rel_error - 0.5 * margin_ratio) ** 2 # 远目标时二次增长
)
# 零类别使用更严格的绝对误差惩罚
loss_per_class = torch.where(
zero_mask,
F.relu(abs_error - dynamic_zero_margin) ** 2 / map_area,
loss_per_class
)
total_loss += loss_per_class.mean()
valid_classes += 1
total_loss /= valid_classes # 平均类别损失
# 改进的熵约束:每个像素的熵
def entropy_loss(pred_probs):
entropy_per_pixel = -torch.sum(pred_probs * torch.log(pred_probs + 1e-6), dim=1)
return entropy_per_pixel.mean() # 所有像素的平均熵
total_loss += lambda_entropy * entropy_loss(pred_probs)
# 新增最大概率约束:鼓励每个位置概率尖锐化
max_probs = pred_probs.max(dim=1)[0] # 每个位置的最大概率
max_loss = (1 - max_probs).mean() # 鼓励接近1
total_loss += lambda_max * max_loss
# 改进局部损失:约束局部区域内的数量
def local_count_loss(pred_probs, target_probs, grid_size):
grid_area = grid_size ** 2
# 计算每个grid内的预测数量
pred_counts = F.avg_pool2d(pred_probs, grid_size, stride=grid_size) * grid_area
target_counts = F.avg_pool2d(target_probs, grid_size, stride=grid_size) * grid_area
# 使用L1损失更鲁棒
return F.l1_loss(pred_counts, target_counts)
total_loss += lambda_local * local_count_loss(pred_probs, target_map, grid_size)
return total_loss
def input_head_illegal_loss(input_map, allowed_classes=(0, 1)): def input_head_illegal_loss(input_map, allowed_classes=(0, 1)):
C = input_map.shape[1] C = input_map.shape[1]
mask = torch.ones(C, device=input_map.device) mask = torch.ones(C, device=input_map.device)
@ -261,7 +171,7 @@ def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1):
return wall_penalty return wall_penalty
def compute_multi_density_loss(probs, target_densities): def compute_multi_density_loss(probs, target_densities, tile_list):
""" """
pred: [B, C, H, W] pred: [B, C, H, W]
target_densities: [B, N] - N 个目标类别密度 target_densities: [B, N] - N 个目标类别密度
@ -271,54 +181,11 @@ def compute_multi_density_loss(probs, target_densities):
for i, classes in enumerate(DENSITY_MAP): for i, classes in enumerate(DENSITY_MAP):
class_map = probs[:, classes, :, :] class_map = probs[:, classes, :, :]
pred_density = torch.mean(class_map, dim=(1, 2, 3)) pred_density = torch.mean(class_map, dim=(1, 2, 3))
loss = F.mse_loss(pred_density, target_densities[:, i]) if i in tile_list:
losses.append(loss) loss = F.mse_loss(pred_density, target_densities[:, i])
losses.append(loss * DENSITY_WEIGHTS[i])
return sum(losses) return sum(losses)
class GinkaLoss(nn.Module):
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
"""Ginka Model 损失函数部分
Args:
weight (list, optional): 每一个损失函数的权重从第 0 项开始依次是
1. Minamo 相似度损失
2. 图块种类损失要求内部不出现箭头外圈只出现箭头和墙壁不允许出现非法图块
3. 入口间距及存在性损失
4. 怪物道具门数量损失
"""
super().__init__()
self.weight = weight
self.minamo = minamo
def forward(self, pred, target, target_vision_feat, target_topo_feat):
# 地图结构损失
class_loss = outer_border_constraint_loss(pred) + inner_constraint_loss(pred)
entrance_loss = entrance_constraint_loss(pred)
count_loss = adaptive_count_loss(pred, target)
# 使用 Minamo Model 计算相似度
graph = batch_convert_soft_map_to_graph(pred)
pred_vision_feat, pred_topo_feat = self.minamo(pred, graph)
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=1)
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=1)
minamo_sim = 0 * vision_sim + 1 * topo_sim
# tqdm.write(f"{vision_sim.mean().item():.12f}, {topo_sim.mean().item():.12f}")
minamo_loss = (1.0 - minamo_sim).mean()
tqdm.write(
f"{minamo_loss.item():.12f}, {class_loss.item():.12f}, {entrance_loss.item():.12f}, {count_loss.item():.12f}"
)
losses = [
minamo_loss * self.weight[0] * 4,
class_loss * self.weight[1],
entrance_loss * self.weight[2],
count_loss * self.weight[3]
]
return sum(losses)
# 对图像数据进行插值 # 对图像数据进行插值
def interpolate_data(real_data, fake_data, epsilon): def interpolate_data(real_data, fake_data, epsilon):
return epsilon * real_data + (1 - epsilon) * fake_data return epsilon * real_data + (1 - epsilon) * fake_data
@ -374,9 +241,16 @@ def immutable_penalty_loss(
return penalty return penalty
def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
not_allowed = get_not_allowed(legal_classes, include_illegal=True)
input_mask = pred[:, not_allowed, :, :]
target = torch.zeros_like(input_mask)
penalty = F.cross_entropy(input_mask, target)
return penalty
class WGANGinkaLoss: class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]): def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]):
# weight: 判别器损失CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失 # weight: 判别器损失CE 损失,不可修改类型损失和非法图块损失,图块类型损失,入口存在性损失,多样性损失,密度损失
self.lambda_gp = lambda_gp # 梯度惩罚系数 self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight self.weight = weight
@ -443,9 +317,9 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) minamo_loss = -torch.mean(fake_scores)
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小 ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
constraint_loss = inner_constraint_loss(probs_fake) constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond) density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
fake_a, fake_b = fake.chunk(2, dim=0) fake_a, fake_b = fake.chunk(2, dim=0)
@ -473,13 +347,15 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) minamo_loss = -torch.mean(fake_scores)
illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
constraint_loss = inner_constraint_loss(probs_fake) constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond) density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
fake_a, fake_b = fake.chunk(2, dim=0) fake_a, fake_b = fake.chunk(2, dim=0)
losses = [ losses = [
minamo_loss * self.weight[0], minamo_loss * self.weight[0],
illegal_loss * self.weight[2],
constraint_loss * self.weight[3], constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
density_loss * self.weight[6], density_loss * self.weight[6],
@ -498,9 +374,9 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond) fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) minamo_loss = -torch.mean(fake_scores)
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
constraint_loss = inner_constraint_loss(probs_fake) constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond) density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
fake_a, fake_b = fake.chunk(2, dim=0) fake_a, fake_b = fake.chunk(2, dim=0)

View File

@ -16,8 +16,8 @@ class GinkaModel(nn.Module):
super().__init__() super().__init__()
self.head = RandomInputHead() self.head = RandomInputHead()
self.cond = ConditionEncoder(64, 16, 256, 256) self.cond = ConditionEncoder(64, 16, 256, 256)
self.input = GinkaInput(32, 32, (13, 13), (32, 32)) self.input = GinkaInput(32, 64, (13, 13), (32, 32))
self.unet = GinkaUNet(32, base_ch, base_ch) self.unet = GinkaUNet(64, base_ch, base_ch)
self.output = GinkaOutput(base_ch, out_ch, (13, 13)) self.output = GinkaOutput(base_ch, out_ch, (13, 13))
def forward(self, x, stage, tag_cond, val_cond, random=False): def forward(self, x, stage, tag_cond, val_cond, random=False):
@ -28,7 +28,7 @@ class GinkaModel(nn.Module):
x_in = F.softmax(self.head(x, cond), dim=1) x_in = F.softmax(self.head(x, cond), dim=1)
else: else:
x_in = x x_in = x
x = self.input(x_in) x = self.input(x_in, cond)
x = self.unet(x, cond) x = self.unet(x, cond)
x = self.output(x, stage, cond) x = self.output(x, stage, cond)
return x, x_in return x, x_in
@ -51,7 +51,7 @@ if __name__ == "__main__":
print(f"输入形状: feat={input.shape}") print(f"输入形状: feat={input.shape}")
print(f"输出形状: output={output.shape}") print(f"输出形状: output={output.shape}")
print(f"Head parameters: {sum(p.numel() for p in model.head.parameters())}") print(f"Random parameters: {sum(p.numel() for p in model.head.parameters())}")
print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}") print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}") print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")

View File

@ -61,6 +61,22 @@ class FusionModule(nn.Module):
x = self.conv(x) x = self.conv(x)
return x return x
class GinkaUNetInput(nn.Module):
def __init__(self, in_ch, out_ch, w, h):
super().__init__()
self.conv = ConvBlock(in_ch, in_ch)
self.gcn = GCNBlock(in_ch, in_ch*2, in_ch, w, h)
self.fusion = ConvBlock(in_ch*2, out_ch)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, cond):
x1 = self.conv(x)
x2 = self.gcn(x)
x = torch.cat([x1, x2], dim=1)
x = self.fusion(x)
x = self.inject(x, cond)
return x
class GinkaEncoder(nn.Module): class GinkaEncoder(nn.Module):
"""编码器(下采样)部分""" """编码器(下采样)部分"""
def __init__(self, in_ch, out_ch): def __init__(self, in_ch, out_ch):
@ -142,7 +158,7 @@ class GinkaBottleneck(nn.Module):
super().__init__() super().__init__()
self.transformer = GinkaTransformerEncoder( self.transformer = GinkaTransformerEncoder(
in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h, in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
token_size=16, ff_dim=1024, num_layers=4 token_size=16, ff_dim=1024, num_layers=6
) )
self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4) self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
self.fusion = nn.Conv2d(module_ch*3, module_ch, 1) self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
@ -167,7 +183,7 @@ class GinkaUNet(nn.Module):
"""Ginka Model UNet 部分 """Ginka Model UNet 部分
""" """
super().__init__() super().__init__()
self.down1 = ConvBlock(in_ch, base_ch) self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32)
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16) self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8) self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4) self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4)
@ -184,7 +200,7 @@ class GinkaUNet(nn.Module):
) )
def forward(self, x, cond): def forward(self, x, cond):
x1 = self.down1(x) # [B, 64, 32, 32] x1 = self.down1(x, cond) # [B, 64, 32, 32]
x2 = self.down2(x1, cond) # [B, 128, 16, 16] x2 = self.down2(x1, cond) # [B, 128, 16, 16]
x3 = self.down3(x2, cond) # [B, 256, 8, 8] x3 = self.down3(x2, cond) # [B, 256, 8, 8]
x4 = self.down4(x3, cond) # [B, 512, 4, 4] x4 = self.down4(x3, cond) # [B, 512, 4, 4]

View File

@ -325,11 +325,11 @@ def train():
low_loss_epochs = 0 low_loss_epochs = 0
if train_stage >= 2: if train_stage >= 2:
if stage_epoch % 5 == 1: if (epoch + 1) % 5 == 1:
train_stage = 3 train_stage = 3
elif stage_epoch % 5 == 3: elif (epoch + 1) % 5 == 3:
train_stage = 4 train_stage = 4
elif stage_epoch % 5 == 0: elif (epoch + 1) % 5 == 0:
train_stage = 2 train_stage = 2
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch: if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
@ -350,6 +350,9 @@ def train():
else: else:
g_steps = 1 g_steps = 1
if avg_loss_ginka > 0:
g_steps += int(max(avg_loss_ginka * 5, 0))
if avg_loss_minamo > 0: if avg_loss_minamo > 0:
c_steps = int(min(5 + avg_loss_minamo * 5, 15)) c_steps = int(min(5 + avg_loss_minamo * 5, 15))
else: else: