feat: 修改下采样部分

This commit is contained in:
unanmed 2025-03-22 18:19:24 +08:00
parent 11378380b4
commit ca068bbea3
6 changed files with 31 additions and 12 deletions

View File

@ -26,6 +26,6 @@ class GinkaModel(nn.Module):
x = self.fc(feat)
x = x.view(-1, self.base_ch, 32, 32)
x = self.unet(x)
x = self.down_sample(x)
return F.softmax(x, dim=1)
x = F.interpolate(x, (13, 13), mode='bilinear')
return x, F.softmax(x, dim=1)

View File

@ -7,7 +7,7 @@ class MapDownSample(nn.Module):
self.down = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
nn.ReLU(),
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=0)
nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0)
)
def forward(self, x):

View File

@ -48,7 +48,7 @@ def train():
)
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo)
@ -75,10 +75,10 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# 前向传播
optimizer.zero_grad()
output = model(feat_vec)
_, output_softmax = model(feat_vec)
# 计算损失
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
# 反向传播
scaled_losses.backward()
@ -115,11 +115,11 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# 前向传播
output = model(feat_vec)
output, output_softmax = model(feat_vec)
print(torch.argmax(output, dim=1)[0])
# 计算损失
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
loss_val += losses.item()
avg_val_loss = loss_val / len(dataloader_val)

View File

@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel()
state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
state = torch.load("result/ginka.pth", map_location=device)["model_state"]
model.load_state_dict(state)
model.to(device)
@ -108,7 +108,7 @@ def validate():
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
# 前向传播
output = model(feat_vec)
output, output_softmax = model(feat_vec)
map_matrix = torch.argmax(output, dim=1)
for matrix in map_matrix[:].cpu():
@ -118,7 +118,7 @@ def validate():
idx += 1
# 计算损失
_, loss = criterion(output, target, target_vision_feat, target_topo_feat)
_, loss = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
val_loss += loss.item()
avg_val_loss = val_loss / len(val_loader)

View File

@ -4,6 +4,22 @@ import torch.nn.functional as F
from torch.utils.data import Dataset
from shared.graph import convert_soft_map_to_graph
def random_smooth_onehot(onehot_map, min_main=0.8, max_main=1.0, epsilon=0.8):
    """Randomly smooth a one-hot map so the main class probability fluctuates.

    For every spatial position the main (hot) class gets a probability drawn
    uniformly from [min_main, max_main); the leftover mass (1 - main_prob) is
    distributed randomly over the remaining classes, so each per-pixel vector
    is a valid probability distribution (sums to 1).

    Args:
        onehot_map: float tensor of shape [C, H, W], exactly one 1 per (h, w).
        min_main: lower bound for the main-class probability.
        max_main: upper bound for the main-class probability.
        epsilon: kept for backward compatibility; a uniform scale applied to
            the raw noise is cancelled by the normalization below, so it has
            no effect on the output.

    Returns:
        Float tensor of shape [C, H, W] summing to 1 over the class dimension.
    """
    C, H, W = onehot_map.shape
    # Random main-class probability per spatial position, in [min_main, max_main).
    main_prob = torch.rand(H, W) * (max_main - min_main) + min_main
    # Raw random noise for the non-main classes.
    noise = torch.rand(C, H, W) * epsilon
    # Zero out the main class so the leftover mass goes only to other classes.
    noise = noise * (1 - onehot_map)
    # Normalize over the CLASS dimension (dim=0 for a [C, H, W] tensor) — the
    # original code summed over dim=1 (height), which produced garbage weights.
    # clamp_min guards against division by zero when C == 1.
    noise = noise / noise.sum(dim=0, keepdim=True).clamp_min(1e-12)
    # Scale the non-main classes to exactly the remaining probability mass,
    # so every per-pixel vector sums to 1.
    noise = noise * (1 - main_prob)
    smooth_onehot = onehot_map * main_prob + noise
    return smooth_onehot
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
data = json.load(f)
@ -27,6 +43,9 @@ class MinamoDataset(Dataset):
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
map1_probs = random_smooth_onehot(map1_probs)
map2_probs = random_smooth_onehot(map2_probs)
graph1 = convert_soft_map_to_graph(map1_probs)
graph2 = convert_soft_map_to_graph(map2_probs)

View File

@ -79,7 +79,7 @@ def train():
# for name, param in model.named_parameters():
# param.requires_grad = True
for batch in dataloader:
for batch in tqdm(dataloader, leave=False):
# 数据迁移到设备
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
map1 = map1.to(device) # 转为 [B, C, H, W]