Mirror of https://github.com/unanmed/ginka-generator.git (synced 2026-05-17 15:01:10 +08:00)
feat: 增加 Minamo Model 模型深度
This commit is contained in:
parent c9c52109ed
commit eb0626ef88
@@ -335,13 +335,13 @@ class GinkaLoss(nn.Module):
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
border_loss * self.weight[1],
|
||||
border_loss * self.weight[1] * 0.1,
|
||||
entrance_loss * self.weight[2],
|
||||
count_loss * self.weight[3],
|
||||
illegal_loss * self.weight[4]
|
||||
]
|
||||
|
||||
# 梯度归一化
|
||||
# scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
|
||||
total_loss = sum(losses)
|
||||
return total_loss
|
||||
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
|
||||
total_loss = sum(scaled_losses)
|
||||
return total_loss, sum(losses)
|
||||
|
||||
@@ -78,12 +78,12 @@ def train():
|
||||
output = model(feat_vec)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
|
||||
# 反向传播
|
||||
loss.backward()
|
||||
scaled_losses.backward()
|
||||
optimizer.step()
|
||||
total_loss += loss.item()
|
||||
total_loss += losses.item()
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
||||
@@ -119,8 +119,8 @@ def train():
|
||||
print(torch.argmax(output, dim=1)[0])
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
loss_val += loss.item()
|
||||
scaled_losses, losses = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
loss_val += losses.item()
|
||||
|
||||
avg_val_loss = loss_val / len(dataloader_val)
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
|
||||
@@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
def validate():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||
model = GinkaModel()
|
||||
state = torch.load("result/ginka_checkpoint/15.pth", map_location=device)["model_state"]
|
||||
state = torch.load("result/ginka_checkpoint/30.pth", map_location=device)["model_state"]
|
||||
model.load_state_dict(state)
|
||||
model.to(device)
|
||||
|
||||
@@ -113,7 +113,7 @@ def validate():
|
||||
idx += 1
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
_, loss = criterion(output, target, target_vision_feat, target_topo_feat)
|
||||
val_loss += loss.item()
|
||||
|
||||
avg_val_loss = val_loss / len(val_loader)
|
||||
|
||||
@@ -12,4 +12,5 @@ class MinamoLoss(nn.Module):
|
||||
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item())
|
||||
vis_loss = self.mse(vis_pred, vis_true)
|
||||
topo_loss = self.mse(topo_pred, topo_true)
|
||||
# print(vis_loss.item(), topo_loss.item())
|
||||
return self.vision_weight * vis_loss + self.topo_weight * topo_loss
|
||||
|
||||
@@ -14,11 +14,15 @@ class MinamoTopoModel(nn.Module):
|
||||
# 图卷积层
|
||||
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
|
||||
self.conv2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4)
|
||||
self.conv_ins2 = GATConv(hidden_dim*16, hidden_dim*4, heads=4, dropout=0.3)
|
||||
self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2)
|
||||
self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
|
||||
|
||||
# 正则化
|
||||
self.norm1 = nn.LayerNorm(hidden_dim*16)
|
||||
self.norm2 = nn.LayerNorm(hidden_dim*16)
|
||||
self.norm_ins2 = nn.LayerNorm(hidden_dim*16)
|
||||
self.norm_ins1 = nn.LayerNorm(hidden_dim*16)
|
||||
self.norm3 = nn.LayerNorm(out_dim)
|
||||
|
||||
# 池化层
|
||||
@@ -40,6 +44,12 @@ class MinamoTopoModel(nn.Module):
|
||||
x = self.conv2(x, graph.edge_index)
|
||||
x = F.elu(self.norm2(x))
|
||||
|
||||
x = self.conv_ins2(x, graph.edge_index)
|
||||
x = F.elu(self.norm_ins2(x))
|
||||
|
||||
x = self.conv_ins1(x, graph.edge_index)
|
||||
x = F.elu(self.norm_ins1(x))
|
||||
|
||||
x = self.conv3(x, graph.edge_index)
|
||||
x = F.elu(self.norm3(x))
|
||||
|
||||
|
||||
@@ -30,6 +30,11 @@ class MinamoVisionModel(nn.Module):
|
||||
CBAM(conv_ch*8),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_ch*8),
|
||||
CBAM(conv_ch*8),
|
||||
nn.GELU(),
|
||||
|
||||
nn.AdaptiveMaxPool2d(1)
|
||||
)
|
||||
|
||||
|
||||
@@ -61,16 +61,24 @@ def train():
|
||||
|
||||
if args.resume:
|
||||
data = torch.load(args.from_state, map_location=device)
|
||||
model.load_state_dict(data["model_state"])
|
||||
model.load_state_dict(data["model_state"], strict=False)
|
||||
if args.load_optim:
|
||||
optimizer.load_state_dict(data["optimizer_state"])
|
||||
print("Train from loaded state.")
|
||||
|
||||
# for name, param in model.named_parameters():
|
||||
# if 'ins' not in name: # 仅训练扩展部分
|
||||
# param.requires_grad = False
|
||||
|
||||
# 开始训练
|
||||
for epoch in tqdm(range(args.epochs)):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
|
||||
# if epoch == 30:
|
||||
# for name, param in model.named_parameters():
|
||||
# param.requires_grad = True
|
||||
|
||||
for batch in dataloader:
|
||||
# 数据迁移到设备
|
||||
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
|
||||
|
||||
@@ -15,7 +15,7 @@ def validate():
|
||||
model.to(device)
|
||||
|
||||
# 准备数据集
|
||||
val_dataset = MinamoDataset("minamo-eval.json")
|
||||
val_dataset = MinamoDataset("datasets/minamo-eval.json")
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=32,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user