mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
Merge branches 'master' and 'master' of github.com:unanmed/ginka-generator
This commit is contained in:
commit
f9211965db
@ -335,13 +335,13 @@ class GinkaLoss(nn.Module):
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
border_loss * self.weight[1] * 0.1,
|
||||
border_loss * self.weight[1],
|
||||
entrance_loss * self.weight[2],
|
||||
count_loss * self.weight[3],
|
||||
illegal_loss * self.weight[4]
|
||||
]
|
||||
|
||||
# 梯度归一化
|
||||
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
|
||||
total_loss = sum(scaled_losses)
|
||||
return total_loss
|
||||
# scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
|
||||
total_loss = sum(losses)
|
||||
return total_loss
|
||||
|
||||
@ -8,19 +8,21 @@ from .model.model import GinkaModel
|
||||
from .model.loss import GinkaLoss
|
||||
from .dataset import GinkaDataset
|
||||
from minamo.model.model import MinamoModel
|
||||
from shared.args import parse_arguments
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||
|
||||
epochs = 150
|
||||
|
||||
# 在生成器输出后添加梯度检查钩子
|
||||
def grad_hook(module, grad_input, grad_output):
|
||||
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
|
||||
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
|
||||
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
|
||||
|
||||
model = GinkaModel()
|
||||
model.to(device)
|
||||
minamo = MinamoModel(32)
|
||||
@ -53,9 +55,15 @@ def train():
|
||||
# model.register_full_backward_hook(grad_hook)
|
||||
# converter.register_full_backward_hook(grad_hook)
|
||||
# criterion.register_full_backward_hook(grad_hook)
|
||||
if args.resume:
|
||||
data = torch.load(args.from_state, map_location=device)
|
||||
model.load_state_dict(data["model_state"])
|
||||
if args.load_optim:
|
||||
optimizer.load_state_dict(data["optimizer_state"])
|
||||
print("Train from loaded state.")
|
||||
|
||||
# 开始训练
|
||||
for epoch in tqdm(range(epochs)):
|
||||
for epoch in tqdm(range(args.epochs)):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
|
||||
@ -118,7 +126,7 @@ def train():
|
||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||
torch.save({
|
||||
"model_state": model.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
# "optimizer_state": optimizer.state_dict(),
|
||||
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
|
||||
|
||||
|
||||
@ -126,7 +134,7 @@ def train():
|
||||
|
||||
torch.save({
|
||||
"model_state": model.state_dict(),
|
||||
"optimizer_state": optimizer.state_dict(),
|
||||
# "optimizer_state": optimizer.state_dict(),
|
||||
}, f"result/ginka.pth")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -8,13 +8,12 @@ from tqdm import tqdm
|
||||
from .model.model import MinamoModel
|
||||
from .model.loss import MinamoLoss
|
||||
from .dataset import MinamoDataset
|
||||
from shared.args import parse_arguments
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/minamo_checkpoint", exist_ok=True)
|
||||
|
||||
epochs = 150
|
||||
|
||||
def collate_fn(batch):
|
||||
"""动态处理不同尺寸地图的批处理"""
|
||||
map1_batch = [item[0] for item in batch]
|
||||
@ -35,6 +34,9 @@ def collate_fn(batch):
|
||||
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
|
||||
args = parse_arguments("result/minamo.pth", "minamo-dataset.json", 'minamo-eval.json')
|
||||
|
||||
model = MinamoModel(32)
|
||||
model.to(device)
|
||||
|
||||
@ -57,8 +59,15 @@ def train():
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion = MinamoLoss()
|
||||
|
||||
if args.resume:
|
||||
data = torch.load(args.from_state, map_location=device)
|
||||
model.load_state_dict(data["model_state"])
|
||||
if args.load_optim:
|
||||
optimizer.load_state_dict(data["optimizer_state"])
|
||||
print("Train from loaded state.")
|
||||
|
||||
# 开始训练
|
||||
for epoch in tqdm(range(epochs)):
|
||||
for epoch in tqdm(range(args.epochs)):
|
||||
model.train()
|
||||
total_loss = 0
|
||||
|
||||
|
||||
12
shared/args.py
Normal file
12
shared/args.py
Normal file
@ -0,0 +1,12 @@
|
||||
import argparse
|
||||
|
||||
def parse_arguments(from_default: str, train_default: str, val_default: str):
|
||||
parser = argparse.ArgumentParser(description="training codes")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--from_state", type=str, default=from_default)
|
||||
parser.add_argument("--load_optim", type=bool, default=False)
|
||||
parser.add_argument("--train", type=str, default=train_default)
|
||||
parser.add_argument("--validate", type=str, default=val_default)
|
||||
parser.add_argument("--epochs", type=int, default=150)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
Loading…
Reference in New Issue
Block a user