Merge branches 'master' and 'master' of github.com:unanmed/ginka-generator

This commit is contained in:
unanmed 2025-03-21 12:46:31 +08:00
commit f9211965db
4 changed files with 41 additions and 12 deletions

View File

@@ -335,13 +335,13 @@ class GinkaLoss(nn.Module):
losses = [
minamo_loss * self.weight[0],
border_loss * self.weight[1] * 0.1,
border_loss * self.weight[1],
entrance_loss * self.weight[2],
count_loss * self.weight[3],
illegal_loss * self.weight[4]
]
# 梯度归一化
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
total_loss = sum(scaled_losses)
return total_loss
# scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
total_loss = sum(losses)
return total_loss

View File

@@ -8,19 +8,21 @@ from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .dataset import GinkaDataset
from minamo.model.model import MinamoModel
from shared.args import parse_arguments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)
epochs = 150
# 在生成器输出后添加梯度检查钩子
def grad_hook(module, grad_input, grad_output):
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
model = GinkaModel()
model.to(device)
minamo = MinamoModel(32)
@@ -53,9 +55,15 @@ def train():
# model.register_full_backward_hook(grad_hook)
# converter.register_full_backward_hook(grad_hook)
# criterion.register_full_backward_hook(grad_hook)
if args.resume:
data = torch.load(args.from_state, map_location=device)
model.load_state_dict(data["model_state"])
if args.load_optim:
optimizer.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")
# 开始训练
for epoch in tqdm(range(epochs)):
for epoch in tqdm(range(args.epochs)):
model.train()
total_loss = 0
@@ -118,7 +126,7 @@ def train():
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
# "optimizer_state": optimizer.state_dict(),
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
@@ -126,7 +134,7 @@ def train():
torch.save({
"model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(),
# "optimizer_state": optimizer.state_dict(),
}, f"result/ginka.pth")
if __name__ == "__main__":

View File

@@ -8,13 +8,12 @@ from tqdm import tqdm
from .model.model import MinamoModel
from .model.loss import MinamoLoss
from .dataset import MinamoDataset
from shared.args import parse_arguments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/minamo_checkpoint", exist_ok=True)
epochs = 150
def collate_fn(batch):
"""动态处理不同尺寸地图的批处理"""
map1_batch = [item[0] for item in batch]
@@ -35,6 +34,9 @@ def collate_fn(batch):
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments("result/minamo.pth", "minamo-dataset.json", 'minamo-eval.json')
model = MinamoModel(32)
model.to(device)
@@ -57,8 +59,15 @@ def train():
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = MinamoLoss()
if args.resume:
data = torch.load(args.from_state, map_location=device)
model.load_state_dict(data["model_state"])
if args.load_optim:
optimizer.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")
# 开始训练
for epoch in tqdm(range(epochs)):
for epoch in tqdm(range(args.epochs)):
model.train()
total_loss = 0

12
shared/args.py Normal file
View File

@@ -0,0 +1,12 @@
import argparse
def parse_arguments(from_default: str, train_default: str, val_default: str):
parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--from_state", type=str, default=from_default)
parser.add_argument("--load_optim", type=bool, default=False)
parser.add_argument("--train", type=str, default=train_default)
parser.add_argument("--validate", type=str, default=val_default)
parser.add_argument("--epochs", type=int, default=150)
args = parser.parse_args()
return args