Merge pull request #1 from unanmed/feat-add-mps

use mps if available
This commit is contained in:
AncTe 2026-02-06 23:50:25 +08:00 committed by GitHub
commit dd9d8a3713
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -56,7 +56,11 @@ BATCH_SIZE = 128
LATENT_DIM = 48
KL_BETA = 0.05
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
device = torch.device(
"cuda:1" if torch.cuda.is_available()
else "mps" if torch.mps.is_available()
else "cpu"
)
os.makedirs("result", exist_ok=True)
os.makedirs("result/vae", exist_ok=True)
os.makedirs("result/ginka_vae_img", exist_ok=True)
@ -76,7 +80,7 @@ def parse_arguments():
return args
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
print(f"Using {device.type} to train model.")
args = parse_arguments()