diff --git a/README.md b/README.md index 3c1acb9..52387c3 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ GINKA Model 内部集成了 Minamo Model 用于判别两个地图的相似性, 3. 楼层中不应该有闲置怪,不应该有连续 3 个以上的怪物,不应该有无法到达的区域,不宜有过多的入口 4. 最外面一层围上一圈墙壁(箭头楼层切换除外) 5. 将所有的墙壁换成黄墙(数字 1) -6. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),删除除此之外的资源,剑盾可以当成红蓝宝石看待 +6. 将所有的血瓶换成红血瓶(数字 31),所有红宝石换成最基础的红宝石(数字 27),蓝宝石换成最基础的蓝宝石(数字 28),道具全部换为幸运金币(数字 53),剑盾可以当成红蓝宝石看待,删除除此之外的资源 7. 所有钥匙换成黄钥匙(数字 21),所有门换成黄门(数字 81) 8. 所有箭头换成样板原版箭头(数字 161 至 164),所有上下楼梯换成样板原版楼梯(数字 87 和 88) 9. 怪物分为三个强度,弱怪,中怪,强怪,弱怪换为绿头怪(数字 201),中怪换成红头怪(数字 202),强怪换成青头怪(数字 203) diff --git a/data/src/floor.ts b/data/src/floor.ts index 3de9db4..d4b802d 100644 --- a/data/src/floor.ts +++ b/data/src/floor.ts @@ -14,7 +14,8 @@ const numMap: Record = { 161: 11, // 箭头 162: 11, // 箭头 163: 11, // 箭头 - 164: 11 // 箭头 + 164: 11, // 箭头 + 53: 12 // 道具 }; export function convertFloor( diff --git a/data/src/minamo.ts b/data/src/minamo.ts index 2d6f08f..d2e53ea 100644 --- a/data/src/minamo.ts +++ b/data/src/minamo.ts @@ -1,18 +1,10 @@ -import { readFile, writeFile } from 'fs-extra'; -import { join } from 'path'; -import { convertFloor } from './floor'; -import { - FloorData, - getAllFloors, - mergeDataset, - mergeFloorIds, - parseTowerInfo -} from './utils'; +import { writeFile } from 'fs-extra'; +import { FloorData, getAllFloors, parseTowerInfo } from './utils'; import { compareMap } from './topology/compare'; import { mirrorMapX, mirrorMapY, rotateMap } from './topology/transform'; import { directions, tileType } from './topology/graph'; import { calculateVisualSimilarity } from './vision/similarity'; -import { BaseConfig, TowerInfo } from './types'; +import { BaseConfig } from './types'; interface MinamoConfig extends BaseConfig {} diff --git a/data/src/topology/graph.ts b/data/src/topology/graph.ts index eeee6a5..5345f08 100644 --- a/data/src/topology/graph.ts +++ b/data/src/topology/graph.ts @@ -7,13 +7,13 @@ import { } from './interface'; export 
const tileType = new Set( - Array(12) + Array(13) .fill(0) .map((_, i) => i) ); const branchType = new Set([6, 7, 8, 9]); const entranceType = new Set([10, 11]); -const resourceType = new Set([0, 2, 3, 4, 5, 10, 11]); +const resourceType = new Set([0, 2, 3, 4, 5, 10, 11, 12]); export const directions: [number, number][] = [ [-1, 0], diff --git a/data/src/vision/similarity.ts b/data/src/vision/similarity.ts index 85e131b..5caab99 100644 --- a/data/src/vision/similarity.ts +++ b/data/src/vision/similarity.ts @@ -20,7 +20,8 @@ const DEFAULT_CONFIG: VisualSimilarityConfig = { 8: 0.6, // 中怪 9: 0.6, // 强怪 10: 0.4, // 楼梯 - 11: 0.4 // 箭头 + 11: 0.4, // 箭头 + 12: 0.7 // 道具 }, enableVisualFocus: true, enableDensityAwareness: true diff --git a/minamo/model/loss.py b/minamo/model/loss.py index 94a2cdf..4ada733 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -1,7 +1,7 @@ import torch.nn as nn class MinamoLoss(nn.Module): - def __init__(self, vision_weight=0.4, topo_weight=0.6): + def __init__(self, vision_weight=0, topo_weight=1): super().__init__() self.vision_weight = vision_weight self.topo_weight = topo_weight @@ -11,4 +11,4 @@ class MinamoLoss(nn.Module): # print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item()) vis_loss = self.mse(vis_pred, vis_true) topo_loss = self.mse(topo_pred, topo_true) - return self.vision_weight * vis_loss + self.topo_weight * topo_loss \ No newline at end of file + return self.vision_weight * vis_loss + self.topo_weight * topo_loss diff --git a/minamo/model/model.py b/minamo/model/model.py index 5ebe553..1ea8ad0 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -45,77 +45,52 @@ class DirectionalAttention(nn.Module): return x * (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1) class MinamoModel(nn.Module): - def __init__(self, num_tile_types, embedding_dim=64, conv_channels=256): + def __init__(self, tile_types=32, embedding_dim=64, conv_channels=256): super().__init__() # 
嵌入层处理不同图块类型 - self.embedding = nn.Embedding(num_tile_types, embedding_dim) + self.embedding = nn.Embedding(tile_types, embedding_dim) - # 共享特征提取的卷积层 - self.conv_layers = nn.Sequential( + self.vision_conv = nn.Sequential( nn.Conv2d(embedding_dim, conv_channels, 3, padding=1), DualAttention(conv_channels), - DirectionalAttention(), - nn.ReLU(), nn.BatchNorm2d(conv_channels), - + nn.ReLU(), nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1), DualAttention(conv_channels*2), - DirectionalAttention(), - nn.ReLU(), - nn.BatchNorm2d(conv_channels*2), - - nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1), - DualAttention(conv_channels*4), - DirectionalAttention(), - nn.ReLU(), - nn.BatchNorm2d(conv_channels*4), + nn.AdaptiveAvgPool2d(1) ) - # 自适应池化处理任意尺寸 - self.pool = nn.ModuleDict({ - 'avg': nn.AdaptiveAvgPool2d((1,1)), - 'max': nn.AdaptiveMaxPool2d((1,1)) - }) + # 拓扑特征分支 + self.topo_conv = nn.Sequential( + nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构 + nn.MaxPool2d(2), + # GraphConvLayer(128, 256), # 图卷积层 + nn.AdaptiveMaxPool2d(1) + ) # 多任务预测头 - head_dim = conv_channels * 4 * 2 * 4 # 2个池化,四个交互项 self.vision_head = nn.Sequential( - nn.Linear(head_dim, 512), - nn.ReLU(), - nn.Dropout(0.3), - nn.Linear(512, 1), + nn.Linear(conv_channels*2, 1), nn.Sigmoid() ) self.topo_head = nn.Sequential( - nn.Linear(head_dim, 512), - nn.GELU(), - nn.Dropout(0.3), - nn.Linear(512, 256), - nn.GELU(), - nn.Linear(256, 1), + nn.Linear(conv_channels, 1), nn.Sigmoid() ) - def forward(self, map1, map2): - # 增强特征提取 - def process_map(x): - x = self.embedding(x).permute(0,3,1,2) - x = self.conv_layers(x) - return torch.cat([ - self.pool['avg'](x), - self.pool['max'](x) - ], dim=1).flatten(1) + e1 = self.embedding(map1).permute(0, 3, 1, 2) + e2 = self.embedding(map2).permute(0, 3, 1, 2) - f1 = process_map(map1) - f2 = process_map(map2) + v1 = self.vision_conv(e1).squeeze() + v2 = self.vision_conv(e2).squeeze() - # 特征融合 - combined = torch.cat([f1, f2, f1-f2, f1*f2], 
dim=1) # [B, 256] + t1 = self.topo_conv(e1).squeeze() + t2 = self.topo_conv(e2).squeeze() # 多任务输出 - vision_sim = self.vision_head(combined) - topo_sim = self.topo_head(combined) + vision_sim = self.vision_head(torch.abs(v1 - v2)) + topo_sim = self.topo_head(torch.abs(t1 - t2)) return vision_sim, topo_sim diff --git a/minamo/train.py b/minamo/train.py index 179f424..7143a72 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -10,6 +10,7 @@ from .dataset import MinamoDataset device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) +os.makedirs("result/minamo_checkpoint", exist_ok=True) epochs = 100 @@ -51,7 +52,7 @@ def train(): ) # 设定优化器与调度器 - optimizer = optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-4) + optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=5e-3) scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) criterion = MinamoLoss() @@ -91,7 +92,7 @@ def train(): # tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间 ave_loss = total_loss / len(dataloader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}") # 学习率调整 scheduler.step() @@ -116,8 +117,11 @@ def train(): val_loss += loss_val.item() avg_val_loss = val_loss / len(val_loader) - tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation :: loss: {avg_val_loss:.6f}") - + tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") + torch.save({ + "model_state": model.state_dict(), + "optimizer_state": optimizer.state_dict(), + }, f"result/minamo_checkpoint/{epoch + 1}.pth") print("Train ended.") diff --git a/minamo/validate.py b/minamo/validate.py new file mode 100644 index 
0000000..ca5b629
--- /dev/null
+++ b/minamo/validate.py
@@ -0,0 +1,64 @@
+import torch
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from .model.model import MinamoModel
+from .model.loss import MinamoLoss
+from .dataset import MinamoDataset
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+def validate(
+    model_path="result/minamo.pth",
+    dataset_path="F:/github-ai/ginka-generator/minamo-eval.json",
+    batch_size=32,
+):
+    """Evaluate a saved Minamo checkpoint and print its mean per-batch loss.
+
+    Args:
+        model_path: Checkpoint file; it must hold the weights under the
+            "model_state" key.
+        dataset_path: Evaluation dataset file handed to MinamoDataset.
+        batch_size: Batch size passed to the DataLoader.
+
+    Returns:
+        The loss averaged over all evaluation batches.
+    """
+    # device.type is "cuda" or "cpu"; this avoids nesting double quotes inside
+    # the f-string, which is a SyntaxError before Python 3.12.
+    print(f"Using {device.type} to validate model.")
+    model = MinamoModel(32)
+    checkpoint = torch.load(model_path, map_location=device)
+    model.load_state_dict(checkpoint["model_state"])
+    model.to(device)
+
+    # Prepare the evaluation dataset. Sample order is irrelevant here, so keep
+    # it fixed to make the reported loss reproducible between runs.
+    val_dataset = MinamoDataset(dataset_path)
+    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
+
+    # MinamoLoss only accepts vision_weight/topo_weight (there is no `temp`
+    # argument), so rely on its defaults.
+    criterion = MinamoLoss()
+
+    model.eval()
+    val_loss = 0.0
+    with torch.no_grad():
+        for map1, map2, vision_simi, topo_simi in tqdm(val_loader):
+            map1 = map1.to(device)
+            map2 = map2.to(device)
+            vision_simi = vision_simi.to(device)
+            topo_simi = topo_simi.to(device)
+
+            vision_pred, topo_pred = model(map1, map2)
+            loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
+            val_loss += loss.item()
+
+    avg_val_loss = val_loss / len(val_loader)
+    tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
+    return avg_val_loss
+
+
+if __name__ == "__main__":
+    torch.set_num_threads(2)
+    validate()