diff --git a/ginka/train_vae.py b/ginka/train_vae.py index c3d4b65..b529c71 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -56,7 +56,11 @@ BATCH_SIZE = 128 LATENT_DIM = 48 KL_BETA = 0.05 -device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") +device = torch.device( + "cuda:1" if torch.cuda.is_available() + else "mps" if torch.mps.is_available() + else "cpu" +) os.makedirs("result", exist_ok=True) os.makedirs("result/vae", exist_ok=True) os.makedirs("result/ginka_vae_img", exist_ok=True) @@ -76,7 +80,7 @@ def parse_arguments(): return args def train(): - print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") + print(f"Using {device.type} to train model.") args = parse_arguments()