mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
commit
dd9d8a3713
@ -56,7 +56,11 @@ BATCH_SIZE = 128
|
|||||||
LATENT_DIM = 48
|
LATENT_DIM = 48
|
||||||
KL_BETA = 0.05
|
KL_BETA = 0.05
|
||||||
|
|
||||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
device = torch.device(
|
||||||
|
"cuda:1" if torch.cuda.is_available()
|
||||||
|
else "mps" if torch.mps.is_available()
|
||||||
|
else "cpu"
|
||||||
|
)
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
os.makedirs("result/vae", exist_ok=True)
|
os.makedirs("result/vae", exist_ok=True)
|
||||||
os.makedirs("result/ginka_vae_img", exist_ok=True)
|
os.makedirs("result/ginka_vae_img", exist_ok=True)
|
||||||
@ -76,7 +80,7 @@ def parse_arguments():
|
|||||||
return args
|
return args
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
print(f"Using {device.type} to train model.")
|
||||||
|
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user