refactor: reorganize directory structure & feat: conditional injection

This commit is contained in:
unanmed 2025-04-29 18:24:01 +08:00
parent a28e56456d
commit 44b90e7630
17 changed files with 294 additions and 497 deletions

ginka/common/cond.py (new file, 43 lines)
View File

@@ -0,0 +1,43 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConditionEncoder(nn.Module):
def __init__(self, tag_dim, val_dim, hidden_dim, out_dim):
super().__init__()
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
self.val_embed = nn.Linear(val_dim, hidden_dim)
self.fusion = nn.Sequential(
nn.LayerNorm(hidden_dim*2),
nn.ELU(),
nn.Linear(hidden_dim*2, hidden_dim*4),
nn.LayerNorm(hidden_dim*4),
nn.ELU(),
nn.Linear(hidden_dim*4, out_dim)
)
def forward(self, tag, val):
tag = self.tag_embed(tag)
val = self.val_embed(val)
feat = torch.cat([tag, val], dim=1)
feat = self.fusion(feat)
return feat
class ConditionInjector(nn.Module):
def __init__(self, cond_dim, out_dim):
super().__init__()
self.fc = nn.Sequential(
nn.Linear(cond_dim, cond_dim*2),
nn.LayerNorm(cond_dim*2),
nn.ELU(),
nn.Linear(cond_dim*2, out_dim)
)
def forward(self, x, cond):
cond = self.fc(cond)
B, D = cond.shape
cond = cond.view(B, D, 1, 1)
return x + cond
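A minimal usage sketch of the two modules above, assuming ginka is importable as a package and using the dimensions this commit passes elsewhere (64-dim tags, 16-dim values, a 256-dim condition):

import torch
from ginka.common.cond import ConditionEncoder, ConditionInjector

encoder = ConditionEncoder(tag_dim=64, val_dim=16, hidden_dim=128, out_dim=256)
injector = ConditionInjector(cond_dim=256, out_dim=256)  # out_dim must match the feature channels

tag = torch.rand(2, 64)           # multi-hot tag vector
val = torch.rand(2, 16)           # scalar conditions
feat = torch.randn(2, 256, 4, 4)  # an intermediate feature map

cond = encoder(tag, val)          # [2, 256]
out = injector(feat, cond)        # cond is reshaped to [2, 256, 1, 1] and broadcast-added
assert out.shape == feat.shape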

View File

@@ -3,30 +3,17 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool
from .vision import MinamoVisionModel
from .topo import MinamoTopoModel
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from shared.graph import batch_convert_soft_map_to_graph
from .vision import MinamoVisionModel
from .topo import MinamoTopoModel
from ..common.cond import ConditionEncoder, ConditionInjector
def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
class MinamoModel(nn.Module):
def __init__(self, tile_types=32):
super().__init__()
# vision similarity branch
self.vision_model = MinamoVisionModel(tile_types)
# topology similarity branch
self.topo_model = MinamoTopoModel(tile_types)
def forward(self, map, graph):
vision_feat = self.vision_model(map)
topo_feat = self.topo_model(graph)
return vision_feat, topo_feat
class CNNHead(nn.Module):
def __init__(self, in_ch, out_dim):
def __init__(self, in_ch):
super().__init__()
self.cnn = nn.Sequential(
spectral_norm(nn.Conv2d(in_ch, in_ch, 3)),
@@ -35,61 +22,69 @@ class CNNHead(nn.Module):
nn.AdaptiveMaxPool2d((2, 2))
)
self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_ch*2*2, out_dim))
spectral_norm(nn.Linear(in_ch*2*2, 1))
)
self.proj = nn.Linear(256, in_ch*2*2)
def forward(self, x):
def forward(self, x, cond):
x = self.cnn(x)
B, C, H, W = x.shape
x = x.view(B, -1)
x = self.fc(x)
cond = self.proj(cond)
proj = torch.sum(x * cond, dim=1, keepdim=True)
x = self.fc(x) + proj
return x
class GCNHead(nn.Module):
def __init__(self, in_dim, out_dim):
def __init__(self, in_dim):
super().__init__()
self.gcn = GCNConv(in_dim, in_dim)
self.proj = nn.Linear(256, in_dim)
self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_dim, out_dim))
spectral_norm(nn.Linear(in_dim, 1))
)
def forward(self, x, graph):
def forward(self, x, graph, cond):
x = self.gcn(x, graph.edge_index)
x = F.leaky_relu(x, 0.2)
x = global_max_pool(x, graph.batch)
x = self.fc(x)
cond = self.proj(cond)
proj = torch.sum(x * cond, dim=1, keepdim=True)
x = self.fc(x) + proj
return x
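Both heads compute their conditional score as an unconditional linear score plus an inner product between the pooled features and a projected condition, in the spirit of projection discriminators. A standalone sketch of that pattern with illustrative dimensions:

import torch
import torch.nn as nn

feat_dim, cond_dim = 128, 256
fc = nn.Linear(feat_dim, 1)           # unconditional score
proj = nn.Linear(cond_dim, feat_dim)  # project the condition into feature space

h = torch.randn(4, feat_dim)          # pooled features h(x)
c = torch.randn(4, cond_dim)          # encoded condition
score = fc(h) + torch.sum(h * proj(c), dim=1, keepdim=True)  # [4, 1]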
class MinamoScoreHead(nn.Module):
def __init__(self, vision_dim, topo_dim, out_dim):
def __init__(self, vision_dim, topo_dim):
super().__init__()
self.vision_head = CNNHead(vision_dim, out_dim)
self.topo_head = GCNHead(topo_dim, out_dim)
self.vision_head = CNNHead(vision_dim)
self.topo_head = GCNHead(topo_dim)
def forward(self, vis, topo, graph):
vis_score = self.vision_head(vis)
topo_score = self.topo_head(topo, graph)
def forward(self, vis, topo, graph, cond):
vis_score = self.vision_head(vis, cond)
topo_score = self.topo_head(topo, graph, cond)
return vis_score, topo_score
class MinamoScoreModule(nn.Module):
class MinamoModel(nn.Module):
def __init__(self, tile_types=32):
super().__init__()
self.topo_model = MinamoTopoModel(tile_types)
self.vision_model = MinamoVisionModel(tile_types)
self.cond = ConditionEncoder(64, 16, 128, 256)
# output heads
self.head1 = MinamoScoreHead(512, 512, 1)
self.head2 = MinamoScoreHead(512, 512, 1)
self.head3 = MinamoScoreHead(512, 512, 1)
self.head1 = MinamoScoreHead(512, 512)
self.head2 = MinamoScoreHead(512, 512)
self.head3 = MinamoScoreHead(512, 512)
def forward(self, map, graph, stage):
def forward(self, map, graph, stage, tag_cond, val_cond):
vision = self.vision_model(map)
topo = self.topo_model(graph)
cond = self.cond(tag_cond, val_cond)
if stage == 1:
vision_score, topo_score = self.head1(vision, topo, graph)
vision_score, topo_score = self.head1(vision, topo, graph, cond)
elif stage == 2:
vision_score, topo_score = self.head2(vision, topo, graph)
vision_score, topo_score = self.head2(vision, topo, graph, cond)
elif stage == 3:
vision_score, topo_score = self.head3(vision, topo, graph)
vision_score, topo_score = self.head3(vision, topo, graph, cond)
else:
raise RuntimeError("Unknown critic stage.")
score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
@@ -98,19 +93,22 @@ class MinamoScoreModule(nn.Module):
# check VRAM usage
if __name__ == "__main__":
input = torch.randn((1, 32, 13, 13)).cuda()
tag = torch.rand(1, 64).cuda()
val = torch.rand(1, 16).cuda()
# initialize the model
model = MinamoScoreModule().cuda()
model = MinamoModel().cuda()
print_memory("after init")
# forward pass
output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1)
output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1, tag, val)
print_memory("after forward")
print(f"Input shape: feat={input.shape}")
print(f"Output shape: output={output.shape}")
print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}")

View File

@@ -87,54 +87,94 @@ class GinkaWGANDataset(Dataset):
def __len__(self):
return len(self.data)
def handle_stage1(self, target):
def handle_stage1(self, target, tag_cond, val_cond):
# curriculum stage 1: masked in-fill
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
return removed1, masked1, removed2, masked2, removed3, masked3
return {
"real1": removed1,
"masked1": masked1,
"real2": removed2,
"masked2": masked2,
"real3": removed3,
"masked3": masked3,
"tag_cond": tag_cond,
"val_cond": val_cond
}
def handle_stage2(self, target):
def handle_stage2(self, target, tag_cond, val_cond):
# curriculum stage 2: fully random masks
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
# the last two stages keep some tile classes, so fully random masking is enough
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
return removed1, masked1, removed2, masked2, removed3, masked3
return {
"real1": removed1,
"masked1": masked1,
"real2": removed2,
"masked2": masked2,
"real3": removed3,
"masked3": masked3,
"tag_cond": tag_cond,
"val_cond": val_cond
}
def handle_stage3(self, target):
def handle_stage3(self, target, tag_cond, val_cond):
# stage 3: joint generation with randomly masked input
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
return {
"real1": removed1,
"masked1": masked1,
"real2": removed2,
"masked2": torch.zeros_like(target),
"real3": removed3,
"masked3": torch.zeros_like(target),
"tag_cond": tag_cond,
"val_cond": val_cond
}
def handle_stage4(self, target):
# stage 4: alternates with stage 2, fully random input
def handle_stage4(self, target, tag_cond, val_cond):
# stage 4: fully random input
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
rand = torch.rand(32, 32, 32, device=target.device)
return removed1, rand, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
return {
"real1": removed1,
"masked1": rand,
"real2": removed2,
"masked2": torch.zeros_like(target),
"real3": removed3,
"masked3": torch.zeros_like(target),
"tag_cond": tag_cond,
"val_cond": val_cond
}
def __getitem__(self, idx):
item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
tag_cond = torch.FloatTensor(item['tag'])
val_cond = torch.FloatTensor(item['val'])
if self.train_stage == 1:
return self.handle_stage1(target)
return self.handle_stage1(target, tag_cond, val_cond)
elif self.train_stage == 2:
return self.handle_stage2(target)
return self.handle_stage2(target, tag_cond, val_cond)
elif self.train_stage == 3:
return self.handle_stage3(target)
return self.handle_stage3(target, tag_cond, val_cond)
elif self.train_stage == 4:
return self.handle_stage4(target)
return self.handle_stage4(target, tag_cond, val_cond)
raise RuntimeError(f"Invalid train stage: {self.train_stage}")
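Because the stage handlers now return dicts, PyTorch's default collate stacks each key across the batch, so a DataLoader yields a single dict of batched tensors. A brief consumption sketch, assuming an already constructed GinkaWGANDataset named dataset:

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=16, shuffle=True)
batch = next(iter(loader))                                 # dict of stacked tensors
real1, masked1 = batch["real1"], batch["masked1"]          # [16, 32, H, W] each
tag_cond, val_cond = batch["tag_cond"], batch["val_cond"]  # [16, 64] and [16, 16]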

View File

@@ -1,29 +1,34 @@
import torch
import torch.nn as nn
from ..common.common import GCNBlock, DoubleConvBlock
from ..common.cond import ConditionInjector
class RandomInputHead(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(32, 32, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(32),
nn.ELU(),
nn.Conv2d(32, 64, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(64),
nn.ELU(),
nn.Conv2d(64, 128, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(128),
self.conv = DoubleConvBlock([32, 64, 128])
self.gcn = GCNBlock(32, 128, 128, 32, 32)
self.fusion = nn.Sequential(
nn.Conv2d(256, 256, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(256),
nn.ELU(),
)
self.out_conv = nn.Sequential(
nn.Conv2d(256, 128, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(128),
nn.ELU(),
nn.AdaptiveMaxPool2d((13, 13)),
nn.Conv2d(128, 32, 1),
)
self.inject = ConditionInjector(256, 256)
def forward(self, x):
x = self.conv(x)
def forward(self, x, cond):
x_cnn = self.conv(x)
x_gcn = self.gcn(x)
x = torch.cat([x_cnn, x_gcn], dim=1)
x = self.fusion(x)
x = self.inject(x, cond)
x = self.out_conv(x)
return x

View File

@@ -4,11 +4,9 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from minamo.model.model import MinamoModel
from shared.graph import batch_convert_soft_map_to_graph
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from shared.similarity.topo import overall_similarity, build_topological_graph
from shared.similarity.vision import calculate_visual_similarity
from ..critic.model import MinamoModel
CLASS_NUM = 32
ILLEGAL_MAX_NUM = 13
@@ -355,7 +353,7 @@ class WGANGinkaLoss:
self.lambda_gp = lambda_gp # gradient penalty coefficient
self.weight = weight
def compute_gradient_penalty(self, critic, stage, real_data, fake_data):
def compute_gradient_penalty(self, critic, stage, real_data, fake_data, tag_cond, val_cond):
# interpolate between real and fake samples
batch_size = real_data.size(0)
epsilon_data = torch.rand(batch_size, 1, 1, 1, device=real_data.device)  # uniform in [0, 1] for WGAN-GP interpolation
@@ -366,7 +364,7 @@ class WGANGinkaLoss:
interp_data.requires_grad_()
interp_graph.x.requires_grad_()
_, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage)
_, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage, tag_cond, val_cond)
# compute gradients w.r.t. the interpolated inputs
grad_vis = torch.autograd.grad(
@@ -392,29 +390,30 @@ class WGANGinkaLoss:
return gp_loss
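The elided middle of this hunk reduces the vision and topology gradients to a single penalty; for reference, a self-contained sketch of the standard WGAN-GP term this follows, assuming a critic that returns one scalar score per sample:

import torch

def gradient_penalty(critic, real, fake):
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)  # uniform in [0, 1]
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(interp)                                      # [B, 1]
    grad = torch.autograd.grad(scores.sum(), interp, create_graph=True)[0]
    grad_norm = grad.flatten(1).norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()                         # E[(||grad|| - 1)^2]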
def discriminator_loss(
self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor,
tag_cond: torch.Tensor, val_cond: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
""" 判别器损失函数 """
fake_data = F.softmax(fake_data, dim=1)
real_graph = batch_convert_soft_map_to_graph(real_data)
fake_graph = batch_convert_soft_map_to_graph(fake_data)
real_scores, _, _ = critic(real_data, real_graph, stage)
fake_scores, _, _ = critic(fake_data, fake_graph, stage)
real_scores, _, _ = critic(real_data, real_graph, stage, tag_cond, val_cond)
fake_scores, _, _ = critic(fake_data, fake_graph, stage, tag_cond, val_cond)
# Wasserstein distance
d_loss = fake_scores.mean() - real_scores.mean()
grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data)
grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data, tag_cond, val_cond)
total_loss = d_loss + self.lambda_gp * grad_loss
return total_loss, d_loss
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def generator_loss(self, critic, stage, mask_ratio, real, fake: torch.Tensor, input, tag_cond, val_cond) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" 生成器损失函数 """
probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores)
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # the larger the mask ratio, the smaller the cross-entropy weight
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
@@ -439,11 +438,11 @@ class WGANGinkaLoss:
return sum(losses), minamo_loss, ce_loss, immutable_loss
def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor:
probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores)
constraint_loss = inner_constraint_loss(probs_fake)
@@ -462,11 +461,11 @@ class WGANGinkaLoss:
return sum(losses)
def generator_loss_total_with_input(self, critic, stage, fake, input) -> torch.Tensor:
def generator_loss_total_with_input(self, critic, stage, fake, input, tag_cond, val_cond) -> torch.Tensor:
probs_fake = F.softmax(fake, dim=1)
fake_graph = batch_convert_soft_map_to_graph(probs_fake)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores)
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = inner_constraint_loss(probs_fake)

View File

@@ -4,6 +4,7 @@ import torch.nn.functional as F
from .unet import GinkaUNet
from .output import GinkaOutput
from .input import GinkaInput, RandomInputHead
from ..common.cond import ConditionEncoder
def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@@ -14,23 +15,27 @@ class GinkaModel(nn.Module):
"""
super().__init__()
self.head = RandomInputHead()
self.cond = ConditionEncoder(64, 16, 128, 256)
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
self.unet = GinkaUNet(32, base_ch, base_ch)
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
def forward(self, x, stage, random=False):
def forward(self, x, stage, tag_cond, val_cond, random=False):
cond = self.cond(tag_cond, val_cond)
if random:
x_in = F.softmax(self.head(x), dim=1)
x_in = F.softmax(self.head(x, cond), dim=1)
else:
x_in = x
x = self.input(x_in)
x = self.unet(x)
x = self.output(x, stage)
x = self.unet(x, cond)
x = self.output(x, stage, cond)
return x, x_in
# check VRAM usage
if __name__ == "__main__":
input = torch.randn((1, 32, 32, 32)).cuda()
input = torch.rand(1, 32, 32, 32).cuda()
tag = torch.rand(1, 64).cuda()
val = torch.rand(1, 16).cuda()
# initialize the model
model = GinkaModel().cuda()
@@ -38,12 +43,14 @@ if __name__ == "__main__":
print_memory("after init")
# forward pass
output, _ = model(input, 1, True)
output, _ = model(input, 1, tag, val, True)
print_memory("after forward")
print(f"Input shape: feat={input.shape}")
print(f"Output shape: output={output.shape}")
print(f"Head parameters: {sum(p.numel() for p in model.head.parameters())}")
print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")

View File

@@ -1,6 +1,7 @@
import torch
import torch.nn as nn
from .common import GCNBlock, DoubleConvBlock
from ..common.common import GCNBlock, DoubleConvBlock
from ..common.cond import ConditionInjector
class StageHead(nn.Module):
def __init__(self, in_ch, out_ch, out_size=(13, 13)):
@@ -9,15 +10,21 @@ class StageHead(nn.Module):
self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
self.pool = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch),
nn.ELU(),
nn.AdaptiveMaxPool2d(out_size),
nn.Conv2d(in_ch, out_ch, 1)
)
self.inject = ConditionInjector(256, in_ch)
def forward(self, x):
def forward(self, x, cond):
x_cnn = self.cnn_head(x)
x_gcn = self.gcn_head(x)
x = torch.cat([x_cnn, x_gcn], dim=1)
x = self.fusion(x)
x = self.inject(x, cond)
x = self.pool(x)
return x
@@ -28,13 +35,13 @@ class GinkaOutput(nn.Module):
self.head2 = StageHead(in_ch, out_ch, out_size)
self.head3 = StageHead(in_ch, out_ch, out_size)
def forward(self, x, stage):
def forward(self, x, stage, cond):
if stage == 1:
x = self.head1(x)
x = self.head1(x, cond)
elif stage == 2:
x = self.head2(x)
x = self.head2(x, cond)
elif stage == 3:
x = self.head3(x)
x = self.head3(x, cond)
else:
raise RuntimeError("Unknown generate stage.")
return x

View File

@@ -2,7 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from shared.attention import ChannelAttention
from .common import GCNBlock, DoubleConvBlock
from ..common.common import GCNBlock
from ..common.cond import ConditionInjector
class GinkaTransformerEncoder(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
@@ -53,7 +54,7 @@ class ConvBlock(nn.Module):
class FusionModule(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])
self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate')
def forward(self, x1, x2):
x = torch.cat([x1, x2], dim=1)
@@ -66,10 +67,12 @@ class GinkaEncoder(nn.Module):
super().__init__()
self.conv = ConvBlock(in_ch, out_ch)
self.pool = nn.MaxPool2d(2)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x):
def forward(self, x, cond):
x = self.conv(x)
x = self.pool(x)
x = self.inject(x, cond)
return x
class GinkaGCNFusedEncoder(nn.Module):
@@ -79,12 +82,14 @@ class GinkaGCNFusedEncoder(nn.Module):
self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
self.pool = nn.MaxPool2d(2)
self.fusion = FusionModule(out_ch*2, out_ch)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x):
def forward(self, x, cond):
x = self.conv(x)
x = self.pool(x)
x2 = self.gcn(x)
x = self.fusion(x, x2)
x = self.inject(x, cond)
return x
class GinkaUpSample(nn.Module):
@@ -105,11 +110,13 @@ class GinkaDecoder(nn.Module):
super().__init__()
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
self.conv = ConvBlock(in_ch, out_ch)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, feat):
def forward(self, x, feat, cond):
x = self.upsample(x)
x = torch.cat([x, feat], dim=1)
x = self.conv(x)
x = self.inject(x, cond)
return x
class GinkaGCNFusedDecoder(nn.Module):
@@ -119,13 +126,15 @@ class GinkaGCNFusedDecoder(nn.Module):
self.conv = ConvBlock(in_ch, out_ch)
self.gcn = GCNBlock(out_ch, out_ch*2, out_ch, w, h)
self.fusion = FusionModule(out_ch*2, out_ch)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, feat):
def forward(self, x, feat, cond):
x = self.upsample(x)
x = torch.cat([x, feat], dim=1)
x = self.conv(x)
x2 = self.gcn(x)
x = self.fusion(x, x2)
x = self.inject(x, cond)
return x
class GinkaBottleneck(nn.Module):
@@ -136,9 +145,10 @@ class GinkaBottleneck(nn.Module):
token_size=16, ff_dim=1024, num_layers=4
)
self.gcn = GCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
self.fusion = FusionModule(module_ch*2, module_ch)
self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
self.inject = ConditionInjector(256, module_ch)
def forward(self, x):
def forward(self, x, cond):
B = x.size(0)
x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
@@ -146,7 +156,9 @@ class GinkaBottleneck(nn.Module):
x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
x2 = self.gcn(x)
x = self.fusion(x1, x2)
x = torch.cat([x, x1, x2], dim=1)
x = self.fusion(x)
x = self.inject(x, cond)
return x
@@ -162,7 +174,7 @@ class GinkaUNet(nn.Module):
self.down1 = ConvBlock(in_ch, base_ch)
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
self.down4 = GinkaGCNFusedEncoder(base_ch*4, base_ch*8, 4, 4)
self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
self.up1 = GinkaGCNFusedDecoder(base_ch*8, base_ch*4, 8, 8)
@@ -175,17 +187,17 @@ class GinkaUNet(nn.Module):
nn.ELU(),
)
def forward(self, x):
def forward(self, x, cond):
x1 = self.down1(x) # [B, 64, 32, 32]
x2 = self.down2(x1) # [B, 128, 16, 16]
x3 = self.down3(x2) # [B, 256, 8, 8]
x4 = self.down4(x3) # [B, 512, 4, 4]
x4 = self.bottleneck(x4) # [B, 512, 4, 4]
x2 = self.down2(x1, cond) # [B, 128, 16, 16]
x3 = self.down3(x2, cond) # [B, 256, 8, 8]
x4 = self.down4(x3, cond) # [B, 512, 4, 4]
x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4]
# upsampling path
x = self.up1(x4, x3) # [B, 256, 8, 8]
x = self.up2(x, x2) # [B, 128, 16, 16]
x = self.up3(x, x1) # [B, 64, 32, 32]
x = self.up1(x4, x3, cond) # [B, 256, 8, 8]
x = self.up2(x, x2, cond) # [B, 128, 16, 16]
x = self.up3(x, x1, cond) # [B, 64, 32, 32]
x = self.final(x) # [B, 32, 32, 32]
return x

View File

@@ -8,13 +8,38 @@ import torch.nn.functional as F
import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import GinkaModel
from .generator.model import GinkaModel
from .dataset import GinkaWGANDataset
from .model.loss import WGANGinkaLoss
from .model.input import RandomInputHead
from minamo.model.model import MinamoScoreModule
from .generator.loss import WGANGinkaLoss
from .generator.input import RandomInputHead
from .critic.model import MinamoModel
from shared.image import matrix_to_image_cv
# Tag definitions:
# 0. blue sea, 1. red sea, 2. indoor, 3. outdoor, 4. left-right symmetric, 5. top-bottom symmetric, 6. pseudo-symmetric, 7. filler floor,
# 8. story floor, 9. water floor, 10. power-fantasy tower, 11. boss floor, 12. pure boss floor, 13. multi-room, 14. multi-corridor, 15. item tower
# Scalar value definitions:
# 0. overall density: non-empty tiles / map area, where empty tiles also include decorative tiles
# 1. monster density: monster count / map area
# 2. resource density: resource count / map area
# 3. door density: door count / map area
# 4. number of entrances
# Tile definitions:
# 0. floor, 1. wall, 2. decoration (outdoor decoration, treated as floor),
# 3. yellow door, 4. blue door, 5. red door, 6. mechanism door; other door kinds such as green doors are treated as red doors
# 7-9. keys for the yellow/blue/red doors; mechanism doors are not opened with keys
# 10-12. three tiers of red gems
# 13-15. three tiers of blue gems
# 16-18. three tiers of green gems
# 19-21. three tiers of potions
# 22-24. three tiers of items
# 25-27. three tiers of monsters
# 28-29. reserved
# 30. stair entrance
# 31. arrow entrance
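Given these definitions, tag_cond is naturally a 64-dim multi-hot vector (only indices 0-15 are assigned so far) and val_cond a 16-dim vector (only indices 0-4 assigned); a hypothetical helper for building them:

import torch

def make_conditions(tags: list[int], values: list[float]):
    tag_cond = torch.zeros(64)           # indices 16-63 are currently unused
    tag_cond[torch.tensor(tags)] = 1.0   # e.g. [2, 13] = indoor + multi-room
    val_cond = torch.zeros(16)           # indices 5-15 are currently unused
    val_cond[: len(values)] = torch.tensor(values)
    return tag_cond, val_cond

# density values 0-3 plus the entrance count at index 4
tag_cond, val_cond = make_conditions([2, 13], [0.45, 0.08, 0.10, 0.05, 2.0])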
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -34,27 +59,28 @@ def parse_arguments():
parser.add_argument("--checkpoint", type=int, default=5)
parser.add_argument("--load_optim", type=bool, default=True)
parser.add_argument("--curr_epoch", type=int, default=20) # minimum number of curriculum-learning epochs
parser.add_argument("--tuning", type=bool, default=False)
args = parser.parse_args()
return args
def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
fake1, _ = gen(masked1, 1)
fake2, _ = gen(masked2, 2)
fake3, _ = gen(masked3, 3)
def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
fake1, _ = gen(masked1, 1, tag, val)
fake2, _ = gen(masked2, 2, tag, val)
fake3, _ = gen(masked3, 3, tag, val)
if detach:
return fake1.detach(), fake2.detach(), fake3.detach()
else:
return fake1, fake2, fake3
def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
if progress_detach:
fake1, x_in = gen(input.detach(), 1, random)
fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2)
fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3)
fake1, x_in = gen(input.detach(), 1, tag, val, random)
fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val)
fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val)
else:
fake1, x_in = gen(input, 1, random)
fake2, _ = gen(F.softmax(fake1, dim=1), 2)
fake3, _ = gen(F.softmax(fake2, dim=1), 3)
fake1, x_in = gen(input, 1, tag, val, random)
fake2, _ = gen(F.softmax(fake1, dim=1), 2, tag, val)
fake3, _ = gen(F.softmax(fake2, dim=1), 3, tag, val)
if result_detach:
return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
else:
@@ -74,7 +100,7 @@ def train():
ginka = GinkaModel().to(device)
ginka_head = RandomInputHead().to(device)
minamo = MinamoScoreModule().to(device)
minamo = MinamoModel().to(device)
dataset = GinkaWGANDataset(args.train, device)
dataset_val = GinkaWGANDataset(args.validate, device)
@@ -133,6 +159,14 @@ def train():
print("Train from loaded state.")
curr_epoch = args.curr_epoch
if args.tuning:
train_stage = 1
curr_epoch = curr_epoch // 4
stage_epoch = 0
mask_ratio = 0.2
low_loss_epochs = 0
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
@@ -142,7 +176,14 @@ def train():
loss_ce_total = torch.Tensor([0]).to(device)
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
real1 = batch["real1"].to(device)
masked1 = batch["masked1"].to(device)
real2 = batch["real2"].to(device)
masked2 = batch["masked2"].to(device)
real3 = batch["real3"].to(device)
masked3 = batch["masked3"].to(device)
tag_cond = batch["tag_cond"].to(device)
val_cond = batch["val_cond"].to(device)
# ---------- train the critic
for _ in range(c_steps):
@@ -152,10 +193,10 @@ def train():
with torch.no_grad():
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, _ = gen_total(ginka, masked1, True, True, train_stage == 4)
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
@@ -235,7 +276,7 @@ def train():
if train_stage == 5:
train_stage = 2
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= args.curr_epoch:
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.2
@@ -283,13 +324,21 @@ def train():
idx = 0
with torch.no_grad():
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
real1 = batch["real1"].to(device)
masked1 = batch["masked1"].to(device)
real2 = batch["real2"].to(device)
masked2 = batch["masked2"].to(device)
real3 = batch["real3"].to(device)
masked3 = batch["masked3"].to(device)
tag_cond = batch["tag_cond"].to(device)
val_cond = batch["val_cond"].to(device)
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4:
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1, ginka.cond(tag_cond, val_cond)), dim=1)
fake1, fake2, fake3, _ = gen_total(ginka, input, True, True)
fake1, fake2, fake3, _ = gen_total(ginka, input, tag_cond, val_cond, True, True, train_stage == 4)
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()

View File

@@ -1,49 +0,0 @@
import json
import random
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from shared.graph import differentiable_convert_to_data
from shared.utils import random_smooth_onehot
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
data = json.load(f)
data_list = []
for value in data["data"].values():
data_list.append(value)
return data_list
class MinamoDataset(Dataset):
def __init__(self, data_path: str):
self.data = load_data(data_path) # 自定义数据加载函数
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
item = self.data[idx]
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
min_main = random.uniform(0.6, 1)
max_main = random.uniform(0.8, 1)
epsilon = random.uniform(0, 0.4)
map1_probs = random_smooth_onehot(map1_probs, min_main, max_main, epsilon)
map2_probs = random_smooth_onehot(map2_probs, min_main, max_main, epsilon)
graph1 = differentiable_convert_to_data(map1_probs)
graph2 = differentiable_convert_to_data(map2_probs)
return (
map1_probs,
map2_probs,
torch.FloatTensor([item['visionSimilarity']]),
torch.FloatTensor([item['topoSimilarity']]),
graph1,
graph2
)

View File

@@ -1,17 +0,0 @@
import torch.nn as nn
from tqdm import tqdm
class MinamoLoss(nn.Module):
def __init__(self, vision_weight=0.2, topo_weight=0.8):
super().__init__()
self.vision_weight = vision_weight
self.topo_weight = topo_weight
self.loss = nn.L1Loss()
def forward(self, vis_pred, topo_pred, vis_true, topo_true):
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
vis_loss = self.loss(vis_pred, vis_true)
topo_loss = self.loss(topo_pred, topo_true)
# tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f} | {vis_loss.item():.12f}, {topo_loss.item():.12f}")
# print(vis_loss.item(), topo_loss.item())
return self.vision_weight * vis_loss + self.topo_weight * topo_loss

View File

@@ -1,83 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.data import Data
class MinamoSimilarityVision(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, in_ch * 2, 3, padding=1),
nn.InstanceNorm2d(in_ch * 2),
nn.ReLU(),
nn.Conv2d(in_ch * 2, in_ch * 4, 3, padding=1),
nn.InstanceNorm2d(in_ch * 4),
nn.ReLU(),
nn.Conv2d(in_ch * 4, in_ch * 8, 3),
nn.InstanceNorm2d(in_ch * 8),
nn.ReLU(),
nn.AdaptiveAvgPool2d(1)
)
self.fc = nn.Sequential(
nn.Linear(in_ch * 8, out_ch),
)
def forward(self, x):
x = self.conv(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
class MinamoSimilarityTopo(nn.Module):
def __init__(self, in_ch, hidden_dim, out_ch):
super().__init__()
self.input_fc = nn.Sequential(
nn.Linear(in_ch, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
)
self.conv1 = GCNConv(hidden_dim, hidden_dim*2)
self.conv2 = GCNConv(hidden_dim*2, hidden_dim*4)
self.conv3 = GCNConv(hidden_dim*4, hidden_dim*8)
self.norm1 = nn.LayerNorm(hidden_dim*2)
self.norm2 = nn.LayerNorm(hidden_dim*4)
self.norm3 = nn.LayerNorm(hidden_dim*8)
self.output_fc = nn.Sequential(
nn.Linear(hidden_dim*8, out_ch)
)
def forward(self, graph: Data):
x = self.input_fc(graph.x)
x = self.conv1(x, graph.edge_index)
x = F.relu(self.norm1(x))
x = self.conv2(x, graph.edge_index)
x = F.relu(self.norm2(x))
x = self.conv3(x, graph.edge_index)
x = F.relu(self.norm3(x))
x = global_mean_pool(x, graph.batch)
x = self.output_fc(x)
return x
class MinamoSimilarityModel(nn.Module):
def __init__(self, tile_type=32):
super().__init__()
self.vision = MinamoSimilarityVision(tile_type, 512)
self.topo = MinamoSimilarityTopo(tile_type, 64, 512)
def forward(self, x, graph):
vis_feat = self.vision(x)
topo_feat = self.topo(graph)
return vis_feat, topo_feat

View File

@@ -1,153 +0,0 @@
import os
import sys
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import MinamoModel
from .model.loss import MinamoLoss
from .dataset import MinamoDataset
from shared.args import parse_arguments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/minamo_checkpoint", exist_ok=True)
disable_tqdm = not sys.stdout.isatty() # disable tqdm if stdout is redirected
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments("result/minamo.pth", "minamo-dataset.json", 'minamo-eval.json')
model = MinamoModel(32)
model.to(device)
# prepare datasets
dataset = MinamoDataset(args.train)
val_dataset = MinamoDataset(args.validate)
dataloader = DataLoader(
dataset,
batch_size=64,
shuffle=True
)
val_loader = DataLoader(
val_dataset,
batch_size=64,
shuffle=True
)
# set up the optimizer and scheduler
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = MinamoLoss()
if args.resume:
data = torch.load(args.from_state, map_location=device)
model.load_state_dict(data["model_state"], strict=False)
if args.load_optim:
optimizer.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")
# for name, param in model.named_parameters():
# if 'ins' not in name: # train only the extension part
# param.requires_grad = False
# start training
for epoch in tqdm(range(args.epochs), disable=disable_tqdm):
model.train()
total_loss = 0
# if epoch == 30:
# for name, param in model.named_parameters():
# param.requires_grad = True
for batch in tqdm(dataloader, leave=False, disable=disable_tqdm):
# move data to the device
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
map1 = map1.to(device) # convert to [B, C, H, W]
map2 = map2.to(device)
topo_simi = topo_simi.to(device)
vision_simi = vision_simi.to(device)
graph1 = graph1.to(device)
graph2 = graph2.to(device)
if map1.shape[0] == 1:
continue
# forward pass
optimizer.zero_grad()
vision_feat1, topo_feat1 = model(map1, graph1)
vision_feat2, topo_feat2 = model(map2, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
# compute loss
loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
# backward pass
loss.backward()
optimizer.step()
total_loss += loss.item()
ave_loss = total_loss / len(dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
# total_norm = 0
# for p in model.parameters():
# if p.grad is not None:
# param_norm = p.grad.detach().data.norm(2)
# total_norm += param_norm.item() ** 2
# total_norm = total_norm ** 0.5
# tqdm.write(f"Gradient Norm: {total_norm:.4f}") # 正常应保持在1~100之间
# for name, param in model.named_parameters():
# if param.grad is not None:
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
# learning rate step
scheduler.step()
# run validation on the held-out set every 5 epochs
if (epoch + 1) % 5 == 0:
model.eval()
val_loss = 0
with torch.no_grad():
for val_batch in tqdm(val_loader, leave=False, disable=disable_tqdm):
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
map1_val = map1_val.to(device)
map2_val = map2_val.to(device)
vision_simi_val = vision_simi_val.to(device)
topo_simi_val = topo_simi_val.to(device)
graph1 = graph1.to(device)
graph2 = graph2.to(device)
vision_feat1, topo_feat1 = model(map1_val, graph1)
vision_feat2, topo_feat2 = model(map2_val, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
# compute loss
loss_val = criterion(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
val_loss += loss_val.item()
avg_val_loss = val_loss / len(val_loader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
print("Train ended.")
torch.save({
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
}, "result/minamo.pth")
if __name__ == "__main__":
torch.set_num_threads(2)
train()

View File

@@ -1,61 +0,0 @@
import torch
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import MinamoModel
from .model.loss import MinamoLoss
from .dataset import MinamoDataset
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = MinamoModel(32)
model.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
model.to(device)
for name, param in model.named_parameters():
print(f"Layer: {name}, Params: {param.numel()}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
# prepare the dataset
val_dataset = MinamoDataset("datasets/minamo-eval-1.json")
val_loader = DataLoader(
val_dataset,
batch_size=32,
shuffle=True
)
criterion = MinamoLoss()
model.eval()
val_loss = 0
with torch.no_grad():
for val_batch in tqdm(val_loader):
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
map1_val = map1_val.to(device)
map2_val = map2_val.to(device)
vision_simi_val = vision_simi_val.to(device)
topo_simi_val = topo_simi_val.to(device)
graph1 = graph1.to(device)
graph2 = graph2.to(device)
vision_feat1, topo_feat1 = model(map1_val, graph1)
vision_feat2, topo_feat2 = model(map2_val, graph2)
vision_pred_val = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
topo_pred_val = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
loss_val = criterion(
vision_pred_val, topo_pred_val,
vision_simi_val, topo_simi_val
)
val_loss += loss_val.item()
avg_val_loss = val_loss / len(val_loader)
tqdm.write(f"Validation::loss: {avg_val_loss:.6f}")
if __name__ == "__main__":
torch.set_num_threads(2)
validate()