File size: 224 Bytes
c2fb848
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
# device_config.py
import torch

def get_device() -> str:
    """Return the name of the preferred torch device available here.

    Backends are probed in priority order: CUDA first, then MPS
    (Apple Silicon), with the CPU as the final fallback.

    Returns:
        One of the strings ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"

    # ``torch.backends.mps`` is absent from PyTorch builds that predate the
    # MPS backend; look it up defensively so those builds fall through to
    # the CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"  # For Apple Silicon

    return "cpu"