mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
Merge branches 'master' and 'master' of github.com:unanmed/ginka-generator
This commit is contained in:
commit
f9211965db
@ -335,13 +335,13 @@ class GinkaLoss(nn.Module):
|
|||||||
|
|
||||||
losses = [
|
losses = [
|
||||||
minamo_loss * self.weight[0],
|
minamo_loss * self.weight[0],
|
||||||
border_loss * self.weight[1] * 0.1,
|
border_loss * self.weight[1],
|
||||||
entrance_loss * self.weight[2],
|
entrance_loss * self.weight[2],
|
||||||
count_loss * self.weight[3],
|
count_loss * self.weight[3],
|
||||||
illegal_loss * self.weight[4]
|
illegal_loss * self.weight[4]
|
||||||
]
|
]
|
||||||
|
|
||||||
# 梯度归一化
|
# 梯度归一化
|
||||||
scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
|
# scaled_losses = [loss / (loss.detach() + 1e-6) for loss in losses]
|
||||||
total_loss = sum(scaled_losses)
|
total_loss = sum(losses)
|
||||||
return total_loss
|
return total_loss
|
||||||
|
|||||||
@ -8,19 +8,21 @@ from .model.model import GinkaModel
|
|||||||
from .model.loss import GinkaLoss
|
from .model.loss import GinkaLoss
|
||||||
from .dataset import GinkaDataset
|
from .dataset import GinkaDataset
|
||||||
from minamo.model.model import MinamoModel
|
from minamo.model.model import MinamoModel
|
||||||
|
from shared.args import parse_arguments
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||||
|
|
||||||
epochs = 150
|
|
||||||
|
|
||||||
# 在生成器输出后添加梯度检查钩子
|
# 在生成器输出后添加梯度检查钩子
|
||||||
def grad_hook(module, grad_input, grad_output):
|
def grad_hook(module, grad_input, grad_output):
|
||||||
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
|
print(f"Generator output grad norm: {grad_output[0].norm().item()}")
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
|
|
||||||
|
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
|
||||||
|
|
||||||
model = GinkaModel()
|
model = GinkaModel()
|
||||||
model.to(device)
|
model.to(device)
|
||||||
minamo = MinamoModel(32)
|
minamo = MinamoModel(32)
|
||||||
@ -53,9 +55,15 @@ def train():
|
|||||||
# model.register_full_backward_hook(grad_hook)
|
# model.register_full_backward_hook(grad_hook)
|
||||||
# converter.register_full_backward_hook(grad_hook)
|
# converter.register_full_backward_hook(grad_hook)
|
||||||
# criterion.register_full_backward_hook(grad_hook)
|
# criterion.register_full_backward_hook(grad_hook)
|
||||||
|
if args.resume:
|
||||||
|
data = torch.load(args.from_state, map_location=device)
|
||||||
|
model.load_state_dict(data["model_state"])
|
||||||
|
if args.load_optim:
|
||||||
|
optimizer.load_state_dict(data["optimizer_state"])
|
||||||
|
print("Train from loaded state.")
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
for epoch in tqdm(range(epochs)):
|
for epoch in tqdm(range(args.epochs)):
|
||||||
model.train()
|
model.train()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
|
||||||
@ -118,7 +126,7 @@ def train():
|
|||||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": model.state_dict(),
|
"model_state": model.state_dict(),
|
||||||
"optimizer_state": optimizer.state_dict(),
|
# "optimizer_state": optimizer.state_dict(),
|
||||||
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
|
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
|
||||||
|
|
||||||
|
|
||||||
@ -126,7 +134,7 @@ def train():
|
|||||||
|
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": model.state_dict(),
|
"model_state": model.state_dict(),
|
||||||
"optimizer_state": optimizer.state_dict(),
|
# "optimizer_state": optimizer.state_dict(),
|
||||||
}, f"result/ginka.pth")
|
}, f"result/ginka.pth")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -8,13 +8,12 @@ from tqdm import tqdm
|
|||||||
from .model.model import MinamoModel
|
from .model.model import MinamoModel
|
||||||
from .model.loss import MinamoLoss
|
from .model.loss import MinamoLoss
|
||||||
from .dataset import MinamoDataset
|
from .dataset import MinamoDataset
|
||||||
|
from shared.args import parse_arguments
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
os.makedirs("result/minamo_checkpoint", exist_ok=True)
|
os.makedirs("result/minamo_checkpoint", exist_ok=True)
|
||||||
|
|
||||||
epochs = 150
|
|
||||||
|
|
||||||
def collate_fn(batch):
|
def collate_fn(batch):
|
||||||
"""动态处理不同尺寸地图的批处理"""
|
"""动态处理不同尺寸地图的批处理"""
|
||||||
map1_batch = [item[0] for item in batch]
|
map1_batch = [item[0] for item in batch]
|
||||||
@ -35,6 +34,9 @@ def collate_fn(batch):
|
|||||||
|
|
||||||
def train():
|
def train():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
|
|
||||||
|
args = parse_arguments("result/minamo.pth", "minamo-dataset.json", 'minamo-eval.json')
|
||||||
|
|
||||||
model = MinamoModel(32)
|
model = MinamoModel(32)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
@ -57,8 +59,15 @@ def train():
|
|||||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
criterion = MinamoLoss()
|
criterion = MinamoLoss()
|
||||||
|
|
||||||
|
if args.resume:
|
||||||
|
data = torch.load(args.from_state, map_location=device)
|
||||||
|
model.load_state_dict(data["model_state"])
|
||||||
|
if args.load_optim:
|
||||||
|
optimizer.load_state_dict(data["optimizer_state"])
|
||||||
|
print("Train from loaded state.")
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
for epoch in tqdm(range(epochs)):
|
for epoch in tqdm(range(args.epochs)):
|
||||||
model.train()
|
model.train()
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
|
||||||
|
|||||||
12
shared/args.py
Normal file
12
shared/args.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
import argparse
|
||||||
|
|
||||||
|
def parse_arguments(from_default: str, train_default: str, val_default: str, argv=None):
    """Parse the command-line options shared by the training scripts.

    Args:
        from_default: Default value for ``--from_state`` (the training
            scripts pass this path to ``torch.load`` when resuming).
        train_default: Default value for ``--train``
            (presumably the training dataset path; confirm against callers).
        val_default: Default value for ``--validate``
            (presumably the validation dataset path; confirm against callers).
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]`` as before.

    Returns:
        An ``argparse.Namespace`` with the attributes ``resume``,
        ``from_state``, ``load_optim``, ``train``, ``validate`` and ``epochs``.
    """
    parser = argparse.ArgumentParser(description="training codes")
    # NOTE: ``type=bool`` must not be used here. argparse would call
    # ``bool("False")``, which is True for every non-empty string, so
    # "--resume False" silently enabled resuming. ``nargs="?"`` with
    # ``const=True`` additionally allows a bare "--resume" flag.
    parser.add_argument(
        "--resume", type=_str_to_bool, nargs="?", const=True, default=False
    )
    parser.add_argument("--from_state", type=str, default=from_default)
    parser.add_argument(
        "--load_optim", type=_str_to_bool, nargs="?", const=True, default=False
    )
    parser.add_argument("--train", type=str, default=train_default)
    parser.add_argument("--validate", type=str, default=val_default)
    parser.add_argument("--epochs", type=int, default=150)
    return parser.parse_args(argv)


def _str_to_bool(value):
    """Convert a command-line token such as "true" / "0" / "no" into a bool."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, matching what ``bool("")`` returned before.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")
|
||||||
Loading…
Reference in New Issue
Block a user