mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 16:41:10 +08:00
refactor: Minamo Model 改为双特征提取通道
This commit is contained in:
parent
5a19a23518
commit
c34756016d
@ -13,7 +13,7 @@ GINKA Model 内部集成了 Minamo Model 用于判别两个地图的相似性,
|
||||
3. 楼层中不应该有闲置怪,不应该有连续 3 个以上的怪物,不应该有无法到达的区域,不宜有过多的入口
|
||||
4. 最外面一层围上一圈墙壁(箭头楼层切换除外)
|
||||
5. 将所有的墙壁换成黄墙(数字 1)
|
||||
6. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),删除除此之外的资源,剑盾可以当成红蓝宝石看待
|
||||
6. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),道具全部换为幸运金币(数字 53),剑盾可以当成红蓝宝石看待,删除除此之外的资源
|
||||
7. 所有钥匙换成黄钥匙(数字 21),所有门换成黄门(数字 81)
|
||||
8. 所有箭头换成样板原版箭头(数字 161 至 164),所有上下楼梯换成样板原版楼梯(数字 87 和 88)
|
||||
9. 怪物分为三个强度,弱怪,中怪,强怪,弱怪换为绿头怪(数字 201),中怪换成红头怪(数字 202),强怪换成青头怪(数字 203)
|
||||
|
||||
@ -14,7 +14,8 @@ const numMap: Record<number, number> = {
|
||||
161: 11, // 箭头
|
||||
162: 11, // 箭头
|
||||
163: 11, // 箭头
|
||||
164: 11 // 箭头
|
||||
164: 11, // 箭头
|
||||
53: 12 // 道具
|
||||
};
|
||||
|
||||
export function convertFloor(
|
||||
|
||||
@ -1,18 +1,10 @@
|
||||
import { readFile, writeFile } from 'fs-extra';
|
||||
import { join } from 'path';
|
||||
import { convertFloor } from './floor';
|
||||
import {
|
||||
FloorData,
|
||||
getAllFloors,
|
||||
mergeDataset,
|
||||
mergeFloorIds,
|
||||
parseTowerInfo
|
||||
} from './utils';
|
||||
import { writeFile } from 'fs-extra';
|
||||
import { FloorData, getAllFloors, parseTowerInfo } from './utils';
|
||||
import { compareMap } from './topology/compare';
|
||||
import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform';
|
||||
import { directions, tileType } from './topology/graph';
|
||||
import { calculateVisualSimilarity } from './vision/similarity';
|
||||
import { BaseConfig, TowerInfo } from './types';
|
||||
import { BaseConfig } from './types';
|
||||
|
||||
interface MinamoConfig extends BaseConfig {}
|
||||
|
||||
|
||||
@ -7,13 +7,13 @@ import {
|
||||
} from './interface';
|
||||
|
||||
export const tileType = new Set(
|
||||
Array(12)
|
||||
Array(13)
|
||||
.fill(0)
|
||||
.map((_, i) => i)
|
||||
);
|
||||
const branchType = new Set([6, 7, 8, 9]);
|
||||
const entranceType = new Set([10, 11]);
|
||||
const resourceType = new Set([0, 2, 3, 4, 5, 10, 11]);
|
||||
const resourceType = new Set([0, 2, 3, 4, 5, 10, 11, 12]);
|
||||
|
||||
export const directions: [number, number][] = [
|
||||
[-1, 0],
|
||||
|
||||
@ -20,7 +20,8 @@ const DEFAULT_CONFIG: VisualSimilarityConfig = {
|
||||
8: 0.6, // 中怪
|
||||
9: 0.6, // 强怪
|
||||
10: 0.4, // 楼梯
|
||||
11: 0.4 // 箭头
|
||||
11: 0.4, // 箭头
|
||||
12: 0.7 // 道具
|
||||
},
|
||||
enableVisualFocus: true,
|
||||
enableDensityAwareness: true
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import torch.nn as nn
|
||||
|
||||
class MinamoLoss(nn.Module):
|
||||
def __init__(self, vision_weight=0.4, topo_weight=0.6):
|
||||
def __init__(self, vision_weight=0, topo_weight=1):
|
||||
super().__init__()
|
||||
self.vision_weight = vision_weight
|
||||
self.topo_weight = topo_weight
|
||||
@ -11,4 +11,4 @@ class MinamoLoss(nn.Module):
|
||||
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item())
|
||||
vis_loss = self.mse(vis_pred, vis_true)
|
||||
topo_loss = self.mse(topo_pred, topo_true)
|
||||
return self.vision_weight * vis_loss + self.topo_weight * topo_loss
|
||||
return self.vision_weight * vis_loss + self.topo_weight * topo_loss
|
||||
|
||||
@ -45,77 +45,52 @@ class DirectionalAttention(nn.Module):
|
||||
return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
|
||||
|
||||
class MinamoModel(nn.Module):
|
||||
def __init__(self, num_tile_types, embedding_dim=64, conv_channels=256):
|
||||
def __init__(self, tile_types=32, embedding_dim=64, conv_channels=256):
|
||||
super().__init__()
|
||||
# 嵌入层处理不同图块类型
|
||||
self.embedding = nn.Embedding(num_tile_types, embedding_dim)
|
||||
self.embedding = nn.Embedding(tile_types, embedding_dim)
|
||||
|
||||
# 共享特征提取的卷积层
|
||||
self.conv_layers = nn.Sequential(
|
||||
self.vision_conv = nn.Sequential(
|
||||
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
|
||||
DualAttention(conv_channels),
|
||||
DirectionalAttention(),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm2d(conv_channels),
|
||||
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
||||
DualAttention(conv_channels*2),
|
||||
DirectionalAttention(),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm2d(conv_channels*2),
|
||||
|
||||
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
||||
DualAttention(conv_channels*4),
|
||||
DirectionalAttention(),
|
||||
nn.ReLU(),
|
||||
nn.BatchNorm2d(conv_channels*4),
|
||||
nn.AdaptiveAvgPool2d(1)
|
||||
)
|
||||
|
||||
# 自适应池化处理任意尺寸
|
||||
self.pool = nn.ModuleDict({
|
||||
'avg': nn.AdaptiveAvgPool2d((1,1)),
|
||||
'max': nn.AdaptiveMaxPool2d((1,1))
|
||||
})
|
||||
# 拓扑特征分支
|
||||
self.topo_conv = nn.Sequential(
|
||||
nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构
|
||||
nn.MaxPool2d(2),
|
||||
# GraphConvLayer(128, 256), # 图卷积层
|
||||
nn.AdaptiveMaxPool2d(1)
|
||||
)
|
||||
|
||||
# 多任务预测头
|
||||
head_dim = conv_channels * 4 * 2 * 4 # 2个池化,四个交互项
|
||||
self.vision_head = nn.Sequential(
|
||||
nn.Linear(head_dim, 512),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(512, 1),
|
||||
nn.Linear(conv_channels*2, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.topo_head = nn.Sequential(
|
||||
nn.Linear(head_dim, 512),
|
||||
nn.GELU(),
|
||||
nn.Dropout(0.3),
|
||||
nn.Linear(512, 256),
|
||||
nn.GELU(),
|
||||
nn.Linear(256, 1),
|
||||
nn.Linear(conv_channels, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
|
||||
def forward(self, map1, map2):
|
||||
# 增强特征提取
|
||||
def process_map(x):
|
||||
x = self.embedding(x).permute(0,3,1,2)
|
||||
x = self.conv_layers(x)
|
||||
return torch.cat([
|
||||
self.pool['avg'](x),
|
||||
self.pool['max'](x)
|
||||
], dim=1).flatten(1)
|
||||
e1 = self.embedding(map1).permute(0, 3, 1, 2)
|
||||
e2 = self.embedding(map2).permute(0, 3, 1, 2)
|
||||
|
||||
f1 = process_map(map1)
|
||||
f2 = process_map(map2)
|
||||
v1 = self.vision_conv(e1).squeeze()
|
||||
v2 = self.vision_conv(e2).squeeze()
|
||||
|
||||
# 特征融合
|
||||
combined = torch.cat([f1, f2, f1-f2, f1*f2], dim=1) # [B, 256]
|
||||
t1 = self.topo_conv(e1).squeeze()
|
||||
t2 = self.topo_conv(e2).squeeze()
|
||||
|
||||
# 多任务输出
|
||||
vision_sim = self.vision_head(combined)
|
||||
topo_sim = self.topo_head(combined)
|
||||
vision_sim = self.vision_head(torch.abs(v1 - v2))
|
||||
topo_sim = self.topo_head(torch.abs(t1 - t2))
|
||||
|
||||
return vision_sim, topo_sim
|
||||
|
||||
@ -10,6 +10,7 @@ from .dataset import MinamoDataset
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/minamo_checkpoint", exist_ok=True)
|
||||
|
||||
epochs = 100
|
||||
|
||||
@ -51,7 +52,7 @@ def train():
|
||||
)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-4)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-3)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
|
||||
criterion = MinamoLoss()
|
||||
|
||||
@ -91,7 +92,7 @@ def train():
|
||||
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
|
||||
|
||||
ave_loss = total_loss / len(dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
|
||||
# 学习率调整
|
||||
scheduler.step()
|
||||
@ -116,8 +117,11 @@ def train():
|
||||
val_loss += loss_val.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation :: loss: {avg_val_loss:.6f}")
|
||||
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
torch.save({
|
||||
"model_state": model.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
|
||||
|
||||
print("Train ended.")
|
||||
|
||||
|
||||
49
minamo/validate.py
Normal file
49
minamo/validate.py
Normal file
@ -0,0 +1,49 @@
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from tqdm import tqdm
|
||||
from .model.model import MinamoModel
|
||||
from .model.loss import MinamoLoss
|
||||
from .dataset import MinamoDataset
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def validate():
|
||||
print(f"Using {"cuda" if torch.cuda.is_available() else "cpu"} to validate model.")
|
||||
model = MinamoModel(32)
|
||||
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
||||
model.to(device)
|
||||
|
||||
# 准备数据集
|
||||
val_dataset = MinamoDataset("F:/github-ai/ginka-generator/minamo-eval.json")
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=32,
|
||||
shuffle=True
|
||||
)
|
||||
|
||||
criterion = MinamoLoss(temp=0.8)
|
||||
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
for val_batch in tqdm(val_loader):
|
||||
map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch
|
||||
map1_val = map1_val.to(device)
|
||||
map2_val = map2_val.to(device)
|
||||
vision_simi_val = vision_simi_val.to(device)
|
||||
topo_simi_val = topo_simi_val.to(device)
|
||||
|
||||
vision_pred_val, topo_pred_val = model(map1_val, map2_val)
|
||||
loss_val = criterion(
|
||||
vision_pred_val, topo_pred_val,
|
||||
vision_simi_val, topo_simi_val
|
||||
)
|
||||
val_loss += loss_val.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
torch.set_num_threads(2)
|
||||
validate()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user