mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
feat: 修改下采样部分
This commit is contained in:
parent
11378380b4
commit
ca068bbea3
@ -26,6 +26,6 @@ class GinkaModel(nn.Module):
|
|||||||
x = self.fc(feat)
|
x = self.fc(feat)
|
||||||
x = x.view(-1, self.base_ch, 32, 32)
|
x = x.view(-1, self.base_ch, 32, 32)
|
||||||
x = self.unet(x)
|
x = self.unet(x)
|
||||||
x = self.down_sample(x)
|
x = F.interpolate(x, (13, 13), mode='bilinear')
|
||||||
return F.softmax(x, dim=1)
|
return x, F.softmax(x, dim=1)
|
||||||
|
|
||||||
@ -7,7 +7,7 @@ class MapDownSample(nn.Module):
|
|||||||
self.down = nn.Sequential(
|
self.down = nn.Sequential(
|
||||||
nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
|
nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=0)
|
nn.Conv2d(in_ch, out_ch, 4, stride=1, padding=0)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|||||||
@ -48,7 +48,7 @@ def train():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 设定优化器与调度器
|
# 设定优化器与调度器
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
|
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
criterion = GinkaLoss(minamo)
|
criterion = GinkaLoss(minamo)
|
||||||
|
|
||||||
@ -75,10 +75,10 @@ def train():
|
|||||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||||
# 前向传播
|
# 前向传播
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
output = model(feat_vec)
|
_, output_softmax = model(feat_vec)
|
||||||
|
|
||||||
# 计算损失
|
# 计算损失
|
||||||
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
|
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
|
|
||||||
# 反向传播
|
# 反向传播
|
||||||
scaled_losses.backward()
|
scaled_losses.backward()
|
||||||
@ -115,11 +115,11 @@ def train():
|
|||||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
output = model(feat_vec)
|
output, output_softmax = model(feat_vec)
|
||||||
print(torch.argmax(output, dim=1)[0])
|
print(torch.argmax(output, dim=1)[0])
|
||||||
|
|
||||||
# 计算损失
|
# 计算损失
|
||||||
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
|
scaled_losses, losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
loss_val += losses.item()
|
loss_val += losses.item()
|
||||||
|
|
||||||
avg_val_loss = loss_val / len(dataloader_val)
|
avg_val_loss = loss_val / len(dataloader_val)
|
||||||
|
|||||||
@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
|||||||
def validate():
|
def validate():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||||
model = GinkaModel()
|
model = GinkaModel()
|
||||||
state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
|
state = torch.load("result/ginka.pth", map_location=device)["model_state"]
|
||||||
model.load_state_dict(state)
|
model.load_state_dict(state)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
@ -108,7 +108,7 @@ def validate():
|
|||||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device)
|
||||||
# 前向传播
|
# 前向传播
|
||||||
output = model(feat_vec)
|
output, output_softmax = model(feat_vec)
|
||||||
map_matrix = torch.argmax(output, dim=1)
|
map_matrix = torch.argmax(output, dim=1)
|
||||||
|
|
||||||
for matrix in map_matrix[:].cpu():
|
for matrix in map_matrix[:].cpu():
|
||||||
@ -118,7 +118,7 @@ def validate():
|
|||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
# 计算损失
|
# 计算损失
|
||||||
_, loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
_, loss = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
||||||
val_loss += loss.item()
|
val_loss += loss.item()
|
||||||
|
|
||||||
avg_val_loss = val_loss / len(val_loader)
|
avg_val_loss = val_loss / len(val_loader)
|
||||||
|
|||||||
@ -4,6 +4,22 @@ import torch.nn.functional as F
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from shared.graph import convert_soft_map_to_graph
|
from shared.graph import convert_soft_map_to_graph
|
||||||
|
|
||||||
|
def random_smooth_onehot(onehot_map, min_main=0.8, max_main=1.0, epsilon=0.8):
|
||||||
|
"""
|
||||||
|
生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动
|
||||||
|
"""
|
||||||
|
C, H, W = onehot_map.shape
|
||||||
|
# 生成主类别的随机概率 (min_main, max_main)
|
||||||
|
main_prob = torch.rand(H, W) * (max_main - min_main) + min_main
|
||||||
|
|
||||||
|
# 计算剩余概率并随机分配到其他类别
|
||||||
|
noise = torch.rand(C, H, W) * epsilon # 随机噪声
|
||||||
|
noise = noise / noise.sum(dim=1, keepdim=True) # 归一化到总和为 epsilon
|
||||||
|
|
||||||
|
# 计算最终平滑 one-hot 结果
|
||||||
|
smooth_onehot = onehot_map * main_prob + (1 - onehot_map) * noise
|
||||||
|
return smooth_onehot
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
with open(path, 'r', encoding="utf-8") as f:
|
with open(path, 'r', encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
@ -27,6 +43,9 @@ class MinamoDataset(Dataset):
|
|||||||
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
|
|
||||||
|
map1_probs = random_smooth_onehot(map1_probs)
|
||||||
|
map2_probs = random_smooth_onehot(map2_probs)
|
||||||
|
|
||||||
graph1 = convert_soft_map_to_graph(map1_probs)
|
graph1 = convert_soft_map_to_graph(map1_probs)
|
||||||
graph2 = convert_soft_map_to_graph(map2_probs)
|
graph2 = convert_soft_map_to_graph(map2_probs)
|
||||||
|
|
||||||
|
|||||||
@ -79,7 +79,7 @@ def train():
|
|||||||
# for name, param in model.named_parameters():
|
# for name, param in model.named_parameters():
|
||||||
# param.requires_grad = True
|
# param.requires_grad = True
|
||||||
|
|
||||||
for batch in dataloader:
|
for batch in tqdm(dataloader, leave=False):
|
||||||
# 数据迁移到设备
|
# 数据迁移到设备
|
||||||
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
|
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
|
||||||
map1 = map1.to(device) # 转为 [B, C, H, W]
|
map1 = map1.to(device) # 转为 [B, C, H, W]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user