mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
style: 调整代码风格问题
This commit is contained in:
parent
5d95027894
commit
5f542fb577
@ -93,12 +93,12 @@ FOCAL_GAMMA = 2.0 # Focal Loss 参数
|
||||
VQ_BETA = 0.5 # 承诺损失权重
|
||||
|
||||
# 训练超参
|
||||
BATCH_SIZE = 64 # 每批样本数
|
||||
LR = 1e-4 # AdamW 初始学习率
|
||||
MIN_LR = 1e-6 # 余弦退火最低学习率
|
||||
WEIGHT_DECAY = 1e-4 # L2 正则化系数
|
||||
EPOCHS = 400 # 总训练轮数
|
||||
CHECKPOINT = 20 # 每隔多少 epoch 保存检查点并执行验证
|
||||
BATCH_SIZE = 64 # 每批样本数
|
||||
LR = 1e-4 # AdamW 初始学习率
|
||||
MIN_LR = 1e-6 # 余弦退火最低学习率
|
||||
WEIGHT_DECAY = 1e-4 # L2 正则化系数
|
||||
EPOCHS = 400 # 总训练轮数
|
||||
CHECKPOINT = 20 # 每隔多少 epoch 保存检查点并执行验证
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
@ -173,10 +173,10 @@ def focal_loss(logits, target):
|
||||
def random_struct(device: torch.device) -> torch.Tensor:
|
||||
# 随机采样一组结构参量,用于无条件自由生成
|
||||
# struct_inject 格式:[cond_sym(0-7), cond_room(0-2), cond_branch(0-2), cond_outer(0-1)]
|
||||
cond_sym = random.randint(0, 7) # 地图对称类型
|
||||
cond_room = random.randint(0, 2) # 房间数量档位
|
||||
cond_branch = random.randint(0, 2) # 分支复杂度档位
|
||||
cond_outer = random.randint(0, 1) # 是否有外围走廊
|
||||
cond_sym = random.randint(0, 7) # 地图对称类型
|
||||
cond_room = random.randint(0, 2) # 房间数量档位
|
||||
cond_branch = random.randint(0, 2) # 分支复杂度档位
|
||||
cond_outer = random.randint(0, 1) # 是否有外围走廊
|
||||
return torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer]).unsqueeze(0).to(device)
|
||||
|
||||
def maskgit_sample(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user