chore: 调整部分依赖

This commit is contained in:
unanmed 2026-03-12 20:41:03 +08:00
parent 513f27c7ac
commit d0f86018f1
5 changed files with 5 additions and 13 deletions

View File

@ -1,12 +1,6 @@
import json import json
import math
import random
import torch import torch
import torch.nn.functional as F
from torch.utils.data import Dataset from torch.utils.data import Dataset
import torch
import torch.nn.functional as F
from typing import List
def load_data(path: str): def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f: with open(path, 'r', encoding="utf-8") as f:

View File

@ -1,5 +1,4 @@
import random import random
import torch
import numpy as np import numpy as np
from scipy.ndimage import binary_dilation, binary_erosion from scipy.ndimage import binary_dilation, binary_erosion

View File

@ -9,8 +9,8 @@ import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
import cv2 import cv2
import numpy as np import numpy as np
from torch_geometric.loader import DataLoader
from tqdm import tqdm from tqdm import tqdm
from torch.utils.data import DataLoader
from .maskGIT.model import GinkaMaskGIT from .maskGIT.model import GinkaMaskGIT
from .dataset import GinkaMaskGITDataset from .dataset import GinkaMaskGITDataset
from shared.image import matrix_to_image_cv from shared.image import matrix_to_image_cv
@ -61,7 +61,7 @@ disable_tqdm = not sys.stdout.isatty()
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description="training codes") parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False) parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--state_ginka", type=str, default="result/vae/ginka-100.pth") parser.add_argument("--state_ginka", type=str, default="result/transformer/ginka-100.pth")
parser.add_argument("--train", type=str, default="ginka-dataset.json") parser.add_argument("--train", type=str, default="ginka-dataset.json")
parser.add_argument("--validate", type=str, default="ginka-eval.json") parser.add_argument("--validate", type=str, default="ginka-eval.json")
parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--epochs", type=int, default=100)

View File

@ -1,7 +1,7 @@
torch torch
torchvision
torchaudio
tqdm tqdm
torch-geometric torch-geometric
transformers transformers
torch-scatter scipy
numpy
cv2

View File

@ -1 +0,0 @@
python3 -u -m ginka.train_wgan --epochs 200 --checkpoint 20 --resume true --state_ginka result/wgan/ginka-400.pth --state_minamo result/wgan/minamo-400.pth >> output.log