mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-13 20:32:44 +08:00
feat: 资源压缩为一种 tile
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
3874d4dd95
commit
3e273c5a9d
190
docs/resource-compress-design.md
Normal file
190
docs/resource-compress-design.md
Normal file
@ -0,0 +1,190 @@
|
||||
# 资源压缩与后续规划设计文档
|
||||
|
||||
## 背景与问题诊断
|
||||
|
||||
### 现有图块定义
|
||||
|
||||
当前地图生成模型共有以下图块类别(NUM_CLASSES = 16,掩码 token 占位 15):
|
||||
|
||||
| ID | 语义 | 阶段 |
|
||||
| --- | ---------- | ------ |
|
||||
| 0 | 空地 | 阶段一 |
|
||||
| 1 | 墙壁 | 阶段一 |
|
||||
| 2 | 门 | 阶段二 |
|
||||
| 3 | 钥匙 | 阶段三 |
|
||||
| 4 | 红宝石 | 阶段三 |
|
||||
| 5 | 蓝宝石 | 阶段三 |
|
||||
| 6 | 绿宝石 | 阶段三 |
|
||||
| 7 | 血瓶 | 阶段三 |
|
||||
| 8 | 道具 | 阶段三 |
|
||||
| 9 | 怪物 | 阶段二 |
|
||||
| 10 | 入口 | 阶段二 |
|
||||
| 15 | 掩码 token | — |
|
||||
|
||||
### 问题分析
|
||||
|
||||
三阶段生成的任务量严重不均衡:
|
||||
|
||||
- **阶段一**(空地 + 墙壁):2 类,决定地图基本骨架,结构约束强。
|
||||
- **阶段二**(门 + 怪物 + 入口):3 类,决定关卡通路与挑战点,结构约束强。
|
||||
- **阶段三**(钥匙 + 红宝石 + 蓝宝石 + 绿宝石 + 血瓶 + 道具):**6 类**,资源种类几乎与前两阶段总和相同,但资源放置对地图结构的约束极弱——同一位置放红宝石还是蓝宝石,对地图整体结构几乎没有影响。
|
||||
|
||||
这种任务不均衡导致:
|
||||
|
||||
1. 模型在阶段三花费大量参数容量学习细粒度资源分类,而这一分类对结构生成没有实质贡献。
|
||||
2. 掩码预测任务的类别分布偏斜(大量位置都是资源,其类别却彼此高度相似),训练信号稀疏,模型难以稳定收敛。
|
||||
|
||||
---
|
||||
|
||||
## 方案一:资源类别压缩(核心改进)
|
||||
|
||||
### 核心思路
|
||||
|
||||
将阶段三的所有资源种类(钥匙、红宝石、蓝宝石、绿宝石、血瓶、道具)**统一压缩为单一的 `Resource` 类别**,地图生成模型不再区分具体资源类型。资源的具体种类与数值由后续独立模型负责(见方案二规划)。
|
||||
|
||||
### 新图块定义
|
||||
|
||||
压缩后,NUM_CLASSES 从 16 降至 **7**(含掩码 token),图块重新编号如下:
|
||||
|
||||
| 新 ID | 语义 | 原 ID |
|
||||
| ----- | ------------ | ----------- |
|
||||
| 0 | 空地 | 0 |
|
||||
| 1 | 墙壁 | 1 |
|
||||
| 2 | 门 | 2 |
|
||||
| 3 | 资源(统一) | 3/4/5/6/7/8 |
|
||||
| 4 | 怪物 | 9 |
|
||||
| 5 | 入口 | 10 |
|
||||
| 6 | 掩码 token | 15 |
|
||||
|
||||
三阶段任务调整为:
|
||||
|
||||
| 阶段 | 包含类别 | 类别数 |
|
||||
| ------ | ----------------------------- | ------ |
|
||||
| 阶段一 | 空地(0)、墙壁(1) | 2 |
|
||||
| 阶段二 | 门(2)、怪物(4)、入口(5) | 3 |
|
||||
| 阶段三 | 资源(3) | **1** |
|
||||
|
||||
阶段三现在退化为"在已知位置上填入资源"的简单任务,模型只需判断哪些空位应当放资源,而无需区分资源种类,任务难度大幅降低,信号更加清晰。
|
||||
|
||||
### 需要修改的位置
|
||||
|
||||
**实施策略**:优先在 Python 训练侧完成验证,确认效果后再同步修改 TypeScript 数据管线。各图块的原始数字编号在此阶段保持不变(如 entry 仍为 10),最小化改动范围;后续与 TS 侧统一调整时再重新编号。
|
||||
|
||||
- **立即执行**:第 3、4 项(训练脚本、dataset 重映射)
|
||||
- **与 TS 同步执行**:第 1、2 项(数据管线重编号)
|
||||
- **无需修改**:第 5 项(可视化模块)
|
||||
|
||||
#### 1. `data/src/shared.ts`(后续,与 TS 同步执行)
|
||||
|
||||
将 `resourceTiles` 的各子类别合并计入统一映射逻辑,图块 ID 重新编号:
|
||||
|
||||
```typescript
|
||||
// 新图块 ID 常量
|
||||
export const TILE_EMPTY = 0;
|
||||
export const TILE_WALL = 1;
|
||||
export const TILE_DOOR = 2;
|
||||
export const TILE_RESOURCE = 3; // 统一资源(原 3~8)
|
||||
export const TILE_ENEMY = 4; // 原 9
|
||||
export const TILE_ENTRY = 5; // 原 10
|
||||
export const TILE_MASK = 6; // 原 15,掩码 token
|
||||
```
|
||||
|
||||
#### 2. `data/src/auto/converter.ts`(后续,与 TS 同步执行)
|
||||
|
||||
在 `convertTile()` 方法中,将原 `key / redGem / blueGem / greenGem / potion / item` 的输出统一映射为 `TILE_RESOURCE = 3`。同时保留原始图块 ID 至 `ResourceType` 的映射,供后续数值模型使用(不丢弃语义信息,只是在地图 token 层面合并)。
|
||||
|
||||
#### 3. `ginka/train_joint.py`(及其他训练脚本)
|
||||
|
||||
```python
|
||||
NUM_CLASSES = 7 # 原来是 16
|
||||
MASK_TOKEN = 6 # 原来是 15
|
||||
```
|
||||
|
||||
当前尚未实现阶段分离,阶段划分常量待后续引入多阶段生成时再补充。
|
||||
|
||||
#### 4. `ginka/dataset.py`
|
||||
|
||||
在 `__getitem__` 中,加载 `map` 数据后做一次原地重映射,将各类资源统一压缩为 3,其余图块保持原始编号不变(与 TS 侧尚未同步重编号保持一致):
|
||||
|
||||
```python
|
||||
REMAP = {
|
||||
0: 0, # 空地
|
||||
1: 1, # 墙壁
|
||||
2: 2, # 门
|
||||
3: 3, 4: 3, 5: 3, 6: 3, 7: 3, 8: 3, # 各类资源 → 统一资源
|
||||
9: 9, # 怪物(保持原始编号)
|
||||
10: 10, # 入口(保持原始编号)
|
||||
}
|
||||
|
||||
target_np = np.vectorize(REMAP.get)(target_np)
|
||||
```
|
||||
|
||||
#### 5. `shared/image.py` 与 `shared/visual.py`(无需修改)
|
||||
|
||||
`shared/image.py` 使用图片素材渲染,不依赖颜色调色板;`shared/visual.py` 仅用于数据集可视化查看,两者均不受此次图块合并影响,无需改动。
|
||||
|
||||
### 预期收益
|
||||
|
||||
| 指标 | 改动前 | 改动后 |
|
||||
| ---------------- | ------------------ | -------------------- |
|
||||
| NUM_CLASSES | 16 | 7 |
|
||||
| 阶段三任务复杂度 | 6 类细粒度分类 | 1 类二元(放/不放) |
|
||||
| 嵌入表大小 | 16 × d_model | 7 × d_model |
|
||||
| 分类头输出维度 | 16 | 7 |
|
||||
| 训练信号质量 | 阶段三信号弱、偏斜 | 三阶段均衡、信号清晰 |
|
||||
|
||||
模型参数量略有下降,但更重要的是任务难度降低、各阶段学习目标更清晰,预期可显著改善阶段三(及整体)的收敛稳定性。
|
||||
|
||||
---
|
||||
|
||||
## 方案二:资源数值模型与怪物数值模型(后续规划)
|
||||
|
||||
> 此部分为后续计划,暂不细化实现,待方案一验证稳定后推进。
|
||||
|
||||
方案一的地图生成模型只负责"在哪里放资源/怪物",而"放什么资源/什么强度的怪物"由一组独立模型负责。
|
||||
|
||||
### 整体流程
|
||||
|
||||
```
|
||||
地图生成模型(方案一)
|
||||
│ 输出:含统一 Resource/Enemy 的地图骨架
|
||||
▼
|
||||
资源数值模型 / 怪物数值模型
|
||||
│ 输入:地图骨架 + 当前关卡强度条件
|
||||
│ 输出:每个资源/怪物位置的 { type, value } 分布
|
||||
▼
|
||||
分类模型
|
||||
│ 输入:{ type, value } 分布
|
||||
│ 输出:具体图块 ID(如 redGem_lv2、potion_lv1)
|
||||
▼
|
||||
完整地图(含具体资源与怪物种类)
|
||||
```
|
||||
|
||||
### 两类子模型结构
|
||||
|
||||
每类模型(资源 / 怪物)均分为两个独立子模型:
|
||||
|
||||
**a. 数值模型(Value Model)**
|
||||
|
||||
- 输入:地图骨架(含资源/怪物占位标记)+ 关卡强度向量
|
||||
- 输出:每个位置的类型与归一化数值,例如 `{ type: "potion", value: 0.8 }`
|
||||
- 数值在 `[0, 1]` 范围内归一化,推理时线性映射到目标区间(如等级 1~5)
|
||||
|
||||
**b. 分类模型(Classifier Model)**
|
||||
|
||||
- 输入:数值模型的输出分布
|
||||
- 输出:具体图块 ID(离散)
|
||||
- 职责:将连续数值量化为游戏中实际存在的有限种类,防止模型输出连续值导致种类爆炸(一般地图上资源/怪物只有少数几种反复复用)
|
||||
|
||||
### 设计动机
|
||||
|
||||
- 解耦地图结构与关卡数值,使二者可独立调优。
|
||||
- 分类模型的引入是关键:直接让数值模型输出离散 ID 会导致种类碎片化;分类模型将"数值相近的资源/怪物聚类到同一具体图块",符合游戏设计中资源重用的实际规律。
|
||||
- 两类模型结构一致,可共用框架代码,仅在训练数据与输出头上有差异。
|
||||
|
||||
---
|
||||
|
||||
## 实施顺序
|
||||
|
||||
1. **完成方案一**:修改数据管线(TypeScript 侧)、数据集类(Python 侧)和训练脚本,重新生成数据集,以新图块体系从头训练地图生成模型,验证收敛效果。
|
||||
2. **稳定后推进方案二**:在地图生成模型可以稳定生成结构合理的骨架图之后,再设计并实现数值模型与分类模型,最终串联为完整的地图生成管线。
|
||||
@ -15,6 +15,14 @@ def load_data(path: str):
|
||||
|
||||
return data_list
|
||||
|
||||
# 资源类别压缩:将所有资源 tile(钥匙/红宝石/蓝宝石/绿宝石/血瓶/道具)统一映射为 3
|
||||
# 其余 tile 保持原始编号(enemy=9, entry=10, mask=15)
|
||||
_RESOURCE_REMAP = np.array([0, 1, 2, 3, 3, 3, 3, 3, 3, 9, 10, 11, 12, 13, 14, 15], dtype=np.int64)
|
||||
|
||||
def remap_resources(arr: np.ndarray) -> np.ndarray:
|
||||
"""将地图 numpy 数组中的资源 tile (3~8) 统一压缩为 3。"""
|
||||
return _RESOURCE_REMAP[arr]
|
||||
|
||||
class GinkaMaskGITDataset(Dataset):
|
||||
def __init__(
|
||||
self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6,
|
||||
@ -444,6 +452,7 @@ class GinkaVQDataset(Dataset):
|
||||
item = self.data[idx]
|
||||
|
||||
raw_np = self._augment(np.array(item['map'], dtype=np.int64)) # [H, W]
|
||||
raw_np = remap_resources(raw_np) # 资源压缩
|
||||
subset = self._choose_subset()
|
||||
masked_np = self._apply_subset(raw_np, subset) # [H*W]
|
||||
raw_flat = raw_np.reshape(-1) # [H*W]
|
||||
@ -525,6 +534,7 @@ class GinkaSplitDataset(Dataset):
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
arr = np.array(item['map'], dtype=np.int64) # [H, W]
|
||||
arr = remap_resources(arr) # 资源压缩
|
||||
|
||||
# 随机旋转 / 翻转数据增强
|
||||
if np.random.rand() > 0.5:
|
||||
|
||||
@ -61,12 +61,18 @@ class GinkaMaskGIT(nn.Module):
|
||||
nn.LayerNorm(d_model),
|
||||
)
|
||||
|
||||
# 结构标签嵌入(编码到 d_z 维度,与 z 拼接后统一投影到 d_model)
|
||||
# 结构标签嵌入(编码到 d_z 维度)
|
||||
# 注意:结构标签与 VQ 码字语义不同,使用独立投影层避免混用
|
||||
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
|
||||
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
|
||||
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||
|
||||
self.struct_proj = nn.Sequential(
|
||||
nn.Linear(d_z, d_model),
|
||||
nn.LayerNorm(d_model),
|
||||
)
|
||||
|
||||
# Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention
|
||||
self.transformer = Transformer(
|
||||
d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
|
||||
@ -131,10 +137,11 @@ class GinkaMaskGIT(nn.Module):
|
||||
e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
|
||||
struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z]
|
||||
z_ext = torch.cat([z, struct_seq], dim=1) # [B, L+4, d_z]
|
||||
|
||||
# 统一投影到 d_model 维度
|
||||
z_mem = self.z_proj(z_ext) # [B, L+4, d_model]
|
||||
# VQ 码字与结构标签语义不同,使用各自独立的投影层后再拼接
|
||||
z_mem_vq = self.z_proj(z) # [B, L, d_model]
|
||||
z_mem_struct = self.struct_proj(struct_seq) # [B, 4, d_model]
|
||||
z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L+4, d_model]
|
||||
|
||||
# tile embedding + 位置编码
|
||||
x = self.tile_embedding(map) # [B, H*W, d_model]
|
||||
|
||||
@ -62,7 +62,7 @@ VQ_DIM_FF = 256
|
||||
# 通道专属损失计算范围(用于监控验证召回率)
|
||||
CH1_LOSS = {1}
|
||||
CH2_LOSS = {2, 9, 10}
|
||||
CH3_LOSS = {3, 4, 5, 6, 7, 8}
|
||||
CH3_LOSS = {3} # 资源已压缩为单一 tile=3
|
||||
|
||||
# MaskGIT 超参
|
||||
MG_D_MODEL = 256
|
||||
|
||||
@ -26,16 +26,6 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
if '0' in tile_set:
|
||||
canvas[y:y+tile_size, x:x+tile_size] = tile_set['0'][:, :, :3] # 仅填充 RGB
|
||||
|
||||
if tile_index == '30':
|
||||
if row == 0:
|
||||
tile_index = '30_1'
|
||||
elif row == W - 1:
|
||||
tile_index = '30_3'
|
||||
elif col == 0:
|
||||
tile_index = '30_2'
|
||||
elif col == H - 1:
|
||||
tile_index = '30_4'
|
||||
|
||||
# 叠加其他透明图块
|
||||
if tile_index in tile_set and tile_index != 0:
|
||||
tile_rgba = tile_set[tile_index]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user