mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 03:51:11 +08:00
chore: 调整热力图生成
This commit is contained in:
parent
36265a9bce
commit
266f50db73
@ -20,9 +20,9 @@ export function generateHeatmap(
|
|||||||
for (let y = 0; y < height; y++) {
|
for (let y = 0; y < height; y++) {
|
||||||
for (let x = 0; x < width; x++) {
|
for (let x = 0; x < width; x++) {
|
||||||
const left = Math.max(0, x - radius);
|
const left = Math.max(0, x - radius);
|
||||||
const right = Math.min(width, x + radius);
|
const right = Math.min(width, x + radius + 1);
|
||||||
const top = Math.max(0, y - radius);
|
const top = Math.max(0, y - radius);
|
||||||
const bottom = Math.min(height, y + radius);
|
const bottom = Math.min(height, y + radius + 1);
|
||||||
const size = (right - left) * (bottom - top);
|
const size = (right - left) * (bottom - top);
|
||||||
let num = 0;
|
let num = 0;
|
||||||
for (let ky = top; ky < bottom; ky++) {
|
for (let ky = top; ky < bottom; ky++) {
|
||||||
@ -54,9 +54,9 @@ export function gaussainHeatmap(map: number[][], sigma: number = 1) {
|
|||||||
for (let y = 0; y < height; y++) {
|
for (let y = 0; y < height; y++) {
|
||||||
for (let x = 0; x < width; x++) {
|
for (let x = 0; x < width; x++) {
|
||||||
const left = Math.max(0, x - radius);
|
const left = Math.max(0, x - radius);
|
||||||
const right = Math.min(width - 1, x + radius);
|
const right = Math.min(width - 1, x + radius + 1);
|
||||||
const top = Math.max(0, y - radius);
|
const top = Math.max(0, y - radius);
|
||||||
const bottom = Math.min(height - 1, y + radius);
|
const bottom = Math.min(height - 1, y + radius + 1);
|
||||||
let res = 0;
|
let res = 0;
|
||||||
for (let ky = top; ky < bottom; ky++) {
|
for (let ky = top; ky < bottom; ky++) {
|
||||||
for (let kx = left; kx < right; kx++) {
|
for (let kx = left; kx < right; kx++) {
|
||||||
|
|||||||
@ -229,15 +229,17 @@ export function parseFloorInfo(tower: ITowerInfo, map: number[][]): IFloorInfo {
|
|||||||
fishCount,
|
fishCount,
|
||||||
hasUselessBranch,
|
hasUselessBranch,
|
||||||
wallDensityStd: computeWallDensityStd(map, wallTiles, 5),
|
wallDensityStd: computeWallDensityStd(map, wallTiles, 5),
|
||||||
wallHeatmap: gaussainHeatmap(generateHeatmap(map, wallTiles)),
|
wallHeatmap: gaussainHeatmap(generateHeatmap(map, wallTiles, 1)),
|
||||||
enemyHeatmap: gaussainHeatmap(generateHeatmap(map, enemyTiles)),
|
enemyHeatmap: gaussainHeatmap(generateHeatmap(map, enemyTiles, 1)),
|
||||||
resourceHeatmap: gaussainHeatmap(generateHeatmap(map, resourceTiles)),
|
resourceHeatmap: gaussainHeatmap(
|
||||||
potionHeatmap: gaussainHeatmap(generateHeatmap(map, potionTiles)),
|
generateHeatmap(map, resourceTiles, 1)
|
||||||
gemHeatmap: gaussainHeatmap(generateHeatmap(map, gemTiles)),
|
),
|
||||||
keyHeatmap: gaussainHeatmap(generateHeatmap(map, keyTiles)),
|
potionHeatmap: gaussainHeatmap(generateHeatmap(map, potionTiles, 1)),
|
||||||
itemHeatmap: gaussainHeatmap(generateHeatmap(map, itemTiles)),
|
gemHeatmap: gaussainHeatmap(generateHeatmap(map, gemTiles, 1)),
|
||||||
entryHeatmap: gaussainHeatmap(generateHeatmap(map, entryTiles)),
|
keyHeatmap: gaussainHeatmap(generateHeatmap(map, keyTiles, 1)),
|
||||||
doorHeatmap: gaussainHeatmap(generateHeatmap(map, doorTiles))
|
itemHeatmap: gaussainHeatmap(generateHeatmap(map, itemTiles, 1)),
|
||||||
|
entryHeatmap: gaussainHeatmap(generateHeatmap(map, entryTiles, 1)),
|
||||||
|
doorHeatmap: gaussainHeatmap(generateHeatmap(map, doorTiles, 1))
|
||||||
};
|
};
|
||||||
|
|
||||||
return floorInfo;
|
return floorInfo;
|
||||||
|
|||||||
@ -43,6 +43,7 @@ MASK_TOKEN = 15
|
|||||||
GENERATE_STEP = 8
|
GENERATE_STEP = 8
|
||||||
MAP_SIZE = 13 * 13
|
MAP_SIZE = 13 * 13
|
||||||
HEATMAP_CHANNEL = 9
|
HEATMAP_CHANNEL = 9
|
||||||
|
LABEL_SMOOTHING = 0.1
|
||||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
@ -129,7 +130,7 @@ def train():
|
|||||||
|
|
||||||
logits = model(masked_input, cond, heatmap)
|
logits = model(masked_input, cond, heatmap)
|
||||||
|
|
||||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=LABEL_SMOOTHING)
|
||||||
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
|
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
|
||||||
|
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@ -181,7 +182,7 @@ def train():
|
|||||||
|
|
||||||
logits = model(masked_input, cond, heatmap)
|
logits = model(masked_input, cond, heatmap)
|
||||||
|
|
||||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=LABEL_SMOOTHING)
|
||||||
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
|
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
|
||||||
|
|
||||||
val_loss_total += loss.detach()
|
val_loss_total += loss.detach()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user