refactor: 重构 ginka 模型部分

This commit is contained in:
unanmed 2025-03-18 21:11:53 +08:00
parent 09b66f2569
commit 1566acf691
9 changed files with 172 additions and 336 deletions

View File

@ -1,8 +1,8 @@
import json
import random
import torch
from torch.utils.data import Dataset
from transformers import BertTokenizer
from minamo.model.model import MinamoModel
from shared.graph import convert_map_to_graph
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
@ -15,11 +15,10 @@ def load_data(path: str):
return data_list
class GinkaDataset(Dataset):
def __init__(self, data_path: str, tokenizer: BertTokenizer, max_len=128):
def __init__(self, data_path: str, minamo: MinamoModel):
self.data = load_data(data_path) # 自定义数据加载函数
self.tokenizer = tokenizer
self.max_len = max_len
self.max_size = 32
self.minamo = minamo
def __len__(self):
return len(self.data)
@ -27,28 +26,13 @@ class GinkaDataset(Dataset):
def __getitem__(self, idx):
item = self.data[idx]
# 文本处理
text = random.choice(item["text"])
encoding = self.tokenizer(
text,
max_length=self.max_len,
padding="max_length",
truncation=True,
return_tensors="pt"
)
# 噪声生成
w, h = item["size"]
noise = torch.randn(h, w, 1)
# 目标矩阵填充
target = torch.full((self.max_size, self.max_size), -100) # 使用-100忽略填充区域
target[:h, :w] = torch.tensor(item["map"])
target = torch.tensor(item["map"])
graph = convert_map_to_graph(target)
vision_feat, topo_feat = self.minamo(target, graph)
feat_vec = torch.cat([vision_feat, topo_feat])
return {
"noise": noise,
"input_ids": encoding["input_ids"].squeeze(),
"attention_mask": encoding["attention_mask"].squeeze(),
"map_size": torch.tensor([h, w]),
"feat_vec": feat_vec,
"target": target
}
}

View File

@ -2,7 +2,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_toolbelt import losses as L
from minamo.model.model import MinamoModel
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
"""地图最外层是否为墙"""
@ -283,7 +283,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler
return avg_loss
class GinkaLoss(nn.Module):
def __init__(self, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
"""Ginka Model 损失函数部分
Args:
@ -299,13 +299,11 @@ class GinkaLoss(nn.Module):
"""
super().__init__()
self.weight = weight
self.dice = L.DiceLoss(mode='multiclass')
self.ce = nn.CrossEntropyLoss()
self.minamo = minamo
def forward(self, pred, target):
def forward(self, pred, pred_softmax, target):
probs = F.softmax(pred, dim=1)
# 拓扑结构损失
# structure_loss = topology_loss(pred, target)
# 地图结构损失
border_loss = wall_border_loss(pred, probs)
wall_loss = internal_wall_loss(pred, probs)
@ -315,6 +313,9 @@ class GinkaLoss(nn.Module):
valid_block_loss = illegal_block_loss(pred, probs, used_classes=12, mode="mean")
count_loss = integrated_count_loss(probs, target)
# 使用 Minamo Model 计算相似度
print(
# structure_loss.item(),
border_loss.item(),

View File

@ -1,180 +1,44 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel
from ...shared.attention import CBAM, SpatialAttention
from .sample import HybridUpsample, FinalUpsample, GumbelSampler
class ResidualBlock(nn.Module):
"""残差块"""
def __init__(self, channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(channels, channels, 3, padding=1),
nn.GroupNorm(8, channels),
nn.GELU(),
nn.Conv2d(channels, channels, 3, padding=1),
nn.GroupNorm(8, channels)
)
def forward(self, x):
return x + self.conv(x)
from .unet import GinkaUNet
class DynamicPadConv(nn.Module):
"""支持动态处理奇数尺寸的智能卷积"""
def __init__(self, in_ch, out_ch, kernel=3, stride=1):
class GumbelSoftmax(nn.Module):
def __init__(self, tau=1.0, hard=True):
super().__init__()
self.conv = nn.Conv2d(
in_ch, out_ch, kernel,
stride=stride,
padding=kernel//2
)
self.requires_pad = (stride > 1) # 仅在下采样时需要填充
def forward(self, x):
if self.requires_pad:
# 动态计算各维度需要填充的量
pad_h = x.size(-2) % 2
pad_w = x.size(-1) % 2
if pad_h or pad_w:
x = F.pad(x, (0, pad_w, 0, pad_h)) # 右下填充
return self.conv(x)
self.tau = tau # 温度参数
self.hard = hard # 是否生成硬性one-hot
class ConditionInjector(nn.Module):
"""基于注意力机制的条件注入"""
def __init__(self, cond_dim=128, feat_dim=256):
super().__init__()
self.cond_proj = nn.Sequential(
nn.Linear(cond_dim, feat_dim * 2),
nn.GELU(),
nn.LayerNorm(feat_dim * 2)
)
self.channel_att = nn.Sequential(
nn.Conv2d(feat_dim, feat_dim//8, 1),
nn.GELU(),
nn.Conv2d(feat_dim//8, feat_dim, 1),
nn.Sigmoid()
)
def forward(self, x, cond):
# 投影条件向量
gamma, beta = self.cond_proj(cond).chunk(2, dim=1) # [B, D]
def forward(self, logits):
# logits形状: [BS, C, H, W]
y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
# 通道注意力调制
att = self.channel_att(x) # [B, C, H, W]
modulated = x * att
# 添加条件偏置
return modulated + beta.view(-1, gamma.size(1), 1, 1)
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
def __init__(self, in_ch, out_ch):
super().__init__()
self.encoder = nn.Sequential(
DynamicPadConv(in_ch, out_ch, stride=1),
ResidualBlock(out_ch),
CBAM(out_ch),
nn.GroupNorm(8, out_ch),
nn.GELU()
)
def forward(self, x):
return self.encoder(x)
# 转换为类索引的连续表示
class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
return (y * class_indices).sum(dim=1) # 形状[BS, H, W]
class GinkaModel(nn.Module):
def __init__(self, in_ch=1, base_ch=64, num_classes=32):
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
"""Ginka Model 模型定义部分
Args:
in_ch (int, optional): 输入通道数默认是 1
base_ch (int, optional): UNet 上下采样卷积基础通道数默认 64
num_classes (int, optional): 图块种类数量默认 32 预留出一部分以供后续拓展功能
"""
super().__init__()
# 轻量级文本编码器使用BERT前4层
self.bert = BertModel.from_pretrained('google-bert/bert-base-chinese', output_hidden_states=True)
self.text_proj = nn.Linear(768, 128)
# 动态尺寸处理系统
self.size_embed = nn.Embedding(32, 16) # 处理最大32的尺寸
# 编码器
self.enc1 = GinkaEncoder(in_ch, base_ch)
self.enc2 = GinkaEncoder(base_ch, base_ch * 2)
# self.enc3 = GinkaEncoder(base_ch * 2, base_ch * 4)
# 中间层
self.mid = nn.Sequential(
DynamicPadConv(base_ch * 2, base_ch * 4),
ConditionInjector(160, base_ch * 4)
self.base_ch = base_ch
self.fc = nn.Sequential(
nn.Linear(feat_dim, 32 * 32 * base_ch)
)
self.unet = GinkaUNet(base_ch, num_classes)
self.softmax = GumbelSoftmax()
# 解码器,解码器仅使用空间注意力
self.dec1 = HybridUpsample(base_ch * 4, base_ch * 2)
self.dec1_att = SpatialAttention()
self.dec2 = HybridUpsample(base_ch * 2, base_ch)
self.dec2_att = SpatialAttention()
# self.dec3 = HybridUpsample(base_ch * 2, base_ch)
# self.dec3_att = SpatialAttention()
# 输出层
self.out = FinalUpsample(base_ch, num_classes)
def forward(self, noise, input_ids, attention_mask, map_size):
def forward(self, feat):
"""
Args:
noise: 噪声输入 [BS, H, W, 1]
input_ids: 文本token id [BS, seq_len]
attention_mask: 文本attention mask [BS, seq_len]
map_size: 地图尺寸 [BS, 2] (height, width)
feat: 参考地图的特征向量
Returns:
logits: 输出logits [BS, num_classes, H, W]
"""
# 文本特征提取
with torch.no_grad(): # 冻结BERT参数
bert_outputs = self.bert(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True
)
# 取前4层隐藏状态的平均
hidden_states = torch.stack(bert_outputs.hidden_states[1:5]) # [4, BS, seq_len, 768]
text_features = torch.mean(hidden_states, dim=0)[:, 0, :] # [BS, 768]
text_features = self.text_proj(text_features) # [BS, 128]
# 尺寸特征处理
h_emb = self.size_embed(map_size[:, 0]) # [BS, 16]
w_emb = self.size_embed(map_size[:, 1]) # [BS, 16]
size_features = torch.cat([h_emb, w_emb], dim=1) # [BS, 32]
# 特征融合
conditional = torch.cat([text_features, size_features], dim=1) # [BS, 160]
# 调整噪声输入维度
x = noise.permute(0, 3, 1, 2) # [BS, 1, H, W]
# 编码器路径
x1 = self.enc1(x) # [BS, 64, H / 2, W / 2]
x2 = self.enc2(x1) # [BS, 128, H / 4, W / 4]
# 中间层(注入条件)
x_mid = self.mid[0](x2) # [BS, 256, H / 4, W / 4]
x_mid = self.mid[1](x_mid, conditional)
# 解码器路径
d1 = self.dec1(x_mid, x2) # [BS, 128, H / 2, W / 2]
d1 = self.dec1_att(d1)
d2 = self.dec2(d1, x1) # [BS, 64, H, W]
d2 = self.dec2_att(d2)
# d3 = self.dec3(d2, x1)
# d3 = self.dec3_att(d3)
# 最终自适应上采样
h, w = noise.shape[1:3] # 获取原始输入尺寸
return self.out(d2, (h, w))
x = self.fc(feat)
x = x.view(-1, self.base_ch, 32, 32)
x = self.unet(x)
x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
return x, self.softmax(x)

View File

@ -1,79 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.ops as ops
class HybridUpsample(nn.Module):
"""自适应尺寸的混合上采样"""
def __init__(self, in_ch, out_ch, skip_ch=None):
super().__init__()
# 子像素卷积上采样
self.subpixel = nn.Sequential(
nn.Conv2d(in_ch, out_ch * 4, 3, padding=1),
nn.PixelShuffle(2) # 2倍上采样
)
# 跳跃连接处理
self.skip_conv = nn.Conv2d(skip_ch, out_ch, 1) if skip_ch else None
self.adaptive_pool = nn.AdaptiveAvgPool2d(None)
def forward(self, x, skip=None):
x = self.subpixel(x) # [B, out_ch, 2H, 2W]
if skip is not None and self.skip_conv:
# 自动对齐尺寸
if x.shape[-2:] != skip.shape[-2:]:
skip = F.interpolate(skip, size=x.shape[-2:], mode='nearest')
# 融合特征
x = x + self.skip_conv(skip)
return x
class DiscreteAwareUpsample(nn.Module):
"""离散感知的智能上采样模块"""
def __init__(self, in_ch, out_ch, base_size=16):
super().__init__()
self.base_size = base_size
self.scale_factors = [2, 4, 8] # 支持放大倍数
# 可变形卷积增强几何感知
self.deform_conv = ops.DeformConv2d(in_ch, in_ch, kernel_size=3, padding=1)
# 多尺度特征融合
self.multi_scale = nn.ModuleList([
nn.Sequential(
nn.Conv2d(in_ch, in_ch//4, 1),
nn.Upsample(scale_factor=s, mode='nearest')
) for s in self.scale_factors
])
# 门控上采样机制
self.gate_conv = nn.Conv2d(in_ch*2, len(self.scale_factors)+1, 3, padding=1)
# 离散化输出层
self.final_conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch*4, 3, padding=1),
nn.PixelShuffle(2), # 亚像素卷积
nn.Conv2d(out_ch, out_ch, 3, padding=1)
)
def forward(self, x, target_size):
# 几何特征提取
deform_feat = self.deform_conv(x)
# 生成多尺度特征
scale_features = [f(deform_feat) for f in self.multi_scale]
# 动态门控选择
gate_map = F.softmax(self.gate_conv(torch.cat([x, deform_feat], dim=1)), dim=1)
# 加权融合多尺度特征
combined = sum(g * F.interpolate(f, size=target_size, mode='nearest')
for g, f in zip(gate_map.unbind(1), scale_features+[x]))
# 离散化上采样
out = self.final_conv(combined)
# 结构化约束(保持通道独立性)
return out.argmax(dim=1).unsqueeze(1).float() # 伪梯度保留

91
ginka/model/unet.py Normal file
View File

@ -0,0 +1,91 @@
import torch
import torch.nn as nn
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
self.pool = nn.MaxPool2d(2)
def forward(self, x):
x_res = self.conv(x) # 卷积提取特征
x_down = self.pool(x_res) # 进行池化
return x_down, x_res # 返回池化后的特征和跳跃连接特征
class GinkaDecoder(nn.Module):
"""解码器(上采样)部分"""
def __init__(self, in_channels, out_channels):
super().__init__()
# 上采样(双线性插值 + 卷积)
self.upsample = nn.Sequential(
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
# 跳跃连接融合
self.conv = nn.Sequential(
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU()
)
def forward(self, x, skip):
x = self.upsample(x)
# 跳跃连接融合
x = torch.cat([x, skip], dim=1)
x = self.conv(x)
return x
class GinkaBottleneck(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
)
def forward(self, x):
return self.conv(x)
class GinkaUNet(nn.Module):
def __init__(self, in_ch=64, out_ch=32):
"""Ginka Model UNet 部分
"""
super().__init__()
self.down1 = GinkaEncoder(in_ch, in_ch*2)
self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
self.up2 = GinkaDecoder(in_ch*2, in_ch)
self.final = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 1)
)
def forward(self, x):
x, skip1 = self.down1(x)
x, skip2 = self.down2(x)
x = self.bottleneck(x)
x = self.up1(x, skip2)
x = self.up2(x, skip1)
return self.final(x)

View File

@ -9,87 +9,57 @@ from tqdm import tqdm
from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .dataset import GinkaDataset
from minamo.model.model import MinamoModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
epochs = 100
epochs = 70
def collate_fn(batch):
# 动态填充噪声到最大尺寸
max_h = max([b["noise"].shape[0] for b in batch])
max_w = max([b["noise"].shape[1] for b in batch])
padded_batch = {}
for key in ["noise", "target"]:
padded = []
for b in batch:
tensor = b[key]
pad_h = max_h - tensor.shape[0]
pad_w = max_w - tensor.shape[1]
padded.append(F.pad(tensor, (0, pad_w, 0, pad_h), value=-100 if key=="target" else 0))
padded_batch[key] = torch.stack(padded)
# 其他字段直接堆叠
for key in ["input_ids", "attention_mask", "map_size"]:
padded_batch[key] = torch.stack([b[key] for b in batch])
return padded_batch
def update_tau(epoch):
start_tau = 1.0
min_tau = 0.1
decay_rate = 0.95
return max(min_tau, start_tau * (decay_rate ** epoch))
def train():
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.")
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
model = GinkaModel()
model.to(device)
minamo = MinamoModel(32)
minamo.to(device)
minamo.eval()
# 准备数据集
tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-chinese')
dataset = GinkaDataset("F:/github-ai/ginka-generator/dataset.json", tokenizer)
dataset = GinkaDataset("dataset.json", minamo)
dataloader = DataLoader(
dataset,
batch_size=4,
shuffle=True,
collate_fn=collate_fn,
num_workers=0
batch_size=32,
shuffle=True
)
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
criterion = GinkaLoss()
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo)
# 开始训练
for epoch in tqdm(range(epochs)):
model.train()
total_loss = 0
# 温度退火
model.gumbel.tau = max(0.1, 1.0 - 0.9 * epoch / epochs)
model.softmax.tau = update_tau(epoch)
for batch in dataloader:
# 数据迁移到设备
noise = batch["noise"].to(device)
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
map_size = batch["map_size"].to(device)
target = batch["target"].to(device)
feat_vec = batch["feat_vec"].to(device)
# 前向传播
optimizer.zero_grad()
outputs = model(noise, input_ids, attention_mask, map_size)
print(torch.argmax(torch.softmax(outputs, dim=1), dim=1))
# print(sampled[0, :, :, 1])
# 构建拓扑图
# with torch.no_grad():
# pred_graphs = build_topology_graph(outputs.argmax(1))
# ref_graphs = build_topology_graph(target)
output, output_softmax = model(feat_vec)
# 计算损失
loss = criterion(
outputs, # 调整为 [BS, C, H, W]
target
)
loss = criterion(output, output_softmax, target)
# 反向传播
loss.backward()

View File

@ -11,11 +11,8 @@ class MinamoModel(nn.Module):
# 拓扑相似度部分
self.topo_model = MinamoTopoModel(tile_types)
def forward(self, map1, map2, graph1, graph2):
vision_feat1 = self.vision_model(map1)
vision_feat2 = self.vision_model(map2)
def forward(self, map, graph):
vision_feat = self.vision_model(map)
topo_feat = self.topo_model(graph)
topo_feat1 = self.topo_model(graph1)
topo_feat2 = self.topo_model(graph2)
return vision_feat1, vision_feat2, topo_feat1, topo_feat2
return vision_feat, topo_feat

View File

@ -76,7 +76,8 @@ def train():
# 前向传播
optimizer.zero_grad()
vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1, map2, graph1, graph2)
vision_feat1, topo_feat1 = model(map1, graph1)
vision_feat2, topo_feat2 = model(map2, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)

View File

@ -1,5 +1,6 @@
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import MinamoModel
from .model.loss import MinamoLoss
@ -8,32 +9,38 @@ from .dataset import MinamoDataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def validate():
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to validate model.")
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = MinamoModel(32)
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
model.to(device)
# 准备数据集
val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json")
val_dataset = MinamoDataset("minamo-eval.json")
val_loader = DataLoader(
val_dataset,
batch_size=32,
shuffle=True
)
criterion = MinamoLoss(temp=0.8)
criterion = MinamoLoss()
model.eval()
val_loss = 0
with torch.no_grad():
for val_batch in tqdm(val_loader):
map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
map1_val = map1_val.to(device)
map2_val = map2_val.to(device)
vision_simi_val = vision_simi_val.to(device)
topo_simi_val = topo_simi_val.to(device)
graph1 = graph1.to(device)
graph2 = graph2.to(device)
vision_pred_val, topo_pred_val = model(map1_val, map2_val)
vision_feat1, topo_feat1 = model(map1_val, graph1)
vision_feat2, topo_feat2 = model(map2_val, graph2)
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
loss_val = criterion(
vision_pred_val, topo_pred_val,
vision_simi_val, topo_simi_val