mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
use mps if available
This commit is contained in:
parent
9887abcd01
commit
2357363d76
@ -55,7 +55,11 @@ from shared.image import matrix_to_image_cv
|
||||
BATCH_SIZE = 128
|
||||
LATENT_DIM = 48
|
||||
|
||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||
device = torch.device(
|
||||
"cuda:1" if torch.cuda.is_available()
|
||||
else "mps" if torch.mps.is_available()
|
||||
else "cpu"
|
||||
)
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/vae", exist_ok=True)
|
||||
os.makedirs("result/ginka_vae_img", exist_ok=True)
|
||||
@ -75,7 +79,7 @@ def parse_arguments():
|
||||
return args
|
||||
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
print(f"Using {device.type} to train model.")
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user