mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
perf: 优化训练流程,改进判别器与损失值
This commit is contained in:
parent
a94b07bda8
commit
5be9decf95
7
cycle.sh
7
cycle.sh
@ -1,7 +0,0 @@
|
||||
# Run gan.sh over and over, starting at the iteration number passed as $1.
# The loop never terminates on its own; stop it by interrupting the script.
i=$1
while true
do
    sh gan.sh "$i"
    # Report the iteration that just ran, then advance the counter.
    # (Incrementing first made the message name the *next* iteration.)
    echo "第 $i 次循环完成"
    i=$((i+1))
done
|
||||
@ -1,7 +0,0 @@
|
||||
# Run gan.sh once for every iteration number from $1 to $2 (inclusive).
start=$1
end=$2
# A POSIX while loop is used instead of the bash-only `for ((...))` form so
# the script also works when launched with `sh`, like gan.sh below.
i=$start
while [ "$i" -le "$end" ]
do
    sh gan.sh "$i"
    echo "第 $i 次循环完成"
    i=$((i+1))
done
|
||||
17
gan.sh
17
gan.sh
@ -1,17 +0,0 @@
|
||||
# One training / data-generation round. $1 is the round number used to name
# the archived dataset files.

# Training
python3 -m minamo.train --epochs 10 --resume true
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
python3 -m minamo.train --epochs 10 --resume true
python3 -m ginka.train --epochs 30 --resume true
python3 -m ginka.validate

# Training finished; process the data
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
# Stop if the data directory cannot be entered: every pnpm command below uses
# paths relative to it, and the final `cd ..` would otherwise leave the
# project root.
cd data || exit 1
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:30
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
cd ..
|
||||
@ -325,7 +325,7 @@ def js_divergence(p, q, eps=1e-6, softmax=False):
|
||||
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
|
||||
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
|
||||
|
||||
return torch.clamp(0.5 * (kl_pm + kl_qm), max=10)
|
||||
return torch.log1p(0.5 * (kl_pm + kl_qm))
|
||||
|
||||
def immutable_penalty_loss(
|
||||
pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
|
||||
@ -334,8 +334,8 @@ def immutable_penalty_loss(
|
||||
惩罚模型修改不可更改区域的损失。
|
||||
|
||||
Args:
|
||||
input: 模型输出 [B, C, H, W],概率分布 (softmax 后)
|
||||
target: 原始输入图 [B, C, H, W],概率分布 (softmax 后)
|
||||
input: 模型输出 [B, C, H, W],概率分布 (softmax 前)
|
||||
target: 原始输入图 [B, C, H, W],概率分布 (softmax 前)
|
||||
modifiable_classes: 允许被修改的类别列表
|
||||
"""
|
||||
not_allowed = get_not_allowed(modifiable_classes, include_illegal=True)
|
||||
@ -344,13 +344,10 @@ def immutable_penalty_loss(
|
||||
target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1)
|
||||
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
|
||||
|
||||
target_mask = torch.log(target_mask + 1e-6) # 转换为 log 概率分布
|
||||
input_mask = torch.log(input_mask + 1e-6) # 转换为 log 概率分布
|
||||
|
||||
# 差异区域(模型试图改变的地方)
|
||||
penalty = F.kl_div(input_mask, target_mask, reduction='batchmean', log_target=True)
|
||||
penalty = F.cross_entropy(input_mask, target_mask)
|
||||
|
||||
return torch.clamp(penalty, max=1)
|
||||
return penalty
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]):
|
||||
@ -420,7 +417,7 @@ class WGANGinkaLoss:
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
@ -471,7 +468,7 @@ class WGANGinkaLoss:
|
||||
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
@ -496,7 +493,7 @@ class WGANGinkaLoss:
|
||||
losses = [
|
||||
input_head_illegal_loss(probs),
|
||||
input_head_wall_loss(probs),
|
||||
-js_divergence(probs_a, probs_b, softmax=False) * 0.2
|
||||
-js_divergence(probs_a, probs_b, softmax=False) * 0.3
|
||||
]
|
||||
|
||||
return sum(losses)
|
||||
|
||||
@ -20,7 +20,7 @@ class GinkaModel(nn.Module):
|
||||
|
||||
def forward(self, x, stage, random=False):
|
||||
if random:
|
||||
x_in = F.softmax(self.head(x))
|
||||
x_in = F.softmax(self.head(x), dim=1)
|
||||
else:
|
||||
x_in = x
|
||||
x = self.input(x_in)
|
||||
@ -30,7 +30,7 @@ class GinkaModel(nn.Module):
|
||||
|
||||
# 检查显存占用
|
||||
if __name__ == "__main__":
|
||||
input = torch.randn((1, 32, 13, 13)).cuda()
|
||||
input = torch.randn((1, 32, 32, 32)).cuda()
|
||||
|
||||
# 初始化模型
|
||||
model = GinkaModel().cuda()
|
||||
@ -38,7 +38,7 @@ if __name__ == "__main__":
|
||||
print_memory("初始化后")
|
||||
|
||||
# 前向传播
|
||||
output = model(input, 1)
|
||||
output, _ = model(input, 1, True)
|
||||
|
||||
print_memory("前向传播后")
|
||||
|
||||
|
||||
@ -49,12 +49,12 @@ def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.
|
||||
def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
|
||||
if progress_detach:
|
||||
fake1, x_in = gen(input.detach(), 1, random)
|
||||
fake2, _ = gen(F.softmax(fake1.detach()), 2)
|
||||
fake3, _ = gen(F.softmax(fake2.detach()), 3)
|
||||
fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2)
|
||||
fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3)
|
||||
else:
|
||||
fake1, x_in = gen(input, 1, random)
|
||||
fake2, _ = gen(F.softmax(fake1), 2)
|
||||
fake3, _ = gen(F.softmax(fake2), 3)
|
||||
fake2, _ = gen(F.softmax(fake1, dim=1), 2)
|
||||
fake3, _ = gen(F.softmax(fake2, dim=1), 3)
|
||||
if result_detach:
|
||||
return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
|
||||
else:
|
||||
@ -69,7 +69,6 @@ def train():
|
||||
g_steps = 1
|
||||
# 训练阶段
|
||||
train_stage = 1
|
||||
last_stage = False
|
||||
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
||||
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
|
||||
|
||||
@ -83,7 +82,6 @@ def train():
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
optimizer_head = optim.Adam(ginka_head.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
||||
|
||||
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
|
||||
@ -117,9 +115,6 @@ def train():
|
||||
if data_ginka.get("stage") is not None:
|
||||
train_stage = data_ginka["stage"]
|
||||
|
||||
if data_ginka.get("last_stage") is not None:
|
||||
last_stage = data_ginka["last_stage"]
|
||||
|
||||
if args.load_optim:
|
||||
if data_ginka.get("optim_state") is not None:
|
||||
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
|
||||
@ -149,22 +144,11 @@ def train():
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
|
||||
|
||||
if train_stage == 4:
|
||||
# 最后一个阶段训练输入头
|
||||
count = 5 if stage_epoch <= 20 else 2
|
||||
for _ in range(count):
|
||||
optimizer_head.zero_grad()
|
||||
output = F.softmax(ginka_head(masked1), dim=1)
|
||||
loss_head = criterion.generator_input_head_loss(output)
|
||||
loss_head.backward()
|
||||
optimizer_head.step()
|
||||
|
||||
# ---------- 训练判别器
|
||||
for _ in range(c_steps):
|
||||
# 生成假样本
|
||||
optimizer_minamo.zero_grad()
|
||||
optimizer_ginka.zero_grad()
|
||||
optimizer_head.zero_grad()
|
||||
|
||||
with torch.no_grad():
|
||||
if train_stage == 1 or train_stage == 2:
|
||||
@ -193,7 +177,6 @@ def train():
|
||||
for _ in range(g_steps):
|
||||
optimizer_minamo.zero_grad()
|
||||
optimizer_ginka.zero_grad()
|
||||
optimizer_head.zero_grad()
|
||||
if train_stage == 1 or train_stage == 2:
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
|
||||
|
||||
@ -210,10 +193,10 @@ def train():
|
||||
loss_ce_total += loss_ce.detach()
|
||||
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3, x_in = gen_total(ginka, input, True, False)
|
||||
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, True, False, train_stage == 4)
|
||||
|
||||
if train_stage == 3:
|
||||
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input)
|
||||
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
|
||||
else:
|
||||
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
||||
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
|
||||
@ -221,7 +204,7 @@ def train():
|
||||
|
||||
if train_stage == 4:
|
||||
loss_head = criterion.generator_input_head_loss(x_in)
|
||||
loss_head.backward()
|
||||
loss_head.backward(retain_graph=True)
|
||||
|
||||
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
||||
loss_g.backward()
|
||||
@ -239,40 +222,26 @@ def train():
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
|
||||
)
|
||||
|
||||
if avg_loss_ce < 1.0:
|
||||
if avg_loss_ce < 0.5:
|
||||
low_loss_epochs += 1
|
||||
else:
|
||||
low_loss_epochs = 0
|
||||
|
||||
# 训练流程控制
|
||||
|
||||
if low_loss_epochs >= 3 and train_stage == 1 and stage_epoch >= args.curr_epoch:
|
||||
if train_stage >= 2:
|
||||
train_stage += 1
|
||||
|
||||
if train_stage == 5:
|
||||
train_stage = 2
|
||||
|
||||
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= args.curr_epoch:
|
||||
if mask_ratio >= 0.9:
|
||||
train_stage = 2
|
||||
mask_ratio += 0.2
|
||||
mask_ratio = min(mask_ratio, 0.9)
|
||||
low_loss_epochs = 0
|
||||
stage_epoch = 0
|
||||
|
||||
if (train_stage == 3 or train_stage == 2) and not last_stage:
|
||||
if stage_epoch >= 25:
|
||||
train_stage += 1
|
||||
stage_epoch = 0
|
||||
|
||||
if train_stage == 4:
|
||||
last_stage = True
|
||||
|
||||
if train_stage >= 3 or last_stage:
|
||||
# 第三阶段后交叉熵损失不再应该生效
|
||||
mask_ratio = 1.0
|
||||
|
||||
if last_stage:
|
||||
mask_ratio = 1.0
|
||||
if train_stage == 2 and stage_epoch % 5 == 0:
|
||||
train_stage = 4
|
||||
|
||||
if train_stage == 4 and stage_epoch % 5 == 1:
|
||||
train_stage = 2
|
||||
|
||||
stage_epoch += 1
|
||||
|
||||
@ -305,7 +274,6 @@ def train():
|
||||
"stage": train_stage,
|
||||
"mask_ratio": mask_ratio,
|
||||
"stage_epoch": stage_epoch,
|
||||
"last_stage": last_stage
|
||||
}, f"result/wgan/ginka-{epoch + 1}.pth")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict(),
|
||||
|
||||
@ -2,9 +2,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool
|
||||
from .vision import MinamoVisionModel
|
||||
from .topo import MinamoTopoModel
|
||||
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
||||
from shared.graph import batch_convert_soft_map_to_graph
|
||||
|
||||
def print_memory(tag=""):
    """Print the current and peak CUDA memory allocation in MB, prefixed by ``tag``."""
    bytes_per_mb = 1024**2
    current_mb = torch.cuda.memory_allocated() / bytes_per_mb
    peak_mb = torch.cuda.max_memory_allocated() / bytes_per_mb
    print(f"{tag} | 当前显存: {current_mb:.2f} MB, 最大显存: {peak_mb:.2f} MB")
|
||||
|
||||
class MinamoModel(nn.Module):
|
||||
def __init__(self, tile_types=32):
|
||||
@ -19,20 +24,51 @@ class MinamoModel(nn.Module):
|
||||
topo_feat = self.topo_model(graph)
|
||||
|
||||
return vision_feat, topo_feat
|
||||
|
||||
class CNNHead(nn.Module):
    """Convolutional head: one 3x3 conv, 2x2 adaptive max pooling, then a linear layer.

    Args:
        in_ch: Channel count of the incoming feature map; the conv keeps it unchanged.
        out_dim: Size of the vector produced for each sample.
    """

    def __init__(self, in_ch, out_dim):
        super().__init__()
        self.cnn = nn.Sequential(
            # 3x3 conv without padding: each spatial side shrinks by 2
            spectral_norm(nn.Conv2d(in_ch, in_ch, 3)),
            nn.LeakyReLU(0.2),

            # Pool to a fixed 2x2 grid so the linear layer's input size is in_ch*2*2
            nn.AdaptiveMaxPool2d((2, 2))
        )
        self.fc = nn.Sequential(
            spectral_norm(nn.Linear(in_ch*2*2, out_dim))
        )

    def forward(self, x):
        """Map a [B, in_ch, H, W] feature map to a [B, out_dim] tensor."""
        x = self.cnn(x)
        # Flatten everything except the batch dimension
        x = x.view(x.size(0), -1)
        return self.fc(x)
|
||||
|
||||
class MinamoScoreHead(nn.Module):
|
||||
class GCNHead(nn.Module):
|
||||
def __init__(self, in_dim, out_dim):
|
||||
super().__init__()
|
||||
self.vision_fc = nn.Sequential(
|
||||
spectral_norm(nn.Linear(in_dim, out_dim)),
|
||||
)
|
||||
self.topo_fc = nn.Sequential(
|
||||
self.gcn = GCNConv(in_dim, in_dim)
|
||||
self.fc = nn.Sequential(
|
||||
spectral_norm(nn.Linear(in_dim, out_dim))
|
||||
)
|
||||
|
||||
def forward(self, vis_feat, topo_feat):
|
||||
vis_score = self.vision_fc(vis_feat)
|
||||
topo_score = self.topo_fc(topo_feat)
|
||||
def forward(self, x, graph):
|
||||
x = self.gcn(x, graph.edge_index)
|
||||
x = F.leaky_relu(x, 0.2)
|
||||
x = global_max_pool(x, graph.batch)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
class MinamoScoreHead(nn.Module):
    """Pair of scoring heads: a CNNHead for vision input and a GCNHead for topology input."""

    def __init__(self, vision_dim, topo_dim, out_dim):
        super().__init__()
        self.vision_head = CNNHead(vision_dim, out_dim)
        self.topo_head = GCNHead(topo_dim, out_dim)

    def forward(self, vis, topo, graph):
        """Return ``(vision_score, topo_score)``; ``graph`` is forwarded to the topology head."""
        return self.vision_head(vis), self.topo_head(topo, graph)
|
||||
|
||||
class MinamoScoreModule(nn.Module):
|
||||
@ -41,20 +77,41 @@ class MinamoScoreModule(nn.Module):
|
||||
self.topo_model = MinamoTopoModel(tile_types)
|
||||
self.vision_model = MinamoVisionModel(tile_types)
|
||||
# 输出层
|
||||
self.head1 = MinamoScoreHead(512, 1)
|
||||
self.head2 = MinamoScoreHead(512, 1)
|
||||
self.head3 = MinamoScoreHead(512, 1)
|
||||
self.head1 = MinamoScoreHead(512, 512, 1)
|
||||
self.head2 = MinamoScoreHead(512, 512, 1)
|
||||
self.head3 = MinamoScoreHead(512, 512, 1)
|
||||
|
||||
def forward(self, map, graph, stage):
|
||||
vision_feat = self.vision_model(map)
|
||||
topo_feat = self.topo_model(graph)
|
||||
vision = self.vision_model(map)
|
||||
topo = self.topo_model(graph)
|
||||
if stage == 1:
|
||||
vision_score, topo_score = self.head1(vision_feat, topo_feat)
|
||||
vision_score, topo_score = self.head1(vision, topo, graph)
|
||||
elif stage == 2:
|
||||
vision_score, topo_score = self.head2(vision_feat, topo_feat)
|
||||
vision_score, topo_score = self.head2(vision, topo, graph)
|
||||
elif stage == 3:
|
||||
vision_score, topo_score = self.head3(vision_feat, topo_feat)
|
||||
vision_score, topo_score = self.head3(vision, topo, graph)
|
||||
else:
|
||||
raise RuntimeError("Unknown critic stage.")
|
||||
score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
|
||||
return score, vision_score, topo_score
|
||||
|
||||
# Check GPU memory usage
if __name__ == "__main__":
    sample = torch.randn((1, 32, 13, 13)).cuda()

    # Build the model
    model = MinamoScoreModule().cuda()
    print_memory("初始化后")

    # Forward pass
    sample_graph = batch_convert_soft_map_to_graph(sample)
    output, _, _ = model(sample, sample_graph, 1)
    print_memory("前向传播后")

    print(f"输入形状: feat={sample.shape}")
    print(f"输出形状: output={output.shape}")

    # Parameter counts per sub-module, then for the whole model
    report = (
        ("Topo", model.topo_model),
        ("Vision", model.vision_model),
        ("Head", model.head1),
        ("Total", model),
    )
    for label, module in report:
        print(f"{label} parameters: {sum(p.numel() for p in module.parameters())}")
|
||||
|
||||
@ -2,12 +2,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch_geometric.nn import GATConv, global_max_pool, GCNConv, global_mean_pool
|
||||
from torch_geometric.nn import GATConv
|
||||
from torch_geometric.data import Data
|
||||
|
||||
class MinamoTopoModel(nn.Module):
|
||||
def __init__(
|
||||
self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512, feat_dim=512
|
||||
self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512
|
||||
):
|
||||
super().__init__()
|
||||
# 传入 softmax 概率值,直接映射
|
||||
@ -20,15 +20,6 @@ class MinamoTopoModel(nn.Module):
|
||||
self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8)
|
||||
self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1)
|
||||
|
||||
# self.norm1 = nn.LayerNorm(hidden_dim*8)
|
||||
# self.norm2 = nn.LayerNorm(hidden_dim*8)
|
||||
# self.norm3 = nn.LayerNorm(out_dim)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
spectral_norm(nn.Linear(out_dim, feat_dim)),
|
||||
nn.LeakyReLU(0.2)
|
||||
)
|
||||
|
||||
def forward(self, graph: Data):
|
||||
x = self.input_proj(graph.x)
|
||||
|
||||
@ -41,10 +32,5 @@ class MinamoTopoModel(nn.Module):
|
||||
x = self.conv3(x, graph.edge_index)
|
||||
x = F.leaky_relu(x, 0.2)
|
||||
|
||||
# 池化
|
||||
x = global_mean_pool(x, graph.batch)
|
||||
|
||||
topo_vec = self.fc(x)
|
||||
|
||||
return topo_vec
|
||||
return x
|
||||
|
||||
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
|
||||
class MinamoVisionModel(nn.Module):
|
||||
def __init__(self, in_ch=32, out_dim=512):
|
||||
def __init__(self, in_ch=32, out_ch=512):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11
|
||||
@ -13,18 +13,10 @@ class MinamoVisionModel(nn.Module):
|
||||
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #9*9
|
||||
nn.LeakyReLU(0.2),
|
||||
|
||||
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7
|
||||
spectral_norm(nn.Conv2d(in_ch*4, out_ch, 3)), # 7*7
|
||||
nn.LeakyReLU(0.2),
|
||||
|
||||
nn.AdaptiveAvgPool2d(2)
|
||||
)
|
||||
self.fc = nn.Sequential(
|
||||
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)),
|
||||
nn.LeakyReLU(0.2)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
4
train.sh
4
train.sh
@ -1,4 +1,4 @@
|
||||
# 从头训练
|
||||
python3 -u -m ginka.train_wgan >> output.log
|
||||
python3 -u -m ginka.train_wgan --epochs 300 >> output.log
|
||||
# 接续训练
|
||||
python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
|
||||
python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
|
||||
|
||||
Loading…
Reference in New Issue
Block a user