fix: maskGIT
@ -93,9 +93,9 @@ def train():
|
|||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
tile_dict = dict()
|
tile_dict = dict()
|
||||||
for file in os.listdir('tiles'):
|
for file in os.listdir('tiles2'):
|
||||||
name = os.path.splitext(file)[0]
|
name = os.path.splitext(file)[0]
|
||||||
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
if args.resume:
|
if args.resume:
|
||||||
data_ginka = torch.load(args.state_ginka, map_location=device)
|
data_ginka = torch.load(args.state_ginka, map_location=device)
|
||||||
@ -111,30 +111,32 @@ def train():
|
|||||||
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
|
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
|
||||||
loss_total = torch.Tensor([0]).to(device)
|
loss_total = torch.Tensor([0]).to(device)
|
||||||
|
|
||||||
# for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
# target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
# cond = batch["val_cond"].to(device)
|
cond = batch["val_cond"].to(device)
|
||||||
# B, H, W = target_map.shape
|
B, H, W = target_map.shape
|
||||||
# target_map = target_map.view(B, H * W)
|
target_map = target_map.view(B, H * W)
|
||||||
|
|
||||||
# # 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
|
# 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
|
||||||
# r = torch.rand(B).to(device)
|
r = torch.rand(B).to(device)
|
||||||
# r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
|
r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
|
||||||
|
|
||||||
# # 2. 生成掩码矩阵
|
# 2. 生成掩码矩阵
|
||||||
# masks = torch.rand(target_map.shape).to(device) < r
|
masks = torch.rand(target_map.shape).to(device) < r
|
||||||
# masked_input = target_map.clone()
|
masked_input = target_map.clone()
|
||||||
# masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
|
masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
|
||||||
|
|
||||||
# logits = model(masked_input, cond)
|
logits = model(masked_input, cond)
|
||||||
|
|
||||||
# loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none')
|
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none')
|
||||||
# loss = (loss * masks).sum() / (masks.sum() + 1e-6)
|
loss = (loss * masks).sum() / (masks.sum() + 1e-6)
|
||||||
|
|
||||||
# optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
# loss.backward()
|
loss.backward()
|
||||||
# optimizer.step()
|
optimizer.step()
|
||||||
# loss_total += loss.detach()
|
loss_total += loss.detach()
|
||||||
|
|
||||||
|
scheduler.step()
|
||||||
|
|
||||||
avg_loss = loss_total.item() / len(dataloader)
|
avg_loss = loss_total.item() / len(dataloader)
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
|
|||||||
@ -102,9 +102,9 @@ def train():
|
|||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
tile_dict = dict()
|
tile_dict = dict()
|
||||||
for file in os.listdir('tiles'):
|
for file in os.listdir('tiles2'):
|
||||||
name = os.path.splitext(file)[0]
|
name = os.path.splitext(file)[0]
|
||||||
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
|
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)
|
||||||
|
|
||||||
if args.resume:
|
if args.resume:
|
||||||
data_ginka = torch.load(args.state_ginka, map_location=device)
|
data_ginka = torch.load(args.state_ginka, map_location=device)
|
||||||
@ -196,6 +196,7 @@ def train():
|
|||||||
color = (255, 255, 255) # 白色
|
color = (255, 255, 255) # 白色
|
||||||
vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线
|
vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线
|
||||||
# 地图重建展示
|
# 地图重建展示
|
||||||
|
vae.teacher_forcing()
|
||||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
B, H, W = target_map.shape
|
B, H, W = target_map.shape
|
||||||
@ -218,10 +219,10 @@ def train():
|
|||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
# 随机采样
|
# 随机采样
|
||||||
|
vae.autoregressive()
|
||||||
for i in range(0, 8):
|
for i in range(0, 8):
|
||||||
z = torch.randn(1, LATENT_DIM).to(device)
|
z = torch.randn(1, LATENT_DIM).to(device)
|
||||||
|
|
||||||
vae.autoregressive()
|
|
||||||
fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device))
|
fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device))
|
||||||
fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy()
|
fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy()
|
||||||
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
|
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
|
||||||
|
|||||||
BIN
tiles2/0.png
Normal file
|
After Width: | Height: | Size: 1.4 KiB |
BIN
tiles2/1.png
Normal file
|
After Width: | Height: | Size: 576 B |
BIN
tiles2/10.png
Normal file
|
After Width: | Height: | Size: 699 B |
BIN
tiles2/2.png
Normal file
|
After Width: | Height: | Size: 426 B |
BIN
tiles2/3.png
Normal file
|
After Width: | Height: | Size: 368 B |
BIN
tiles2/4.png
Normal file
|
After Width: | Height: | Size: 406 B |
BIN
tiles2/5.png
Normal file
|
After Width: | Height: | Size: 396 B |
BIN
tiles2/6.png
Normal file
|
After Width: | Height: | Size: 419 B |
BIN
tiles2/7.png
Normal file
|
After Width: | Height: | Size: 441 B |
BIN
tiles2/8.png
Normal file
|
After Width: | Height: | Size: 448 B |
BIN
tiles2/9.png
Normal file
|
After Width: | Height: | Size: 353 B |