# ginka-generator/minamo/model/vision.py
#
# 15 lines
# 527 B
# Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
class MinamoVisionModel(nn.Module):
    """ResNet-18 encoder producing unit-length embedding vectors.

    A torchvision ``resnet18`` is created with ``num_classes=out_dim``, so its
    final fully-connected layer projects straight into the embedding space.
    Its stem convolution (``conv1``) is then replaced by one that accepts
    ``tile_types`` input channels instead of the default RGB channels.

    Args:
        tile_types: Number of input channels. Presumably one channel per tile
            type (e.g. a one-hot encoded map) -- TODO confirm against callers.
        out_dim: Dimensionality of the returned embedding.
    """

    def __init__(self, tile_types=32, out_dim=512):
        super().__init__()
        # Create the backbone before the replacement stem so parameters are
        # initialized in the same order (same RNG draws) as before.
        backbone = resnet18(num_classes=out_dim)
        backbone.conv1 = nn.Conv2d(
            in_channels=tile_types,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        # Attribute name "resnet" determines the state_dict keys; keep it stable.
        self.resnet = backbone

    def forward(self, x):
        """Encode ``x`` and return its L2-normalized embedding.

        Args:
            x: Input tensor whose channel dimension must equal ``tile_types``.
                Presumably shaped (batch, tile_types, height, width) --
                TODO confirm.

        Returns:
            The backbone output rescaled to unit L2 norm along the last dim.
        """
        return F.normalize(self.resnet(x), p=2, dim=-1)  # L2 normalization