refactor: Minamo Model 改为双特征提取通道

This commit is contained in:
unanmed 2025-03-16 16:31:40 +08:00
parent 5a19a23518
commit c34756016d
9 changed files with 91 additions and 69 deletions

View File

@ -13,7 +13,7 @@ GINKA Model 内部集成了 Minamo Model 用于判别两个地图的相似性,
3. 楼层中不应该有闲置怪,不应该有连续 3 个以上的怪物,不应该有无法到达的区域,不宜有过多的入口
4. 最外面一层围上一圈墙壁(箭头楼层切换除外)
5. 将所有的墙壁换成黄墙(数字 1
6. 将所有的血瓶换成红血瓶(数字 31所有红宝石换成最基础的红宝石数字 27蓝宝石换成最基础的蓝宝石数字 28删除除此之外的资源,剑盾可以当成红蓝宝石看待
6. 将所有的血瓶换成红血瓶(数字 31所有红宝石换成最基础的红宝石数字 27蓝宝石换成最基础的蓝宝石数字 28道具全部换为幸运金币(数字 53剑盾可以当成红蓝宝石看待删除除此之外的资源
7. 所有钥匙换成黄钥匙(数字 21所有门换成黄门数字 81
8. 所有箭头换成样板原版箭头(数字 161 至 164所有上下楼梯换成样板原版楼梯数字 87 和 88
9. 怪物分为三个强度,弱怪,中怪,强怪,弱怪换为绿头怪(数字 201中怪换成红头怪数字 202强怪换成青头怪数字 203

View File

@ -14,7 +14,8 @@ const numMap: Record<number, number> = {
161: 11, // 箭头
162: 11, // 箭头
163: 11, // 箭头
164: 11 // 箭头
164: 11, // 箭头
53: 12 // 道具
};
export function convertFloor(

View File

@ -1,18 +1,10 @@
import { readFile, writeFile } from 'fs-extra';
import { join } from 'path';
import { convertFloor } from './floor';
import {
FloorData,
getAllFloors,
mergeDataset,
mergeFloorIds,
parseTowerInfo
} from './utils';
import { writeFile } from 'fs-extra';
import { FloorData, getAllFloors, parseTowerInfo } from './utils';
import { compareMap } from './topology/compare';
import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform';
import { directions, tileType } from './topology/graph';
import { calculateVisualSimilarity } from './vision/similarity';
import { BaseConfig, TowerInfo } from './types';
import { BaseConfig } from './types';
interface MinamoConfig extends BaseConfig {}

View File

@ -7,13 +7,13 @@ import {
} from './interface';
export const tileType = new Set(
Array(12)
Array(13)
.fill(0)
.map((_, i) => i)
);
const branchType = new Set([6, 7, 8, 9]);
const entranceType = new Set([10, 11]);
const resourceType = new Set([0, 2, 3, 4, 5, 10, 11]);
const resourceType = new Set([0, 2, 3, 4, 5, 10, 11, 12]);
export const directions: [number, number][] = [
[-1, 0],

View File

@ -20,7 +20,8 @@ const DEFAULT_CONFIG: VisualSimilarityConfig = {
8: 0.6, // 中怪
9: 0.6, // 强怪
10: 0.4, // 楼梯
11: 0.4 // 箭头
11: 0.4, // 箭头
12: 0.7 // 道具
},
enableVisualFocus: true,
enableDensityAwareness: true

View File

@ -1,7 +1,7 @@
import torch.nn as nn
class MinamoLoss(nn.Module):
def __init__(self, vision_weight=0.4, topo_weight=0.6):
def __init__(self, vision_weight=0, topo_weight=1):
super().__init__()
self.vision_weight = vision_weight
self.topo_weight = topo_weight
@ -11,4 +11,4 @@ class MinamoLoss(nn.Module):
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item())
vis_loss = self.mse(vis_pred, vis_true)
topo_loss = self.mse(topo_pred, topo_true)
return self.vision_weight * vis_loss + self.topo_weight * topo_loss
return self.vision_weight * vis_loss + self.topo_weight * topo_loss

View File

@ -45,77 +45,52 @@ class DirectionalAttention(nn.Module):
return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
class MinamoModel(nn.Module):
def __init__(self, num_tile_types, embedding_dim=64, conv_channels=256):
def __init__(self, tile_types=32, embedding_dim=64, conv_channels=256):
super().__init__()
# 嵌入层处理不同图块类型
self.embedding = nn.Embedding(num_tile_types, embedding_dim)
self.embedding = nn.Embedding(tile_types, embedding_dim)
# 共享特征提取的卷积层
self.conv_layers = nn.Sequential(
self.vision_conv = nn.Sequential(
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
DualAttention(conv_channels),
DirectionalAttention(),
nn.ReLU(),
nn.BatchNorm2d(conv_channels),
nn.ReLU(),
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
DualAttention(conv_channels*2),
DirectionalAttention(),
nn.ReLU(),
nn.BatchNorm2d(conv_channels*2),
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
DualAttention(conv_channels*4),
DirectionalAttention(),
nn.ReLU(),
nn.BatchNorm2d(conv_channels*4),
nn.AdaptiveAvgPool2d(1)
)
# 自适应池化处理任意尺寸
self.pool = nn.ModuleDict({
'avg': nn.AdaptiveAvgPool2d((1,1)),
'max': nn.AdaptiveMaxPool2d((1,1))
})
# 拓扑特征分支
self.topo_conv = nn.Sequential(
nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构
nn.MaxPool2d(2),
# GraphConvLayer(128, 256), # 图卷积层
nn.AdaptiveMaxPool2d(1)
)
# 多任务预测头
head_dim = conv_channels * 4 * 2 * 4 # 2个池化四个交互项
self.vision_head = nn.Sequential(
nn.Linear(head_dim, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, 1),
nn.Linear(conv_channels*2, 1),
nn.Sigmoid()
)
self.topo_head = nn.Sequential(
nn.Linear(head_dim, 512),
nn.GELU(),
nn.Dropout(0.3),
nn.Linear(512, 256),
nn.GELU(),
nn.Linear(256, 1),
nn.Linear(conv_channels, 1),
nn.Sigmoid()
)
def forward(self, map1, map2):
# 增强特征提取
def process_map(x):
x = self.embedding(x).permute(0,3,1,2)
x = self.conv_layers(x)
return torch.cat([
self.pool['avg'](x),
self.pool['max'](x)
], dim=1).flatten(1)
e1 = self.embedding(map1).permute(0, 3, 1, 2)
e2 = self.embedding(map2).permute(0, 3, 1, 2)
f1 = process_map(map1)
f2 = process_map(map2)
v1 = self.vision_conv(e1).squeeze()
v2 = self.vision_conv(e2).squeeze()
# 特征融合
combined = torch.cat([f1, f2, f1-f2, f1*f2], dim=1) # [B, 256]
t1 = self.topo_conv(e1).squeeze()
t2 = self.topo_conv(e2).squeeze()
# 多任务输出
vision_sim = self.vision_head(combined)
topo_sim = self.topo_head(combined)
vision_sim = self.vision_head(torch.abs(v1 - v2))
topo_sim = self.topo_head(torch.abs(t1 - t2))
return vision_sim, topo_sim

View File

@ -10,6 +10,7 @@ from .dataset import MinamoDataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/minamo_checkpoint", exist_ok=True)
epochs = 100
@ -51,7 +52,7 @@ def train():
)
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-4)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
criterion = MinamoLoss()
@ -91,7 +92,7 @@ def train():
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
ave_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
# 学习率调整
scheduler.step()
@ -116,8 +117,11 @@ def train():
val_loss += loss_val.item()
avg_val_loss = val_loss / len(val_loader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation :: loss: {avg_val_loss:.6f}")
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
print("Train ended.")

49
minamo/validate.py Normal file
View File

@ -0,0 +1,49 @@
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from .model.model import MinamoModel
from .model.loss import MinamoLoss
from .dataset import MinamoDataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def validate():
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to validate model.")
model = MinamoModel(32)
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
model.to(device)
# 准备数据集
val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json")
val_loader = DataLoader(
val_dataset,
batch_size=32,
shuffle=True
)
criterion = MinamoLoss(temp=0.8)
model.eval()
val_loss = 0
with torch.no_grad():
for val_batch in tqdm(val_loader):
map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch
map1_val = map1_val.to(device)
map2_val = map2_val.to(device)
vision_simi_val = vision_simi_val.to(device)
topo_simi_val = topo_simi_val.to(device)
vision_pred_val, topo_pred_val = model(map1_val, map2_val)
loss_val = criterion(
vision_pred_val, topo_pred_val,
vision_simi_val, topo_simi_val
)
val_loss += loss_val.item()
avg_val_loss = val_loss / len(val_loader)
tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
if __name__ == "__main__":
torch.set_num_threads(2)
validate()