mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-16 06:04:49 +08:00
15 lines
527 B
Python
15 lines
527 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torchvision.models import resnet18
|
|
|
|
class MinamoVisionModel(nn.Module):
    """ResNet-18 encoder that maps a ``tile_types``-channel input to a unit-length vector.

    A torchvision ``resnet18`` is built with ``num_classes=out_dim`` so its
    final layer produces ``out_dim`` features. Its first convolution is then
    replaced so the network accepts ``tile_types`` input channels instead of
    the default image channels.

    Args:
        tile_types: Number of input channels for the replacement first
            convolution. Presumably one channel per tile type (e.g. a one-hot
            encoded map); TODO confirm against the caller.
        out_dim: Size of the output embedding (passed as ``num_classes``).
    """

    def __init__(self, tile_types=32, out_dim=512):
        super().__init__()
        self.resnet = resnet18(num_classes=out_dim)
        # Replace the stem so the backbone consumes `tile_types` channels.
        # Output channels (64) are kept so the rest of ResNet-18 still fits.
        self.resnet.conv1 = nn.Conv2d(
            tile_types, 64, kernel_size=7, stride=2, padding=3, bias=False
        )

    def forward(self, x):
        """Encode ``x`` and return its L2-normalized embedding.

        Args:
            x: Input tensor whose channel count equals ``tile_types``;
                presumably shaped (batch, tile_types, H, W). TODO confirm.

        Returns:
            The ResNet output normalized to unit L2 norm along the last
            dimension.
        """
        vision_vec = self.resnet(x)
        return F.normalize(vision_vec, p=2, dim=-1)  # L2-normalize