fix: 联合训练报错

This commit is contained in:
unanmed 2026-04-23 18:24:18 +08:00
parent 765cdcaeb0
commit 90cfe54bd2
2 changed files with 8 additions and 8 deletions

View File

@@ -219,7 +219,7 @@ def train():
print("Train ended.")
torch.save({
"model_state": maskGIT.state_dict(),
"model_state": model.state_dict(),
}, f"result/ginka_heatmap.pth")
def get_nms_sampling_count():

View File

@@ -17,7 +17,7 @@ from .heatmap.model import GinkaHeatmapModel
from .maskGIT.model import GinkaMaskGIT
BATCH_SIZE = 128
BATCH_SIZE = 64
VAL_BATCH_DIVIDER = 64
NUM_CLASSES = 16
MASK_TOKEN = 15
@@ -33,6 +33,8 @@ D_MODEL_DIFFUSION = 128
T_DIFFUSION = 100
MIN_MASK = 0
MAX_MASK = 1
CE_WEIGHT = 0.5
DROP_RATE = 0.2
device = torch.device(
"cuda:1" if torch.cuda.is_available()
@@ -48,15 +50,13 @@ disable_tqdm = not sys.stdout.isatty()
def parse_arguments():
parser = argparse.ArgumentParser(description="joint training codes")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--state_heatmap", type=str, default="result/heatmap/ginka-100.pth")
parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth")
parser.add_argument("--train", type=str, default="ginka-dataset.json")
parser.add_argument("--validate", type=str, default="ginka-eval.json")
parser.add_argument("--epochs", type=int, default=50)
parser.add_argument("--checkpoint", type=int, default=5)
parser.add_argument("--load_optim", type=bool, default=True)
parser.add_argument("--maskgit_path", type=str, default="result/ginka_transformer.pth")
parser.add_argument("--ce_weight", type=float, default=1.0)
parser.add_argument("--cfg_drop_rate", type=float, default=0.2)
args = parser.parse_args()
return args
@@ -231,7 +231,7 @@ def train():
cond_for_diffusion = cond_heatmap
use_unconditional_branch = False
if np.random.rand() < args.cfg_drop_rate:
if np.random.rand() < DROP_RATE:
cond_for_diffusion = torch.zeros_like(cond_heatmap)
use_unconditional_branch = True
@@ -245,7 +245,7 @@ def train():
generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t)
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
loss = diffusion_loss + args.ce_weight * maskgit_loss
loss = diffusion_loss + CE_WEIGHT * maskgit_loss
loss.backward()
optimizer.step()
@@ -275,7 +275,7 @@ def train():
checkpoint_path,
)
metrics = validate(model, maskgit, diffusion, dataloader_val, args.ce_weight)
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT)
tqdm.write(
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
f"E: {epoch + 1} | "