mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 20:41:12 +08:00
refactor: 重构 ginka 模型部分
This commit is contained in:
parent
09b66f2569
commit
1566acf691
@ -1,8 +1,8 @@
|
|||||||
import json
|
import json
|
||||||
import random
|
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from transformers import BertTokenizer
|
from minamo.model.model import MinamoModel
|
||||||
|
from shared.graph import convert_map_to_graph
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
with open(path, 'r', encoding="utf-8") as f:
|
with open(path, 'r', encoding="utf-8") as f:
|
||||||
@ -15,11 +15,10 @@ def load_data(path: str):
|
|||||||
return data_list
|
return data_list
|
||||||
|
|
||||||
class GinkaDataset(Dataset):
|
class GinkaDataset(Dataset):
|
||||||
def __init__(self, data_path: str, tokenizer: BertTokenizer, max_len=128):
|
def __init__(self, data_path: str, minamo: MinamoModel):
|
||||||
self.data = load_data(data_path) # 自定义数据加载函数
|
self.data = load_data(data_path) # 自定义数据加载函数
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.max_len = max_len
|
|
||||||
self.max_size = 32
|
self.max_size = 32
|
||||||
|
self.minamo = minamo
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -27,28 +26,13 @@ class GinkaDataset(Dataset):
|
|||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
# 文本处理
|
target = torch.tensor(item["map"])
|
||||||
text = random.choice(item["text"])
|
graph = convert_map_to_graph(target)
|
||||||
encoding = self.tokenizer(
|
vision_feat, topo_feat = self.minamo(target, graph)
|
||||||
text,
|
feat_vec = torch.cat([vision_feat, topo_feat])
|
||||||
max_length=self.max_len,
|
|
||||||
padding="max_length",
|
|
||||||
truncation=True,
|
|
||||||
return_tensors="pt"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 噪声生成
|
|
||||||
w, h = item["size"]
|
|
||||||
noise = torch.randn(h, w, 1)
|
|
||||||
|
|
||||||
# 目标矩阵填充
|
|
||||||
target = torch.full((self.max_size, self.max_size), -100) # 使用-100忽略填充区域
|
|
||||||
target[:h, :w] = torch.tensor(item["map"])
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"noise": noise,
|
"feat_vec": feat_vec,
|
||||||
"input_ids": encoding["input_ids"].squeeze(),
|
|
||||||
"attention_mask": encoding["attention_mask"].squeeze(),
|
|
||||||
"map_size": torch.tensor([h, w]),
|
|
||||||
"target": target
|
"target": target
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2,7 +2,7 @@ import math
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from pytorch_toolbelt import losses as L
|
from minamo.model.model import MinamoModel
|
||||||
|
|
||||||
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
|
def wall_border_loss(pred: torch.Tensor, probs: torch.Tensor, allow_border=[1, 11]):
|
||||||
"""地图最外层是否为墙"""
|
"""地图最外层是否为墙"""
|
||||||
@ -283,7 +283,7 @@ def integrated_count_loss(probs, target, class_list=[0,1,2,3,4,5,6,7,8,9], toler
|
|||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
class GinkaLoss(nn.Module):
|
class GinkaLoss(nn.Module):
|
||||||
def __init__(self, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
|
def __init__(self, minamo: MinamoModel, weight=[0.35, 0.1, 0.1, 0.1, 0.1, 0.05, 0.1, 0.1]):
|
||||||
"""Ginka Model 损失函数部分
|
"""Ginka Model 损失函数部分
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -299,13 +299,11 @@ class GinkaLoss(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.dice = L.DiceLoss(mode='multiclass')
|
|
||||||
self.ce = nn.CrossEntropyLoss()
|
self.ce = nn.CrossEntropyLoss()
|
||||||
|
self.minamo = minamo
|
||||||
|
|
||||||
def forward(self, pred, target):
|
def forward(self, pred, pred_softmax, target):
|
||||||
probs = F.softmax(pred, dim=1)
|
probs = F.softmax(pred, dim=1)
|
||||||
# 拓扑结构损失
|
|
||||||
# structure_loss = topology_loss(pred, target)
|
|
||||||
# 地图结构损失
|
# 地图结构损失
|
||||||
border_loss = wall_border_loss(pred, probs)
|
border_loss = wall_border_loss(pred, probs)
|
||||||
wall_loss = internal_wall_loss(pred, probs)
|
wall_loss = internal_wall_loss(pred, probs)
|
||||||
@ -315,6 +313,9 @@ class GinkaLoss(nn.Module):
|
|||||||
valid_block_loss = illegal_block_loss(pred, probs, used_classes=12, mode="mean")
|
valid_block_loss = illegal_block_loss(pred, probs, used_classes=12, mode="mean")
|
||||||
count_loss = integrated_count_loss(probs, target)
|
count_loss = integrated_count_loss(probs, target)
|
||||||
|
|
||||||
|
# 使用 Minamo Model 计算相似度
|
||||||
|
|
||||||
|
|
||||||
print(
|
print(
|
||||||
# structure_loss.item(),
|
# structure_loss.item(),
|
||||||
border_loss.item(),
|
border_loss.item(),
|
||||||
|
|||||||
@ -1,180 +1,44 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from transformers import BertModel
|
from .unet import GinkaUNet
|
||||||
from ...shared.attention import CBAM, SpatialAttention
|
|
||||||
from .sample import HybridUpsample, FinalUpsample, GumbelSampler
|
|
||||||
|
|
||||||
class ResidualBlock(nn.Module):
|
class GumbelSoftmax(nn.Module):
|
||||||
"""残差块"""
|
def __init__(self, tau=1.0, hard=True):
|
||||||
def __init__(self, channels):
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv = nn.Sequential(
|
self.tau = tau # 温度参数
|
||||||
nn.Conv2d(channels, channels, 3, padding=1),
|
self.hard = hard # 是否生成硬性one-hot
|
||||||
nn.GroupNorm(8, channels),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Conv2d(channels, channels, 3, padding=1),
|
|
||||||
nn.GroupNorm(8, channels)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, logits):
|
||||||
return x + self.conv(x)
|
# logits形状: [BS, C, H, W]
|
||||||
|
y = F.gumbel_softmax(logits, tau=self.tau, hard=self.hard)
|
||||||
|
|
||||||
class DynamicPadConv(nn.Module):
|
# 转换为类索引的连续表示
|
||||||
"""支持动态处理奇数尺寸的智能卷积"""
|
class_indices = torch.arange(y.size(1), device=y.device).view(1, -1, 1, 1)
|
||||||
def __init__(self, in_ch, out_ch, kernel=3, stride=1):
|
return (y * class_indices).sum(dim=1) # 形状[BS, H, W]
|
||||||
super().__init__()
|
|
||||||
self.conv = nn.Conv2d(
|
|
||||||
in_ch, out_ch, kernel,
|
|
||||||
stride=stride,
|
|
||||||
padding=kernel//2
|
|
||||||
)
|
|
||||||
self.requires_pad = (stride > 1) # 仅在下采样时需要填充
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if self.requires_pad:
|
|
||||||
# 动态计算各维度需要填充的量
|
|
||||||
pad_h = x.size(-2) % 2
|
|
||||||
pad_w = x.size(-1) % 2
|
|
||||||
if pad_h or pad_w:
|
|
||||||
x = F.pad(x, (0, pad_w, 0, pad_h)) # 右下填充
|
|
||||||
return self.conv(x)
|
|
||||||
|
|
||||||
class ConditionInjector(nn.Module):
|
|
||||||
"""基于注意力机制的条件注入"""
|
|
||||||
def __init__(self, cond_dim=128, feat_dim=256):
|
|
||||||
super().__init__()
|
|
||||||
self.cond_proj = nn.Sequential(
|
|
||||||
nn.Linear(cond_dim, feat_dim * 2),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.LayerNorm(feat_dim * 2)
|
|
||||||
)
|
|
||||||
self.channel_att = nn.Sequential(
|
|
||||||
nn.Conv2d(feat_dim, feat_dim//8, 1),
|
|
||||||
nn.GELU(),
|
|
||||||
nn.Conv2d(feat_dim//8, feat_dim, 1),
|
|
||||||
nn.Sigmoid()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, cond):
|
|
||||||
# 投影条件向量
|
|
||||||
gamma, beta = self.cond_proj(cond).chunk(2, dim=1) # [B, D]
|
|
||||||
|
|
||||||
# 通道注意力调制
|
|
||||||
att = self.channel_att(x) # [B, C, H, W]
|
|
||||||
modulated = x * att
|
|
||||||
|
|
||||||
# 添加条件偏置
|
|
||||||
return modulated + beta.view(-1, gamma.size(1), 1, 1)
|
|
||||||
|
|
||||||
class GinkaEncoder(nn.Module):
|
|
||||||
"""编码器(下采样)部分"""
|
|
||||||
def __init__(self, in_ch, out_ch):
|
|
||||||
super().__init__()
|
|
||||||
self.encoder = nn.Sequential(
|
|
||||||
DynamicPadConv(in_ch, out_ch, stride=1),
|
|
||||||
ResidualBlock(out_ch),
|
|
||||||
CBAM(out_ch),
|
|
||||||
nn.GroupNorm(8, out_ch),
|
|
||||||
nn.GELU()
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return self.encoder(x)
|
|
||||||
|
|
||||||
class GinkaModel(nn.Module):
|
class GinkaModel(nn.Module):
|
||||||
def __init__(self, in_ch=1, base_ch=64, num_classes=32):
|
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
|
||||||
"""Ginka Model 模型定义部分
|
"""Ginka Model 模型定义部分
|
||||||
|
|
||||||
Args:
|
|
||||||
in_ch (int, optional): 输入通道数,默认是 1
|
|
||||||
base_ch (int, optional): UNet 上下采样卷积基础通道数,默认 64
|
|
||||||
num_classes (int, optional): 图块种类数量,默认 32 预留出一部分以供后续拓展功能
|
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.base_ch = base_ch
|
||||||
# 轻量级文本编码器(使用BERT前4层)
|
self.fc = nn.Sequential(
|
||||||
self.bert = BertModel.from_pretrained('google-bert/bert-base-chinese', output_hidden_states=True)
|
nn.Linear(feat_dim, 32 * 32 * base_ch)
|
||||||
self.text_proj = nn.Linear(768, 128)
|
|
||||||
|
|
||||||
# 动态尺寸处理系统
|
|
||||||
self.size_embed = nn.Embedding(32, 16) # 处理最大32的尺寸
|
|
||||||
|
|
||||||
# 编码器
|
|
||||||
self.enc1 = GinkaEncoder(in_ch, base_ch)
|
|
||||||
self.enc2 = GinkaEncoder(base_ch, base_ch * 2)
|
|
||||||
# self.enc3 = GinkaEncoder(base_ch * 2, base_ch * 4)
|
|
||||||
|
|
||||||
# 中间层
|
|
||||||
self.mid = nn.Sequential(
|
|
||||||
DynamicPadConv(base_ch * 2, base_ch * 4),
|
|
||||||
ConditionInjector(160, base_ch * 4)
|
|
||||||
)
|
)
|
||||||
|
self.unet = GinkaUNet(base_ch, num_classes)
|
||||||
|
self.softmax = GumbelSoftmax()
|
||||||
|
|
||||||
# 解码器,解码器仅使用空间注意力
|
def forward(self, feat):
|
||||||
self.dec1 = HybridUpsample(base_ch * 4, base_ch * 2)
|
|
||||||
self.dec1_att = SpatialAttention()
|
|
||||||
|
|
||||||
self.dec2 = HybridUpsample(base_ch * 2, base_ch)
|
|
||||||
self.dec2_att = SpatialAttention()
|
|
||||||
|
|
||||||
# self.dec3 = HybridUpsample(base_ch * 2, base_ch)
|
|
||||||
# self.dec3_att = SpatialAttention()
|
|
||||||
|
|
||||||
# 输出层
|
|
||||||
self.out = FinalUpsample(base_ch, num_classes)
|
|
||||||
|
|
||||||
def forward(self, noise, input_ids, attention_mask, map_size):
|
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
noise: 噪声输入 [BS, H, W, 1]
|
feat: 参考地图的特征向量
|
||||||
input_ids: 文本token id [BS, seq_len]
|
|
||||||
attention_mask: 文本attention mask [BS, seq_len]
|
|
||||||
map_size: 地图尺寸 [BS, 2] (height, width)
|
|
||||||
Returns:
|
Returns:
|
||||||
logits: 输出logits [BS, num_classes, H, W]
|
logits: 输出logits [BS, num_classes, H, W]
|
||||||
"""
|
"""
|
||||||
# 文本特征提取
|
x = self.fc(feat)
|
||||||
with torch.no_grad(): # 冻结BERT参数
|
x = x.view(-1, self.base_ch, 32, 32)
|
||||||
bert_outputs = self.bert(
|
x = self.unet(x)
|
||||||
input_ids=input_ids,
|
x = F.interpolate(x, (13, 13), mode='bilinear', align_corners=False)
|
||||||
attention_mask=attention_mask,
|
return x, self.softmax(x)
|
||||||
output_hidden_states=True
|
|
||||||
)
|
|
||||||
# 取前4层隐藏状态的平均
|
|
||||||
hidden_states = torch.stack(bert_outputs.hidden_states[1:5]) # [4, BS, seq_len, 768]
|
|
||||||
text_features = torch.mean(hidden_states, dim=0)[:, 0, :] # [BS, 768]
|
|
||||||
text_features = self.text_proj(text_features) # [BS, 128]
|
|
||||||
|
|
||||||
# 尺寸特征处理
|
|
||||||
h_emb = self.size_embed(map_size[:, 0]) # [BS, 16]
|
|
||||||
w_emb = self.size_embed(map_size[:, 1]) # [BS, 16]
|
|
||||||
size_features = torch.cat([h_emb, w_emb], dim=1) # [BS, 32]
|
|
||||||
|
|
||||||
# 特征融合
|
|
||||||
conditional = torch.cat([text_features, size_features], dim=1) # [BS, 160]
|
|
||||||
|
|
||||||
# 调整噪声输入维度
|
|
||||||
x = noise.permute(0, 3, 1, 2) # [BS, 1, H, W]
|
|
||||||
|
|
||||||
# 编码器路径
|
|
||||||
x1 = self.enc1(x) # [BS, 64, H / 2, W / 2]
|
|
||||||
x2 = self.enc2(x1) # [BS, 128, H / 4, W / 4]
|
|
||||||
|
|
||||||
# 中间层(注入条件)
|
|
||||||
x_mid = self.mid[0](x2) # [BS, 256, H / 4, W / 4]
|
|
||||||
x_mid = self.mid[1](x_mid, conditional)
|
|
||||||
|
|
||||||
# 解码器路径
|
|
||||||
d1 = self.dec1(x_mid, x2) # [BS, 128, H / 2, W / 2]
|
|
||||||
d1 = self.dec1_att(d1)
|
|
||||||
d2 = self.dec2(d1, x1) # [BS, 64, H, W]
|
|
||||||
d2 = self.dec2_att(d2)
|
|
||||||
# d3 = self.dec3(d2, x1)
|
|
||||||
# d3 = self.dec3_att(d3)
|
|
||||||
|
|
||||||
# 最终自适应上采样
|
|
||||||
h, w = noise.shape[1:3] # 获取原始输入尺寸
|
|
||||||
return self.out(d2, (h, w))
|
|
||||||
|
|
||||||
|
|
||||||
@ -1,79 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import torchvision.ops as ops
|
|
||||||
|
|
||||||
class HybridUpsample(nn.Module):
|
|
||||||
"""自适应尺寸的混合上采样"""
|
|
||||||
def __init__(self, in_ch, out_ch, skip_ch=None):
|
|
||||||
super().__init__()
|
|
||||||
# 子像素卷积上采样
|
|
||||||
self.subpixel = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch * 4, 3, padding=1),
|
|
||||||
nn.PixelShuffle(2) # 2倍上采样
|
|
||||||
)
|
|
||||||
|
|
||||||
# 跳跃连接处理
|
|
||||||
self.skip_conv = nn.Conv2d(skip_ch, out_ch, 1) if skip_ch else None
|
|
||||||
self.adaptive_pool = nn.AdaptiveAvgPool2d(None)
|
|
||||||
|
|
||||||
def forward(self, x, skip=None):
|
|
||||||
x = self.subpixel(x) # [B, out_ch, 2H, 2W]
|
|
||||||
|
|
||||||
if skip is not None and self.skip_conv:
|
|
||||||
# 自动对齐尺寸
|
|
||||||
if x.shape[-2:] != skip.shape[-2:]:
|
|
||||||
skip = F.interpolate(skip, size=x.shape[-2:], mode='nearest')
|
|
||||||
|
|
||||||
# 融合特征
|
|
||||||
x = x + self.skip_conv(skip)
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
class DiscreteAwareUpsample(nn.Module):
|
|
||||||
"""离散感知的智能上采样模块"""
|
|
||||||
def __init__(self, in_ch, out_ch, base_size=16):
|
|
||||||
super().__init__()
|
|
||||||
self.base_size = base_size
|
|
||||||
self.scale_factors = [2, 4, 8] # 支持放大倍数
|
|
||||||
|
|
||||||
# 可变形卷积增强几何感知
|
|
||||||
self.deform_conv = ops.DeformConv2d(in_ch, in_ch, kernel_size=3, padding=1)
|
|
||||||
|
|
||||||
# 多尺度特征融合
|
|
||||||
self.multi_scale = nn.ModuleList([
|
|
||||||
nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, in_ch//4, 1),
|
|
||||||
nn.Upsample(scale_factor=s, mode='nearest')
|
|
||||||
) for s in self.scale_factors
|
|
||||||
])
|
|
||||||
|
|
||||||
# 门控上采样机制
|
|
||||||
self.gate_conv = nn.Conv2d(in_ch*2, len(self.scale_factors)+1, 3, padding=1)
|
|
||||||
|
|
||||||
# 离散化输出层
|
|
||||||
self.final_conv = nn.Sequential(
|
|
||||||
nn.Conv2d(in_ch, out_ch*4, 3, padding=1),
|
|
||||||
nn.PixelShuffle(2), # 亚像素卷积
|
|
||||||
nn.Conv2d(out_ch, out_ch, 3, padding=1)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x, target_size):
|
|
||||||
# 几何特征提取
|
|
||||||
deform_feat = self.deform_conv(x)
|
|
||||||
|
|
||||||
# 生成多尺度特征
|
|
||||||
scale_features = [f(deform_feat) for f in self.multi_scale]
|
|
||||||
|
|
||||||
# 动态门控选择
|
|
||||||
gate_map = F.softmax(self.gate_conv(torch.cat([x, deform_feat], dim=1)), dim=1)
|
|
||||||
|
|
||||||
# 加权融合多尺度特征
|
|
||||||
combined = sum(g * F.interpolate(f, size=target_size, mode='nearest')
|
|
||||||
for g, f in zip(gate_map.unbind(1), scale_features+[x]))
|
|
||||||
|
|
||||||
# 离散化上采样
|
|
||||||
out = self.final_conv(combined)
|
|
||||||
|
|
||||||
# 结构化约束(保持通道独立性)
|
|
||||||
return out.argmax(dim=1).unsqueeze(1).float() # 伪梯度保留
|
|
||||||
91
ginka/model/unet.py
Normal file
91
ginka/model/unet.py
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
class GinkaEncoder(nn.Module):
|
||||||
|
"""编码器(下采样)部分"""
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
self.pool = nn.MaxPool2d(2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x_res = self.conv(x) # 卷积提取特征
|
||||||
|
x_down = self.pool(x_res) # 进行池化
|
||||||
|
return x_down, x_res # 返回池化后的特征和跳跃连接特征
|
||||||
|
|
||||||
|
class GinkaDecoder(nn.Module):
|
||||||
|
"""解码器(上采样)部分"""
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
# 上采样(双线性插值 + 卷积)
|
||||||
|
self.upsample = nn.Sequential(
|
||||||
|
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
|
||||||
|
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
# 跳跃连接融合
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(out_channels * 2, out_channels, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x, skip):
|
||||||
|
x = self.upsample(x)
|
||||||
|
# 跳跃连接融合
|
||||||
|
x = torch.cat([x, skip], dim=1)
|
||||||
|
x = self.conv(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class GinkaBottleneck(nn.Module):
|
||||||
|
def __init__(self, in_channels, out_channels):
|
||||||
|
super().__init__()
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU(),
|
||||||
|
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
|
||||||
|
nn.BatchNorm2d(out_channels),
|
||||||
|
nn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.conv(x)
|
||||||
|
|
||||||
|
class GinkaUNet(nn.Module):
|
||||||
|
def __init__(self, in_ch=64, out_ch=32):
|
||||||
|
"""Ginka Model UNet 部分
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.down1 = GinkaEncoder(in_ch, in_ch*2)
|
||||||
|
self.down2 = GinkaEncoder(in_ch*2, in_ch*4)
|
||||||
|
|
||||||
|
self.bottleneck = GinkaBottleneck(in_ch*4, in_ch*4)
|
||||||
|
|
||||||
|
self.up1 = GinkaDecoder(in_ch*4, in_ch*2)
|
||||||
|
self.up2 = GinkaDecoder(in_ch*2, in_ch)
|
||||||
|
|
||||||
|
self.final = nn.Sequential(
|
||||||
|
nn.Conv2d(in_ch, out_ch, 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
x, skip1 = self.down1(x)
|
||||||
|
x, skip2 = self.down2(x)
|
||||||
|
|
||||||
|
x = self.bottleneck(x)
|
||||||
|
|
||||||
|
x = self.up1(x, skip2)
|
||||||
|
x = self.up2(x, skip1)
|
||||||
|
|
||||||
|
return self.final(x)
|
||||||
@ -9,87 +9,57 @@ from tqdm import tqdm
|
|||||||
from .model.model import GinkaModel
|
from .model.model import GinkaModel
|
||||||
from .model.loss import GinkaLoss
|
from .model.loss import GinkaLoss
|
||||||
from .dataset import GinkaDataset
|
from .dataset import GinkaDataset
|
||||||
|
from minamo.model.model import MinamoModel
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
|
|
||||||
epochs = 100
|
epochs = 70
|
||||||
|
|
||||||
def collate_fn(batch):
|
def update_tau(epoch):
|
||||||
# 动态填充噪声到最大尺寸
|
start_tau = 1.0
|
||||||
max_h = max([b["noise"].shape[0] for b in batch])
|
min_tau = 0.1
|
||||||
max_w = max([b["noise"].shape[1] for b in batch])
|
decay_rate = 0.95
|
||||||
|
return max(min_tau, start_tau * (decay_rate ** epoch))
|
||||||
padded_batch = {}
|
|
||||||
for key in ["noise", "target"]:
|
|
||||||
padded = []
|
|
||||||
for b in batch:
|
|
||||||
tensor = b[key]
|
|
||||||
pad_h = max_h - tensor.shape[0]
|
|
||||||
pad_w = max_w - tensor.shape[1]
|
|
||||||
padded.append(F.pad(tensor, (0, pad_w, 0, pad_h), value=-100 if key=="target" else 0))
|
|
||||||
padded_batch[key] = torch.stack(padded)
|
|
||||||
|
|
||||||
# 其他字段直接堆叠
|
|
||||||
for key in ["input_ids", "attention_mask", "map_size"]:
|
|
||||||
padded_batch[key] = torch.stack([b[key] for b in batch])
|
|
||||||
|
|
||||||
return padded_batch
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to train model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
model = GinkaModel()
|
model = GinkaModel()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
minamo = MinamoModel(32)
|
||||||
|
minamo.to(device)
|
||||||
|
minamo.eval()
|
||||||
|
|
||||||
# 准备数据集
|
# 准备数据集
|
||||||
tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-chinese')
|
dataset = GinkaDataset("dataset.json", minamo)
|
||||||
dataset = GinkaDataset("F:/github-ai/ginka-generator/dataset.json", tokenizer)
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
batch_size=4,
|
batch_size=32,
|
||||||
shuffle=True,
|
shuffle=True
|
||||||
collate_fn=collate_fn,
|
|
||||||
num_workers=0
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设定优化器与调度器
|
# 设定优化器与调度器
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
|
optimizer = optim.AdamW(model.parameters(), lr=3e-4)
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
criterion = GinkaLoss()
|
criterion = GinkaLoss(minamo)
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
for epoch in tqdm(range(epochs)):
|
for epoch in tqdm(range(epochs)):
|
||||||
model.train()
|
model.train()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
model.softmax.tau = update_tau(epoch)
|
||||||
# 温度退火
|
|
||||||
model.gumbel.tau = max(0.1, 1.0 - 0.9 * epoch / epochs)
|
|
||||||
|
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
# 数据迁移到设备
|
# 数据迁移到设备
|
||||||
noise = batch["noise"].to(device)
|
|
||||||
input_ids = batch["input_ids"].to(device)
|
|
||||||
attention_mask = batch["attention_mask"].to(device)
|
|
||||||
map_size = batch["map_size"].to(device)
|
|
||||||
target = batch["target"].to(device)
|
target = batch["target"].to(device)
|
||||||
|
feat_vec = batch["feat_vec"].to(device)
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
outputs = model(noise, input_ids, attention_mask, map_size)
|
output, output_softmax = model(feat_vec)
|
||||||
|
|
||||||
print(torch.argmax(torch.softmax(outputs, dim=1), dim=1))
|
|
||||||
# print(sampled[0, :, :, 1])
|
|
||||||
|
|
||||||
# 构建拓扑图
|
|
||||||
# with torch.no_grad():
|
|
||||||
# pred_graphs = build_topology_graph(outputs.argmax(1))
|
|
||||||
# ref_graphs = build_topology_graph(target)
|
|
||||||
|
|
||||||
# 计算损失
|
# 计算损失
|
||||||
loss = criterion(
|
loss = criterion(output, output_softmax, target)
|
||||||
outputs, # 调整为 [BS, C, H, W]
|
|
||||||
target
|
|
||||||
)
|
|
||||||
|
|
||||||
# 反向传播
|
# 反向传播
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|||||||
@ -11,11 +11,8 @@ class MinamoModel(nn.Module):
|
|||||||
# 拓扑相似度部分
|
# 拓扑相似度部分
|
||||||
self.topo_model = MinamoTopoModel(tile_types)
|
self.topo_model = MinamoTopoModel(tile_types)
|
||||||
|
|
||||||
def forward(self, map1, map2, graph1, graph2):
|
def forward(self, map, graph):
|
||||||
vision_feat1 = self.vision_model(map1)
|
vision_feat = self.vision_model(map)
|
||||||
vision_feat2 = self.vision_model(map2)
|
topo_feat = self.topo_model(graph)
|
||||||
|
|
||||||
topo_feat1 = self.topo_model(graph1)
|
return vision_feat, topo_feat
|
||||||
topo_feat2 = self.topo_model(graph2)
|
|
||||||
|
|
||||||
return vision_feat1, vision_feat2, topo_feat1, topo_feat2
|
|
||||||
|
|||||||
@ -76,7 +76,8 @@ def train():
|
|||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
vision_feat1, vision_feat2, topo_feat1, topo_feat2 = model(map1, map2, graph1, graph2)
|
vision_feat1, topo_feat1 = model(map1, graph1)
|
||||||
|
vision_feat2, topo_feat2 = model(map2, graph2)
|
||||||
|
|
||||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
import torch.nn.functional as F
|
||||||
|
from torch_geometric.loader import DataLoader
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from .model.model import MinamoModel
|
from .model.model import MinamoModel
|
||||||
from .model.loss import MinamoLoss
|
from .model.loss import MinamoLoss
|
||||||
@ -8,32 +9,38 @@ from .dataset import MinamoDataset
|
|||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
def validate():
|
def validate():
|
||||||
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to validate model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||||
model = MinamoModel(32)
|
model = MinamoModel(32)
|
||||||
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
# 准备数据集
|
# 准备数据集
|
||||||
val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json")
|
val_dataset = MinamoDataset("minamo-eval.json")
|
||||||
val_loader = DataLoader(
|
val_loader = DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=32,
|
batch_size=32,
|
||||||
shuffle=True
|
shuffle=True
|
||||||
)
|
)
|
||||||
|
|
||||||
criterion = MinamoLoss(temp=0.8)
|
criterion = MinamoLoss()
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
val_loss = 0
|
val_loss = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for val_batch in tqdm(val_loader):
|
for val_batch in tqdm(val_loader):
|
||||||
map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch
|
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
|
||||||
map1_val = map1_val.to(device)
|
map1_val = map1_val.to(device)
|
||||||
map2_val = map2_val.to(device)
|
map2_val = map2_val.to(device)
|
||||||
vision_simi_val = vision_simi_val.to(device)
|
vision_simi_val = vision_simi_val.to(device)
|
||||||
topo_simi_val = topo_simi_val.to(device)
|
topo_simi_val = topo_simi_val.to(device)
|
||||||
|
graph1 = graph1.to(device)
|
||||||
|
graph2 = graph2.to(device)
|
||||||
|
|
||||||
vision_pred_val, topo_pred_val = model(map1_val, map2_val)
|
vision_feat1, topo_feat1 = model(map1_val, graph1)
|
||||||
|
vision_feat2, topo_feat2 = model(map2_val, graph2)
|
||||||
|
|
||||||
|
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||||
|
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||||
loss_val = criterion(
|
loss_val = criterion(
|
||||||
vision_pred_val, topo_pred_val,
|
vision_pred_val, topo_pred_val,
|
||||||
vision_simi_val, topo_simi_val
|
vision_simi_val, topo_simi_val
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user