mirror of https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00

commit 1566acf691 (parent 09b66f2569)
refactor: 重构 ginka 模型部分 (restructure the Ginka model code)
@@ -1,8 +1,8 @@
 import json
-import random
 import torch
 from torch.utils.data import Dataset
-from transformers import BertTokenizer
+from minamo.model.model import MinamoModel
+from shared.graph import convert_map_to_graph
 
 def load_data(path: str):
     with open(path, 'r', encoding="utf-8") as f:
@@ -15,11 +15,10 @@ def load_data(path: str):
     return data_list
 
 class GinkaDataset(Dataset):
-    def __init__(self, data_path: str, tokenizer: BertTokenizer, max_len=128):
+    def __init__(self, data_path: str, minamo: MinamoModel):
         self.data = load_data(data_path)  # custom data-loading helper
-        self.tokenizer = tokenizer
-        self.max_len = max_len
         self.max_size = 32
+        self.minamo = minamo
 
     def __len__(self):
         return len(self.data)
@@ -27,28 +26,13 @@ class GinkaDataset(Dataset):
     def __getitem__(self, idx):
         item = self.data[idx]
 
-        # text processing
-        text = random.choice(item["text"])
-        encoding = self.tokenizer(
-            text,
-            max_length=self.max_len,
-            padding="max_length",
-            truncation=True,
-            return_tensors="pt"
-        )
-
-        # noise generation
-        w, h = item["size"]
-        noise = torch.randn(h, w, 1)
-
-        # pad the target matrix
-        target = torch.full((self.max_size, self.max_size), -100)  # use -100 so padded regions are ignored
-        target[:h, :w] = torch.tensor(item["map"])
+        target = torch.tensor(item["map"])
+        graph = convert_map_to_graph(target)
+        vision_feat, topo_feat = self.minamo(target, graph)
+        feat_vec = torch.cat([vision_feat, topo_feat])
 
         return {
-            "noise": noise,
-            "input_ids": encoding["input_ids"].squeeze(),
-            "attention_mask": encoding["attention_mask"].squeeze(),
-            "map_size": torch.tensor([h, w]),
+            "feat_vec": feat_vec,
             "target": target
         }
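With this hunk each sample is conditioned on the Minamo critic's embedding of the reference map instead of tokenized text plus noise. A minimal usage sketch of the new interface (the GinkaDataset module path and the torch.no_grad wrapper are assumptions; the diff itself does not freeze gradients here):

import torch
from minamo.model.model import MinamoModel
from ginka.dataset import GinkaDataset  # assumed module path

minamo = MinamoModel(32)
minamo.eval()
dataset = GinkaDataset("dataset.json", minamo)

with torch.no_grad():  # assumption: the frozen critic's features need no gradients
    sample = dataset[0]
# feat_vec concatenates Minamo's vision and topology embeddings
print(sample["feat_vec"].shape, sample["target"].shape)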
@@ -2,7 +2,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from pytorch_toolbelt import losses as L
+from minamo.model.model import MinamoModel
 
 def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
     """Checks whether the outermost ring of the map is wall"""
@@ -283,7 +283,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler
     return avg_loss
 
 class GinkaLoss(nn.Module):
-    def __init__(self, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
+    def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
         """Loss function of the Ginka Model
 
         Args:
@@ -299,13 +299,11 @@ class GinkaLoss(nn.Module):
         """
         super().__init__()
         self.weight = weight
         self.dice = L.DiceLoss(mode='multiclass')
-        self.ce = nn.CrossEntropyLoss()
+        self.minamo = minamo
 
-    def forward(self, pred, target):
+    def forward(self, pred, pred_softmax, target):
         probs = F.softmax(pred, dim=1)
-        # topology loss
-        # structure_loss = topology_loss(pred, target)
         # map structure losses
         border_loss = wall_border_loss(pred, probs)
         wall_loss = internal_wall_loss(pred, probs)
@@ -315,6 +313,9 @@
         valid_block_loss = illegal_block_loss(pred, probs, used_classes=12, mode="mean")
         count_loss = integrated_count_loss(probs, target)
 
+        # compute similarity with the Minamo Model
+
+
         print(
             # structure_loss.item(),
             border_loss.item(),
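The similarity term itself is still a stub: the hunk only reserves a comment where the Minamo call will go. One plausible shape for it, shown purely as a hedged sketch (minamo_similarity_loss and all of its arguments are hypothetical names, not part of this commit): embed the generated and reference maps with the frozen critic and penalize low cosine similarity in both feature spaces.

import torch.nn.functional as F

def minamo_similarity_loss(minamo, pred_map, pred_graph, target_map, target_graph):
    # hypothetical: embed both maps with the frozen Minamo critic
    vision_pred, topo_pred = minamo(pred_map, pred_graph)
    vision_tgt, topo_tgt = minamo(target_map, target_graph)
    # penalize dissimilarity in the vision and topology feature spaces
    vision_term = 1 - F.cosine_similarity(vision_pred, vision_tgt, dim=-1).mean()
    topo_term = 1 - F.cosine_similarity(topo_pred, topo_tgt, dim=-1).mean()
    return vision_term + topo_term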
@@ -1,180 +1,44 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import BertModel
-from ...shared.attention import CBAM, SpatialAttention
-from .sample import HybridUpsample, FinalUpsample, GumbelSampler
-
-class ResidualBlock(nn.Module):
-    """Residual block"""
-    def __init__(self, channels):
-        super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(channels, channels, 3, padding=1),
-            nn.GroupNorm(8, channels),
-            nn.GELU(),
-            nn.Conv2d(channels, channels, 3, padding=1),
-            nn.GroupNorm(8, channels)
-        )
-
-    def forward(self, x):
-        return x + self.conv(x)
+from .unet import GinkaUNet
 
-class DynamicPadConv(nn.Module):
-    """Smart convolution that dynamically handles odd input sizes"""
-    def __init__(self, in_ch, out_ch, kernel=3, stride=1):
+class GumbelSoftmax(nn.Module):
+    def __init__(self, tau=1.0, hard=True):
         super().__init__()
-        self.conv = nn.Conv2d(
-            in_ch, out_ch, kernel,
-            stride=stride,
-            padding=kernel//2
-        )
-        self.requires_pad = (stride > 1)  # padding is only needed when downsampling
-
-    def forward(self, x):
-        if self.requires_pad:
-            # dynamically compute the padding needed per dimension
-            pad_h = x.size(-2) % 2
-            pad_w = x.size(-1) % 2
-            if pad_h or pad_w:
-                x = F.pad(x, (0, pad_w, 0, pad_h))  # pad bottom-right
-        return self.conv(x)
+        self.tau = tau  # temperature parameter
+        self.hard = hard  # whether to produce hard one-hot samples
 
-class ConditionInjector(nn.Module):
-    """Attention-based condition injection"""
-    def __init__(self, cond_dim=128, feat_dim=256):
-        super().__init__()
-        self.cond_proj = nn.Sequential(
-            nn.Linear(cond_dim, feat_dim * 2),
-            nn.GELU(),
-            nn.LayerNorm(feat_dim * 2)
-        )
-        self.channel_att = nn.Sequential(
-            nn.Conv2d(feat_dim, feat_dim//8, 1),
-            nn.GELU(),
-            nn.Conv2d(feat_dim//8, feat_dim, 1),
-            nn.Sigmoid()
-        )
-
-    def forward(self, x, cond):
-        # project the condition vector
-        gamma, beta = self.cond_proj(cond).chunk(2, dim=1)  # [B, D]
+    def forward(self, logits):
+        # logits shape: [BS, C, H, W]
+        y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
 
-        # channel attention modulation
-        att = self.channel_att(x)  # [B, C, H, W]
-        modulated = x * att
-
-        # add the conditional bias
-        return modulated + beta.view(-1, gamma.size(1), 1, 1)
-
-class GinkaEncoder(nn.Module):
-    """Encoder (downsampling) part"""
-    def __init__(self, in_ch, out_ch):
-        super().__init__()
-        self.encoder = nn.Sequential(
-            DynamicPadConv(in_ch, out_ch, stride=1),
-            ResidualBlock(out_ch),
-            CBAM(out_ch),
-            nn.GroupNorm(8, out_ch),
-            nn.GELU()
-        )
-
-    def forward(self, x):
-        return self.encoder(x)
+        # convert to a continuous representation of class indices
+        class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
+        return (y * class_indices).sum(dim=1)  # shape [BS, H, W]
 
 class GinkaModel(nn.Module):
-    def __init__(self, in_ch=1, base_ch=64, num_classes=32):
+    def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
         """Model definition of the Ginka Model
 
         Args:
-            in_ch (int, optional): number of input channels; defaults to 1
             base_ch (int, optional): base channel count of the UNet down/upsampling convolutions; defaults to 64
             num_classes (int, optional): number of tile classes; defaults to 32, with part of the range reserved for future extensions
         """
         super().__init__()
 
-        # lightweight text encoder (uses the first 4 BERT layers)
-        self.bert = BertModel.from_pretrained('google-bert/bert-base-chinese', output_hidden_states=True)
-        self.text_proj = nn.Linear(768, 128)
-
-        # dynamic size handling
-        self.size_embed = nn.Embedding(32, 16)  # handles sizes up to 32
-
-        # encoders
-        self.enc1 = GinkaEncoder(in_ch, base_ch)
-        self.enc2 = GinkaEncoder(base_ch, base_ch * 2)
-        # self.enc3 = GinkaEncoder(base_ch * 2, base_ch * 4)
-
-        # middle layer
-        self.mid = nn.Sequential(
-            DynamicPadConv(base_ch * 2, base_ch * 4),
-            ConditionInjector(160, base_ch * 4)
+        self.base_ch = base_ch
+        self.fc = nn.Sequential(
+            nn.Linear(feat_dim, 32 * 32 * base_ch)
         )
+        self.unet = GinkaUNet(base_ch, num_classes)
+        self.softmax = GumbelSoftmax()
 
-        # decoders; the decoders use spatial attention only
-        self.dec1 = HybridUpsample(base_ch * 4, base_ch * 2)
-        self.dec1_att = SpatialAttention()
-
-        self.dec2 = HybridUpsample(base_ch * 2, base_ch)
-        self.dec2_att = SpatialAttention()
-
-        # self.dec3 = HybridUpsample(base_ch * 2, base_ch)
-        # self.dec3_att = SpatialAttention()
-
-        # output layer
-        self.out = FinalUpsample(base_ch, num_classes)
-
-    def forward(self, noise, input_ids, attention_mask, map_size):
+    def forward(self, feat):
         """
         Args:
-            noise: noise input [BS, H, W, 1]
-            input_ids: text token ids [BS, seq_len]
-            attention_mask: text attention mask [BS, seq_len]
-            map_size: map size [BS, 2] (height, width)
+            feat: feature vector of the reference map
         Returns:
             logits: output logits [BS, num_classes, H, W]
         """
-        # text feature extraction
-        with torch.no_grad():  # freeze the BERT parameters
-            bert_outputs = self.bert(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                output_hidden_states=True
-            )
-        # average the hidden states of the first 4 layers
-        hidden_states = torch.stack(bert_outputs.hidden_states[1:5])  # [4, BS, seq_len, 768]
-        text_features = torch.mean(hidden_states, dim=0)[:, 0, :]  # [BS, 768]
-        text_features = self.text_proj(text_features)  # [BS, 128]
-
-        # size feature processing
-        h_emb = self.size_embed(map_size[:, 0])  # [BS, 16]
-        w_emb = self.size_embed(map_size[:, 1])  # [BS, 16]
-        size_features = torch.cat([h_emb, w_emb], dim=1)  # [BS, 32]
-
-        # feature fusion
-        conditional = torch.cat([text_features, size_features], dim=1)  # [BS, 160]
-
-        # rearrange the noise input dimensions
-        x = noise.permute(0, 3, 1, 2)  # [BS, 1, H, W]
-
-        # encoder path
-        x1 = self.enc1(x)  # [BS, 64, H / 2, W / 2]
-        x2 = self.enc2(x1)  # [BS, 128, H / 4, W / 4]
-
-        # middle layer (inject the condition)
-        x_mid = self.mid[0](x2)  # [BS, 256, H / 4, W / 4]
-        x_mid = self.mid[1](x_mid, conditional)
-
-        # decoder path
-        d1 = self.dec1(x_mid, x2)  # [BS, 128, H / 2, W / 2]
-        d1 = self.dec1_att(d1)
-        d2 = self.dec2(d1, x1)  # [BS, 64, H, W]
-        d2 = self.dec2_att(d2)
-        # d3 = self.dec3(d2, x1)
-        # d3 = self.dec3_att(d3)
-
-        # final adaptive upsampling
-        h, w = noise.shape[1:3]  # get the original input size
-        return self.out(d2, (h, w))
+        x = self.fc(feat)
+        x = x.view(-1, self.base_ch, 32, 32)
+        x = self.unet(x)
+        x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
+        return x, self.softmax(x)
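GumbelSoftmax discretizes the logits while keeping them differentiable: with hard=True, F.gumbel_softmax returns one-hot samples in the forward pass but routes gradients through the soft distribution (the straight-through trick). A standalone check under assumed sizes; note that F.gumbel_softmax defaults to dim=-1, so sampling along the class axis of [BS, C, H, W] logits needs an explicit dim=1, which the forward above does not pass:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 32, 13, 13, requires_grad=True)  # [BS, C, H, W], assumed sizes
y = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=1)  # one-hot along the class axis
idx = torch.arange(y.size(1)).view(1, -1, 1, 1)
tile_map = (y * idx).sum(dim=1)  # [BS, H, W] map of class indices
tile_map.sum().backward()        # gradients still reach the logits
print(tile_map.shape, logits.grad is not None)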
@@ -1,79 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.ops as ops
-
-class HybridUpsample(nn.Module):
-    """Size-adaptive hybrid upsampling"""
-    def __init__(self, in_ch, out_ch, skip_ch=None):
-        super().__init__()
-        # sub-pixel convolution upsampling
-        self.subpixel = nn.Sequential(
-            nn.Conv2d(in_ch, out_ch * 4, 3, padding=1),
-            nn.PixelShuffle(2)  # 2x upsampling
-        )
-
-        # skip-connection handling
-        self.skip_conv = nn.Conv2d(skip_ch, out_ch, 1) if skip_ch else None
-        self.adaptive_pool = nn.AdaptiveAvgPool2d(None)
-
-    def forward(self, x, skip=None):
-        x = self.subpixel(x)  # [B, out_ch, 2H, 2W]
-
-        if skip is not None and self.skip_conv:
-            # align sizes automatically
-            if x.shape[-2:] != skip.shape[-2:]:
-                skip = F.interpolate(skip, size=x.shape[-2:], mode='nearest')
-
-            # fuse the features
-            x = x + self.skip_conv(skip)
-
-        return x
-
-class DiscreteAwareUpsample(nn.Module):
-    """Discreteness-aware upsampling module"""
-    def __init__(self, in_ch, out_ch, base_size=16):
-        super().__init__()
-        self.base_size = base_size
-        self.scale_factors = [2, 4, 8]  # supported scale factors
-
-        # deformable convolution for stronger geometric awareness
-        self.deform_conv = ops.DeformConv2d(in_ch, in_ch, kernel_size=3, padding=1)
-
-        # multi-scale feature fusion
-        self.multi_scale = nn.ModuleList([
-            nn.Sequential(
-                nn.Conv2d(in_ch, in_ch//4, 1),
-                nn.Upsample(scale_factor=s, mode='nearest')
-            ) for s in self.scale_factors
-        ])
-
-        # gated upsampling mechanism
-        self.gate_conv = nn.Conv2d(in_ch*2, len(self.scale_factors)+1, 3, padding=1)
-
-        # discretized output layer
-        self.final_conv = nn.Sequential(
-            nn.Conv2d(in_ch, out_ch*4, 3, padding=1),
-            nn.PixelShuffle(2),  # sub-pixel convolution
-            nn.Conv2d(out_ch, out_ch, 3, padding=1)
-        )
-
-    def forward(self, x, target_size):
-        # geometric feature extraction
-        deform_feat = self.deform_conv(x)
-
-        # generate multi-scale features
-        scale_features = [f(deform_feat) for f in self.multi_scale]
-
-        # dynamic gated selection
-        gate_map = F.softmax(self.gate_conv(torch.cat([x, deform_feat], dim=1)), dim=1)
-
-        # weighted fusion of the multi-scale features
-        combined = sum(g * F.interpolate(f, size=target_size, mode='nearest')
-                       for g, f in zip(gate_map.unbind(1), scale_features+[x]))
-
-        # discretized upsampling
-        out = self.final_conv(combined)
-
-        # structural constraint (preserve channel independence)
-        return out.argmax(dim=1).unsqueeze(1).float()  # pseudo-gradient retention
ginka/model/unet.py (new file, 91 lines)
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+
+class GinkaEncoder(nn.Module):
+    """Encoder (downsampling) part"""
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU()
+        )
+        self.pool = nn.MaxPool2d(2)
+
+    def forward(self, x):
+        x_res = self.conv(x)  # extract features with convolutions
+        x_down = self.pool(x_res)  # pooling
+        return x_down, x_res  # return the pooled features and the skip-connection features
+
+class GinkaDecoder(nn.Module):
+    """Decoder (upsampling) part"""
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        # upsampling (bilinear interpolation + convolution)
+        self.upsample = nn.Sequential(
+            nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU()
+        )
+
+        # skip-connection fusion
+        self.conv = nn.Sequential(
+            nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU()
+        )
+
+    def forward(self, x, skip):
+        x = self.upsample(x)
+        # fuse with the skip connection
+        x = torch.cat([x, skip], dim=1)
+        x = self.conv(x)
+        return x
+
+class GinkaBottleneck(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(),
+        )
+
+    def forward(self, x):
+        return self.conv(x)
+
+class GinkaUNet(nn.Module):
+    def __init__(self, in_ch=64, out_ch=32):
+        """UNet part of the Ginka Model
+        """
+        super().__init__()
+
+        self.down1 = GinkaEncoder(in_ch, in_ch*2)
+        self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
+
+        self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
+
+        self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
+        self.up2 = GinkaDecoder(in_ch*2, in_ch)
+
+        self.final = nn.Sequential(
+            nn.Conv2d(in_ch, out_ch, 1)
+        )
+
+    def forward(self, x):
+        x, skip1 = self.down1(x)
+        x, skip2 = self.down2(x)
+
+        x = self.bottleneck(x)
+
+        x = self.up1(x, skip2)
+        x = self.up2(x, skip1)
+
+        return self.final(x)
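Each encoder stage returns both the pooled tensor and the pre-pooling features, so the decoders can fuse skips at matching resolutions. A shape check of a single stage in isolation (the 32x32 input mirrors what GinkaModel feeds the UNet; the import path assumes the repository root is on PYTHONPATH):

import torch
from ginka.model.unet import GinkaEncoder

enc = GinkaEncoder(64, 128)
x = torch.randn(1, 64, 32, 32)
x_down, x_res = enc(x)
print(x_down.shape)  # torch.Size([1, 128, 16, 16]), halved by MaxPool2d(2)
print(x_res.shape)   # torch.Size([1, 128, 32, 32]), kept for the skip connection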
@@ -9,87 +9,57 @@ from tqdm import tqdm
 from .model.model import GinkaModel
 from .model.loss import GinkaLoss
 from .dataset import GinkaDataset
+from minamo.model.model import MinamoModel
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 os.makedirs("result", exist_ok=True)
 
-epochs = 100
+epochs = 70
 
-def collate_fn(batch):
-    # dynamically pad the noise to the largest size in the batch
-    max_h = max([b["noise"].shape[0] for b in batch])
-    max_w = max([b["noise"].shape[1] for b in batch])
-
-    padded_batch = {}
-    for key in ["noise", "target"]:
-        padded = []
-        for b in batch:
-            tensor = b[key]
-            pad_h = max_h - tensor.shape[0]
-            pad_w = max_w - tensor.shape[1]
-            padded.append(F.pad(tensor, (0, pad_w, 0, pad_h), value=-100 if key=="target" else 0))
-        padded_batch[key] = torch.stack(padded)
-
-    # stack the remaining fields directly
-    for key in ["input_ids", "attention_mask", "map_size"]:
-        padded_batch[key] = torch.stack([b[key] for b in batch])
-
-    return padded_batch
+def update_tau(epoch):
+    start_tau = 1.0
+    min_tau = 0.1
+    decay_rate = 0.95
+    return max(min_tau, start_tau * (decay_rate ** epoch))
 
 def train():
-    print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.")
+    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
     model = GinkaModel()
     model.to(device)
+    minamo = MinamoModel(32)
+    minamo.to(device)
+    minamo.eval()
 
     # prepare the dataset
-    tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-chinese')
-    dataset = GinkaDataset("F:/github-ai/ginka-generator/dataset.json", tokenizer)
+    dataset = GinkaDataset("dataset.json", minamo)
     dataloader = DataLoader(
         dataset,
-        batch_size=4,
-        shuffle=True,
-        collate_fn=collate_fn,
-        num_workers=0
+        batch_size=32,
+        shuffle=True
     )
 
     # set up the optimizer and scheduler
     optimizer = optim.AdamW(model.parameters(), lr=3e-4)
-    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
-    criterion = GinkaLoss()
+    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
+    criterion = GinkaLoss(minamo)
 
     # start training
     for epoch in tqdm(range(epochs)):
         model.train()
         total_loss = 0
 
         # temperature annealing
-        model.gumbel.tau = max(0.1, 1.0 - 0.9 * epoch / epochs)
+        model.softmax.tau = update_tau(epoch)
 
         for batch in dataloader:
            # move the data to the device
-            noise = batch["noise"].to(device)
-            input_ids = batch["input_ids"].to(device)
-            attention_mask = batch["attention_mask"].to(device)
-            map_size = batch["map_size"].to(device)
             target = batch["target"].to(device)
+            feat_vec = batch["feat_vec"].to(device)
 
             # forward pass
             optimizer.zero_grad()
-            outputs = model(noise, input_ids, attention_mask, map_size)
-
-            print(torch.argmax(torch.softmax(outputs, dim=1), dim=1))
-            # print(sampled[0, :, :, 1])
-
-            # build the topology graphs
-            # with torch.no_grad():
-            #     pred_graphs = build_topology_graph(outputs.argmax(1))
-            #     ref_graphs = build_topology_graph(target)
+            output, output_softmax = model(feat_vec)
 
             # compute the loss
-            loss = criterion(
-                outputs,  # rearranged to [BS, C, H, W]
-                target
-            )
+            loss = criterion(output, output_softmax, target)
 
             # backpropagation
             loss.backward()
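update_tau swaps the old linear schedule for exponential decay, tau = max(0.1, 0.95 ** epoch): the Gumbel temperature roughly halves every 14 epochs and reaches its 0.1 floor around epoch 45, well within the new 70-epoch run. A quick check of the schedule:

def update_tau(epoch, start_tau=1.0, min_tau=0.1, decay_rate=0.95):
    return max(min_tau, start_tau * (decay_rate ** epoch))

for epoch in (0, 10, 20, 44, 45, 69):
    print(epoch, round(update_tau(epoch), 3))
# 0 1.0 | 10 0.599 | 20 0.358 | 44 0.105 | 45 0.1 | 69 0.1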
@@ -11,11 +11,8 @@ class MinamoModel(nn.Module):
         # topological similarity part
         self.topo_model = MinamoTopoModel(tile_types)
 
-    def forward(self, map1, map2, graph1, graph2):
-        vision_feat1 = self.vision_model(map1)
-        vision_feat2 = self.vision_model(map2)
+    def forward(self, map, graph):
+        vision_feat = self.vision_model(map)
+        topo_feat = self.topo_model(graph)
 
-        topo_feat1 = self.topo_model(graph1)
-        topo_feat2 = self.topo_model(graph2)
-
-        return vision_feat1, vision_feat2, topo_feat1, topo_feat2
+        return vision_feat, topo_feat
@@ -76,7 +76,8 @@ def train():
 
         # forward pass
         optimizer.zero_grad()
-        vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1, map2, graph1, graph2)
+        vision_feat1, topo_feat1 = model(map1, graph1)
+        vision_feat2, topo_feat2 = model(map2, graph2)
 
         vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
         topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
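Since the critic now embeds one map at a time, similarity is computed outside the model as a plain cosine between two embedding batches, and .unsqueeze(-1) reshapes the [B] similarity vector to [B, 1] to match the label shape. A toy check with an assumed embedding width:

import torch
import torch.nn.functional as F

feat1 = torch.randn(4, 256)  # assumed embedding size
feat2 = torch.randn(4, 256)
pred = F.cosine_similarity(feat1, feat2, -1).unsqueeze(-1)
print(pred.shape)  # torch.Size([4, 1]): one similarity score per pair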
@@ -1,5 +1,6 @@
 import torch
-from torch.utils.data import DataLoader
+import torch.nn.functional as F
+from torch_geometric.loader import DataLoader
 from tqdm import tqdm
 from .model.model import MinamoModel
 from .model.loss import MinamoLoss
@@ -8,32 +9,38 @@ from .dataset import MinamoDataset
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 def validate():
-    print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to validate model.")
+    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
     model = MinamoModel(32)
     model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
     model.to(device)
 
     # prepare the dataset
-    val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json")
+    val_dataset = MinamoDataset("minamo-eval.json")
     val_loader = DataLoader(
         val_dataset,
         batch_size=32,
         shuffle=True
     )
 
-    criterion = MinamoLoss(temp=0.8)
+    criterion = MinamoLoss()
 
     model.eval()
     val_loss = 0
     with torch.no_grad():
         for val_batch in tqdm(val_loader):
-            map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch
+            map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
             map1_val = map1_val.to(device)
             map2_val = map2_val.to(device)
             vision_simi_val = vision_simi_val.to(device)
             topo_simi_val = topo_simi_val.to(device)
+            graph1 = graph1.to(device)
+            graph2 = graph2.to(device)
 
-            vision_pred_val, topo_pred_val = model(map1_val, map2_val)
+            vision_feat1, topo_feat1 = model(map1_val, graph1)
+            vision_feat2, topo_feat2 = model(map2_val, graph2)
+
+            vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
+            topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
             loss_val = criterion(
                 vision_pred_val, topo_pred_val,
                 vision_simi_val, topo_simi_val