# Mirror of https://github.com/unanmed/ginka-generator.git
# Synced 2026-05-14 04:41:12 +08:00
# 46 lines, 1.3 KiB, Python
import torch
|
|
from ..model.model import DynamicPadConv, ConditionInjector, HybridUpsample
|
|
|
|
def test_dynamic_conv():
    """Check the output shape of a stride-2 ``DynamicPadConv``.

    A single ``DynamicPadConv(3, 64, stride=2)`` is applied to random
    inputs with odd and with even spatial sizes. In both listed cases the
    expected spatial size is the input size halved and rounded up
    (15 -> 8, 17 -> 9, 16 -> 8).

    Raises:
        AssertionError: if an output shape differs from the expected one;
            the message reports the shape that was actually produced.
    """
    conv = DynamicPadConv(3, 64, stride=2)

    # (input shape, expected output shape)
    cases = (
        # Odd spatial size.
        ((1, 3, 15, 17), (1, 64, 8, 9)),
        # Even spatial size.
        ((1, 3, 16, 16), (1, 64, 8, 8)),
    )

    for in_shape, expected in cases:
        out = conv(torch.randn(*in_shape))
        # Every case carries a message so a failure shows the real shape.
        assert out.shape == expected, f"Got {out.shape}"
|
|
|
|
def test_condition_injector():
    """Check that ``ConditionInjector`` keeps the input shape but changes values.

    Builds ``ConditionInjector(128, 256)``, calls it with a random
    ``(2, 256, 16, 16)`` tensor and a random ``(2, 128)`` condition, and
    asserts that the result has the same shape as the input tensor and is
    not element-wise close to it.
    """
    injector = ConditionInjector(128, 256)
    features = torch.randn(2, 256, 16, 16)
    condition = torch.randn(2, 128)

    injected = injector(features, condition)

    same_shape = injected.shape == features.shape
    assert same_shape

    # Make sure the condition actually has an effect on the output.
    unchanged = torch.allclose(injected, features)
    assert not unchanged
|
|
|
|
def test_hybrid_upsample():
    """Check the ``HybridUpsample`` output shape with and without a skip input.

    The same random ``(2, 256, 8, 8)`` tensor is passed to two modules:
    one built with ``skip_ch=64`` and called with an extra
    ``(2, 64, 16, 16)`` tensor, and one built without ``skip_ch`` and
    called with the single tensor. Both must return ``(2, 128, 16, 16)``.
    """
    expected_shape = (2, 128, 16, 16)

    # Case with a skip connection.
    with_skip = HybridUpsample(256, 128, skip_ch=64)
    low_res = torch.randn(2, 256, 8, 8)
    skip_feat = torch.randn(2, 64, 16, 16)
    result = with_skip(low_res, skip_feat)
    assert result.shape == expected_shape

    # Case without a skip connection; reuses the same low-resolution input.
    without_skip = HybridUpsample(256, 128)
    result = without_skip(low_res)
    assert result.shape == expected_shape
|
|
|
|
def test_all():
    """Run every test in this module in order, printing a line after each one.

    A message is printed only after the corresponding test returns, so an
    assertion failure stops the run before its message appears.
    """
    # (test function, message printed once that test has returned)
    steps = (
        (test_dynamic_conv, "✅ 动态卷积测试完毕"),
        (test_condition_injector, "✅ 条件注入测试完毕"),
        (test_hybrid_upsample, "✅ 混合上采样测试完毕"),
    )

    for run_test, done_message in steps:
        run_test()
        print(done_message)
|
|
|