diff --git a/data/src/auto/heatmap.ts b/data/src/auto/heatmap.ts index 2a27844..40876fe 100644 --- a/data/src/auto/heatmap.ts +++ b/data/src/auto/heatmap.ts @@ -20,9 +20,9 @@ export function generateHeatmap( for (let y = 0; y < height; y++) { for (let x = 0; x < width; x++) { const left = Math.max(0, x - radius); - const right = Math.min(width, x + radius); + const right = Math.min(width, x + radius + 1); const top = Math.max(0, y - radius); - const bottom = Math.min(height, y + radius); + const bottom = Math.min(height, y + radius + 1); const size = (right - left) * (bottom - top); let num = 0; for (let ky = top; ky < bottom; ky++) { @@ -54,9 +54,9 @@ export function gaussainHeatmap(map: number[][], sigma: number = 1) { for (let y = 0; y < height; y++) { for (let x = 0; x < width; x++) { const left = Math.max(0, x - radius); - const right = Math.min(width - 1, x + radius); + const right = Math.min(width - 1, x + radius + 1); const top = Math.max(0, y - radius); - const bottom = Math.min(height - 1, y + radius); + const bottom = Math.min(height - 1, y + radius + 1); let res = 0; for (let ky = top; ky < bottom; ky++) { for (let kx = left; kx < right; kx++) { diff --git a/data/src/auto/info.ts b/data/src/auto/info.ts index eebb302..3e3495b 100644 --- a/data/src/auto/info.ts +++ b/data/src/auto/info.ts @@ -229,15 +229,17 @@ export function parseFloorInfo(tower: ITowerInfo, map: number[][]): IFloorInfo { fishCount, hasUselessBranch, wallDensityStd: computeWallDensityStd(map, wallTiles, 5), - wallHeatmap: gaussainHeatmap(generateHeatmap(map, wallTiles)), - enemyHeatmap: gaussainHeatmap(generateHeatmap(map, enemyTiles)), - resourceHeatmap: gaussainHeatmap(generateHeatmap(map, resourceTiles)), - potionHeatmap: gaussainHeatmap(generateHeatmap(map, potionTiles)), - gemHeatmap: gaussainHeatmap(generateHeatmap(map, gemTiles)), - keyHeatmap: gaussainHeatmap(generateHeatmap(map, keyTiles)), - itemHeatmap: gaussainHeatmap(generateHeatmap(map, itemTiles)), - entryHeatmap: gaussainHeatmap(generateHeatmap(map, entryTiles)), - doorHeatmap: gaussainHeatmap(generateHeatmap(map, doorTiles)) + wallHeatmap: gaussainHeatmap(generateHeatmap(map, wallTiles, 1)), + enemyHeatmap: gaussainHeatmap(generateHeatmap(map, enemyTiles, 1)), + resourceHeatmap: gaussainHeatmap( + generateHeatmap(map, resourceTiles, 1) + ), + potionHeatmap: gaussainHeatmap(generateHeatmap(map, potionTiles, 1)), + gemHeatmap: gaussainHeatmap(generateHeatmap(map, gemTiles, 1)), + keyHeatmap: gaussainHeatmap(generateHeatmap(map, keyTiles, 1)), + itemHeatmap: gaussainHeatmap(generateHeatmap(map, itemTiles, 1)), + entryHeatmap: gaussainHeatmap(generateHeatmap(map, entryTiles, 1)), + doorHeatmap: gaussainHeatmap(generateHeatmap(map, doorTiles, 1)) }; return floorInfo; diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 26d0df7..15a02d2 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -43,6 +43,7 @@ MASK_TOKEN = 15 GENERATE_STEP = 8 MAP_SIZE = 13 * 13 HEATMAP_CHANNEL = 9 +LABEL_SMOOTHING = 0.1 MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 device = torch.device( @@ -129,7 +130,7 @@ def train(): logits = model(masked_input, cond, heatmap) - loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1) + loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=LABEL_SMOOTHING) loss = (loss * mask).sum() / (mask.sum() + 1e-6) optimizer.zero_grad() @@ -181,7 +182,7 @@ def train(): logits = model(masked_input, cond, heatmap) - loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1) + loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=LABEL_SMOOTHING) loss = (loss * mask).sum() / (mask.sum() + 1e-6) val_loss_total += loss.detach()