mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 12:57:15 +08:00
feat: 修改下采样部分
This commit is contained in:
parent
11378380b4
commit
ca068bbea3
@ -26,6 +26,6 @@ class GinkaModel(nn.Module):
|
||||
x = self.fc(feat)
|
||||
x = x.view(-1, self.base_ch, 32, 32)
|
||||
x = self.unet(x)
|
||||
x = self.down_sample(x)
|
||||
return F.softmax(x, dim=1)
|
||||
x = F.interpolate(x, (13, 13), mode='bilinear')
|
||||
return x, F.softmax(x, dim=1)
|
||||
|
||||
@ -7,7 +7,7 @@ class MapDownSample(nn.Module):
|
||||
self.down = nn.Sequential(
|
||||
nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=0)
|
||||
nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@ -48,7 +48,7 @@ def train():
|
||||
)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion = GinkaLoss(minamo)
|
||||
|
||||
@ -75,10 +75,10 @@ def train():
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
output = model(feat_vec)
|
||||
_, output_softmax = model(feat_vec)
|
||||
|
||||
# 计算损失
|
||||
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
|
||||
# 反向传播
|
||||
scaled_losses.backward()
|
||||
@ -115,11 +115,11 @@ def train():
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
|
||||
# 前向传播
|
||||
output = model(feat_vec)
|
||||
output, output_softmax = model(feat_vec)
|
||||
print(torch.argmax(output, dim=1)[0])
|
||||
|
||||
# 计算损失
|
||||
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
loss_val += losses.item()
|
||||
|
||||
avg_val_loss = loss_val / len(dataloader_val)
|
||||
|
||||
@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
def validate():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||
model = GinkaModel()
|
||||
state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
|
||||
state = torch.load("result/ginka.pth", map_location=device)["model_state"]
|
||||
model.load_state_dict(state)
|
||||
model.to(device)
|
||||
|
||||
@ -108,7 +108,7 @@ def validate():
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||
# 前向传播
|
||||
output = model(feat_vec)
|
||||
output, output_softmax = model(feat_vec)
|
||||
map_matrix = torch.argmax(output, dim=1)
|
||||
|
||||
for matrix in map_matrix[:].cpu():
|
||||
@ -118,7 +118,7 @@ def validate():
|
||||
idx += 1
|
||||
|
||||
# 计算损失
|
||||
_, loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
_, loss = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||
val_loss += loss.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
|
||||
@ -4,6 +4,22 @@ import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
from shared.graph import convert_soft_map_to_graph
|
||||
|
||||
def random_smooth_onehot(onehot_map, min_main=0.8, max_main=1.0, epsilon=0.8):
|
||||
"""
|
||||
生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动
|
||||
"""
|
||||
C, H, W = onehot_map.shape
|
||||
# 生成主类别的随机概率 (min_main, max_main)
|
||||
main_prob = torch.rand(H, W) * (max_main - min_main) + min_main
|
||||
|
||||
# 计算剩余概率并随机分配到其他类别
|
||||
noise = torch.rand(C, H, W) * epsilon # 随机噪声
|
||||
noise = noise / noise.sum(dim=1, keepdim=True) # 归一化到总和为 epsilon
|
||||
|
||||
# 计算最终平滑 one-hot 结果
|
||||
smooth_onehot = onehot_map * main_prob + (1 - onehot_map) * noise
|
||||
return smooth_onehot
|
||||
|
||||
def load_data(path: str):
|
||||
with open(path, 'r', encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
@ -27,6 +43,9 @@ class MinamoDataset(Dataset):
|
||||
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
|
||||
map1_probs = random_smooth_onehot(map1_probs)
|
||||
map2_probs = random_smooth_onehot(map2_probs)
|
||||
|
||||
graph1 = convert_soft_map_to_graph(map1_probs)
|
||||
graph2 = convert_soft_map_to_graph(map2_probs)
|
||||
|
||||
|
||||
@ -79,7 +79,7 @@ def train():
|
||||
# for name, param in model.named_parameters():
|
||||
# param.requires_grad = True
|
||||
|
||||
for batch in dataloader:
|
||||
for batch in tqdm(dataloader, leave=False):
|
||||
# 数据迁移到设备
|
||||
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
|
||||
map1 = map1.to(device) # 转为 [B, C, H, W]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user