diff --git a/gan.sh b/gan.sh
index 5065d5c..4767ca5 100644
--- a/gan.sh
+++ b/gan.sh
@@ -2,7 +2,7 @@
 python3 -m minamo.train --epochs 10 --resume true
 python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
 python3 -m minamo.train --epochs 10 --resume true
-python3 -m ginka.train --epochs 30 --resume true
+python3 -m ginka.train --epochs 70 --resume true
 python3 -m ginka.validate
 # Training finished; process the data
 mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
diff --git a/ginka/model/input.py b/ginka/model/input.py
deleted file mode 100644
index 5348a2b..0000000
--- a/ginka/model/input.py
+++ /dev/null
@@ -1,41 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-class ResidualUpsampleBlock(nn.Module):
-    def __init__(self, in_ch, out_ch):
-        super().__init__()
-        self.conv = nn.Sequential(
-            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
-            nn.Conv2d(in_ch, out_ch, 3, padding=1),
-            nn.InstanceNorm2d(out_ch),
-            nn.GELU(),
-            nn.Conv2d(out_ch, out_ch, 3, padding=1),
-            nn.InstanceNorm2d(out_ch),
-            nn.GELU()
-        )
-
-    def forward(self, x):
-        return self.conv(x)
-
-class GinkaInput(nn.Module):
-    def __init__(self, feat_dim=1024, out_ch=64):
-        super().__init__()
-        fc_dim = out_ch * 8 * 4 * 4
-        self.out_ch = out_ch
-        self.fc = nn.Sequential(
-            nn.Linear(feat_dim, fc_dim),
-            nn.BatchNorm1d(fc_dim),
-            nn.ReLU()
-        )
-        self.upsample = nn.Sequential(
-            ResidualUpsampleBlock(out_ch*8, out_ch*8),
-            ResidualUpsampleBlock(out_ch*8, out_ch*4),
-            ResidualUpsampleBlock(out_ch*4, out_ch)
-        )
-
-    def forward(self, x):
-        x = self.fc(x)
-        x = x.view(-1, self.out_ch*8, 4, 4)
-        x = self.upsample(x)
-        return x
diff --git a/ginka/model/loss.py b/ginka/model/loss.py
index ea635ca..a279627 100644
--- a/ginka/model/loss.py
+++ b/ginka/model/loss.py
@@ -117,71 +117,77 @@ def adaptive_count_loss(
     class_list: list = list(range(32)),
     margin_ratio: float = 0.2,
     zero_margin_scale: float = 0.2,
+    lambda_entropy: float = 0.05,
+    lambda_local: float = 0.1,
+    grid_size: int = 8,
     eps: float = 1e-3
 ) -> torch.Tensor:
     """
-    Adaptive tile-count constraint loss
+    Improved adaptive tile-count constraint loss, with local matching and an entropy term
-
-    Args:
-        pred_probs: predicted probability distribution [B, C, H, W]
-        target_map: ground-truth map [B, C, H, W]
-        class_list: list of classes to constrain
-        margin_ratio: allowed relative error (e.g. 0.2 means ±20%)
-        zero_margin_scale: slack coefficient when the reference count is 0 (slack = scale * sqrt(H*W))
-        eps: numerical-stability constant
-
-    Returns:
-        loss: scalar loss value
     """
     B, C, H, W = pred_probs.shape
     device = pred_probs.device
     total_loss = 0.0
     valid_classes = 0

-    # Precompute the map area for the slack computation
+    # Precompute the map area
     map_area = math.sqrt(H * W)

+    # Mean of the per-cell maximum class probability over the constrained classes
+    min_nonzero_prob = pred_probs[:, class_list].max(dim=1).values.mean()
+    dynamic_zero_margin = zero_margin_scale * min_nonzero_prob * map_area  # keeps zero classes from being filled in
+
     for cls in class_list:
-        # Predicted count (sum of probabilities)
-        pred_count = pred_probs[:, cls].sum(dim=(1,2))  # [B]
-        # Ground-truth count
-        true_count = target_map[:, cls].sum(dim=(1,2))  # [B]
+        pred_count = pred_probs[:, cls].sum(dim=(1,2))  # predicted class count
+        true_count = target_map[:, cls].sum(dim=(1,2))  # ground-truth class count

-        # Dynamic tolerance
-        with torch.no_grad():
-            # Allowed upper bound when the true count is 0
-            zero_mask = (true_count == 0)
-            dynamic_margin = torch.where(
-                zero_mask,
-                zero_margin_scale * map_area,  # allow a small amount
-                margin_ratio * true_count  # relative error band
-            )
+        zero_mask = (true_count == 0)
+        dynamic_margin = torch.where(
+            zero_mask,
+            dynamic_zero_margin,
+            margin_ratio * true_count
+        )

-        # Error computation (with numerical stability)
-        safe_true = true_count + eps * zero_mask  # add a tiny amount when the true count is zero
+        safe_true = true_count + eps * zero_mask
         abs_error = torch.abs(pred_count - true_count)
         rel_error = abs_error / safe_true

-        # Two-stage loss
-        # Stage 1: quadratic inside the tolerance band (strong gradient)
-        # Stage 2: linear outside the band (stable training)
         loss_per_class = torch.where(
             abs_error <= dynamic_margin,
-            (rel_error ** 2) * 0.5,  # strong gradient inside the band
-            rel_error - (0.5 * margin_ratio)  # stable gradient outside
+            (rel_error ** 2) * 0.8 + 0.2 * rel_error,
+            rel_error - (0.5 * margin_ratio)
         )

-        # Special case for zero true counts: only penalize the excess beyond the slack
         loss_per_class = torch.where(
             zero_mask,
-            F.relu(abs_error - dynamic_margin) / map_area,  # normalized
+            F.relu(abs_error - dynamic_margin) / map_area,
             loss_per_class
         )

         total_loss += loss_per_class.mean()
         valid_classes += 1

-    return total_loss / valid_classes  # average over classes
+    # Average over classes
+    total_loss /= valid_classes
+
+    # Entropy penalty, keeps the class distribution from flattening toward uniform
+    def entropy_loss(pred_probs):
+        avg_probs = pred_probs.mean(dim=(2, 3))
+        entropy = -torch.sum(avg_probs * torch.log(avg_probs + 1e-6), dim=1)
+        return entropy.mean()
+
+    total_loss += lambda_entropy * entropy_loss(pred_probs)
+
+    # Local class-count matching
+    def local_count_loss(pred_probs, target_probs, grid_size=8):
+        pred_local = F.avg_pool2d(pred_probs, kernel_size=grid_size, stride=grid_size)
+        target_local = F.avg_pool2d(target_probs, kernel_size=grid_size, stride=grid_size)
+        return F.mse_loss(pred_local, target_local)
+
+    total_loss += lambda_local * local_count_loss(pred_probs, target_map, grid_size)
+
+    return total_loss

 def illegal_tile_loss(
     pred_probs: torch.Tensor,
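Note on the two new regularizers in `adaptive_count_loss`: the entropy term penalizes a generator that hedges by spreading probability mass evenly over all 32 tile classes, and the pooled MSE term forces tile counts to match region by region rather than only globally. A minimal standalone sketch of the same math (re-stated here for illustration, not part of the patch):

```python
import torch
import torch.nn.functional as F

def entropy_term(pred_probs: torch.Tensor) -> torch.Tensor:
    # Average the per-cell distributions over the map, then penalize high entropy.
    avg_probs = pred_probs.mean(dim=(2, 3))  # [B, C]
    return -(avg_probs * torch.log(avg_probs + 1e-6)).sum(dim=1).mean()

def local_count_term(pred_probs, target_probs, grid_size=8):
    # Average pooling turns each grid cell into a local tile-frequency histogram;
    # MSE then matches counts region by region instead of only in aggregate.
    pred_local = F.avg_pool2d(pred_probs, grid_size, stride=grid_size)
    target_local = F.avg_pool2d(target_probs, grid_size, stride=grid_size)
    return F.mse_loss(pred_local, target_local)

uniform = torch.full((1, 32, 32, 32), 1 / 32)  # fully hedged prediction
peaked = torch.zeros((1, 32, 32, 32))
peaked[:, 0] = 1.0                             # single-class map
print(entropy_term(uniform).item())  # ~log(32) ≈ 3.47, heavily penalized
print(entropy_term(peaked).item())   # ~0
```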
diff --git a/ginka/model/model.py b/ginka/model/model.py
index 4cdb830..177966c 100644
--- a/ginka/model/model.py
+++ b/ginka/model/model.py
@@ -2,31 +2,48 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .unet import GinkaUNet
-from .input import GinkaInput
 from .output import GinkaOutput

+def print_memory(tag=""):
+    print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
+
 class GinkaModel(nn.Module):
     def __init__(self, feat_dim=1024, base_ch=64, num_classes=32):
         """Ginka Model definition
         """
         super().__init__()
-        self.input = GinkaInput(feat_dim, base_ch)
-        self.unet = GinkaUNet(base_ch, num_classes)
+        self.unet = GinkaUNet(1, base_ch, num_classes, feat_dim)
         self.output = GinkaOutput(num_classes, (13, 13))
-        print(f"Input parameters: {sum(p.numel() for p in self.input.parameters())}")
-        print(f"UNet parameters: {sum(p.numel() for p in self.unet.parameters())}")
-        print(f"Output parameters: {sum(p.numel() for p in self.output.parameters())}")
-        print(f"Total parameters: {sum(p.numel() for p in self.parameters())}")

-    def forward(self, x):
+    def forward(self, x, feat):
         """
         Args:
             feat: feature vector of the reference map
         Returns:
             logits: output logits [BS, num_classes, H, W]
         """
-        x = self.input(x)
-        x = self.unet(x)
+        x = self.unet(x, feat)
         x = self.output(x)
         return x, F.softmax(x, dim=1)
+
+# Check VRAM usage
+if __name__ == "__main__":
+    x = torch.randn((1, 1, 32, 32)).cuda()
+    feat = torch.randn((1, 1024)).cuda()
+
+    # Initialize the model
+    model = GinkaModel().cuda()
+
+    print_memory("after init")
+
+    # Forward pass
+    output, output_softmax = model(x, feat)
+
+    print_memory("after forward")
+
+    print(f"input shapes: x={x.shape}, feat={feat.shape}")
+    print(f"output shapes: output={output.shape}, softmax={output_softmax.shape}")
+    print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
+    print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
+    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
\ No newline at end of file
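With the new signature, generation is driven by a noise map plus the Minamo feature vector, and the model returns both logits and their softmax. At inference the softmax branch still has to be discretized into tile ids; a minimal sketch (the 13×13 shape is inferred from the `GinkaOutput(num_classes, (13, 13))` constructor call above):

```python
import torch

def to_tile_map(output_softmax: torch.Tensor) -> torch.Tensor:
    # Collapse per-class probabilities [B, 32, H, W] into integer tile ids [B, H, W].
    return output_softmax.argmax(dim=1)

probs = torch.softmax(torch.randn(1, 32, 13, 13), dim=1)  # stand-in for model output
print(to_tile_map(probs).shape)  # torch.Size([1, 13, 13])
```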
diff --git a/ginka/model/unet.py b/ginka/model/unet.py
index 9303e51..3b7484d 100644
--- a/ginka/model/unet.py
+++ b/ginka/model/unet.py
@@ -1,106 +1,128 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from shared.attention import CBAM, SEBlock
+
+class GinkaAdaIN(nn.Module):
+    def __init__(self, num_features, condition_dim):
+        """
+        Adaptive instance normalization (AdaIN)
+        Args:
+            num_features: number of channels to normalize
+            condition_dim: feature dimension of the condition input
+        """
+        super(GinkaAdaIN, self).__init__()
+        self.fc = nn.Linear(condition_dim, num_features * 2)  # γ and β
+
+    def forward(self, x, condition):
+        """
+        x: [B, C, H, W] - input feature map
+        condition: [B, condition_dim] - condition vector to inject
+        """
+        gamma, beta = self.fc(condition).chunk(2, dim=1)  # split into γ and β
+        gamma = gamma.view(x.shape[0], x.shape[1], 1, 1)  # reshape for broadcasting
+        beta = beta.view(x.shape[0], x.shape[1], 1, 1)
+
+        x = F.instance_norm(x)  # normalize
+        return gamma * x + beta  # modulate
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_ch, out_ch):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_ch, out_ch, 3, padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(),
+            nn.Conv2d(out_ch, out_ch, 3, padding=1),
+            nn.BatchNorm2d(out_ch),
+            nn.ReLU(),
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+class AdaINConvBlock(nn.Module):
+    def __init__(self, in_ch, out_ch, feat_dim):
+        super().__init__()
+        self.conv = ConvBlock(in_ch, out_ch)
+        self.adain = GinkaAdaIN(out_ch, feat_dim)
+
+    def forward(self, x, feat):
+        x = self.conv(x)
+        x = self.adain(x, feat)
+        return x

 class GinkaEncoder(nn.Module):
     """Encoder (downsampling) part"""
-    def __init__(self, in_channels, out_channels, attention=False, block='CBAM'):
+    def __init__(self, in_ch, out_ch, feat_dim):
+        super().__init__()
+        self.conv = ConvBlock(in_ch, out_ch)
+        self.pool = nn.MaxPool2d(2)
+        self.adain = GinkaAdaIN(out_ch, feat_dim)
+
+    def forward(self, x, feat):
+        x = self.conv(x)
+        x = self.pool(x)
+        x = self.adain(x, feat)
+        return x
+
+class GinkaUpSample(nn.Module):
+    def __init__(self, in_ch, out_ch):
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.InstanceNorm2d(out_channels),
+            nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
+            nn.BatchNorm2d(out_ch),
             nn.GELU(),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.InstanceNorm2d(out_channels),
         )
-        # Attention
-        if attention:
-            if block == 'CBAM':
-                self.conv.append(CBAM(out_channels))
-            elif block == 'SEBlock':
-                self.conv.append(SEBlock(out_channels))
-        self.conv.append(nn.GELU())
-        self.down = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
-
+
     def forward(self, x):
-        x_res = self.conv(x)
-        x_down = self.down(x_res)
-        return x_down, x_res
+        return self.conv(x)

 class GinkaDecoder(nn.Module):
     """Decoder (upsampling) part"""
-    def __init__(self, in_channels, out_channels, attention=False, block='CBAM'):
+    def __init__(self, in_ch, out_ch, feat_dim):
         super().__init__()
-        self.upsample = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
-
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_channels + out_channels, out_channels, kernel_size=3, padding=1),
-            nn.InstanceNorm2d(out_channels),
-        )
-        # Attention
-        if attention:
-            if block == 'CBAM':
-                self.conv.append(CBAM(out_channels))
-            elif block == 'SEBlock':
-                self.conv.append(SEBlock(out_channels))
-        self.conv.append(nn.GELU())
-
-    def forward(self, x, skip):
+        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
+        self.conv = ConvBlock(in_ch, out_ch)
+        self.adain = GinkaAdaIN(out_ch, feat_dim)
+
+    def forward(self, x, skip, feat):
         x = self.upsample(x)
-        x = torch.cat([x, skip], dim=1)
+        x = torch.cat([x, skip], dim=1)
         x = self.conv(x)
+        x = self.adain(x, feat)
         return x
-
-class GinkaBottleneck(nn.Module):
-    def __init__(self, in_channels, out_channels, attention=False):
-        super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
-            nn.InstanceNorm2d(out_channels),
-            nn.GELU(),
-            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
-            nn.InstanceNorm2d(out_channels),
-        )
-        if attention:
-            self.conv.append(SEBlock(out_channels))
-        self.conv.append(nn.GELU())
-
-    def forward(self, x):
-        return self.conv(x)

 class GinkaUNet(nn.Module):
-    def __init__(self, in_ch=64, out_ch=32):
+    def __init__(self, in_ch=1, base_ch=64, out_ch=32, feat_dim=1024):
         """Ginka Model UNet part
         """
         super().__init__()
-        self.down1 = GinkaEncoder(in_ch, in_ch*2, attention=True)
-        self.down2 = GinkaEncoder(in_ch*2, in_ch*4, attention=True)
-        self.down3 = GinkaEncoder(in_ch*4, in_ch*8, attention=True, block='SEBlock')
-        self.down4 = GinkaEncoder(in_ch*8, in_ch*16, attention=True, block='SEBlock')
+        self.in_conv = AdaINConvBlock(in_ch, base_ch, feat_dim)
+        self.down1 = GinkaEncoder(base_ch, base_ch*2, feat_dim)
+        self.down2 = GinkaEncoder(base_ch*2, base_ch*4, feat_dim)
+        self.down3 = GinkaEncoder(base_ch*4, base_ch*8, feat_dim)

-        self.bottleneck = GinkaBottleneck(in_ch*16, in_ch*16, attention=True)
+        self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16, feat_dim)

-        self.up1 = GinkaDecoder(in_ch*16, in_ch*8, attention=True, block='SEBlock')
-        self.up2 = GinkaDecoder(in_ch*8, in_ch*4, attention=True, block='SEBlock')
-        self.up3 = GinkaDecoder(in_ch*4, in_ch*2, attention=True)
-        self.up4 = GinkaDecoder(in_ch*2, in_ch, attention=True)
+        self.up1 = GinkaDecoder(base_ch*16, base_ch*8, feat_dim)
+        self.up2 = GinkaDecoder(base_ch*8, base_ch*4, feat_dim)
+        self.up3 = GinkaDecoder(base_ch*4, base_ch*2, feat_dim)
+        self.up4 = GinkaDecoder(base_ch*2, base_ch, feat_dim)

         self.final = nn.Sequential(
-            nn.Conv2d(in_ch, out_ch, 1),
+            nn.Conv2d(base_ch, out_ch, 1),
         )

-    def forward(self, x):
-        x_down1, skip1 = self.down1(x)
-        x_down2, skip2 = self.down2(x_down1)
-        x_down3, skip3 = self.down3(x_down2)
-        x_down4, skip4 = self.down4(x_down3)
-
-        x = self.bottleneck(x_down4)
-
-        x = self.up1(x, skip4)
-        x = self.up2(x, skip3)
-        x = self.up3(x, skip2)
-        x = self.up4(x, skip1)
+    def forward(self, x, feat):
+        x1 = self.in_conv(x, feat)
+        x2 = self.down1(x1, feat)
+        x3 = self.down2(x2, feat)
+        x4 = self.down3(x3, feat)
+        x5 = self.bottleneck(x4, feat)
+
+        x = self.up1(x5, x4, feat)
+        x = self.up2(x, x3, feat)
+        x = self.up3(x, x2, feat)
+        x = self.up4(x, x1, feat)
         return self.final(x)
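The structural change in the UNet is that the reference-map feature vector is no longer fed in once at the input (the deleted `GinkaInput`) but re-injected into every block through AdaIN: instance-normalize each (sample, channel) plane, then scale and shift it with γ and β predicted from the condition. A small sketch of the same math, showing that the post-AdaIN channel statistics are dictated entirely by the condition vector:

```python
import torch
import torch.nn.functional as F

B, C, D = 2, 8, 1024
x = torch.randn(B, C, 16, 16)   # feature map
cond = torch.randn(B, D)        # Minamo feature vector
fc = torch.nn.Linear(D, C * 2)  # predicts gamma and beta, as in GinkaAdaIN

gamma, beta = fc(cond).chunk(2, dim=1)
y = gamma.view(B, C, 1, 1) * F.instance_norm(x) + beta.view(B, C, 1, 1)

# instance_norm leaves each (sample, channel) plane with ~zero mean and unit std,
# so the per-channel mean of y is ~beta and its std is ~|gamma|:
print(F.instance_norm(x).mean(dim=(2, 3)).abs().max().item())  # ~0
print(y.mean(dim=(2, 3))[0, 0].item(), beta[0, 0].item())      # approximately equal
```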
diff --git a/ginka/train.py b/ginka/train.py
index b4f7b9b..c339629 100644
--- a/ginka/train.py
+++ b/ginka/train.py
@@ -10,14 +10,12 @@
 from .dataset import GinkaDataset
 from minamo.model.model import MinamoModel
 from shared.args import parse_arguments

+BATCH_SIZE = 32
+
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 os.makedirs("result", exist_ok=True)
 os.makedirs("result/ginka_checkpoint", exist_ok=True)

-# Gradient-check hook attached after the generator output
-def grad_hook(module, grad_input, grad_output):
-    print(f"Generator output grad norm: {grad_output[0].norm().item()}")
-
 def train():
     print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
@@ -29,21 +27,18 @@ def train():
     minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
     minamo.to(device)
     minamo.eval()
-
-    # for param in minamo.parameters():
-    #     param.requires_grad = False

     # Prepare the datasets
     dataset = GinkaDataset(args.train, device, minamo)
     dataset_val = GinkaDataset(args.validate, device, minamo)
     dataloader = DataLoader(
         dataset,
-        batch_size=32,
+        batch_size=BATCH_SIZE,
         shuffle=True
     )
     dataloader_val = DataLoader(
         dataset_val,
-        batch_size=32,
+        batch_size=BATCH_SIZE,
         shuffle=True
     )
@@ -52,9 +47,6 @@ def train():
     scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
     criterion = GinkaLoss(minamo)

-    # model.register_full_backward_hook(grad_hook)
-    # converter.register_full_backward_hook(grad_hook)
-    # criterion.register_full_backward_hook(grad_hook)
     if args.resume:
         data = torch.load(args.from_state, map_location=device)
         model.load_state_dict(data["model_state"])
@@ -83,30 +75,20 @@ def train():
             feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
             # Forward pass
             optimizer.zero_grad()
-            _, output_softmax = model(feat_vec)
+            noise = torch.randn((feat_vec.size(0), 1, 32, 32), device=device)  # match the actual batch size and device
+            _, output_softmax = model(noise, feat_vec)
             # Compute the loss
             scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
             # Backward pass
-            scaled_losses.backward()
+            losses.backward()
             optimizer.step()
             total_loss += losses.item()
-            # for name, param in model.named_parameters():
-            #     if param.grad is not None:
-            #         print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")

         avg_loss = total_loss / len(dataloader)
         tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
-        # total_norm = 0
-        # for p in model.parameters():
-        #     if p.grad is not None:
-        #         param_norm = p.grad.detach().data.norm(2)
-        #         total_norm += param_norm.item() ** 2
-        # total_norm = total_norm ** 0.5
-        # tqdm.write(f"Gradient Norm: {total_norm:.4f}")  # should normally stay between 1 and 100
-
         # for name, param in model.named_parameters():
         #     if param.grad is not None:
         #         print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
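One caution on the discriminator here: `minamo.eval()` only switches normalization and dropout to inference behaviour; it does not stop gradients from flowing through Minamo's weights during the generator's backward pass. Since those weights are not in the optimizer they will not be updated, but explicitly freezing them (the loop the deleted comment hinted at) skips the unneeded gradient computation:

```python
# Freeze the Minamo critic: its parameter gradients are never used,
# so skipping them saves memory and compute during ginka's backward pass.
for param in minamo.parameters():
    param.requires_grad = False
```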
diff --git a/minamo/dataset.py b/minamo/dataset.py
index 0aa8060..079c4d4 100644
--- a/minamo/dataset.py
+++ b/minamo/dataset.py
@@ -1,4 +1,5 @@
 import json
+import random
 import torch
 import torch.nn.functional as F
 from torch.utils.data import Dataset
@@ -28,8 +29,12 @@ class MinamoDataset(Dataset):
         map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
         map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]

-        map1_probs = random_smooth_onehot(map1_probs)
-        map2_probs = random_smooth_onehot(map2_probs)
+        min_main = random.uniform(0.7, 1)
+        max_main = random.uniform(0.9, 1)
+        epsilon = random.uniform(0, 0.3)
+
+        map1_probs = random_smooth_onehot(map1_probs, min_main, max_main, epsilon)
+        map2_probs = random_smooth_onehot(map2_probs, min_main, max_main, epsilon)

         graph1 = differentiable_convert_to_data(map1_probs)
         graph2 = differentiable_convert_to_data(map2_probs)
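`random_smooth_onehot` itself is not shown in this diff, so the meaning of the three new arguments is inferred from the call site. One thing worth flagging: `min_main ~ U(0.7, 1)` can exceed `max_main ~ U(0.9, 1)`, so the smoothing function should tolerate crossed bounds. A hypothetical implementation consistent with this call signature (a sketch, not the repository's actual function):

```python
import torch

def random_smooth_onehot_sketch(probs, min_main, max_main, epsilon):
    # Hypothetical label smoothing: keep a random share in [min_main, max_main]
    # on the true class of each cell, spread noise of scale epsilon over the rest.
    lo, hi = sorted((min_main, max_main))                    # sampled bounds may cross
    main = torch.empty(probs.shape[1:]).uniform_(lo, hi)     # [H, W] per-cell weight
    noise = torch.rand_like(probs) * epsilon * (1 - probs)   # off-class noise only
    smoothed = probs * main + noise
    return smoothed / smoothed.sum(dim=0, keepdim=True)      # renormalize per cell
```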
diff --git a/minamo/train.py b/minamo/train.py
index ef81db6..91991eb 100644
--- a/minamo/train.py
+++ b/minamo/train.py
@@ -16,24 +16,6 @@
 os.makedirs("result", exist_ok=True)
 os.makedirs("result/minamo_checkpoint", exist_ok=True)
 disable_tqdm = not sys.stdout.isatty()  # disable tqdm if stdout is redirected

-def collate_fn(batch):
-    """Dynamically collate maps of different sizes"""
-    map1_batch = [item[0] for item in batch]
-    map2_batch = [item[1] for item in batch]
-    vis_sim = torch.cat([item[2] for item in batch])
-    topo_sim = torch.cat([item[3] for item in batch])
-
-    # Keep map sizes consistent within a batch (as the task requires)
-    assert all(m.shape == map1_batch[0].shape for m in map1_batch), \
-        "compared maps must have the same size"
-
-    return (
-        torch.stack(map1_batch),  # (B, H, W)
-        torch.stack(map2_batch),  # (B, H, W)
-        vis_sim,
-        topo_sim
-    )
-
 def train():
     print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
@@ -91,6 +73,9 @@ def train():
         graph1 = graph1.to(device)
         graph2 = graph2.to(device)

+        if map1.shape[0] == 1:
+            continue
+
         # Forward pass
         optimizer.zero_grad()
         vision_feat1, topo_feat1 = model(map1, graph1)
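A likely reason for the new batch-size-1 guard: BatchNorm layers cannot compute batch statistics from a single sample in training mode, so the final, smaller batch of an unevenly sized dataset would crash (assuming MinamoModel contains such a layer, e.g. `nn.BatchNorm1d` on pooled features). A minimal reproduction of the failure mode:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm1d(16)
bn.train()
bn(torch.randn(4, 16))       # fine
try:
    bn(torch.randn(1, 16))   # single-sample batch
except ValueError as e:
    print(e)                 # "Expected more than 1 value per channel when training..."
```

Passing `drop_last=True` to the `DataLoader` would achieve the same effect without the in-loop check.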