mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 00:01:13 +08:00
chore: 条件编码器的激活函数改为 gelu
This commit is contained in:
parent
fee2bcf344
commit
250c2d5f67
@ -1,5 +1,3 @@
|
|||||||
import { GinkaTopologicalGraphs } from '../topology/interface';
|
|
||||||
|
|
||||||
export const enum TowerColor {
|
export const enum TowerColor {
|
||||||
White,
|
White,
|
||||||
Orange,
|
Orange,
|
||||||
|
|||||||
@ -10,7 +10,7 @@ class GinkaMaskGITCond(nn.Module):
|
|||||||
nn.Linear(cond_dim, output_dim // 2),
|
nn.Linear(cond_dim, output_dim // 2),
|
||||||
nn.Dropout(0.3),
|
nn.Dropout(0.3),
|
||||||
nn.LayerNorm(output_dim // 2),
|
nn.LayerNorm(output_dim // 2),
|
||||||
nn.ReLU(),
|
nn.GELU(),
|
||||||
|
|
||||||
nn.Linear(output_dim // 2, output_dim)
|
nn.Linear(output_dim // 2, output_dim)
|
||||||
)
|
)
|
||||||
@ -18,11 +18,11 @@ class GinkaMaskGITCond(nn.Module):
|
|||||||
self.heatmap_conv = nn.Sequential(
|
self.heatmap_conv = nn.Sequential(
|
||||||
nn.Conv2d(heatmap_channel, output_dim // 4, kernel_size=3, padding=1, padding_mode='replicate'),
|
nn.Conv2d(heatmap_channel, output_dim // 4, kernel_size=3, padding=1, padding_mode='replicate'),
|
||||||
nn.BatchNorm2d(output_dim // 4),
|
nn.BatchNorm2d(output_dim // 4),
|
||||||
nn.ReLU(),
|
nn.GELU(),
|
||||||
|
|
||||||
nn.Conv2d(output_dim // 4, output_dim // 2, kernel_size=3, padding=1, padding_mode='replicate'),
|
nn.Conv2d(output_dim // 4, output_dim // 2, kernel_size=3, padding=1, padding_mode='replicate'),
|
||||||
nn.BatchNorm2d(output_dim // 2),
|
nn.BatchNorm2d(output_dim // 2),
|
||||||
nn.ReLU(),
|
nn.GELU(),
|
||||||
|
|
||||||
nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, padding=1, padding_mode='replicate')
|
nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, padding=1, padding_mode='replicate')
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user