定义训练指标 - 亚马逊 SageMaker AI

本文属于机器翻译版本。若本译文内容与英语原文存在差异,则一律以英文原文为准。

定义训练指标

SageMaker AI 会自动解析训练作业日志并将训练指标发送到 CloudWatch。默认情况下,SageMaker AI 会发送 SageMaker AI Jobs and Endpoint Metrics 中列出的系统资源利用率指标。如果您希望 SageMaker AI 解析日志并将您自己算法的训练作业中的自定义指标发送到 CloudWatch,则需要通过在配置 SageMaker AI 训练作业请求时传递指标名称和正则表达式来指定指标定义。

您可以使用 SageMaker AI 控制台、SageMaker AI Python SDK 或低级 SageMaker AI API 来指定要跟踪的指标。

如果您使用的是自己的算法,请执行以下操作:

  • 确保算法将要捕获的指标写入日志。

  • 定义一个正则表达式以准确地搜索日志,来捕获要发送到 CloudWatch 的指标的值。

例如,假定您的算法发出有关训练错误和验证错误的以下指标:

Train_error=0.138318; Valid_error=0.324557;

如果您想在 CloudWatch 中监控这两个指标,指标定义的字典应类似于以下示例:

[ { "Name": "train:error", "Regex": "Train_error=(.*?);" }, { "Name": "validation:error", "Regex": "Valid_error=(.*?);" } ]

在上一示例中定义的 train:error 指标的正则表达式中,第一部分查找确切的文本“Train_error=”,表达式 (.*?); 捕获任意字符,直至第一个分号字符出现。在此表达式中,括号告诉正则表达式捕获括号内的内容,. 表示任何字符,* 表示零或更多个字符,? 表示仅捕获到遇到 ; 字符的第一个实例为止。

使用 SageMaker AI Python SDK 定义指标

通过在初始化 Estimator 对象时将一系列指标名称和正则表达式指定为 metric_definitions 参数,来定义要发送到 CloudWatch 的指标。例如,如果您想在 CloudWatch 中监控 train:errorvalidation:error 指标,则您的 Estimator 初始化将与以下示例类似:

import sagemaker from sagemaker.estimator import Estimator estimator = Estimator( image_uri="your-own-image-uri", role=sagemaker.get_execution_role(), sagemaker_session=sagemaker.Session(), instance_count=1, instance_type='ml.c4.xlarge', metric_definitions=[ {'Name': 'train:error', 'Regex': 'Train_error=(.*?);'}, {'Name': 'validation:error', 'Regex': 'Valid_error=(.*?);'} ] )

有关使用 Amazon SageMaker Python SDK 估算器进行训练的更多信息,请参阅 GitHub 上的 Sagemaker Python SDK

使用 SageMaker AI 控制台定义指标

如果您在创建训练作业时选择 ECR 中您自己的算法容器选项作为 SageMaker AI 控制台中的算法源,请在指标部分中添加指标定义。以下屏幕截图显示了其在添加示例指标名称和相应的正则表达式后的外观。

管理控制台中的算法选项表单示例。

使用低级 SageMaker AI API 定义指标

通过在您传递到 CreateTrainingJob 操作的 AlgorithmSpecification 输入参数的 MetricDefinitions 字段中指定指标名称和正则表达式的列表,定义要发送到 CloudWatch 的指标。例如,如果您想在 CloudWatch 中监控 train:errorvalidation:error 指标,则您的 AlgorithmSpecification 将与以下示例类似:

"AlgorithmSpecification": { "TrainingImage": your-own-image-uri, "TrainingInputMode": "File", "MetricDefinitions" : [ { "Name": "train:error", "Regex": "Train_error=(.*?);" }, { "Name": "validation:error", "Regex": "Valid_error=(.*?);" } ] }

有关使用低级 SageMaker AI API 定义和运行训练作业的更多信息,请参阅 CreateTrainingJob