refactor: 重构 ginka 模型部分

This commit is contained in:
unanmed 2025-03-18 21:11:53 +08:00
parent 09b66f2569
commit 1566acf691
9 changed files with 172 additions and 336 deletions

View File

@ -1,8 +1,8 @@
import json import json
import random
import torch import torch
from torch.utils.data import Dataset from torch.utils.data import Dataset
from transformers import BertTokenizer from minamo.model.model import MinamoModel
from shared.graph import convert_map_to_graph
def load_data(path: str): def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f: with open(path, 'r', encoding="utf-8") as f:
@ -15,11 +15,10 @@ def load_data(path: str):
return data_list return data_list
class GinkaDataset(Dataset): class GinkaDataset(Dataset):
def __init__(self, data_path: str, tokenizer: BertTokenizer, max_len=128): def __init__(self, data_path: str, minamo: MinamoModel):
self.data = load_data(data_path) # 自定义数据加载函数 self.data = load_data(data_path) # 自定义数据加载函数
self.tokenizer = tokenizer
self.max_len = max_len
self.max_size = 32 self.max_size = 32
self.minamo = minamo
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
@ -27,28 +26,13 @@ class GinkaDataset(Dataset):
def __getitem__(self, idx): def __getitem__(self, idx):
item = self.data[idx] item = self.data[idx]
# 文本处理 target = torch.tensor(item["map"])
text = random.choice(item["text"]) graph = convert_map_to_graph(target)
encoding = self.tokenizer( vision_feat, topo_feat = self.minamo(target, graph)
text, feat_vec = torch.cat([vision_feat, topo_feat])
max_length=self.max_len,
padding="max_length",
truncation=True,
return_tensors="pt"
)
# 噪声生成
w, h = item["size"]
noise = torch.randn(h, w, 1)
# 目标矩阵填充
target = torch.full((self.max_size, self.max_size), -100) # 使用-100忽略填充区域
target[:h, :w] = torch.tensor(item["map"])
return { return {
"noise": noise, "feat_vec": feat_vec,
"input_ids": encoding["input_ids"].squeeze(),
"attention_mask": encoding["attention_mask"].squeeze(),
"map_size": torch.tensor([h, w]),
"target": target "target": target
} }

View File

@ -2,7 +2,7 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from pytorch_toolbelt import losses as L from minamo.model.model import MinamoModel
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]): def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
"""地图最外层是否为墙""" """地图最外层是否为墙"""
@ -283,7 +283,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler
return avg_loss return avg_loss
class GinkaLoss(nn.Module): class GinkaLoss(nn.Module):
def __init__(self, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]): def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
"""Ginka Model 损失函数部分 """Ginka Model 损失函数部分
Args: Args:
@ -299,13 +299,11 @@ class GinkaLoss(nn.Module):
""" """
super().__init__() super().__init__()
self.weight = weight self.weight = weight
self.dice = L.DiceLoss(mode='multiclass')
self.ce = nn.CrossEntropyLoss() self.ce = nn.CrossEntropyLoss()
self.minamo = minamo
def forward(self, pred, target): def forward(self, pred, pred_softmax, target):
probs = F.softmax(pred, dim=1) probs = F.softmax(pred, dim=1)
# 拓扑结构损失
# structure_loss = topology_loss(pred, target)
# 地图结构损失 # 地图结构损失
border_loss = wall_border_loss(pred, probs) border_loss = wall_border_loss(pred, probs)
wall_loss = internal_wall_loss(pred, probs) wall_loss = internal_wall_loss(pred, probs)
@ -315,6 +313,9 @@ class GinkaLoss(nn.Module):
valid_block_loss = illegal_block_loss(pred, probs, used_classes=12, mode="mean") valid_block_loss = illegal_block_loss(pred, probs, used_classes=12, mode="mean")
count_loss = integrated_count_loss(probs, target) count_loss = integrated_count_loss(probs, target)
# 使用 Minamo Model 计算相似度
print( print(
# structure_loss.item(), # structure_loss.item(),
border_loss.item(), border_loss.item(),

View File

@ -1,180 +1,44 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from transformers import BertModel from .unet import GinkaUNet
from ...shared.attention import CBAM, SpatialAttention
from .sample import HybridUpsample, FinalUpsample, GumbelSampler
class ResidualBlock(nn.Module): class GumbelSoftmax(nn.Module):
"""残差块""" def __init__(self, tau=1.0, hard=True):
def __init__(self, channels):
super().__init__() super().__init__()
self.conv = nn.Sequential( self.tau = tau # 温度参数
nn.Conv2d(channels, channels, 3, padding=1), self.hard = hard # 是否生成硬性one-hot
nn.GroupNorm(8, channels),
nn.GELU(),
nn.Conv2d(channels, channels, 3, padding=1),
nn.GroupNorm(8, channels)
)
def forward(self, x): def forward(self, logits):
return x + self.conv(x) # logits形状: [BS, C, H, W]
y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
class DynamicPadConv(nn.Module): # 转换为类索引的连续表示
"""支持动态处理奇数尺寸的智能卷积""" class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
def __init__(self, in_ch, out_ch, kernel=3, stride=1): return (y * class_indices).sum(dim=1) # 形状[BS, H, W]
super().__init__()
self.conv = nn.Conv2d(
in_ch, out_ch, kernel,
stride=stride,
padding=kernel//2
)
self.requires_pad = (stride > 1) # 仅在下采样时需要填充
def forward(self, x):
if self.requires_pad:
# 动态计算各维度需要填充的量
pad_h = x.size(-2) % 2
pad_w = x.size(-1) % 2
if pad_h or pad_w:
x = F.pad(x, (0, pad_w, 0, pad_h)) # 右下填充
return self.conv(x)
class ConditionInjector(nn.Module):
"""基于注意力机制的条件注入"""
def __init__(self, cond_dim=128, feat_dim=256):
super().__init__()
self.cond_proj = nn.Sequential(
nn.Linear(cond_dim, feat_dim * 2),
nn.GELU(),
nn.LayerNorm(feat_dim * 2)
)
self.channel_att = nn.Sequential(
nn.Conv2d(feat_dim, feat_dim//8, 1),
nn.GELU(),
nn.Conv2d(feat_dim//8, feat_dim, 1),
nn.Sigmoid()
)
def forward(self, x, cond):
# 投影条件向量
gamma, beta = self.cond_proj(cond).chunk(2, dim=1) # [B, D]
# 通道注意力调制
att = self.channel_att(x) # [B, C, H, W]
modulated = x * att
# 添加条件偏置
return modulated + beta.view(-1, gamma.size(1), 1, 1)
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.encoder = nn.Sequential(
DynamicPadConv(in_ch, out_ch, stride=1),
ResidualBlock(out_ch),
CBAM(out_ch),
nn.GroupNorm(8, out_ch),
nn.GELU()
)
def forward(self, x):
return self.encoder(x)
class GinkaModel(nn.Module): class GinkaModel(nn.Module):
def __init__(self, in_ch=1, base_ch=64, num_classes=32): def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
"""Ginka Model 模型定义部分 """Ginka Model 模型定义部分
Args:
in_ch (int, optional): 输入通道数默认是 1
base_ch (int, optional): UNet 上下采样卷积基础通道数默认 64
num_classes (int, optional): 图块种类数量默认 32 预留出一部分以供后续拓展功能
""" """
super().__init__() super().__init__()
self.base_ch = base_ch
# 轻量级文本编码器使用BERT前4层 self.fc = nn.Sequential(
self.bert = BertModel.from_pretrained('google-bert/bert-base-chinese', output_hidden_states=True) nn.Linear(feat_dim, 32 * 32 * base_ch)
self.text_proj = nn.Linear(768, 128)
# 动态尺寸处理系统
self.size_embed = nn.Embedding(32, 16) # 处理最大32的尺寸
# 编码器
self.enc1 = GinkaEncoder(in_ch, base_ch)
self.enc2 = GinkaEncoder(base_ch, base_ch * 2)
# self.enc3 = GinkaEncoder(base_ch * 2, base_ch * 4)
# 中间层
self.mid = nn.Sequential(
DynamicPadConv(base_ch * 2, base_ch * 4),
ConditionInjector(160, base_ch * 4)
) )
self.unet = GinkaUNet(base_ch, num_classes)
self.softmax = GumbelSoftmax()
# 解码器,解码器仅使用空间注意力 def forward(self, feat):
self.dec1 = HybridUpsample(base_ch * 4, base_ch * 2)
self.dec1_att = SpatialAttention()
self.dec2 = HybridUpsample(base_ch * 2, base_ch)
self.dec2_att = SpatialAttention()
# self.dec3 = HybridUpsample(base_ch * 2, base_ch)
# self.dec3_att = SpatialAttention()
# 输出层
self.out = FinalUpsample(base_ch, num_classes)
def forward(self, noise, input_ids, attention_mask, map_size):
""" """
Args: Args:
noise: 噪声输入 [BS, H, W, 1] feat: 参考地图的特征向量
input_ids: 文本token id [BS, seq_len]
attention_mask: 文本attention mask [BS, seq_len]
map_size: 地图尺寸 [BS, 2] (height, width)
Returns: Returns:
logits: 输出logits [BS, num_classes, H, W] logits: 输出logits [BS, num_classes, H, W]
""" """
# 文本特征提取 x = self.fc(feat)
with torch.no_grad(): # 冻结BERT参数 x = x.view(-1, self.base_ch, 32, 32)
bert_outputs = self.bert( x = self.unet(x)
input_ids=input_ids, x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
attention_mask=attention_mask, return x, self.softmax(x)
output_hidden_states=True
)
# 取前4层隐藏状态的平均
hidden_states = torch.stack(bert_outputs.hidden_states[1:5]) # [4, BS, seq_len, 768]
text_features = torch.mean(hidden_states, dim=0)[:, 0, :] # [BS, 768]
text_features = self.text_proj(text_features) # [BS, 128]
# 尺寸特征处理
h_emb = self.size_embed(map_size[:, 0]) # [BS, 16]
w_emb = self.size_embed(map_size[:, 1]) # [BS, 16]
size_features = torch.cat([h_emb, w_emb], dim=1) # [BS, 32]
# 特征融合
conditional = torch.cat([text_features, size_features], dim=1) # [BS, 160]
# 调整噪声输入维度
x = noise.permute(0, 3, 1, 2) # [BS, 1, H, W]
# 编码器路径
x1 = self.enc1(x) # [BS, 64, H / 2, W / 2]
x2 = self.enc2(x1) # [BS, 128, H / 4, W / 4]
# 中间层(注入条件)
x_mid = self.mid[0](x2) # [BS, 256, H / 4, W / 4]
x_mid = self.mid[1](x_mid, conditional)
# 解码器路径
d1 = self.dec1(x_mid, x2) # [BS, 128, H / 2, W / 2]
d1 = self.dec1_att(d1)
d2 = self.dec2(d1, x1) # [BS, 64, H, W]
d2 = self.dec2_att(d2)
# d3 = self.dec3(d2, x1)
# d3 = self.dec3_att(d3)
# 最终自适应上采样
h, w = noise.shape[1:3] # 获取原始输入尺寸
return self.out(d2, (h, w))

View File

@ -1,79 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.ops as ops
class HybridUpsample(nn.Module):
"""自适应尺寸的混合上采样"""
def __init__(self, in_ch, out_ch, skip_ch=None):
super().__init__()
# 子像素卷积上采样
self.subpixel = nn.Sequential(
nn.Conv2d(in_ch, out_ch * 4, 3, padding=1),
nn.PixelShuffle(2) # 2倍上采样
)
# 跳跃连接处理
self.skip_conv = nn.Conv2d(skip_ch, out_ch, 1) if skip_ch else None
self.adaptive_pool = nn.AdaptiveAvgPool2d(None)
def forward(self, x, skip=None):
x = self.subpixel(x) # [B, out_ch, 2H, 2W]
if skip is not None and self.skip_conv:
# 自动对齐尺寸
if x.shape[-2:] != skip.shape[-2:]:
skip = F.interpolate(skip, size=x.shape[-2:], mode='nearest')
# 融合特征
x = x + self.skip_conv(skip)
return x
class DiscreteAwareUpsample(nn.Module):
"""离散感知的智能上采样模块"""
def __init__(self, in_ch, out_ch, base_size=16):
super().__init__()
self.base_size = base_size
self.scale_factors = [2, 4, 8] # 支持放大倍数
# 可变形卷积增强几何感知
self.deform_conv = ops.DeformConv2d(in_ch, in_ch, kernel_size=3, padding=1)
# 多尺度特征融合
self.multi_scale = nn.ModuleList([
nn.Sequential(
nn.Conv2d(in_ch, in_ch//4, 1),
nn.Upsample(scale_factor=s, mode='nearest')
) for s in self.scale_factors
])
# 门控上采样机制
self.gate_conv = nn.Conv2d(in_ch*2, len(self.scale_factors)+1, 3, padding=1)
# 离散化输出层
self.final_conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch*4, 3, padding=1),
nn.PixelShuffle(2), # 亚像素卷积
nn.Conv2d(out_ch, out_ch, 3, padding=1)
)
def forward(self, x, target_size):
# 几何特征提取
deform_feat = self.deform_conv(x)
# 生成多尺度特征
scale_features = [f(deform_feat) for f in self.multi_scale]
# 动态门控选择
gate_map = F.softmax(self.gate_conv(torch.cat([x, deform_feat], dim=1)), dim=1)
# 加权融合多尺度特征
combined = sum(g * F.interpolate(f, size=target_size, mode='nearest')
for g, f in zip(gate_map.unbind(1), scale_features+[x]))
# 离散化上采样
out = self.final_conv(combined)
# 结构化约束(保持通道独立性)
return out.argmax(dim=1).unsqueeze(1).float() # 伪梯度保留

91
ginka/model/unet.py Normal file
View File

@ -0,0 +1,91 @@
import torch
import torch.nn as nn
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.pool = nn.MaxPool2d(2)
def forward(self, x):
x_res = self.conv(x) # 卷积提取特征
x_down = self.pool(x_res) # 进行池化
return x_down, x_res # 返回池化后的特征和跳跃连接特征
class GinkaDecoder(nn.Module):
"""解码器(上采样)部分"""
def __init__(self, in_channels, out_channels):
super().__init__()
# 上采样(双线性插值 + 卷积)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
# 跳跃连接融合
self.conv = nn.Sequential(
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self, x, skip):
x = self.upsample(x)
# 跳跃连接融合
x = torch.cat([x, skip], dim=1)
x = self.conv(x)
return x
class GinkaBottleneck(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
def forward(self, x):
return self.conv(x)
class GinkaUNet(nn.Module):
def __init__(self, in_ch=64, out_ch=32):
"""Ginka Model UNet 部分
"""
super().__init__()
self.down1 = GinkaEncoder(in_ch, in_ch*2)
self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
self.up2 = GinkaDecoder(in_ch*2, in_ch)
self.final = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 1)
)
def forward(self, x):
x, skip1 = self.down1(x)
x, skip2 = self.down2(x)
x = self.bottleneck(x)
x = self.up1(x, skip2)
x = self.up2(x, skip1)
return self.final(x)

View File

@ -9,87 +9,57 @@ from tqdm import tqdm
from .model.model import GinkaModel from .model.model import GinkaModel
from .model.loss import GinkaLoss from .model.loss import GinkaLoss
from .dataset import GinkaDataset from .dataset import GinkaDataset
from minamo.model.model import MinamoModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True) os.makedirs("result", exist_ok=True)
epochs = 100 epochs = 70
def collate_fn(batch): def update_tau(epoch):
# 动态填充噪声到最大尺寸 start_tau = 1.0
max_h = max([b["noise"].shape[0] for b in batch]) min_tau = 0.1
max_w = max([b["noise"].shape[1] for b in batch]) decay_rate = 0.95
return max(min_tau, start_tau * (decay_rate ** epoch))
padded_batch = {}
for key in ["noise", "target"]:
padded = []
for b in batch:
tensor = b[key]
pad_h = max_h - tensor.shape[0]
pad_w = max_w - tensor.shape[1]
padded.append(F.pad(tensor, (0, pad_w, 0, pad_h), value=-100 if key=="target" else 0))
padded_batch[key] = torch.stack(padded)
# 其他字段直接堆叠
for key in ["input_ids", "attention_mask", "map_size"]:
padded_batch[key] = torch.stack([b[key] for b in batch])
return padded_batch
def train(): def train():
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.") print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
model = GinkaModel() model = GinkaModel()
model.to(device) model.to(device)
minamo = MinamoModel(32)
minamo.to(device)
minamo.eval()
# 准备数据集 # 准备数据集
tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-chinese') dataset = GinkaDataset("dataset.json", minamo)
dataset = GinkaDataset("F:/github-ai/ginka-generator/dataset.json", tokenizer)
dataloader = DataLoader( dataloader = DataLoader(
dataset, dataset,
batch_size=4, batch_size=32,
shuffle=True, shuffle=True
collate_fn=collate_fn,
num_workers=0
) )
# 设定优化器与调度器 # 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=3e-4) optimizer = optim.AdamW(model.parameters(), lr=3e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss() criterion = GinkaLoss(minamo)
# 开始训练 # 开始训练
for epoch in tqdm(range(epochs)): for epoch in tqdm(range(epochs)):
model.train() model.train()
total_loss = 0 total_loss = 0
model.softmax.tau = update_tau(epoch)
# 温度退火
model.gumbel.tau = max(0.1, 1.0 - 0.9 * epoch / epochs)
for batch in dataloader: for batch in dataloader:
# 数据迁移到设备 # 数据迁移到设备
noise = batch["noise"].to(device)
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
map_size = batch["map_size"].to(device)
target = batch["target"].to(device) target = batch["target"].to(device)
feat_vec = batch["feat_vec"].to(device)
# 前向传播 # 前向传播
optimizer.zero_grad() optimizer.zero_grad()
outputs = model(noise, input_ids, attention_mask, map_size) output, output_softmax = model(feat_vec)
print(torch.argmax(torch.softmax(outputs, dim=1), dim=1))
# print(sampled[0, :, :, 1])
# 构建拓扑图
# with torch.no_grad():
# pred_graphs = build_topology_graph(outputs.argmax(1))
# ref_graphs = build_topology_graph(target)
# 计算损失 # 计算损失
loss = criterion( loss = criterion(output, output_softmax, target)
outputs, # 调整为 [BS, C, H, W]
target
)
# 反向传播 # 反向传播
loss.backward() loss.backward()

View File

@ -11,11 +11,8 @@ class MinamoModel(nn.Module):
# 拓扑相似度部分 # 拓扑相似度部分
self.topo_model = MinamoTopoModel(tile_types) self.topo_model = MinamoTopoModel(tile_types)
def forward(self, map1, map2, graph1, graph2): def forward(self, map, graph):
vision_feat1 = self.vision_model(map1) vision_feat = self.vision_model(map)
vision_feat2 = self.vision_model(map2) topo_feat = self.topo_model(graph)
topo_feat1 = self.topo_model(graph1) return vision_feat, topo_feat
topo_feat2 = self.topo_model(graph2)
return vision_feat1, vision_feat2, topo_feat1, topo_feat2

View File

@ -76,7 +76,8 @@ def train():
# 前向传播 # 前向传播
optimizer.zero_grad() optimizer.zero_grad()
vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1, map2, graph1, graph2) vision_feat1, topo_feat1 = model(map1, graph1)
vision_feat2, topo_feat2 = model(map2, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)

View File

@ -1,5 +1,6 @@
import torch import torch
from torch.utils.data import DataLoader import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm from tqdm import tqdm
from .model.model import MinamoModel from .model.model import MinamoModel
from .model.loss import MinamoLoss from .model.loss import MinamoLoss
@ -8,32 +9,38 @@ from .dataset import MinamoDataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def validate(): def validate():
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to validate model.") print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = MinamoModel(32) model = MinamoModel(32)
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"]) model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
model.to(device) model.to(device)
# 准备数据集 # 准备数据集
val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json") val_dataset = MinamoDataset("minamo-eval.json")
val_loader = DataLoader( val_loader = DataLoader(
val_dataset, val_dataset,
batch_size=32, batch_size=32,
shuffle=True shuffle=True
) )
criterion = MinamoLoss(temp=0.8) criterion = MinamoLoss()
model.eval() model.eval()
val_loss = 0 val_loss = 0
with torch.no_grad(): with torch.no_grad():
for val_batch in tqdm(val_loader): for val_batch in tqdm(val_loader):
map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
map1_val = map1_val.to(device) map1_val = map1_val.to(device)
map2_val = map2_val.to(device) map2_val = map2_val.to(device)
vision_simi_val = vision_simi_val.to(device) vision_simi_val = vision_simi_val.to(device)
topo_simi_val = topo_simi_val.to(device) topo_simi_val = topo_simi_val.to(device)
graph1 = graph1.to(device)
graph2 = graph2.to(device)
vision_pred_val, topo_pred_val = model(map1_val, map2_val) vision_feat1, topo_feat1 = model(map1_val, graph1)
vision_feat2, topo_feat2 = model(map2_val, graph2)
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
loss_val = criterion( loss_val = criterion(
vision_pred_val, topo_pred_val, vision_pred_val, topo_pred_val,
vision_simi_val, topo_simi_val vision_simi_val, topo_simi_val