diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 0fd3657..4abcc68 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -93,9 +93,9 @@ def train(): # 用于生成图片 tile_dict = dict() - for file in os.listdir('tiles'): + for file in os.listdir('tiles2'): name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED) if args.resume: data_ginka = torch.load(args.state_ginka, map_location=device) @@ -111,30 +111,32 @@ def train(): for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm): loss_total = torch.Tensor([0]).to(device) - # for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - # target_map = batch["target_map"].to(device) - # cond = batch["val_cond"].to(device) - # B, H, W = target_map.shape - # target_map = target_map.view(B, H * W) + for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): + target_map = batch["target_map"].to(device) + cond = batch["val_cond"].to(device) + B, H, W = target_map.shape + target_map = target_map.view(B, H * W) - # # 1. 随机采样掩码比例 r (遵循余弦调度效果更好) - # r = torch.rand(B).to(device) - # r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本 + # 1. 随机采样掩码比例 r (遵循余弦调度效果更好) + r = torch.rand(B).to(device) + r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本 - # # 2. 生成掩码矩阵 - # masks = torch.rand(target_map.shape).to(device) < r - # masked_input = target_map.clone() - # masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记 + # 2. 生成掩码矩阵 + masks = torch.rand(target_map.shape).to(device) < r + masked_input = target_map.clone() + masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记 - # logits = model(masked_input, cond) + logits = model(masked_input, cond) - # loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none') - # loss = (loss * masks).sum() / (masks.sum() + 1e-6) + loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none') + loss = (loss * masks).sum() / (masks.sum() + 1e-6) - # optimizer.zero_grad() - # loss.backward() - # optimizer.step() - # loss_total += loss.detach() + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_total += loss.detach() + + scheduler.step() avg_loss = loss_total.item() / len(dataloader) tqdm.write( diff --git a/ginka/train_transformer.py b/ginka/train_transformer.py index 750e8e0..e3b5198 100644 --- a/ginka/train_transformer.py +++ b/ginka/train_transformer.py @@ -102,9 +102,9 @@ def train(): # 用于生成图片 tile_dict = dict() - for file in os.listdir('tiles'): + for file in os.listdir('tiles2'): name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED) if args.resume: data_ginka = torch.load(args.state_ginka, map_location=device) @@ -196,6 +196,7 @@ def train(): color = (255, 255, 255) # 白色 vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 # 地图重建展示 + vae.teacher_forcing() for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): target_map = batch["target_map"].to(device) B, H, W = target_map.shape @@ -218,10 +219,10 @@ def train(): idx += 1 # 随机采样 + vae.autoregressive() for i in range(0, 8): z = torch.randn(1, LATENT_DIM).to(device) - vae.autoregressive() fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device)) fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy() fake_img = matrix_to_image_cv(fake_map[0], tile_dict) diff --git a/tiles2/0.png b/tiles2/0.png new file mode 100644 index 0000000..9649930 Binary files /dev/null and b/tiles2/0.png differ diff --git a/tiles2/1.png b/tiles2/1.png new file mode 100644 index 0000000..f8e7142 Binary files /dev/null and b/tiles2/1.png differ diff --git a/tiles2/10.png b/tiles2/10.png new file mode 100644 index 0000000..d2eb533 Binary files /dev/null and b/tiles2/10.png differ diff --git a/tiles2/2.png b/tiles2/2.png new file mode 100644 index 0000000..83de73a Binary files /dev/null and b/tiles2/2.png differ diff --git a/tiles2/3.png b/tiles2/3.png new file mode 100644 index 0000000..339c1c3 Binary files /dev/null and b/tiles2/3.png differ diff --git a/tiles2/4.png b/tiles2/4.png new file mode 100644 index 0000000..08409ab Binary files /dev/null and b/tiles2/4.png differ diff --git a/tiles2/5.png b/tiles2/5.png new file mode 100644 index 0000000..792ed88 Binary files /dev/null and b/tiles2/5.png differ diff --git a/tiles2/6.png b/tiles2/6.png new file mode 100644 index 0000000..4b8d3a6 Binary files /dev/null and b/tiles2/6.png differ diff --git a/tiles2/7.png b/tiles2/7.png new file mode 100644 index 0000000..b121323 Binary files /dev/null and b/tiles2/7.png differ diff --git a/tiles2/8.png b/tiles2/8.png new file mode 100644 index 0000000..38d7a35 Binary files /dev/null and b/tiles2/8.png differ diff --git a/tiles2/9.png b/tiles2/9.png new file mode 100644 index 0000000..1329097 Binary files /dev/null and b/tiles2/9.png differ