mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 联合训练报错
This commit is contained in:
parent
765cdcaeb0
commit
90cfe54bd2
@ -219,7 +219,7 @@ def train():
|
||||
|
||||
print("Train ended.")
|
||||
torch.save({
|
||||
"model_state": maskGIT.state_dict(),
|
||||
"model_state": model.state_dict(),
|
||||
}, f"result/ginka_heatmap.pth")
|
||||
|
||||
def get_nms_sampling_count():
|
||||
|
||||
@ -17,7 +17,7 @@ from .heatmap.model import GinkaHeatmapModel
|
||||
from .maskGIT.model import GinkaMaskGIT
|
||||
|
||||
|
||||
BATCH_SIZE = 128
|
||||
BATCH_SIZE = 64
|
||||
VAL_BATCH_DIVIDER = 64
|
||||
NUM_CLASSES = 16
|
||||
MASK_TOKEN = 15
|
||||
@ -33,6 +33,8 @@ D_MODEL_DIFFUSION = 128
|
||||
T_DIFFUSION = 100
|
||||
MIN_MASK = 0
|
||||
MAX_MASK = 1
|
||||
CE_WEIGHT = 0.5
|
||||
DROP_RATE = 0.2
|
||||
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
@ -48,15 +50,13 @@ disable_tqdm = not sys.stdout.isatty()
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="joint training codes")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--state_heatmap", type=str, default="result/heatmap/ginka-100.pth")
|
||||
parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth")
|
||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||
parser.add_argument("--validate", type=str, default="ginka-eval.json")
|
||||
parser.add_argument("--epochs", type=int, default=50)
|
||||
parser.add_argument("--checkpoint", type=int, default=5)
|
||||
parser.add_argument("--load_optim", type=bool, default=True)
|
||||
parser.add_argument("--maskgit_path", type=str, default="result/ginka_transformer.pth")
|
||||
parser.add_argument("--ce_weight", type=float, default=1.0)
|
||||
parser.add_argument("--cfg_drop_rate", type=float, default=0.2)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
@ -231,7 +231,7 @@ def train():
|
||||
|
||||
cond_for_diffusion = cond_heatmap
|
||||
use_unconditional_branch = False
|
||||
if np.random.rand() < args.cfg_drop_rate:
|
||||
if np.random.rand() < DROP_RATE:
|
||||
cond_for_diffusion = torch.zeros_like(cond_heatmap)
|
||||
use_unconditional_branch = True
|
||||
|
||||
@ -245,7 +245,7 @@ def train():
|
||||
generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t)
|
||||
maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map)
|
||||
|
||||
loss = diffusion_loss + args.ce_weight * maskgit_loss
|
||||
loss = diffusion_loss + CE_WEIGHT * maskgit_loss
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
@ -275,7 +275,7 @@ def train():
|
||||
checkpoint_path,
|
||||
)
|
||||
|
||||
metrics = validate(model, maskgit, diffusion, dataloader_val, args.ce_weight)
|
||||
metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT)
|
||||
tqdm.write(
|
||||
f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] "
|
||||
f"E: {epoch + 1} | "
|
||||
|
||||
Loading…
Reference in New Issue
Block a user