

# HyperPod managed tiered checkpointing
<a name="managed-tier-checkpointing"></a>

This section explains how managed tiered checkpointing works and the benefits it provides for large-scale model training.

Amazon SageMaker HyperPod managed tiered checkpointing helps you train large-scale generative AI models more efficiently. It uses multiple storage tiers, including your cluster's CPU memory, which reduces your time to recovery, minimizes loss of training progress, and puts otherwise underutilized memory resources in your training infrastructure to work.

Managed tiered checkpointing lets you save checkpoints to memory at high frequency while periodically persisting them to durable storage, maintaining both performance and reliability during your training process.

This guide covers how to set up, configure, and use managed tiered checkpointing with PyTorch frameworks on Amazon EKS HyperPod clusters.

## How managed tiered checkpointing works
<a name="managed-tier-checkpointing-works"></a>

Managed tiered checkpointing uses a multi-tier storage approach. CPU memory serves as the primary tier to store model checkpoints. Secondary tiers include persistent storage options like Amazon S3.

When you save a checkpoint, the system stores it in allocated memory space across your cluster nodes. It automatically replicates data across adjacent compute nodes for enhanced reliability. This replication strategy protects against single or multiple node failures while providing fast access for recovery operations.

The system also periodically saves checkpoints to persistent storage according to your configuration. This ensures long-term durability of your training progress.

Key components include:
+ **Memory management system**: A memory management daemon that provides disaggregated memory as a service for checkpoint storage
+ **HyperPod Python library**: Interfaces with the disaggregated storage APIs and provides utilities for saving, loading, and managing checkpoints across tiers
+ **Checkpoint replication**: Automatically replicates checkpoints across multiple nodes for fault tolerance

The system integrates seamlessly with PyTorch training loops through simple API calls. It requires minimal changes to your existing code.

## Benefits
<a name="managed-tier-checkpointing-benefits"></a>

Managed tiered checkpointing delivers several advantages for large-scale model training:
+ **Improved usability**: Manages checkpoint save, replication, persistence, and recovery
+ **Faster checkpoint operations**: Memory-based storage provides faster save and load times compared to disk-based checkpointing, leading to faster recovery
+ **Fault tolerance**: Automatic checkpoint replication across nodes protects against hardware node failures
+ **Minimal code changes**: Simple API integration requires only minor modifications to existing training scripts
+ **Improved training throughput**: Reduced checkpoint overhead means more time spent on actual training

**Topics**
+ [How managed tiered checkpointing works](#managed-tier-checkpointing-works)
+ [Benefits](#managed-tier-checkpointing-benefits)
+ [Set up managed tiered checkpointing](managed-tier-checkpointing-setup.md)
+ [Removing managed tiered checkpointing](managed-tier-checkpointing-remove.md)
+ [Security considerations for managed tiered checkpointing](managed-tier-security-considerations.md)

# Set up managed tiered checkpointing
<a name="managed-tier-checkpointing-setup"></a>

This section describes the setup process for managed tiered checkpointing on Amazon SageMaker HyperPod. You’ll learn how to enable the capability on your cluster and implement checkpointing in your training code.

**Topics**
+ [Prerequisites](#managed-tier-checkpointing-setup-prerequisites)
+ [Step 1: Enable managed tiered checkpointing for your cluster](#managed-tier-checkpointing-setup-step-enable-for-cluster)
+ [Step 2: Install the Python library in your training image](#managed-tier-checkpointing-setup-step-install-library)
+ [Step 3: Save checkpoints in your training loop](#managed-tier-checkpointing-setup-step-save-checkpoint-in-loop)
+ [Step 4: Load checkpoints for recovery](#managed-tier-checkpointing-setup-step-load-checkpoint)
+ [Validate your managed tiered checkpointing operations](#managed-tier-checkpointing-setup-validation)

## Prerequisites
<a name="managed-tier-checkpointing-setup-prerequisites"></a>

Before setting up managed tiered checkpointing, ensure you have:
+ An Amazon EKS HyperPod cluster with sufficient CPU memory available for checkpoint allocation
+ PyTorch training workloads; both standard PyTorch checkpointing and PyTorch Distributed Checkpoint (DCP) jobs are supported
+ Appropriate IAM permissions for cluster management, including:
  + Amazon CloudWatch and Amazon S3 write permissions for the training pod to read/write checkpoints and push metrics
  + These permissions can be configured via [EKS OIDC setup](https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts.html)
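As a sketch of the Amazon S3 and CloudWatch permissions above, an IAM policy for the training pod's service account might look like the following. The bucket name `my-bucket` is a placeholder, and the action list is a minimal illustration rather than an authoritative policy; scope it to your own resources.

```
{
    "Version": "2012-10-17",
    "Statement": [
        {
            "Sid": "CheckpointS3Access",
            "Effect": "Allow",
            "Action": ["s3:GetObject", "s3:PutObject", "s3:ListBucket"],
            "Resource": [
                "arn:aws:s3:::my-bucket",
                "arn:aws:s3:::my-bucket/*"
            ]
        },
        {
            "Sid": "CheckpointMetrics",
            "Effect": "Allow",
            "Action": ["cloudwatch:PutMetricData"],
            "Resource": "*"
        }
    ]
}
```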

## Step 1: Enable managed tiered checkpointing for your cluster
<a name="managed-tier-checkpointing-setup-step-enable-for-cluster"></a>

**Important**  
You must opt in to use managed tiered checkpointing.

Enable managed tiered checkpointing through the HyperPod APIs when creating or updating your cluster. The service automatically installs the memory management system when you specify the `TieredStorageConfig` parameter.

For new clusters, use the [create-cluster](https://docs.aws.amazon.com/cli/latest/reference/sagemaker/create-cluster.html) AWS CLI command.

```
aws sagemaker create-cluster \
    --cluster-name cluster-name \
    --orchestrator "Eks={ClusterArn=eks-cluster-arn}" \
    --instance-groups '{
        "InstanceGroupName": "instance-group-name",
        "InstanceType": "instance-type",
        "InstanceCount": instance-count,
        "LifeCycleConfig": {
            "SourceS3Uri": "s3-path-to-lifecycle-scripts",
            "OnCreate": "lifecycle-script-name"
        },
        "ExecutionRole": "instance-group-iam-role",
        "ThreadsPerCore": threads-per-core,
        "InstanceStorageConfigs": [
            { "EbsVolumeConfig": {"VolumeSizeInGB": volume-size} }
        ]
    }' \
    --vpc-config '{
        "SecurityGroupIds": ["security-group-ids"],
        "Subnets": ["subnets"]
    }' \
    --tiered-storage-config '{
        "Mode": "Enable",
        "InstanceMemoryAllocationPercentage": percentage
    }'
```

The `InstanceMemoryAllocationPercentage` parameter of `TieredStorageConfig` specifies the percentage of cluster memory to allocate for checkpointing, as an integer in the range of 20 to 100.
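For existing clusters, you can enable managed tiered checkpointing with the `update-cluster` AWS CLI command, mirroring the disable example later in this guide. This is a sketch; the memory percentage value shown is illustrative.

```
aws sagemaker update-cluster \
    --cluster-name cluster-name \
    --tiered-storage-config '{
        "Mode": "Enable",
        "InstanceMemoryAllocationPercentage": 50
    }'
```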

## Step 2: Install the Python library in your training image
<a name="managed-tier-checkpointing-setup-step-install-library"></a>

Install the [Amazon SageMaker checkpointing library](https://pypi.org/project/amzn-sagemaker-checkpointing/) and its dependencies in your training image by adding it to your Dockerfile:

```
# Add this line to your training image Dockerfile
RUN pip install amzn-sagemaker-checkpointing s3torchconnector tenacity torch boto3
```

## Step 3: Save checkpoints in your training loop
<a name="managed-tier-checkpointing-setup-step-save-checkpoint-in-loop"></a>

In your training loop, you can save checkpoints asynchronously using PyTorch Distributed Checkpoint (DCP). The following example shows how.

```
import os
import time

import torch
import torch.distributed as dist
from torch.distributed.checkpoint import async_save, load
from amzn_sagemaker_checkpointing.checkpointing.filesystem.filesystem import (
    SageMakerTieredStorageWriter,
    SageMakerTieredStorageReader
)
from amzn_sagemaker_checkpointing.config.sagemaker_checkpoint_config import (
    SageMakerCheckpointConfig
)

# Initialize distributed training
dist.init_process_group(backend="nccl")

# Configure checkpointing
checkpoint_config = SageMakerCheckpointConfig(
    # Unique ID for your training job 
    # Allowed characters in ID include: alphanumeric, hyphens, and underscores
    namespace=os.environ.get('TRAINING_JOB_NAME', f'job-{int(time.time())}'),

    # Number of distributed processes/available GPUs
    world_size=dist.get_world_size(),

    # S3 storage location, required for SageMakerTieredStorageReader for read fallbacks
    # Required for SageMakerTieredStorageWriter when save_to_s3 is True
    s3_tier_base_path="s3://my-bucket/checkpoints"
)

# Your model and optimizer
model = MyModel()
optimizer = torch.optim.AdamW(model.parameters())

# Training loop
future = None
epoch = 0  # Track the current epoch in your own loop structure
in_memory_ckpt_freq = 10
s3_ckpt_freq = 50

for training_step in range(1000):
    # ... training code ...
    
    # Save checkpoint
    if (training_step % in_memory_ckpt_freq == 0 or 
        training_step % s3_ckpt_freq == 0):
        # Create state dictionary
        state_dict = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "step": training_step,
            "epoch": epoch
        }
        
        # Create storage writer for current step
        checkpoint_config.save_to_s3 = training_step % s3_ckpt_freq == 0
        storage_writer = SageMakerTieredStorageWriter(
            checkpoint_config=checkpoint_config,
            step=training_step
        )

        # Wait for the previous async checkpoint save to complete
        if future is not None:
            exc = future.exception()
            if exc:
                print(f"Failure in saving previous checkpoint: {exc}")
                # Handle failures as required
            else:
                result = future.result()
                # Process results from save, if required
        
        # Async save checkpoint using PyTorch DCP
        future = async_save(state_dict=state_dict, storage_writer=storage_writer)
        
        # Continue training while checkpoint saves in background
```
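The save condition in the loop above can be factored into a small helper. This is just a sketch of the two-cadence pattern; the function name is illustrative, not part of the library.

```
def should_save_checkpoint(step: int, in_memory_freq: int, s3_freq: int) -> tuple[bool, bool]:
    """Return (save_now, save_to_s3) for a given training step.

    A step that hits the S3 cadence persists to durable storage; a step
    that only hits the in-memory cadence saves to the memory tier alone.
    """
    save_to_s3 = step % s3_freq == 0
    save_now = save_to_s3 or step % in_memory_freq == 0
    return save_now, save_to_s3
```

For example, with `in_memory_freq=10` and `s3_freq=50`, step 30 saves to memory only, while step 50 also persists to Amazon S3.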

## Step 4: Load checkpoints for recovery
<a name="managed-tier-checkpointing-setup-step-load-checkpoint"></a>

The following example shows how to load a checkpoint, either the latest available or a specific step.

```
# Create state dictionary template
state_dict = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": 0,
    "epoch": 0
}

# Load latest checkpoint
storage_reader = SageMakerTieredStorageReader(checkpoint_config=checkpoint_config)
load(state_dict, storage_reader=storage_reader)

# Load specific checkpoint step
storage_reader = SageMakerTieredStorageReader(
    checkpoint_config=checkpoint_config, 
    step=500  # Omit the step argument to load the latest available checkpoint
)
try:
    load(state_dict, storage_reader=storage_reader)
except Exception as e:
    print(f"Checkpoint load failed: {str(e)}")
    # Add additional exception handling
```
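After `load` populates the state dictionary, you typically resume training from the step after the checkpointed one so that the saved step is not repeated. A minimal sketch of that bookkeeping, with an illustrative helper name:

```
def resume_counters(state_dict: dict) -> tuple[int, int]:
    """Return (start_step, epoch) to resume from after loading a checkpoint.

    Resumes at the step after the checkpointed one; falls back to
    step 0 and epoch 0 for a fresh start.
    """
    start_step = int(state_dict.get("step", -1)) + 1
    epoch = int(state_dict.get("epoch", 0))
    return start_step, epoch

# Usage in your training script:
# start_step, epoch = resume_counters(state_dict)
# for training_step in range(start_step, 1000):
#     ...
```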

## Validate your managed tiered checkpointing operations
<a name="managed-tier-checkpointing-setup-validation"></a>

You can validate your managed tiered checkpointing operations with logs.

**Custom logging (optional)**

You can integrate checkpointing logs with other logs by passing a custom logger to the library. For example, you can add a custom logger to your training code so that all logs from the library are also collected in the training logger.
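If your version of the library does not expose a logger-injection parameter, a generic alternative is to attach your training handler to the library's logger through Python's standard `logging` hierarchy. This sketch assumes the library logs under its package name, `amzn_sagemaker_checkpointing`; verify the namespace for your version.

```
import logging

# Handler and formatter for your training job's logs (illustrative)
handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s %(name)s %(levelname)s %(message)s")
)

# Route the checkpointing library's logs through the same handler.
# The logger namespace below is an assumption based on the package name.
ckpt_logger = logging.getLogger("amzn_sagemaker_checkpointing")
ckpt_logger.setLevel(logging.INFO)
ckpt_logger.addHandler(handler)
```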

**Enhanced service logging (optional)**

For enhanced debugging and service visibility, you can mount the checkpointing log path `/var/log/sagemaker_checkpointing` from within your pod to the same path on your host. This keeps library-specific logs collected separately and provides the service team with enhanced visibility for debugging and support.
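A hypothetical pod spec excerpt for this mount might look like the following; the container and volume names are illustrative.

```
volumes:
  - name: sagemaker-checkpointing-logs
    hostPath:
      path: /var/log/sagemaker_checkpointing
      type: DirectoryOrCreate
containers:
  - name: training
    volumeMounts:
      - name: sagemaker-checkpointing-logs
        mountPath: /var/log/sagemaker_checkpointing
```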

# Removing managed tiered checkpointing
<a name="managed-tier-checkpointing-remove"></a>

This section explains how to disable managed tiered checkpointing when you no longer need it.

To disable managed tiered checkpointing, use the [update-cluster](https://docs.aws.amazon.com/cli/latest/reference/sagemaker/update-cluster.html) AWS CLI command to update your cluster configuration:

```
aws sagemaker update-cluster \
    --cluster-name cluster-name \
    --tiered-storage-config '{ "Mode": "Disable" }'
```

This removes the memory management daemon from your cluster. The daemon is implemented as a standard Kubernetes DaemonSet and follows standard Kubernetes lifecycle management.
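Because the daemon runs as a standard Kubernetes DaemonSet, you can confirm its removal with `kubectl`. The filter below is a sketch; the actual DaemonSet name and namespace may differ on your cluster.

```
kubectl get daemonsets --all-namespaces | grep -i checkpoint
```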

# Security considerations for managed tiered checkpointing
<a name="managed-tier-security-considerations"></a>

This section covers important security considerations when using managed tiered checkpointing. It includes Python pickle usage, Amazon S3 encryption, and network endpoint security.

**Python pickle usage**

Managed tiered checkpointing uses Python’s pickle module to deserialize checkpoint data stored in Amazon S3. This implementation has important security implications:
+ **Extended trust boundary**: When using managed tiered checkpointing with Amazon S3, the Amazon S3 bucket becomes part of your cluster’s trust boundary.
+ **Code execution risk**: Python’s pickle module can execute arbitrary code during deserialization. If an unauthorized user gains write access to your checkpoint Amazon S3 bucket, they could potentially craft malicious pickle data that executes when loaded by managed tiered checkpointing.

**Best practices for Amazon S3 storage**

When using managed tiered checkpointing with Amazon S3 storage:
+ **Restrict Amazon S3 bucket access**: Ensure that only authorized users and roles associated with your training cluster have access to the Amazon S3 bucket used for checkpointing.
+ **Implement bucket policies**: Configure appropriate bucket policies to prevent unauthorized access or modifications.
+ **Validate access patterns**: Implement logging for validating access patterns to your checkpoint Amazon S3 buckets.
+ **Validate bucket names**: Use caution with bucket name selection to avoid potential bucket hijacking.

**Network endpoints**

Managed tiered checkpointing enables network endpoints on each of your compute nodes on the following ports: 9200/TCP, 9209/UDP, 9210/UDP, 9219/UDP, 9220/UDP, 9229/UDP, 9230/UDP, 9239/UDP, 9240/UDP. These ports are necessary for the checkpointing service to function and maintain data synchronization.

By default, SageMaker’s network configuration restricts access to these endpoints for security purposes. We recommend that you maintain these default restrictions.

When configuring your network settings for your nodes and VPC, follow AWS best practices for VPCs, security groups, and ACLs. For more information, see the following:
+ [Amazon SageMaker HyperPod prerequisites](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-hyperpod-prerequisites.html#sagemaker-hyperpod-prerequisites-optional-vpcCluster)
+ [VPC security best practices](https://docs.aws.amazon.com/vpc/latest/userguide/vpc-security-best-practices.html)