From fa72e18c1c1f3dbdf9b2fc0c86d6390b6633127e Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 10 Mar 2026 23:26:36 +0800 Subject: [PATCH] fix: maskGIT --- ginka/train_maskGIT.py | 44 +++++++++++++++++++------------------ ginka/train_transformer.py | 7 +++--- tiles2/0.png | Bin 0 -> 1410 bytes tiles2/1.png | Bin 0 -> 576 bytes tiles2/10.png | Bin 0 -> 699 bytes tiles2/2.png | Bin 0 -> 426 bytes tiles2/3.png | Bin 0 -> 368 bytes tiles2/4.png | Bin 0 -> 406 bytes tiles2/5.png | Bin 0 -> 396 bytes tiles2/6.png | Bin 0 -> 419 bytes tiles2/7.png | Bin 0 -> 441 bytes tiles2/8.png | Bin 0 -> 448 bytes tiles2/9.png | Bin 0 -> 353 bytes 13 files changed, 27 insertions(+), 24 deletions(-) create mode 100644 tiles2/0.png create mode 100644 tiles2/1.png create mode 100644 tiles2/10.png create mode 100644 tiles2/2.png create mode 100644 tiles2/3.png create mode 100644 tiles2/4.png create mode 100644 tiles2/5.png create mode 100644 tiles2/6.png create mode 100644 tiles2/7.png create mode 100644 tiles2/8.png create mode 100644 tiles2/9.png diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 0fd3657..4abcc68 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -93,9 +93,9 @@ def train(): # 用于生成图片 tile_dict = dict() - for file in os.listdir('tiles'): + for file in os.listdir('tiles2'): name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED) if args.resume: data_ginka = torch.load(args.state_ginka, map_location=device) @@ -111,30 +111,32 @@ def train(): for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm): loss_total = torch.Tensor([0]).to(device) - # for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - # target_map = batch["target_map"].to(device) - # cond = batch["val_cond"].to(device) - # B, H, W = target_map.shape - # target_map = target_map.view(B, H * W) + for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): + target_map = batch["target_map"].to(device) + cond = batch["val_cond"].to(device) + B, H, W = target_map.shape + target_map = target_map.view(B, H * W) - # # 1. 随机采样掩码比例 r (遵循余弦调度效果更好) - # r = torch.rand(B).to(device) - # r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本 + # 1. 随机采样掩码比例 r (遵循余弦调度效果更好) + r = torch.rand(B).to(device) + r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本 - # # 2. 生成掩码矩阵 - # masks = torch.rand(target_map.shape).to(device) < r - # masked_input = target_map.clone() - # masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记 + # 2. 生成掩码矩阵 + masks = torch.rand(target_map.shape).to(device) < r + masked_input = target_map.clone() + masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记 - # logits = model(masked_input, cond) + logits = model(masked_input, cond) - # loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none') - # loss = (loss * masks).sum() / (masks.sum() + 1e-6) + loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none') + loss = (loss * masks).sum() / (masks.sum() + 1e-6) - # optimizer.zero_grad() - # loss.backward() - # optimizer.step() - # loss_total += loss.detach() + optimizer.zero_grad() + loss.backward() + optimizer.step() + loss_total += loss.detach() + + scheduler.step() avg_loss = loss_total.item() / len(dataloader) tqdm.write( diff --git a/ginka/train_transformer.py b/ginka/train_transformer.py index 750e8e0..e3b5198 100644 --- a/ginka/train_transformer.py +++ b/ginka/train_transformer.py @@ -102,9 +102,9 @@ def train(): # 用于生成图片 tile_dict = dict() - for file in os.listdir('tiles'): + for file in os.listdir('tiles2'): name = os.path.splitext(file)[0] - tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED) + tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED) if args.resume: data_ginka = torch.load(args.state_ginka, map_location=device) @@ -196,6 +196,7 @@ def train(): color = (255, 255, 255) # 白色 vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 # 地图重建展示 + vae.teacher_forcing() for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): target_map = batch["target_map"].to(device) B, H, W = target_map.shape @@ -218,10 +219,10 @@ def train(): idx += 1 # 随机采样 + vae.autoregressive() for i in range(0, 8): z = torch.randn(1, LATENT_DIM).to(device) - vae.autoregressive() fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device)) fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy() fake_img = matrix_to_image_cv(fake_map[0], tile_dict) diff --git a/tiles2/0.png b/tiles2/0.png new file mode 100644 index 0000000000000000000000000000000000000000..9649930b3c311f91fd88e6d8175b41feb3eac820 GIT binary patch literal 1410 zcmV-|1%3L7P)Px)JV``BR9HuqSKChGNDw`3JDHuQ5rQO+12xK zc8P?)Bth48aQr!kx~uVZ{uNo8;oHw|IP4C9LpVd(=jZ2W>IS>*F09vWjX^qqTl}@~ ztJM-kF~e@RLq5t;Jys}ZCF-Vz^A7i`d%y#O!5|n>?Y0<=M%eB57^Xvpde!quX+`S+`<@CJB3#`^FWG+LC7PF$ja=Sv;waDBccoFuK40U^?dJK|52tkSz z$M%TCB}lUr<+KdS#-lN^EQ5y+&=XazPbN9s<>jSkoOpmp*gK&jgr}#c*j8Kk&iAob z4)7kF?q1N~bDrnHJKi^&O~1}FeeO7i*=z=PadF|Pl*(SLnxLwx@GRwAuj~2+ULY8$ z^Ng?eJ}7G2w%?~SQWnc+sQ69@9}%FXRK`26^?DudGJ*+51wu;eJfEg%I7@}A)ha0E zJ`oV2@p;POpXUPt!{IQz#5*r~OCxAGm9Wf~%VjT|#u6#+>yki=9F4F}NWMb$@bCb4 zb#>)+NJ1ryF`|dDa3W%b#YQv36*ivZGpmQrC}-rPyS=^jj0uIxGkQz}ilPVx5K2A^ z-WUN2V~7}6zC-0q+>a2j=iOql@PwC2iNxV>2zMGqg-S}rO&;-_O+dw)?8#(gB2$qP zFg{OtJnQc6?z}9bLJE=+dZ0_j5l%8Dl~B0Kk4OnIA|aV7J(Z}=^EnuBb93VfpGAO5 zR25w3OH@i&t#TPfBzVroR|rrMDV)a9Fs_RQLoZ<&+SpQwl&i$#5tY+3kqytq3CVpT zryQV?`FtL7Nph4UG^XbP#z>yWLMenO#qh`V&F~2+kr<6pHQ9afMBss#v6L``RW(b} zV9aR*Ew`kk#HCQ~%NZ$Q2$;;;*E;nE18joC9B-fLv2wt!$GyaIz8tVls$v4k*V`_UuT$1lK zralKm6e1;?C_+X(KklBk6;loo5gmuw(J-f-BVJX3Xw~-f_p(_DXpvE9R87i!TODx~ zh&6j2AbggLGzX{^$HRyvD}^SESih2(XvrDL6K(ro+h?9xpL#F%P?5Am2gQpmj=;)Ix{;r5;!#DE)719V^3_(24GkcZY zniFkfqGx<2%<*@{dn9DZ$^G;9!f&?~M#U5fKd0Lo6+be)gG*A3CpoHW9Slg3U|Qy3 zy=_`ln+kQ?gzv=8b(oiBIMezDn>A+!j|MIH@y1X5G~>6SN0uh2n=a`47yG(;Xs8*e Q=>Px#07*qoM6N<$g0R=B&Hw-a literal 0 HcmV?d00001 diff --git a/tiles2/1.png b/tiles2/1.png new file mode 100644 index 0000000000000000000000000000000000000000..f8e7142b0bce6b6b1db68d5aa7d7bdea44e82c0d GIT binary patch literal 576 zcmV-G0>AxPx$`bk7VR9Hu?SHY3OKnzP-NiC^`7OsVBfm)7QXoVYRMj1uO_AWy{K=5W|*^)Qw z&%V7Db8XwUJfF`c`@|CbAHU-}U$I^<#Kh<0@mSXTZGUAOWx!(LzV8P>UF+7wJ%C}5 zac>OQ#eV<}PzQhv^{4k8cmN9kxQ-G4fKuaX+5tjcyEy*;ECJud zgR(Mo<*6E2!Ic1JGI8BL(x7Jvv^H4@xXCIBj(|T^kOgN{BQt=EX8QCxd|-BQ2B7K< z-2+V89;g7RoRIy@=FHqzeSkr9fB@Y3K<4M8^I@i>Z4a;~RKqD$q%&f~%%%WTnK97L z$SzL!_j@ZjAwVzgq@$Mz9o}yDq6Y{i~^G10|%nTY#*_+Y2gDXHpl>;b$z4|LMX7;rtH45+!S{Lkx$LoFR+2*n@oJ_T)mpRvXPur&kd>nJ$3Fks9vZj1CncW% O0000Px%bxA})R9HvFmphNzKoEduOWdq)g2E`|fy92+_Z9`)-pJ%?Au?f8MA@3X?f&dUr)9xfm z^86o$Ven?NxdZ_2`R4x_Bh_m4D`V_l6h+xf7LUgxaL)fyO26y6zV-#6Y1#)$=}%df zKdY)rD5b6#_WL~yheOzGHVeksmsYE_@&!<@*HZxSST2_n@p#--aQmY5dJXM%dvQ1% zf&)}4l|&>G@dRKP26Q@|#df>(1#l9NQksfJqYuSmQSl03nkIC+U0)d>wOZ|yWm)%> zQiTw5CyJuy5~2sY-OjrMe#LXv2cLd07=US--T{J$XG389+wFER06*dd1qdvj4dAB) zp@?S#91aJtEX#cfP&`7&?_@Ih5htG4mU9Y-^&gE!kk98~JRZAi!!X2i*1yx~06xTO zwQ|=5Hb4;ZoB-H@xm?aw2C(+BX__fnmVcE>CC?M@d_H3d`u)DE41_8kf}nhO1VOm& zK$zlTP;5P1!T26fC=_5eo56CqeA|IA#dEIUUatp=qCh5-!ESiF3&IrdTtIBWOBo1V zJSPAOcqs$V?|AWeTmb+%5{Yji{>EalQwvU~)0Yl}EuM4sI|ZE2=a5dP zUpjCL@m>=!o6WwSfm;IL0Ky4~0$$9(@pybpBoYZ#RZr6@tYc3=Q4~)OOePb)-EN-* zLHPFkU4RYPx$WJyFpR9HvtSIe!!Knx_Uq?0FwRya_=(<=c{00&wj5}l+Ko*Xe~7}?n$kPBRZ zU3)z1^>4e){c}0&Zup)1z5P3PCflO>1km?=mhEYpT)9$xJ3t5aP)>osxRR*Vp9!dP z>V!(7MgKtnAmccO*I+V`GZG*Yt$RSjFwD!)eb+7K8$$;u$IEG{Y3(v7fh*y zz!F9l0S-oETSFki_!21c0K2gy1`x!Yvbe?~0$dX~-x-)Q11SaqoVHa2G#wcb#91^cPKUC)y8sdE}j>>S%CF)KWCO$|7m6{^?ew{fHt1!gt_|j;}pUzPwW=meA@bssI20 literal 0 HcmV?d00001 diff --git a/tiles2/3.png b/tiles2/3.png new file mode 100644 index 0000000000000000000000000000000000000000..339c1c358d1dfc99d5846ac3ee76a8eb19451464 GIT binary patch literal 368 zcmV-$0gwKPP)Px$DoI2^R9HvdmfaBoAq<4MR<0LYu?9=H8Z5yUtVLSU$wQnFlJFOXQy!he{Wg#d zTUfld#rom@0bY;eO{*MZ442#eYkYbIP6M#-djR*-E5IE?`9U1k5&%mapxo~Otw2e9 zDnO068vsA|TJb3WE#gIh=W$I!1OtTzN$Z*O=oc>`3Zomo3EBx4e$(n4_h$egPK7xH z73UCApO=PR0zkzz_ZTQ1;@KTQ083nQgjL}&TXLTO)VkLIdc;M5R`*)Q*SI$Utji{K zg=+=qnQHCE9~tfdS_Pyd>S@0$2YCPg48TYWsyTq=u_Bj#(|i6!0DjxGSat&dL4FtD zH)|e@h^yncG!5xIprg4qeCT)Pn^1cXz#ohXtK+!G#mFA_S8*%Ma)1v_OByiC=&{`Z O0000Px$P)S5VR9HvtmfaD9AP|JH6>6atQ;Q{FE3p<@N#>Z@WOJC|M;Ko8B~D1#&%I;M z#Z11I$@=C7Xm-G_?#1PM>_^jHO@9(`XRJ9)hzz)Ifv5(6xV6^o;fUh61gI9G3V_8$ z6aWGvR|p}xQUg&8pv565p8#lqC<6E=&ON{o;*4A+PGkWfjyj7FtVaC>MDBSFzt2Gw z0j$X6YHpz~n}0DG>V6dgY-h$^p< zY{c_Gk6u~lcLKQ4Z5ZK6^h8A0%lg@@$jt!kn@vKlJwx!Z!xSt{Ki;znK$oNC*y#YN z24ILOFQwcQ{4}LC@8Up`!xz42D!~Q4b%707*qoM6N<$f?Gzd AX#fBK literal 0 HcmV?d00001 diff --git a/tiles2/5.png b/tiles2/5.png new file mode 100644 index 0000000000000000000000000000000000000000..792ed88c56a26fe861b8cc51f302981cd5e0e0de GIT binary patch literal 396 zcmV;70dxL|P)Px$MoC0LR9Hvtmfa15FbsrY6=q3CU=|pGoyrQVz$&dwb@XIX#STPyD0!k3av$fv zv`yvJR@QHBfNTe*>G|UDjpHM0udJ6uyz9DVUDu5P*EtYw0Emm|JkOgrmH>GnTmURC zq5uGd0C*wH09srCI0Wp12f_sKOPm2X6yl65OE{7M5XToH;Y6IRmuvnH&!-?v03n3k zZdeaN{05K)y9_`J_5C;dm*tZWERN?d1wiHxL)c4C#5J6!EX^UnrjlR)fOr)E7XLp0 zsS3S{CLx!cZok|Si&cScGP(zF)_aGOsS0>gZbTf0qKybGHWrE)5N3pB9wZ%*O?V#Y z(Tf9#cEchr&zJxhA_fqf6{!{sU^9S*;IiZ6nddCvasb2}j-0a7{ln#ehNyBVWeENn qYs@ck;LwLp+{g;S3BALsH^3X6{arA>nyOL&0000Px$T}ebiR9HvtmfI16Fcd^X3AT_HXu%d}0b9`$paf{a7HEMMEJ5dHW|{S(lLyJ< zhy2KZWKZ_KXiVkRRMs~qK(Yhf(7ic)^KDMrE9oT>x7UkVpD)LN>lBD+0EoBNnt5ED zIFRgF|2nNvN5R_j4q(B4#{1fLKUITVnXUI2(A&LRY> z5q|-ZYkr3R`yhe3ze)AXyfGyW2MaCCv zM7%M2K9AyG)@N(QvJuY%J$mIie-prsY{LjA(Gw9_?*MPx$b4f%&R9Hvtm(dl2APhz02wlMy=m^fDBRGR2lvz4LS-}-L0zDCj zahwQ%+dCInE(h&7JFKi!GS-( z-+>=HfLCyMhNf%izsmBPz`g|}?8|LINM1x1K>XBW1sF@D|3GmbrX+wbuDm)NK0N@< zalfZS!UDj}IO0<}DmCWFL4eTdw@$un5D=K-Lev2&IybGMd_16UnlGLA)B~*hlPmyO zc5m8=mSs77_IPx$dPzhblj-&FXL?xw zJCLwtq$*ot!2qNN(FDLt=P(+`9Yi+(U;q+@7zV&WA`n9XzTWBBu0QEbIr#`FwL|7% z7y!hfszGRZ7ytQ{_t)0PLXadm0sdD;CamxkdAPf){*Ia0Jpj8|$@En94=n)sc8*IfYUAtW8g>Imy=Fqb*N`)?gA-pah| qJ_ZpW{=0Gjx+bBwnwE8s1N;FqAv7@R56cGt0000Px$8%ab#R9HvtmeCD^FbqVKQJ^C*10%2tBd`l2FhfQtj6!vhOs!C2C$x$xg&zdA zx#u$pSX210h4oDZs18&I{?CEg&HL5fF?2(6L+S!9+_vqBVO`hbUJ8f`fW$uuha?1| zg%AN!#0e$^A_~C8_vd@`_Hix?<5~z20Jh`dUDq9MJLe3CxQYWp8pS05b>HmGA>Q}B zhgg=y5Cn~pM{e@;%a134nGnhH=IDDJ=u>7Ljj0_o5=00000NkvXXu0mjfo%)eN literal 0 HcmV?d00001