diff --git a/ginka/train_vae.py b/ginka/train_vae.py index ed3126e..3eda0c0 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -55,7 +55,11 @@ from shared.image import matrix_to_image_cv BATCH_SIZE = 128 LATENT_DIM = 48 -device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") +device = torch.device( + "cuda:1" if torch.cuda.is_available() + else "mps" if torch.mps.is_available() + else "cpu" +) os.makedirs("result", exist_ok=True) os.makedirs("result/vae", exist_ok=True) os.makedirs("result/ginka_vae_img", exist_ok=True) @@ -75,7 +79,7 @@ def parse_arguments(): return args def train(): - print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") + print(f"Using {device.type} to train model.") args = parse_arguments()