mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
perf: 改进输入部分
This commit is contained in:
parent
d800a2382b
commit
55f09fb37b
@ -61,4 +61,36 @@ class GCNBlock(nn.Module):
|
||||
for i in range(B):
|
||||
offset = i * num_nodes_per_batch
|
||||
batch_edge_index.append(edge_index + offset)
|
||||
return torch.cat(batch_edge_index, dim=1)
|
||||
return torch.cat(batch_edge_index, dim=1)
|
||||
|
||||
class ConvFusionModule(nn.Module):
|
||||
def __init__(self, in_ch, hidden_ch, out_ch, w: int, h: int):
|
||||
super().__init__()
|
||||
self.cnn = DoubleConvBlock([in_ch, hidden_ch, in_ch])
|
||||
self.gcn = GCNBlock(in_ch, hidden_ch, in_ch, w, h)
|
||||
self.fusion = DoubleConvBlock([in_ch*2, hidden_ch*2, out_ch])
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.cnn(x)
|
||||
x2 = self.gcn(x)
|
||||
x = torch.cat([x1, x2], dim=1)
|
||||
x = self.fusion(x)
|
||||
return x
|
||||
|
||||
class DoubleFCModule(nn.Module):
|
||||
def __init__(self, in_dim, hidden_dim, out_dim):
|
||||
super().__init__()
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(in_dim, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Linear(hidden_dim, out_dim),
|
||||
nn.LayerNorm(out_dim),
|
||||
nn.ELU()
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
@ -1,19 +1,14 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .common import DoubleFCModule
|
||||
|
||||
class ConditionEncoder(nn.Module):
|
||||
def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
|
||||
super().__init__()
|
||||
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
|
||||
self.val_embed = nn.Linear(val_dim, hidden_dim)
|
||||
self.stage_embed = nn.Sequential(
|
||||
nn.Linear(1, 64),
|
||||
nn.LayerNorm(64),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Linear(64, hidden_dim),
|
||||
)
|
||||
self.tag_embed = DoubleFCModule(tag_dim, hidden_dim*2, hidden_dim)
|
||||
self.val_embed = DoubleFCModule(val_dim, hidden_dim*2, hidden_dim)
|
||||
self.stage_embed = DoubleFCModule(1, hidden_dim*2, hidden_dim)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(
|
||||
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
|
||||
|
||||
@ -1,18 +1,12 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from ..common.common import GCNBlock, DoubleConvBlock
|
||||
from ..common.common import ConvFusionModule
|
||||
from ..common.cond import ConditionInjector
|
||||
|
||||
class RandomInputHead(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = DoubleConvBlock([32, 64, 128])
|
||||
self.gcn = GCNBlock(32, 128, 128, 32, 32)
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Conv2d(256, 256, 3, padding=1, padding_mode='replicate'),
|
||||
nn.InstanceNorm2d(256),
|
||||
nn.ELU(),
|
||||
)
|
||||
self.enc = ConvFusionModule(32, 256, 256, 32, 32)
|
||||
self.out_conv = nn.Sequential(
|
||||
nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'),
|
||||
nn.InstanceNorm2d(128),
|
||||
@ -24,33 +18,45 @@ class RandomInputHead(nn.Module):
|
||||
self.inject = ConditionInjector(256, 256)
|
||||
|
||||
def forward(self, x, cond):
|
||||
x_cnn = self.conv(x)
|
||||
x_gcn = self.gcn(x)
|
||||
x = torch.cat([x_cnn, x_gcn], dim=1)
|
||||
x = self.fusion(x)
|
||||
x = self.enc(x)
|
||||
x = self.inject(x, cond)
|
||||
x = self.out_conv(x)
|
||||
return x
|
||||
|
||||
class InputUpsample(nn.Module):
|
||||
def __init__(self, in_ch, hidden_ch=64, out_ch=64):
|
||||
super().__init__()
|
||||
self.net = nn.Sequential(
|
||||
nn.Conv2d(in_ch, hidden_ch, kernel_size=3, padding=1),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Upsample(scale_factor=2, mode='nearest'), # 13x13 → 26x26
|
||||
nn.Conv2d(hidden_ch, hidden_ch, kernel_size=3, padding=1),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Upsample(size=(32, 32), mode='nearest'), # 26x26 → 32x32
|
||||
nn.Conv2d(hidden_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.ELU(),
|
||||
)
|
||||
|
||||
def forward(self, x): # [B, C, 13, 13]
|
||||
x = self.net(x) # [B, C, 32, 32]
|
||||
return x
|
||||
|
||||
class GinkaInput(nn.Module):
|
||||
def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
|
||||
super().__init__()
|
||||
self.out_size = out_size
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(in_size[0] * in_size[1], out_size[0] * out_size[1]),
|
||||
nn.LayerNorm(out_size[0] * out_size[1]),
|
||||
nn.ELU()
|
||||
)
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
|
||||
nn.InstanceNorm2d(out_ch),
|
||||
nn.ELU()
|
||||
)
|
||||
self.enc1 = ConvFusionModule(in_ch, in_ch*4, in_ch, in_size[0], in_size[1])
|
||||
self.upsample = InputUpsample(in_ch, in_ch*2, out_ch)
|
||||
self.enc2 = ConvFusionModule(out_ch, out_ch*4, out_ch, out_size[0], out_size[1])
|
||||
self.inject1 = ConditionInjector(256, in_ch)
|
||||
self.inject2 = ConditionInjector(256, out_ch)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
x = x.view(B, C, H * W)
|
||||
x = self.fc(x)
|
||||
x = x.view(B, C, self.out_size[0], self.out_size[1])
|
||||
x = self.conv(x)
|
||||
def forward(self, x, cond):
|
||||
x = self.enc1(x)
|
||||
x = self.inject1(x, cond)
|
||||
x = self.upsample(x)
|
||||
x = self.enc2(x)
|
||||
x = self.inject2(x, cond)
|
||||
return x
|
||||
|
||||
@ -11,13 +11,20 @@ from ..critic.model import MinamoModel
|
||||
CLASS_NUM = 32
|
||||
ILLEGAL_MAX_NUM = 30
|
||||
|
||||
STAGE_ALLOWED = [
|
||||
STAGE_CHANGEABLE = [
|
||||
[],
|
||||
[0, 1, 2, 29, 30],
|
||||
[3, 4, 5, 6, 26, 27, 28],
|
||||
list(range(7, 26))
|
||||
]
|
||||
|
||||
STAGE_ALLOWED = [
|
||||
[],
|
||||
STAGE_CHANGEABLE[1],
|
||||
[*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2]],
|
||||
[*STAGE_CHANGEABLE[1], *STAGE_CHANGEABLE[2], *STAGE_CHANGEABLE[3]]
|
||||
]
|
||||
|
||||
DENSITY_MAP = [
|
||||
[1, *list(range(3, 30))],
|
||||
[1],
|
||||
@ -32,6 +39,27 @@ DENSITY_MAP = [
|
||||
[29, 30]
|
||||
]
|
||||
|
||||
DENSITY_WEIGHTS = [
|
||||
1,
|
||||
1.5,
|
||||
0.5,
|
||||
5,
|
||||
4,
|
||||
3,
|
||||
3,
|
||||
3,
|
||||
5,
|
||||
10,
|
||||
20
|
||||
]
|
||||
|
||||
DENSITY_STAGE = [
|
||||
[],
|
||||
[1, 2, 10],
|
||||
[1, 2, 3, 4, 10],
|
||||
list(range(0, 11))
|
||||
]
|
||||
|
||||
def get_not_allowed(classes: list[int], include_illegal=False):
|
||||
res = list()
|
||||
for num in range(0, CLASS_NUM):
|
||||
@ -44,37 +72,6 @@ def get_not_allowed(classes: list[int], include_illegal=False):
|
||||
|
||||
return res
|
||||
|
||||
def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[*list(range(0, 29)), 30]):
|
||||
"""
|
||||
强制地图最外圈像素必须为指定类别(墙或箭头)
|
||||
|
||||
参数:
|
||||
pred: 模型输出的概率分布,形状 [B, C, H, W]
|
||||
allowed_classes: 允许出现在外圈的类别列表
|
||||
|
||||
返回:
|
||||
loss: 标量损失值
|
||||
"""
|
||||
B, C, H, W = pred.shape
|
||||
|
||||
# 创建外圈mask [H, W]
|
||||
border_mask = torch.zeros((H, W), dtype=torch.bool, device=pred.device)
|
||||
border_mask[0, :] = True # 第一行
|
||||
border_mask[-1, :] = True # 最后一行
|
||||
border_mask[:, 0] = True # 第一列
|
||||
border_mask[:, -1] = True # 最后一列
|
||||
|
||||
# 提取所有允许和不允许类别的概率和 [B, H, W]
|
||||
unallowed_probs = pred[:, get_not_allowed(allowed_classes, include_illegal=True), :, :].sum(dim=1)
|
||||
|
||||
# 获取外圈区域允许类别的概率 [B, N_pixels]
|
||||
border_unallowed = unallowed_probs[:, border_mask]
|
||||
|
||||
target = torch.zeros_like(border_unallowed)
|
||||
loss_unallowed = F.mse_loss(border_unallowed, target)
|
||||
|
||||
return loss_unallowed
|
||||
|
||||
def inner_constraint_loss(pred: torch.Tensor, allowed=list(range(0, 30))):
|
||||
"""限定内部允许出现的图块种类
|
||||
|
||||
@ -159,93 +156,6 @@ def entrance_constraint_loss(
|
||||
)
|
||||
return total_loss
|
||||
|
||||
def adaptive_count_loss(
|
||||
pred_probs: torch.Tensor,
|
||||
target_map: torch.Tensor,
|
||||
class_list: list = list(range(32)),
|
||||
margin_ratio: float = 0.1, # 降低margin比例以更严格
|
||||
zero_margin_scale: float = 0.1, # 减少零类别的margin
|
||||
lambda_entropy: float = 0.2, # 增大熵约束权重
|
||||
lambda_local: float = 0.2,
|
||||
lambda_max: float = 0, # 新增最大概率约束
|
||||
grid_size: int = 4, # 减小局部网格尺寸
|
||||
eps: float = 1e-3
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
改进版自适应图块数量约束损失,增强局部匹配和概率确定性
|
||||
"""
|
||||
B, C, H, W = pred_probs.shape
|
||||
device = pred_probs.device
|
||||
total_loss = 0.0
|
||||
valid_classes = 0
|
||||
|
||||
# 预计算地图面积
|
||||
map_area = math.sqrt(H * W)
|
||||
|
||||
# 动态调整零类别的margin:基于预测中最小的非零概率
|
||||
min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean()
|
||||
dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area
|
||||
|
||||
# 计算每个类别的数量损失
|
||||
for cls in class_list:
|
||||
pred_count = pred_probs[:, cls].sum(dim=(1,2)) # 预测数量
|
||||
true_count = target_map[:, cls].sum(dim=(1,2)) # 真实数量
|
||||
|
||||
zero_mask = (true_count == 0)
|
||||
dynamic_margin = torch.where(
|
||||
zero_mask,
|
||||
dynamic_zero_margin,
|
||||
margin_ratio * true_count
|
||||
)
|
||||
|
||||
safe_true = true_count + eps * zero_mask
|
||||
abs_error = torch.abs(pred_count - true_count)
|
||||
rel_error = abs_error / safe_true
|
||||
|
||||
# 调整损失函数形状,远离目标时惩罚更大
|
||||
loss_per_class = torch.where(
|
||||
abs_error <= dynamic_margin,
|
||||
rel_error ** 2, # 近目标时二次损失
|
||||
(rel_error - 0.5 * margin_ratio) ** 2 # 远目标时二次增长
|
||||
)
|
||||
|
||||
# 零类别使用更严格的绝对误差惩罚
|
||||
loss_per_class = torch.where(
|
||||
zero_mask,
|
||||
F.relu(abs_error - dynamic_zero_margin) ** 2 / map_area,
|
||||
loss_per_class
|
||||
)
|
||||
|
||||
total_loss += loss_per_class.mean()
|
||||
valid_classes += 1
|
||||
|
||||
total_loss /= valid_classes # 平均类别损失
|
||||
|
||||
# 改进的熵约束:每个像素的熵
|
||||
def entropy_loss(pred_probs):
|
||||
entropy_per_pixel = -torch.sum(pred_probs * torch.log(pred_probs + 1e-6), dim=1)
|
||||
return entropy_per_pixel.mean() # 所有像素的平均熵
|
||||
|
||||
total_loss += lambda_entropy * entropy_loss(pred_probs)
|
||||
|
||||
# 新增最大概率约束:鼓励每个位置概率尖锐化
|
||||
max_probs = pred_probs.max(dim=1)[0] # 每个位置的最大概率
|
||||
max_loss = (1 - max_probs).mean() # 鼓励接近1
|
||||
total_loss += lambda_max * max_loss
|
||||
|
||||
# 改进局部损失:约束局部区域内的数量
|
||||
def local_count_loss(pred_probs, target_probs, grid_size):
|
||||
grid_area = grid_size ** 2
|
||||
# 计算每个grid内的预测数量
|
||||
pred_counts = F.avg_pool2d(pred_probs, grid_size, stride=grid_size) * grid_area
|
||||
target_counts = F.avg_pool2d(target_probs, grid_size, stride=grid_size) * grid_area
|
||||
# 使用L1损失更鲁棒
|
||||
return F.l1_loss(pred_counts, target_counts)
|
||||
|
||||
total_loss += lambda_local * local_count_loss(pred_probs, target_map, grid_size)
|
||||
|
||||
return total_loss
|
||||
|
||||
def input_head_illegal_loss(input_map, allowed_classes=(0, 1)):
|
||||
C = input_map.shape[1]
|
||||
mask = torch.ones(C, device=input_map.device)
|
||||
@ -261,7 +171,7 @@ def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1):
|
||||
|
||||
return wall_penalty
|
||||
|
||||
def compute_multi_density_loss(probs, target_densities):
|
||||
def compute_multi_density_loss(probs, target_densities, tile_list):
|
||||
"""
|
||||
pred: [B, C, H, W]
|
||||
target_densities: [B, N] - N 个目标类别密度
|
||||
@ -271,53 +181,10 @@ def compute_multi_density_loss(probs, target_densities):
|
||||
for i, classes in enumerate(DENSITY_MAP):
|
||||
class_map = probs[:, classes, :, :]
|
||||
pred_density = torch.mean(class_map, dim=(1, 2, 3))
|
||||
loss = F.mse_loss(pred_density, target_densities[:, i])
|
||||
losses.append(loss)
|
||||
if i in tile_list:
|
||||
loss = F.mse_loss(pred_density, target_densities[:, i])
|
||||
losses.append(loss * DENSITY_WEIGHTS[i])
|
||||
return sum(losses)
|
||||
|
||||
class GinkaLoss(nn.Module):
|
||||
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
|
||||
"""Ginka Model 损失函数部分
|
||||
|
||||
Args:
|
||||
weight (list, optional): 每一个损失函数的权重,从第 0 项开始,依次是:
|
||||
1. Minamo 相似度损失
|
||||
2. 图块种类损失,要求内部不出现箭头,外圈只出现箭头和墙壁,不允许出现非法图块
|
||||
3. 入口间距及存在性损失
|
||||
4. 怪物、道具、门数量损失
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = weight
|
||||
self.minamo = minamo
|
||||
|
||||
def forward(self, pred, target, target_vision_feat, target_topo_feat):
|
||||
# 地图结构损失
|
||||
class_loss = outer_border_constraint_loss(pred) + inner_constraint_loss(pred)
|
||||
entrance_loss = entrance_constraint_loss(pred)
|
||||
count_loss = adaptive_count_loss(pred, target)
|
||||
|
||||
# 使用 Minamo Model 计算相似度
|
||||
graph = batch_convert_soft_map_to_graph(pred)
|
||||
pred_vision_feat, pred_topo_feat = self.minamo(pred, graph)
|
||||
|
||||
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=1)
|
||||
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=1)
|
||||
minamo_sim = 0 * vision_sim + 1 * topo_sim
|
||||
# tqdm.write(f"{vision_sim.mean().item():.12f}, {topo_sim.mean().item():.12f}")
|
||||
minamo_loss = (1.0 - minamo_sim).mean()
|
||||
|
||||
tqdm.write(
|
||||
f"{minamo_loss.item():.12f}, {class_loss.item():.12f}, {entrance_loss.item():.12f}, {count_loss.item():.12f}"
|
||||
)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0] * 4,
|
||||
class_loss * self.weight[1],
|
||||
entrance_loss * self.weight[2],
|
||||
count_loss * self.weight[3]
|
||||
]
|
||||
|
||||
return sum(losses)
|
||||
|
||||
# 对图像数据进行插值
|
||||
def interpolate_data(real_data, fake_data, epsilon):
|
||||
@ -374,9 +241,16 @@ def immutable_penalty_loss(
|
||||
|
||||
return penalty
|
||||
|
||||
def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
|
||||
not_allowed = get_not_allowed(legal_classes, include_illegal=True)
|
||||
input_mask = pred[:, not_allowed, :, :]
|
||||
target = torch.zeros_like(input_mask)
|
||||
penalty = F.cross_entropy(input_mask, target)
|
||||
return penalty
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.5]):
|
||||
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失
|
||||
# weight: 判别器损失,CE 损失,不可修改类型损失和非法图块损失,图块类型损失,入口存在性损失,多样性损失,密度损失
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
|
||||
@ -443,9 +317,9 @@ class WGANGinkaLoss:
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
@ -473,13 +347,15 @@ class WGANGinkaLoss:
|
||||
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
illegal_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
density_loss * self.weight[6],
|
||||
@ -498,9 +374,9 @@ class WGANGinkaLoss:
|
||||
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
|
||||
@ -16,8 +16,8 @@ class GinkaModel(nn.Module):
|
||||
super().__init__()
|
||||
self.head = RandomInputHead()
|
||||
self.cond = ConditionEncoder(64, 16, 256, 256)
|
||||
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
|
||||
self.unet = GinkaUNet(32, base_ch, base_ch)
|
||||
self.input = GinkaInput(32, 64, (13, 13), (32, 32))
|
||||
self.unet = GinkaUNet(64, base_ch, base_ch)
|
||||
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
||||
|
||||
def forward(self, x, stage, tag_cond, val_cond, random=False):
|
||||
@ -28,7 +28,7 @@ class GinkaModel(nn.Module):
|
||||
x_in = F.softmax(self.head(x, cond), dim=1)
|
||||
else:
|
||||
x_in = x
|
||||
x = self.input(x_in)
|
||||
x = self.input(x_in, cond)
|
||||
x = self.unet(x, cond)
|
||||
x = self.output(x, stage, cond)
|
||||
return x, x_in
|
||||
@ -51,7 +51,7 @@ if __name__ == "__main__":
|
||||
|
||||
print(f"输入形状: feat={input.shape}")
|
||||
print(f"输出形状: output={output.shape}")
|
||||
print(f"Head parameters: {sum(p.numel() for p in model.head.parameters())}")
|
||||
print(f"Random parameters: {sum(p.numel() for p in model.head.parameters())}")
|
||||
print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
|
||||
print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
|
||||
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
|
||||
|
||||
@ -60,6 +60,22 @@ class FusionModule(nn.Module):
|
||||
x = torch.cat([x1, x2], dim=1)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class GinkaUNetInput(nn.Module):
|
||||
def __init__(self, in_ch, out_ch, w, h):
|
||||
super().__init__()
|
||||
self.conv = ConvBlock(in_ch, in_ch)
|
||||
self.gcn = GCNBlock(in_ch, in_ch*2, in_ch, w, h)
|
||||
self.fusion = ConvBlock(in_ch*2, out_ch)
|
||||
self.inject = ConditionInjector(256, out_ch)
|
||||
|
||||
def forward(self, x, cond):
|
||||
x1 = self.conv(x)
|
||||
x2 = self.gcn(x)
|
||||
x = torch.cat([x1, x2], dim=1)
|
||||
x = self.fusion(x)
|
||||
x = self.inject(x, cond)
|
||||
return x
|
||||
|
||||
class GinkaEncoder(nn.Module):
|
||||
"""编码器(下采样)部分"""
|
||||
@ -142,7 +158,7 @@ class GinkaBottleneck(nn.Module):
|
||||
super().__init__()
|
||||
self.transformer = GinkaTransformerEncoder(
|
||||
in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
|
||||
token_size=16, ff_dim=1024, num_layers=4
|
||||
token_size=16, ff_dim=1024, num_layers=6
|
||||
)
|
||||
self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
|
||||
self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
|
||||
@ -167,7 +183,7 @@ class GinkaUNet(nn.Module):
|
||||
"""Ginka Model UNet 部分
|
||||
"""
|
||||
super().__init__()
|
||||
self.down1 = ConvBlock(in_ch, base_ch)
|
||||
self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32)
|
||||
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
|
||||
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
|
||||
self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4)
|
||||
@ -184,7 +200,7 @@ class GinkaUNet(nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x, cond):
|
||||
x1 = self.down1(x) # [B, 64, 32, 32]
|
||||
x1 = self.down1(x, cond) # [B, 64, 32, 32]
|
||||
x2 = self.down2(x1, cond) # [B, 128, 16, 16]
|
||||
x3 = self.down3(x2, cond) # [B, 256, 8, 8]
|
||||
x4 = self.down4(x3, cond) # [B, 512, 4, 4]
|
||||
|
||||
@ -325,11 +325,11 @@ def train():
|
||||
low_loss_epochs = 0
|
||||
|
||||
if train_stage >= 2:
|
||||
if stage_epoch % 5 == 1:
|
||||
if (epoch + 1) % 5 == 1:
|
||||
train_stage = 3
|
||||
elif stage_epoch % 5 == 3:
|
||||
elif (epoch + 1) % 5 == 3:
|
||||
train_stage = 4
|
||||
elif stage_epoch % 5 == 0:
|
||||
elif (epoch + 1) % 5 == 0:
|
||||
train_stage = 2
|
||||
|
||||
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
|
||||
@ -350,6 +350,9 @@ def train():
|
||||
else:
|
||||
g_steps = 1
|
||||
|
||||
if avg_loss_ginka > 0:
|
||||
g_steps += int(max(avg_loss_ginka * 5, 0))
|
||||
|
||||
if avg_loss_minamo > 0:
|
||||
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
||||
else:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user