ginka-generator/ginka/dataset.py

191 lines
6.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import random
import torch
import numpy as np
from torch.utils.data import Dataset
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
data = json.load(f)
data_list = []
for value in data["data"].values():
data_list.append(value)
return data_list
def compute_symmetry(target_np: np.ndarray) -> tuple:
"""从 numpy 地图矩阵中直接计算三种对称性O(H*W)"""
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
sym_v = bool(np.all(target_np == target_np[::-1, :]))
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
return int(sym_h), int(sym_v), int(sym_c)
class GinkaSeperatedDataset(Dataset):
FLOOR = 0
WALL = 1
DOOR = 2
RESOURCE = 3
MONSTER = 4
ENTRANCE = 5
MASK_ID = 6
def __init__(
self,
data_path: str,
subset_weights: tuple = (0.5, 0.3, 0.2),
subset2_wall_prob: float = 0.7
):
self.data = load_data(data_path)
self.subset2_wall_prob = subset2_wall_prob
total = sum(subset_weights)
self.subset_cumw = [sum(subset_weights[:i+1]) / total for i in range(len(subset_weights))]
n = len(self.data)
rs = sorted(item['roomCount'] for item in self.data)
bs = sorted(item['highDegBranchCount'] for item in self.data)
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
if th1_r == th2_r: th2_r = th1_r + 1
if th1_b == th2_b: th2_b = th1_b + 1
self.room_th = (th1_r, th2_r)
self.branch_th = (th1_b, th2_b)
for item in self.data:
item['roomCountLevel'] = self.to_level(item['roomCount'], self.room_th)
item['branchLevel'] = self.to_level(item['highDegBranchCount'], self.branch_th)
def to_level(self, v, th):
return 0 if v < th[0] else (1 if v < th[1] else 2)
def __len__(self):
return len(self.data)
def degrade_tile(self, m: np.ndarray, tiles: list) -> np.ndarray:
# 将指定 tile ID 替换为 floor(0),原地修改
for t in tiles:
m[m == t] = self.FLOOR
return m
def std_mask(self) -> np.ndarray:
# Beta(2,2) 采样掩码比例50% 随机掩码 / 50% 分块掩码,返回 bool[13, 13]
ratio = float(np.random.beta(2, 2)) * 0.95 + 0.05
if random.random() < 0.5:
idx = np.random.choice(169, int(169 * ratio), replace=False)
mask = np.zeros(169, dtype=bool)
mask[idx] = True
return mask.reshape(13, 13)
target = int(169 * ratio)
mask = np.zeros((13, 13), dtype=bool)
while mask.sum() < target:
bh = np.random.randint(2, 7)
bw = np.random.randint(2, 7)
x = np.random.randint(0, 14 - bh)
y = np.random.randint(0, 14 - bw)
mask[x:x + bh, y:y + bw] = True
return mask
def create_degreaded(self, raw: np.ndarray):
# 阶段一:生成墙壁和入口
target1 = raw.copy()
self.degrade_tile(target1, [self.DOOR, self.RESOURCE, self.MONSTER])
inp1 = target1.copy()
# 阶段二:生成怪物、门,同时也允许生成入口以适配结构
target2 = raw.copy()
self.degrade_tile(target2, [self.RESOURCE, self.WALL])
inp2 = raw.copy()
self.degrade_tile(inp2, [self.RESOURCE])
# 阶段三:生成资源
target3 = raw.copy()
self.degrade_tile(target3, [self.WALL, self.DOOR, self.MONSTER, self.ENTRANCE])
inp3 = raw.copy()
return target1, inp1, target2, inp2, target3, inp3
def apply_subset1(self, raw: np.ndarray):
# 子集 1std_mask 随机掩码
target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw)
enc1 = target1.copy()
enc2 = inp2.copy()
enc3 = raw.copy()
# stage1对整图 std_mask
inp1[self.std_mask()] = self.MASK_ID
# stage2对 floor+功能元素区域 std_mask
need_mask = np.isin(inp2, [self.FLOOR, self.DOOR, self.MONSTER, self.ENTRANCE])
inp2[need_mask & self.std_mask()] = self.MASK_ID
# stage3对 floor+resource 区域 std_mask
need_mask = np.isin(inp3, [self.FLOOR, self.RESOURCE])
inp3[need_mask & self.std_mask()] = self.MASK_ID
return inp1, target1, enc1, inp2, target2, enc2, inp3, target3, enc3
def apply_subset2(self, raw: np.ndarray):
# 子集 2掩码所有内容墙壁随机掩码不掩码入口
target1, inp1, target2, inp2, target3, inp3 = self.create_degreaded(raw)
enc1 = target1.copy()
enc2 = inp2.copy()
enc3 = raw.copy()
if np.random.random() < self.subset2_wall_prob:
inp1[self.std_mask()] = self.MASK_ID
need_mask = np.isin(inp2, [self.FLOOR, self.DOOR, self.MONSTER, self.ENTRANCE])
inp2[need_mask] = self.MASK_ID
need_mask = np.isin(inp3, [self.FLOOR, self.RESOURCE])
inp3[need_mask] = self.MASK_ID
return inp1, target1, enc1, inp2, target2, enc2, inp3, target3, enc3
def apply_subset3(self, raw: np.ndarray):
# 子集 3在 2 的基础上掩码入口
out = self.apply_subset2(raw)
out[0][out[0] == self.ENTRANCE] = self.MASK_ID
return out
def __getitem__(self, idx):
item = self.data[idx]
map_np = np.array(item['map'], dtype=np.int64)
if np.random.rand() > 0.5:
k = np.random.randint(1, 4)
map_np = np.rot90(map_np, k).copy()
if np.random.rand() > 0.5:
map_np = np.fliplr(map_np).copy()
if np.random.rand() > 0.5:
map_np = np.flipud(map_np).copy()
r = random.random()
if r < self.subset_cumw[0]:
out = self.apply_subset1(map_np)
elif r < self.subset_cumw[1]:
out = self.apply_subset2(map_np)
else:
out = self.apply_subset3(map_np)
sym_h, sym_v, sym_c = compute_symmetry(map_np)
cond_sym = sym_h * 4 + sym_v * 2 + sym_c
cond_room = item['roomCountLevel']
cond_branch = item['branchLevel']
cond_outer = item['outerWall']
struct_inject = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
return {
"input_stage1": torch.LongTensor(out[0]),
"target_stage1": torch.LongTensor(out[1]),
"encoder_stage1": torch.LongTensor(out[2]),
"input_stage2": torch.LongTensor(out[3]),
"target_stage2": torch.LongTensor(out[4]),
"encoder_stage2": torch.LongTensor(out[5]),
"input_stage3": torch.LongTensor(out[6]),
"target_stage3": torch.LongTensor(out[7]),
"encoder_stage3": torch.LongTensor(out[8]),
"struct_inject": struct_inject
}