diff --git a/ginka/model/loss.py b/ginka/model/loss.py index f196d7b..9d90951 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -335,13 +335,13 @@ class GinkaLoss(nn.Module): losses = [ minamo_loss * self.weight[0], - border_loss * self.weight[1] * 0.1, + border_loss * self.weight[1], entrance_loss * self.weight[2], count_loss * self.weight[3], illegal_loss * self.weight[4] ] # 梯度归一化 - scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses] - total_loss = sum(scaled_losses) - return total_loss \ No newline at end of file + # scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses] + total_loss = sum(losses) + return total_loss diff --git a/ginka/train.py b/ginka/train.py index 70047a7..f86e75a 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -8,19 +8,21 @@ from .model.model import GinkaModel from .model.loss import GinkaLoss from .dataset import GinkaDataset from minamo.model.model import MinamoModel +from shared.args import parse_arguments device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) os.makedirs("result/ginka_checkpoint", exist_ok=True) -epochs = 150 - # 在生成器输出后添加梯度检查钩子 def grad_hook(module, grad_input, grad_output): print(f"Generator output grad norm: {grad_output[0].norm().item()}") def train(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") + + args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json') + model = GinkaModel() model.to(device) minamo = MinamoModel(32) @@ -53,9 +55,15 @@ def train(): # model.register_full_backward_hook(grad_hook) # converter.register_full_backward_hook(grad_hook) # criterion.register_full_backward_hook(grad_hook) + if args.resume: + data = torch.load(args.from_state, map_location=device) + model.load_state_dict(data["model_state"]) + if args.load_optim: + optimizer.load_state_dict(data["optimizer_state"]) + print("Train from loaded state.") # 开始训练 - for epoch in 
tqdm(range(epochs)): + for epoch in tqdm(range(args.epochs)): model.train() total_loss = 0 @@ -118,7 +126,7 @@ def train(): tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") torch.save({ "model_state": model.state_dict(), - "optimizer_state": optimizer.state_dict(), + # "optimizer_state": optimizer.state_dict(), }, f"result/ginka_checkpoint/{epoch + 1}.pth") @@ -126,7 +134,7 @@ def train(): torch.save({ "model_state": model.state_dict(), - "optimizer_state": optimizer.state_dict(), + # "optimizer_state": optimizer.state_dict(), }, f"result/ginka.pth") if __name__ == "__main__": diff --git a/minamo/train.py b/minamo/train.py index 153ac90..bc3a1d9 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -8,13 +8,12 @@ from tqdm import tqdm from .model.model import MinamoModel from .model.loss import MinamoLoss from .dataset import MinamoDataset +from shared.args import parse_arguments device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) os.makedirs("result/minamo_checkpoint", exist_ok=True) -epochs = 150 - def collate_fn(batch): """动态处理不同尺寸地图的批处理""" map1_batch = [item[0] for item in batch] @@ -35,6 +34,9 @@ def collate_fn(batch): def train(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") + + args = parse_arguments("result/minamo.pth", "minamo-dataset.json", 'minamo-eval.json') + model = MinamoModel(32) model.to(device) @@ -57,8 +59,15 @@ def train(): scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) criterion = MinamoLoss() + if args.resume: + data = torch.load(args.from_state, map_location=device) + model.load_state_dict(data["model_state"]) + if args.load_optim: + optimizer.load_state_dict(data["optimizer_state"]) + print("Train from loaded state.") + # 开始训练 - for epoch in tqdm(range(epochs)): + for epoch in tqdm(range(args.epochs)): model.train() total_loss = 0 diff --git 
a/shared/args.py b/shared/args.py new file mode 100644 index 0000000..3fd8f20 --- /dev/null +++ b/shared/args.py @@ -0,0 +1,12 @@ +import argparse + +def parse_arguments(from_default: str, train_default: str, val_default: str): +    parser = argparse.ArgumentParser(description="training codes") +    parser.add_argument("--resume", action="store_true") +    parser.add_argument("--from_state", type=str, default=from_default) +    parser.add_argument("--load_optim", action="store_true") +    parser.add_argument("--train", type=str, default=train_default) +    parser.add_argument("--validate", type=str, default=val_default) +    parser.add_argument("--epochs", type=int, default=150) +    args = parser.parse_args() +    return args