mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 18:31:13 +08:00
fix: 联合训练报错
This commit is contained in:
parent
765cdcaeb0
commit
90cfe54bd2
@ -219,7 +219,7 @@ def train():
|
|||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": maskGIT.state_dict(),
|
"model_state": model.state_dict(),
|
||||||
}, f"result/ginka_heatmap.pth")
|
}, f"result/ginka_heatmap.pth")
|
||||||
|
|
||||||
def get_nms_sampling_count():
|
def get_nms_sampling_count():
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from .heatmap.model import GinkaHeatmapModel
|
|||||||
from .maskGIT.model import GinkaMaskGIT
|
from .maskGIT.model import GinkaMaskGIT
|
||||||
|
|
||||||
|
|
||||||
BATCH_SIZE = 128
|
BATCH_SIZE = 64
|
||||||
VAL_BATCH_DIVIDER = 64
|
VAL_BATCH_DIVIDER = 64
|
||||||
NUM_CLASSES = 16
|
NUM_CLASSES = 16
|
||||||
MASK_TOKEN = 15
|
MASK_TOKEN = 15
|
||||||
@ -33,6 +33,8 @@ D_MODEL_DIFFUSION = 128
|
|||||||
T_DIFFUSION = 100
|
T_DIFFUSION = 100
|
||||||
MIN_MASK = 0
|
MIN_MASK = 0
|
||||||
MAX_MASK = 1
|
MAX_MASK = 1
|
||||||
|
CE_WEIGHT = 0.5
|
||||||
|
DROP_RATE = 0.2
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
"cuda:1" if torch.cuda.is_available()
|
"cuda:1" if torch.cuda.is_available()
|
||||||
@ -48,15 +50,13 @@ disable_tqdm = not sys.stdout.isatty()
|
|||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(description="joint training codes")
|
parser = argparse.ArgumentParser(description="joint training codes")
|
||||||
parser.add_argument("--resume", type=bool, default=False)
|
parser.add_argument("--resume", type=bool, default=False)
|
||||||
parser.add_argument("--state_heatmap", type=str, default="result/heatmap/ginka-100.pth")
|
parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth")
|
||||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||||
parser.add_argument("--validate", type=str, default="ginka-eval.json")
|
parser.add_argument("--validate", type=str, default="ginka-eval.json")
|
||||||
parser.add_argument("--epochs", type=int, default=50)
|
parser.add_argument("--epochs", type=int, default=50)
|
||||||
parser.add_argument("--checkpoint", type=int, default=5)
|
parser.add_argument("--checkpoint", type=int, default=5)
|
||||||
parser.add_argument("--load_optim", type=bool, default=True)
|
parser.add_argument("--load_optim", type=bool, default=True)
|
||||||
parser.add_argument("--maskgit_path", type=str, default="result/ginka_transformer.pth")
|
parser.add_argument("--maskgit_path", type=str, default="result/ginka_transformer.pth")
|
||||||
parser.add_argument("--ce_weight", type=float, default=1.0)
|
|
||||||
parser.add_argument("--cfg_drop_rate", type=float, default=0.2)
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@ -231,7 +231,7 @@ def train():
|
|||||||
|
|
||||||
cond_for_diffusion = cond_heatmap
|
cond_for_diffusion = cond_heatmap
|
||||||
use_unconditional_branch = False
|
use_unconditional_branch = False
|
||||||
if np.random.rand() < args.cfg_drop_rate:
|
if np.random.rand() < DROP_RATE:
|
||||||
cond_for_diffusion = torch.zeros_like(cond_heatmap)
|
cond_for_diffusion = torch.zeros_like(cond_heatmap)
|
||||||
use_unconditional_branch = True
|
use_unconditional_branch = True
|
||||||
|
|
||||||
@ -245,7 +245,7 @@ def train():
|
|||||||
generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t)
|
generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t)
|
||||||
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
|
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
|
||||||
|
|
||||||
loss = diffusion_loss + args.ce_weight * maskgit_loss
|
loss = diffusion_loss + CE_WEIGHT * maskgit_loss
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
@ -275,7 +275,7 @@ def train():
|
|||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
metrics = validate(model, maskgit, diffusion, dataloader_val, args.ce_weight)
|
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT)
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||||
f"E: {epoch + 1} | "
|
f"E: {epoch + 1} | "
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user