ginka-generator/ginka/test/model.py
2025-03-15 22:26:31 +08:00

46 lines
1.3 KiB
Python

import torch
from ..model.model import DynamicPadConv, ConditionInjector, HybridUpsample
def test_dynamic_conv():
    """Check DynamicPadConv output shapes for odd and even spatial sizes.

    The layer is built with 3 input channels, 64 output channels and
    ``stride=2``. The assertions expect each spatial dimension to come out
    as ``ceil(size / 2)``: 15x17 -> 8x9 and 16x16 -> 8x8.
    """
    conv = DynamicPadConv(3, 64, stride=2)

    # Odd spatial size.
    x = torch.randn(1, 3, 15, 17)
    out = conv(x)
    assert out.shape == (1, 64, 8, 9), f"Got {out.shape}"

    # Even spatial size.
    x = torch.randn(1, 3, 16, 16)
    out = conv(x)
    assert out.shape == (1, 64, 8, 8), f"Got {out.shape}"
def test_condition_injector():
    """Check that ConditionInjector preserves shape and changes the values.

    A (2, 256, 16, 16) feature map and a (2, 128) condition tensor are fed
    through ``ConditionInjector(128, 256)``; the output must have the same
    shape as the feature map and must not be numerically equal to it.
    """
    injector = ConditionInjector(128, 256)
    x = torch.randn(2, 256, 16, 16)
    cond = torch.randn(2, 128)
    out = injector(x, cond)

    assert out.shape == x.shape, f"Got {out.shape}, expected {x.shape}"
    # Make sure the condition actually has an effect on the features.
    assert not torch.allclose(out, x), "Output equals input; condition had no effect"
def test_hybrid_upsample():
    """Check HybridUpsample output shapes with and without a skip input.

    In both cases an (2, 256, 8, 8) input is expected to become
    (2, 128, 16, 16), i.e. channels 256 -> 128 and spatial size doubled.
    """
    x = torch.randn(2, 256, 8, 8)

    # With a skip connection (64-channel tensor at the target resolution).
    upsample = HybridUpsample(256, 128, skip_ch=64)
    skip = torch.randn(2, 64, 16, 16)
    out = upsample(x, skip)
    assert out.shape == (2, 128, 16, 16), f"Got {out.shape}"

    # Without a skip connection.
    upsample = HybridUpsample(256, 128)
    out = upsample(x)
    assert out.shape == (2, 128, 16, 16), f"Got {out.shape}"
def test_all():
    """Run every test in this module in order, reporting each one that passes.

    A test that fails raises ``AssertionError`` and stops the run, so a
    progress line is printed only after the corresponding test succeeded.
    """
    test_dynamic_conv()
    print("✅ 动态卷积测试完毕")  # "dynamic convolution test finished"
    test_condition_injector()
    print("✅ 条件注入测试完毕")  # "condition injection test finished"
    test_hybrid_upsample()
    print("✅ 混合上采样测试完毕")  # "hybrid upsample test finished"