Merge branches 'master' and 'master' of github.com:unanmed/ginka-generator

This commit is contained in:
unanmed 2025-03-21 12:46:31 +08:00
commit f9211965db
4 changed files with 41 additions and 12 deletions

View File

@ -335,13 +335,13 @@ class GinkaLoss(nn.Module):
losses = [ losses = [
minamo_loss * self.weight[0], minamo_loss * self.weight[0],
border_loss * self.weight[1] * 0.1, border_loss * self.weight[1],
entrance_loss * self.weight[2], entrance_loss * self.weight[2],
count_loss * self.weight[3], count_loss * self.weight[3],
illegal_loss * self.weight[4] illegal_loss * self.weight[4]
] ]
# 梯度归一化 # 梯度归一化
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses] # scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
total_loss = sum(scaled_losses) total_loss = sum(losses)
return total_loss return total_loss

View File

@ -8,19 +8,21 @@ from .model.model import GinkaModel
from .model.loss import GinkaLoss from .model.loss import GinkaLoss
from .dataset import GinkaDataset from .dataset import GinkaDataset
from minamo.model.model import MinamoModel from minamo.model.model import MinamoModel
from shared.args import parse_arguments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True) os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True) os.makedirs("result/ginka_checkpoint", exist_ok=True)
epochs = 150
# 在生成器输出后添加梯度检查钩子 # 在生成器输出后添加梯度检查钩子
def grad_hook(module, grad_input, grad_output): def grad_hook(module, grad_input, grad_output):
print(f"Generator output grad norm: {grad_output[0].norm().item()}") print(f"Generator output grad norm: {grad_output[0].norm().item()}")
def train(): def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
model = GinkaModel() model = GinkaModel()
model.to(device) model.to(device)
minamo = MinamoModel(32) minamo = MinamoModel(32)
@ -53,9 +55,15 @@ def train():
# model.register_full_backward_hook(grad_hook) # model.register_full_backward_hook(grad_hook)
# converter.register_full_backward_hook(grad_hook) # converter.register_full_backward_hook(grad_hook)
# criterion.register_full_backward_hook(grad_hook) # criterion.register_full_backward_hook(grad_hook)
if args.resume:
data = torch.load(args.from_state, map_location=device)
model.load_state_dict(data["model_state"])
if args.load_optim:
optimizer.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")
# 开始训练 # 开始训练
for epoch in tqdm(range(epochs)): for epoch in tqdm(range(args.epochs)):
model.train() model.train()
total_loss = 0 total_loss = 0
@ -118,7 +126,7 @@ def train():
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}") tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({ torch.save({
"model_state": model.state_dict(), "model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(), # "optimizer_state": optimizer.state_dict(),
}, f"result/ginka_checkpoint/{epoch + 1}.pth") }, f"result/ginka_checkpoint/{epoch + 1}.pth")
@ -126,7 +134,7 @@ def train():
torch.save({ torch.save({
"model_state": model.state_dict(), "model_state": model.state_dict(),
"optimizer_state": optimizer.state_dict(), # "optimizer_state": optimizer.state_dict(),
}, f"result/ginka.pth") }, f"result/ginka.pth")
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -8,13 +8,12 @@ from tqdm import tqdm
from .model.model import MinamoModel from .model.model import MinamoModel
from .model.loss import MinamoLoss from .model.loss import MinamoLoss
from .dataset import MinamoDataset from .dataset import MinamoDataset
from shared.args import parse_arguments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True) os.makedirs("result", exist_ok=True)
os.makedirs("result/minamo_checkpoint", exist_ok=True) os.makedirs("result/minamo_checkpoint", exist_ok=True)
epochs = 150
def collate_fn(batch): def collate_fn(batch):
"""动态处理不同尺寸地图的批处理""" """动态处理不同尺寸地图的批处理"""
map1_batch = [item[0] for item in batch] map1_batch = [item[0] for item in batch]
@ -35,6 +34,9 @@ def collate_fn(batch):
def train(): def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments("result/minamo.pth", "minamo-dataset.json", 'minamo-eval.json')
model = MinamoModel(32) model = MinamoModel(32)
model.to(device) model.to(device)
@ -57,8 +59,15 @@ def train():
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = MinamoLoss() criterion = MinamoLoss()
if args.resume:
data = torch.load(args.from_state, map_location=device)
model.load_state_dict(data["model_state"])
if args.load_optim:
optimizer.load_state_dict(data["optimizer_state"])
print("Train from loaded state.")
# 开始训练 # 开始训练
for epoch in tqdm(range(epochs)): for epoch in tqdm(range(args.epochs)):
model.train() model.train()
total_loss = 0 total_loss = 0

12
shared/args.py Normal file
View File

@ -0,0 +1,12 @@
import argparse
def _str_to_bool(value):
    """Convert a command-line string into a real bool.

    argparse's ``type=bool`` is a well-known trap: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so ``--resume False``
    would silently enable resuming.  This converter accepts the common
    spellings explicitly and rejects everything else.
    """
    if isinstance(value, bool):  # already a bool (e.g. the default value)
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "1", "yes", "y"):
        return True
    if lowered in ("false", "0", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments(from_default: str, train_default: str, val_default: str, argv=None):
    """Parse the training CLI arguments shared by the ginka/minamo trainers.

    Args:
        from_default: default path for ``--from_state`` (checkpoint to resume from).
        train_default: default path for ``--train`` (training dataset file).
        val_default: default path for ``--validate`` (validation dataset file).
        argv: optional explicit argument list; ``None`` means ``sys.argv[1:]``.
            Passing a list makes the function testable without touching the
            real command line.

    Returns:
        argparse.Namespace with ``resume``, ``from_state``, ``load_optim``,
        ``train``, ``validate`` and ``epochs`` attributes.
    """
    parser = argparse.ArgumentParser(description="training codes")
    # Boolean flags use _str_to_bool rather than type=bool — the latter
    # treats any non-empty string (including "False") as True.
    parser.add_argument("--resume", type=_str_to_bool, default=False)
    parser.add_argument("--from_state", type=str, default=from_default)
    parser.add_argument("--load_optim", type=_str_to_bool, default=False)
    parser.add_argument("--train", type=str, default=train_default)
    parser.add_argument("--validate", type=str, default=val_default)
    parser.add_argument("--epochs", type=int, default=150)
    args = parser.parse_args(argv)
    return args