From 1566acf6913984899083a0fe00b755d240eb4d28 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 18 Mar 2025 21:11:53 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E9=87=8D=E6=9E=84=20ginka=20?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/dataset.py | 38 +++------ ginka/model/loss.py | 13 +-- ginka/model/model.py | 184 ++++++------------------------------------ ginka/model/sample.py | 79 ------------------ ginka/model/unet.py | 91 +++++++++++++++++++++ ginka/train.py | 70 +++++----------- minamo/model/model.py | 11 +-- minamo/train.py | 3 +- minamo/validate.py | 19 +++-- 9 files changed, 172 insertions(+), 336 deletions(-) delete mode 100644 ginka/model/sample.py create mode 100644 ginka/model/unet.py diff --git a/ginka/dataset.py b/ginka/dataset.py index d89a8b1..d1158eb 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -1,8 +1,8 @@ import json -import random import torch from torch.utils.data import Dataset -from transformers import BertTokenizer +from minamo.model.model import MinamoModel +from shared.graph import convert_map_to_graph def load_data(path: str): with open(path, 'r', encoding="utf-8") as f: @@ -15,11 +15,10 @@ def load_data(path: str): return data_list class GinkaDataset(Dataset): - def __init__(self, data_path: str, tokenizer: BertTokenizer, max_len=128): + def __init__(self, data_path: str, minamo: MinamoModel): self.data = load_data(data_path) # 自定义数据加载函数 - self.tokenizer = tokenizer - self.max_len = max_len self.max_size = 32 + self.minamo = minamo def __len__(self): return len(self.data) @@ -27,28 +26,13 @@ class GinkaDataset(Dataset): def __getitem__(self, idx): item = self.data[idx] - # 文本处理 - text = random.choice(item["text"]) - encoding = self.tokenizer( - text, - max_length=self.max_len, - padding="max_length", - truncation=True, - return_tensors="pt" - ) - - # 噪声生成 - w, h = item["size"] - noise = torch.randn(h, w, 1) - - # 目标矩阵填充 - target = torch.full((self.max_size, self.max_size), -100) # 使用-100忽略填充区域 - target[:h, :w] = torch.tensor(item["map"]) + target = torch.tensor(item["map"]) + graph = convert_map_to_graph(target) + vision_feat, topo_feat = self.minamo(target, graph) + feat_vec = torch.cat([vision_feat, topo_feat]) return { - "noise": noise, - "input_ids": encoding["input_ids"].squeeze(), - "attention_mask": encoding["attention_mask"].squeeze(), - "map_size": torch.tensor([h, w]), + "feat_vec": feat_vec, "target": target - } \ No newline at end of file + } + \ No newline at end of file diff --git a/ginka/model/loss.py b/ginka/model/loss.py index f87abb5..304d621 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -2,7 +2,7 @@ import math import torch import torch.nn as nn import torch.nn.functional as F -from pytorch_toolbelt import losses as L +from minamo.model.model import MinamoModel def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]): """地图最外层是否为墙""" @@ -283,7 +283,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler return avg_loss class GinkaLoss(nn.Module): - def __init__(self, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]): + def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]): """Ginka Model 损失函数部分 Args: @@ -299,13 +299,11 @@ class GinkaLoss(nn.Module): """ super().__init__() self.weight = weight - self.dice = L.DiceLoss(mode='multiclass') self.ce = nn.CrossEntropyLoss() + self.minamo = minamo - def forward(self, pred, target): + def forward(self, pred, pred_softmax, target): probs = F.softmax(pred, dim=1) - # 拓扑结构损失 - # structure_loss = topology_loss(pred, target) # 地图结构损失 border_loss = wall_border_loss(pred, probs) wall_loss = internal_wall_loss(pred, probs) @@ -315,6 +313,9 @@ class GinkaLoss(nn.Module): valid_block_loss = illegal_block_loss(pred, probs, used_classes=12, mode="mean") count_loss = integrated_count_loss(probs, target) + # 使用 Minamo Model 计算相似度 + + print( # structure_loss.item(), border_loss.item(), diff --git a/ginka/model/model.py b/ginka/model/model.py index 2aa4db8..f79ce49 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -1,180 +1,44 @@ import torch import torch.nn as nn import torch.nn.functional as F -from transformers import BertModel -from ...shared.attention import CBAM, SpatialAttention -from .sample import HybridUpsample, FinalUpsample, GumbelSampler - -class ResidualBlock(nn.Module): - """残差块""" - def __init__(self, channels): - super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(channels, channels, 3, padding=1), - nn.GroupNorm(8, channels), - nn.GELU(), - nn.Conv2d(channels, channels, 3, padding=1), - nn.GroupNorm(8, channels) - ) - - def forward(self, x): - return x + self.conv(x) +from .unet import GinkaUNet -class DynamicPadConv(nn.Module): - """支持动态处理奇数尺寸的智能卷积""" - def __init__(self, in_ch, out_ch, kernel=3, stride=1): +class GumbelSoftmax(nn.Module): + def __init__(self, tau=1.0, hard=True): super().__init__() - self.conv = nn.Conv2d( - in_ch, out_ch, kernel, - stride=stride, - padding=kernel//2 - ) - self.requires_pad = (stride > 1) # 仅在下采样时需要填充 - - def forward(self, x): - if self.requires_pad: - # 动态计算各维度需要填充的量 - pad_h = x.size(-2) % 2 - pad_w = x.size(-1) % 2 - if pad_h or pad_w: - x = F.pad(x, (0, pad_w, 0, pad_h)) # 右下填充 - return self.conv(x) + self.tau = tau # 温度参数 + self.hard = hard # 是否生成硬性one-hot -class ConditionInjector(nn.Module): - """基于注意力机制的条件注入""" - def __init__(self, cond_dim=128, feat_dim=256): - super().__init__() - self.cond_proj = nn.Sequential( - nn.Linear(cond_dim, feat_dim * 2), - nn.GELU(), - nn.LayerNorm(feat_dim * 2) - ) - self.channel_att = nn.Sequential( - nn.Conv2d(feat_dim, feat_dim//8, 1), - nn.GELU(), - nn.Conv2d(feat_dim//8, feat_dim, 1), - nn.Sigmoid() - ) - - def forward(self, x, cond): - # 投影条件向量 - gamma, beta = self.cond_proj(cond).chunk(2, dim=1) # [B, D] + def forward(self, logits): + # logits形状: [BS, C, H, W] + y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard) - # 通道注意力调制 - att = self.channel_att(x) # [B, C, H, W] - modulated = x * att - - # 添加条件偏置 - return modulated + beta.view(-1, gamma.size(1), 1, 1) - -class GinkaEncoder(nn.Module): - """编码器(下采样)部分""" - def __init__(self, in_ch, out_ch): - super().__init__() - self.encoder = nn.Sequential( - DynamicPadConv(in_ch, out_ch, stride=1), - ResidualBlock(out_ch), - CBAM(out_ch), - nn.GroupNorm(8, out_ch), - nn.GELU() - ) - - def forward(self, x): - return self.encoder(x) + # 转换为类索引的连续表示 + class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1) + return (y * class_indices).sum(dim=1) # 形状[BS, H, W] class GinkaModel(nn.Module): - def __init__(self, in_ch=1, base_ch=64, num_classes=32): + def __init__(self, feat_dim=256, base_ch=64, num_classes=32): """Ginka Model 模型定义部分 - - Args: - in_ch (int, optional): 输入通道数,默认是 1 - base_ch (int, optional): UNet 上下采样卷积基础通道数,默认 64 - num_classes (int, optional): 图块种类数量,默认 32 预留出一部分以供后续拓展功能 """ super().__init__() - - # 轻量级文本编码器(使用BERT前4层) - self.bert = BertModel.from_pretrained('google-bert/bert-base-chinese', output_hidden_states=True) - self.text_proj = nn.Linear(768, 128) - - # 动态尺寸处理系统 - self.size_embed = nn.Embedding(32, 16) # 处理最大32的尺寸 - - # 编码器 - self.enc1 = GinkaEncoder(in_ch, base_ch) - self.enc2 = GinkaEncoder(base_ch, base_ch * 2) - # self.enc3 = GinkaEncoder(base_ch * 2, base_ch * 4) - - # 中间层 - self.mid = nn.Sequential( - DynamicPadConv(base_ch * 2, base_ch * 4), - ConditionInjector(160, base_ch * 4) + self.base_ch = base_ch + self.fc = nn.Sequential( + nn.Linear(feat_dim, 32 * 32 * base_ch) ) + self.unet = GinkaUNet(base_ch, num_classes) + self.softmax = GumbelSoftmax() - # 解码器,解码器仅使用空间注意力 - self.dec1 = HybridUpsample(base_ch * 4, base_ch * 2) - self.dec1_att = SpatialAttention() - - self.dec2 = HybridUpsample(base_ch * 2, base_ch) - self.dec2_att = SpatialAttention() - - # self.dec3 = HybridUpsample(base_ch * 2, base_ch) - # self.dec3_att = SpatialAttention() - - # 输出层 - self.out = FinalUpsample(base_ch, num_classes) - - def forward(self, noise, input_ids, attention_mask, map_size): + def forward(self, feat): """ Args: - noise: 噪声输入 [BS, H, W, 1] - input_ids: 文本token id [BS, seq_len] - attention_mask: 文本attention mask [BS, seq_len] - map_size: 地图尺寸 [BS, 2] (height, width) + feat: 参考地图的特征向量 Returns: logits: 输出logits [BS, num_classes, H, W] """ - # 文本特征提取 - with torch.no_grad(): # 冻结BERT参数 - bert_outputs = self.bert( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True - ) - # 取前4层隐藏状态的平均 - hidden_states = torch.stack(bert_outputs.hidden_states[1:5]) # [4, BS, seq_len, 768] - text_features = torch.mean(hidden_states, dim=0)[:, 0, :] # [BS, 768] - text_features = self.text_proj(text_features) # [BS, 128] - - # 尺寸特征处理 - h_emb = self.size_embed(map_size[:, 0]) # [BS, 16] - w_emb = self.size_embed(map_size[:, 1]) # [BS, 16] - size_features = torch.cat([h_emb, w_emb], dim=1) # [BS, 32] - - # 特征融合 - conditional = torch.cat([text_features, size_features], dim=1) # [BS, 160] - - # 调整噪声输入维度 - x = noise.permute(0, 3, 1, 2) # [BS, 1, H, W] - - # 编码器路径 - x1 = self.enc1(x) # [BS, 64, H / 2, W / 2] - x2 = self.enc2(x1) # [BS, 128, H / 4, W / 4] - - # 中间层(注入条件) - x_mid = self.mid[0](x2) # [BS, 256, H / 4, W / 4] - x_mid = self.mid[1](x_mid, conditional) - - # 解码器路径 - d1 = self.dec1(x_mid, x2) # [BS, 128, H / 2, W / 2] - d1 = self.dec1_att(d1) - d2 = self.dec2(d1, x1) # [BS, 64, H, W] - d2 = self.dec2_att(d2) - # d3 = self.dec3(d2, x1) - # d3 = self.dec3_att(d3) - - # 最终自适应上采样 - h, w = noise.shape[1:3] # 获取原始输入尺寸 - return self.out(d2, (h, w)) - + x = self.fc(feat) + x = x.view(-1, self.base_ch, 32, 32) + x = self.unet(x) + x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False) + return x, self.softmax(x) \ No newline at end of file diff --git a/ginka/model/sample.py b/ginka/model/sample.py deleted file mode 100644 index d8e6e95..0000000 --- a/ginka/model/sample.py +++ /dev/null @@ -1,79 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchvision.ops as ops - -class HybridUpsample(nn.Module): - """自适应尺寸的混合上采样""" - def __init__(self, in_ch, out_ch, skip_ch=None): - super().__init__() - # 子像素卷积上采样 - self.subpixel = nn.Sequential( - nn.Conv2d(in_ch, out_ch * 4, 3, padding=1), - nn.PixelShuffle(2) # 2倍上采样 - ) - - # 跳跃连接处理 - self.skip_conv = nn.Conv2d(skip_ch, out_ch, 1) if skip_ch else None - self.adaptive_pool = nn.AdaptiveAvgPool2d(None) - - def forward(self, x, skip=None): - x = self.subpixel(x) # [B, out_ch, 2H, 2W] - - if skip is not None and self.skip_conv: - # 自动对齐尺寸 - if x.shape[-2:] != skip.shape[-2:]: - skip = F.interpolate(skip, size=x.shape[-2:], mode='nearest') - - # 融合特征 - x = x + self.skip_conv(skip) - - return x - -class DiscreteAwareUpsample(nn.Module): - """离散感知的智能上采样模块""" - def __init__(self, in_ch, out_ch, base_size=16): - super().__init__() - self.base_size = base_size - self.scale_factors = [2, 4, 8] # 支持放大倍数 - - # 可变形卷积增强几何感知 - self.deform_conv = ops.DeformConv2d(in_ch, in_ch, kernel_size=3, padding=1) - - # 多尺度特征融合 - self.multi_scale = nn.ModuleList([ - nn.Sequential( - nn.Conv2d(in_ch, in_ch//4, 1), - nn.Upsample(scale_factor=s, mode='nearest') - ) for s in self.scale_factors - ]) - - # 门控上采样机制 - self.gate_conv = nn.Conv2d(in_ch*2, len(self.scale_factors)+1, 3, padding=1) - - # 离散化输出层 - self.final_conv = nn.Sequential( - nn.Conv2d(in_ch, out_ch*4, 3, padding=1), - nn.PixelShuffle(2), # 亚像素卷积 - nn.Conv2d(out_ch, out_ch, 3, padding=1) - ) - - def forward(self, x, target_size): - # 几何特征提取 - deform_feat = self.deform_conv(x) - - # 生成多尺度特征 - scale_features = [f(deform_feat) for f in self.multi_scale] - - # 动态门控选择 - gate_map = F.softmax(self.gate_conv(torch.cat([x, deform_feat], dim=1)), dim=1) - - # 加权融合多尺度特征 - combined = sum(g * F.interpolate(f, size=target_size, mode='nearest') - for g, f in zip(gate_map.unbind(1), scale_features+[x])) - - # 离散化上采样 - out = self.final_conv(combined) - - # 结构化约束(保持通道独立性) - return out.argmax(dim=1).unsqueeze(1).float() # 伪梯度保留 diff --git a/ginka/model/unet.py b/ginka/model/unet.py new file mode 100644 index 0000000..58fee2f --- /dev/null +++ b/ginka/model/unet.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +class GinkaEncoder(nn.Module): + """编码器(下采样)部分""" + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU() + ) + self.pool = nn.MaxPool2d(2) + + def forward(self, x): + x_res = self.conv(x) # 卷积提取特征 + x_down = self.pool(x_res) # 进行池化 + return x_down, x_res # 返回池化后的特征和跳跃连接特征 + +class GinkaDecoder(nn.Module): + """解码器(上采样)部分""" + def __init__(self, in_channels, out_channels): + super().__init__() + # 上采样(双线性插值 + 卷积) + self.upsample = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU() + ) + + # 跳跃连接融合 + self.conv = nn.Sequential( + nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU() + ) + + def forward(self, x, skip): + x = self.upsample(x) + # 跳跃连接融合 + x = torch.cat([x, skip], dim=1) + x = self.conv(x) + return x + +class GinkaBottleneck(nn.Module): + def __init__(self, in_channels, out_channels): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + nn.BatchNorm2d(out_channels), + nn.ReLU(), + ) + + def forward(self, x): + return self.conv(x) + +class GinkaUNet(nn.Module): + def __init__(self, in_ch=64, out_ch=32): + """Ginka Model UNet 部分 + """ + super().__init__() + + self.down1 = GinkaEncoder(in_ch, in_ch*2) + self.down2 = GinkaEncoder(in_ch*2, in_ch*4) + + self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4) + + self.up1 = GinkaDecoder(in_ch*4, in_ch*2) + self.up2 = GinkaDecoder(in_ch*2, in_ch) + + self.final = nn.Sequential( + nn.Conv2d(in_ch, out_ch, 1) + ) + + def forward(self, x): + x, skip1 = self.down1(x) + x, skip2 = self.down2(x) + + x = self.bottleneck(x) + + x = self.up1(x, skip2) + x = self.up2(x, skip1) + + return self.final(x) diff --git a/ginka/train.py b/ginka/train.py index a8e9d2c..076de3b 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -9,87 +9,57 @@ from tqdm import tqdm from .model.model import GinkaModel from .model.loss import GinkaLoss from .dataset import GinkaDataset +from minamo.model.model import MinamoModel device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) -epochs = 100 +epochs = 70 -def collate_fn(batch): - # 动态填充噪声到最大尺寸 - max_h = max([b["noise"].shape[0] for b in batch]) - max_w = max([b["noise"].shape[1] for b in batch]) - - padded_batch = {} - for key in ["noise", "target"]: - padded = [] - for b in batch: - tensor = b[key] - pad_h = max_h - tensor.shape[0] - pad_w = max_w - tensor.shape[1] - padded.append(F.pad(tensor, (0, pad_w, 0, pad_h), value=-100 if key=="target" else 0)) - padded_batch[key] = torch.stack(padded) - - # 其他字段直接堆叠 - for key in ["input_ids", "attention_mask", "map_size"]: - padded_batch[key] = torch.stack([b[key] for b in batch]) - - return padded_batch +def update_tau(epoch): + start_tau = 1.0 + min_tau = 0.1 + decay_rate = 0.95 + return max(min_tau, start_tau * (decay_rate ** epoch)) def train(): - print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.") + print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") model = GinkaModel() model.to(device) + minamo = MinamoModel(32) + minamo.to(device) + minamo.eval() # 准备数据集 - tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-chinese') - dataset = GinkaDataset("F:/github-ai/ginka-generator/dataset.json", tokenizer) + dataset = GinkaDataset("dataset.json", minamo) dataloader = DataLoader( dataset, - batch_size=4, - shuffle=True, - collate_fn=collate_fn, - num_workers=0 + batch_size=32, + shuffle=True ) # 设定优化器与调度器 optimizer = optim.AdamW(model.parameters(), lr=3e-4) - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) - criterion = GinkaLoss() + scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) + criterion = GinkaLoss(minamo) # 开始训练 for epoch in tqdm(range(epochs)): model.train() total_loss = 0 - - # 温度退火 - model.gumbel.tau = max(0.1, 1.0 - 0.9 * epoch / epochs) + model.softmax.tau = update_tau(epoch) for batch in dataloader: # 数据迁移到设备 - noise = batch["noise"].to(device) - input_ids = batch["input_ids"].to(device) - attention_mask = batch["attention_mask"].to(device) - map_size = batch["map_size"].to(device) target = batch["target"].to(device) + feat_vec = batch["feat_vec"].to(device) # 前向传播 optimizer.zero_grad() - outputs = model(noise, input_ids, attention_mask, map_size) - - print(torch.argmax(torch.softmax(outputs, dim=1), dim=1)) - # print(sampled[0, :, :, 1]) - - # 构建拓扑图 - # with torch.no_grad(): - # pred_graphs = build_topology_graph(outputs.argmax(1)) - # ref_graphs = build_topology_graph(target) + output, output_softmax = model(feat_vec) # 计算损失 - loss = criterion( - outputs, # 调整为 [BS, C, H, W] - target - ) + loss = criterion(output, output_softmax, target) # 反向传播 loss.backward() diff --git a/minamo/model/model.py b/minamo/model/model.py index 1f09a65..ffb051e 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -11,11 +11,8 @@ class MinamoModel(nn.Module): # 拓扑相似度部分 self.topo_model = MinamoTopoModel(tile_types) - def forward(self, map1, map2, graph1, graph2): - vision_feat1 = self.vision_model(map1) - vision_feat2 = self.vision_model(map2) + def forward(self, map, graph): + vision_feat = self.vision_model(map) + topo_feat = self.topo_model(graph) - topo_feat1 = self.topo_model(graph1) - topo_feat2 = self.topo_model(graph2) - - return vision_feat1, vision_feat2, topo_feat1, topo_feat2 + return vision_feat, topo_feat diff --git a/minamo/train.py b/minamo/train.py index 04a0103..211d90a 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -76,7 +76,8 @@ def train(): # 前向传播 optimizer.zero_grad() - vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1, map2, graph1, graph2) + vision_feat1, topo_feat1 = model(map1, graph1) + vision_feat2, topo_feat2 = model(map2, graph2) vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) diff --git a/minamo/validate.py b/minamo/validate.py index ca5b629..1af2731 100644 --- a/minamo/validate.py +++ b/minamo/validate.py @@ -1,5 +1,6 @@ import torch -from torch.utils.data import DataLoader +import torch.nn.functional as F +from torch_geometric.loader import DataLoader from tqdm import tqdm from .model.model import MinamoModel from .model.loss import MinamoLoss @@ -8,32 +9,38 @@ from .dataset import MinamoDataset device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def validate(): - print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to validate model.") + print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.") model = MinamoModel(32) model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) model.to(device) # 准备数据集 - val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json") + val_dataset = MinamoDataset("minamo-eval.json") val_loader = DataLoader( val_dataset, batch_size=32, shuffle=True ) - criterion = MinamoLoss(temp=0.8) + criterion = MinamoLoss() model.eval() val_loss = 0 with torch.no_grad(): for val_batch in tqdm(val_loader): - map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch + map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch map1_val = map1_val.to(device) map2_val = map2_val.to(device) vision_simi_val = vision_simi_val.to(device) topo_simi_val = topo_simi_val.to(device) + graph1 = graph1.to(device) + graph2 = graph2.to(device) - vision_pred_val, topo_pred_val = model(map1_val, map2_val) + vision_feat1, topo_feat1 = model(map1_val, graph1) + vision_feat2, topo_feat2 = model(map2_val, graph2) + + vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) + topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) loss_val = criterion( vision_pred_val, topo_pred_val, vision_simi_val, topo_simi_val