본문 바로가기
AI Codes/PyTorch

ML와 DL에서의 seed 고정

by 아르카눔 2025. 1. 2.

난수 생성, Random Number Generation (RNG)에 쓰이는 seed 시드를 고정한다.

 

시드를 고정하면 랜덤 넘버가 생성되는 패턴이 고정되고 재현성이 생긴다.

 

다음의 패키지에 대해서 적용하면 되는듯 하다.

 

  • random
  • numpy
  • sklearn
  • torch
  • torch.cuda

LLM의 경우 multi-gpu 사용시 모든 GPU들에 대한 시드 고려 필요 

 

# Fix the Seed
def set_seed(seed: int = 42):
    """
    PyTorch, NumPy, Python Random, CUDA의 모든 시드를 고정하는 함수
    
    Args:
        seed (int): 고정할 시드값 (기본값: 42)
    """
    # Python Random 시드 고정
    random.seed(seed)
    
    # NumPy 시드 고정
    np.random.seed(seed)
    
    # PyTorch 시드 고정
    torch.manual_seed(seed)
    
    # CUDA 시드 고정 (GPU 사용 시)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # 멀티 GPU 사용 시
        
        # CUDA의 랜덤성 완전 제거
        #torch.backends.cudnn.deterministic = True
        #torch.backends.cudnn.benchmark = False
    
    # scikit-learn의 global random state 고정
    # (일부 함수에서만 적용됨)
    try:
        from sklearn.utils import check_random_state
        check_random_state(seed)
    except ImportError:
        pass

seed = 42
set_seed(seed)

 

 

torch.backends.cudnn.deterministic

 

True라면, 

  • CuDNN에서 항상 고정된 난수를 생성
  • 완벽한 재현성 
  • 일부 비결정론적 알고리즘에 의존하는 사항들의 성능 저하 
  • 속도 저하 

False라면,

  • CuDNN이 항상 고정되지 않은 난수 생성
  • 비교적 덜 정확한 재현성
  • 속도 향상

 

cudnn.benchmark

 

True라면,

  • Benchmarking overhead occurs only in the first few iterations.
  • Training becomes faster for consistent input sizes.

False라면,

  • PyTorch uses a default algorithm for all operations, which might not be optimal.

 

 

 

 

References:

https://knowing.tistory.com/26

https://medium.com/@adhikareen/why-and-when-to-use-cudnn-benchmark-true-in-pytorch-training-f2700bf34289

https://hyunhp.tistory.com/749

https://pytorch.org/docs/stable/notes/randomness.html

 

'AI Codes > PyTorch' 카테고리의 다른 글

PyTorch pre-trained Models  (0) 2025.04.17
PyTorch Tensor의 차원 변환  (0) 2025.04.16
PyTorch repeat, repeat_interleave, expand 차이  (0) 2024.11.01
ResNet (2016) PyTorch Implementation  (0) 2024.04.24
SPPNet(2014) PyTorch Implementation  (0) 2024.04.12