mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-15 05:11:10 +08:00
chore: 条件编码器的激活函数改为 gelu
This commit is contained in:
parent
fee2bcf344
commit
250c2d5f67
@ -1,5 +1,3 @@
|
||||
import { GinkaTopologicalGraphs } from '../topology/interface';
|
||||
|
||||
export const enum TowerColor {
|
||||
White,
|
||||
Orange,
|
||||
|
||||
@ -10,7 +10,7 @@ class GinkaMaskGITCond(nn.Module):
|
||||
nn.Linear(cond_dim, output_dim // 2),
|
||||
nn.Dropout(0.3),
|
||||
nn.LayerNorm(output_dim // 2),
|
||||
nn.ReLU(),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Linear(output_dim // 2, output_dim)
|
||||
)
|
||||
@ -18,11 +18,11 @@ class GinkaMaskGITCond(nn.Module):
|
||||
self.heatmap_conv = nn.Sequential(
|
||||
nn.Conv2d(heatmap_channel, output_dim // 4, kernel_size=3, padding=1, padding_mode='replicate'),
|
||||
nn.BatchNorm2d(output_dim // 4),
|
||||
nn.ReLU(),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Conv2d(output_dim // 4, output_dim // 2, kernel_size=3, padding=1, padding_mode='replicate'),
|
||||
nn.BatchNorm2d(output_dim // 2),
|
||||
nn.ReLU(),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Conv2d(output_dim // 2, output_dim, kernel_size=3, padding=1, padding_mode='replicate')
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user