報酬関数の実装 - Amazon SageMaker AI

翻訳は機械翻訳により提供されています。提供された翻訳内容と英語版の間で齟齬、不一致または矛盾がある場合、英語版が優先します。

報酬関数の実装

概要:

報酬関数 (スコアラーまたはグレーダーとも呼ばれます) は、モデルのレスポンスを評価し、トレーニング用のフィードバックシグナルを提供するコアコンポーネントです。モデルレスポンスを受け入れて報酬スコアを返す Lambda 関数として実装する必要があります。

インターフェイス形式

報酬関数は、次の形式でデータを受け入れて返す必要があります。

トレーニングへのサンプル入力サンプル

{ "messages": [ { "role": "user", "content": "Do you have a dedicated security team?" } ], "reference_answer": { "compliant": "No", "explanation": "As an AI developed by Company, I do not have a traditional security team..." } }

報酬 Lambda のサンプルペイロード

コンテナは、Lambda 関数に送信する前に、次の方法でデータを自動的に変換します。

  1. 各プロンプトのモデルレスポンスの生成

  2. メッセージ配列にアシスタントターン (生成されたレスポンス) を追加する

  3. 追跡用の一意のidフィールドの追加

Lambda 関数は、次の変換された形式でデータを受け取ります。

{ "id": "123", "messages": [ { "role": "user", "content": "Do you have a dedicated security team?" }, { "role": "assistant", "content": "As an AI developed by Amazon, I don not have a dedicated security team..." } ], # Following section will be same as your training dataset sample "reference_answer": { "compliant": "No", "explanation": "As an AI developed by Company, I do not have a traditional security team..." } }

報酬 Lambda 契約

def lambda_handler(event, context): return lambda_grader(event) def lambda_grader(samples: list[dict]) -> list[dict]: """ Args: samples: List of dictionaries in OpenAI format Example input: { "id": "123", "messages": [ { "role": "user", "content": "Do you have a dedicated security team?" }, { "role": "assistant", "content": "As an AI developed by Company, I don nott have a dedicated security team..." } ], # This section will be same as your training dataset "reference_answer": { "compliant": "No", "explanation": "As an AI developed by Company, I do not have a traditional security team..." } } Returns: List of dictionaries with reward scores: { "id": str, # Same id as input sample "aggregate_reward_score": float, # Overall score for the sample "metrics_list": [ # OPTIONAL: Component scores { "name": str, # Name of the component score "value": float, # Value of the component score "type": str # "Reward" or "Metric" } ] } """

入力フィールドと出力フィールド

入力フィールド

フィールド 説明 追加のメモ
id サンプルの一意の識別子 出力にエコーバックしました。文字列形式
メッセージ OpenAI 形式の順序付けられたチャット履歴 メッセージオブジェクトの配列
messages[].role メッセージの発話者 一般的な値: "user"、"assistant"、"system"
messages[].content メッセージのテキストコンテンツ プレーン文字列
**メタデータ グレーディングに役立つ自由形式の情報 オブジェクト、トレーニングデータから渡されるオプションのフィールド

出力フィールド

フィールド 説明 追加のメモ
id 入力サンプルと同じ識別子 入力と一致する必要があります
aggregate_reward_score サンプルの全体的なスコア 浮動小数点数 (例: 0.0~1.0 またはタスク定義の範囲)
metrics_list 集計を構成するコンポーネントスコア メトリクスオブジェクトの配列

技術的な制約

  • タイムアウト制限 – Lambda 呼び出しあたりの最大実行時間 15 分

  • 同時実行rollout_worker_replicas * 64同時リクエストを処理する必要があります

  • 信頼性 – 適切なエラー処理を実装し、有効なスコアを一貫して返す必要があります

  • パフォーマンス – 効率的なトレーニングを可能にするために高速実行 (分ではなく秒) を最適化する

ベストプラクティス

  • 外部 API コールの最小化

  • 効率的なアルゴリズムとデータ構造を使用する

  • 一時的な障害に対する再試行ロジックの実装

  • 再利用可能な計算をキャッシュする

  • トレーニング前に徹底的にテストしてバグのない実行を確保する

カスタム報酬関数の使用

タスク固有の評価基準がある場合は、カスタム報酬関数を実装します。

  • 評価基準の定義 - タスクに適切な対応を行うものを決定します。

  • Lambda 関数の実装 – インターフェイス形式に従って Lambda 関数を作成する

  • ローカルでテストする – 関数がサンプル入力の正しいスコアを返すことを確認する

  • デプロイ先 AWS – Lambda をデプロイし、ARN を書き留めます。

  • レシピの設定 – Lambda ARN をレシピの reward_lambda_arnフィールドに追加する

  • 小さなデータセットでテストする – 最小限のデータで RFT を実行して統合を検証する

IAM アクセス許可

必要なアクセス許可

SageMaker 実行ロールには、Lambda 関数を呼び出すためのアクセス許可が必要です。このポリシーを SageMaker 実行ロールに追加します。

{ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Action": [ "lambda:InvokeFunction" ], "Resource": "arn:aws:lambda:region:account-id:function:function-name" } ] }

Lambda 実行ロール

Lambda 関数の実行ロールには、基本的な Lambda 実行アクセス許可が必要です。

{ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Action": [ "logs:CreateLogGroup", "logs:CreateLogStream", "logs:PutLogEvents" ], "Resource": "arn:aws:logs:*:*:*" } ] }

追加のアクセス許可: Lambda 関数が他の AWS サービス (参照データの場合は S3、ログ記録の場合は DynamoDB など) にアクセスする場合は、それらのアクセス許可を Lambda 実行ロールに追加します。

例: LLM を審査員報酬関数として

この例では、Amazon Bedrock モデルを審査員として使用して、参照回答と比較することでモデル応答を評価します。この Lambda テンプレートは、お客様が Amazon Bedrock への呼び出しを実装して、判断評価を処理する推論リクエストを行うためのフレームワークを提供します。Lambda 関数は、他の報酬関数と同じ入出力契約を維持します。

実装

この Lambda 関数は、2 段階の評価プロセスを実装します。 はモデルレスポンスをlambda_handler抽出し、受信サンプルから回答を参照します。次に、lambda_graded関数は Amazon Bedrock を呼び出して、それらの間のセマンティック類似性をスコアリングします。この実装には、一時的な障害に対する自動再試行による堅牢なエラー処理が含まれており、柔軟なリファレンス回答形式 (文字列形式と構造化ディクショナリ形式の両方) をサポートしています。

実装の詳細:

  • 再試行ロジック: Bedrock API レート制限を処理するためのスロットリング例外のエクスポネンシャルバックオフ (1、2、4) を実装

  • エラー処理: 例外を発生させるのではなく、失敗した評価のスコア 0.0 を返します

  • 決定論的スコアリング: 温度 = 0.0 を使用して、評価間で一貫したスコアを確保します

  • 柔軟な参照形式: 文字列とディクショナリの両方の参照回答を自動的に処理します

  • スコアクランプ: すべてのスコアが有効な [0.0, 1.0] の範囲内にあることを確認します

  • Model Agnostic: JUDGE_MODEL_ID を変更して Amazon Bedrock モデル (Nova、Llama、Mistral など) を使用する

""" LLM Judge Lambda POC - Working implementation using Amazon Bedrock """ import json import time import boto3 bedrock_runtime = boto3.client('bedrock-runtime', region_name='us-east-1') JUDGE_MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0" SYSTEM_PROMPT = "You must output ONLY a number between 0.0 and 1.0. No explanations, no text, just the number." JUDGE_PROMPT_TEMPLATE = """Compare the following two responses and rate how similar they are on a scale of 0.0 to 1.0, where: - 1.0 means the responses are semantically equivalent (same meaning, even if worded differently) - 0.5 means the responses are partially similar - 0.0 means the responses are completely different or contradictory Response A: {response_a} Response B: {response_b} Output ONLY a number between 0.0 and 1.0. No explanations.""" def lambda_graded(response_a: str, response_b: str, max_retries: int = 3) -> float: """Call Bedrock to compare responses and return similarity score.""" prompt = JUDGE_PROMPT_TEMPLATE.format(response_a=response_a, response_b=response_b) for attempt in range(max_retries): try: response = bedrock_runtime.converse( modelId=JUDGE_MODEL_ID, messages=[{"role": "user", "content": [{"text": prompt}]}], system=[{"text": SYSTEM_PROMPT}], inferenceConfig={"temperature": 0.0, "maxTokens": 10} ) print(f"Bedrock call successful: {response}") output = response['output']['message']['content'][0]['text'].strip() score = float(output) print(f"Score parsed: {score}") return max(0.0, min(1.0, score)) except Exception as e: if "ThrottlingException" in str(e) and attempt < max_retries - 1: time.sleep(2 ** attempt) else: print(f"Bedrock call failed: {e}") return None return None def lambda_handler(event, context): """AWS Lambda handler - processes samples from RFTEvalInvoker.""" try: samples = event if isinstance(event, list) else [event] results = [] for sample in samples: sample_id = sample.get("id", "unknown") messages = sample.get("messages", []) # Extract assistant response (response A) response_a = "" for msg in messages: if msg.get("role") in ["assistant", "nova_assistant"]: response_a = msg.get("content", "") break # Extract reference answer from root level (no longer in metadata) reference_answer = sample.get("reference_answer", "") # Handle both string and dict reference_answer formats if isinstance(reference_answer, dict): # If reference_answer is a dict, extract the explanation or compliant field response_b = reference_answer.get("explanation", reference_answer.get("compliant", "")) else: response_b = reference_answer if not response_a or not response_b: results.append({ "id": sample_id, "aggregate_reward_score": 0.0, "metrics_list": [{"name": "similarity_score", "value": 0.0, "type": "Metric"}] }) continue # Get similarity score score = lambda_graded(response_a, response_b) results.append({ "id": sample_id, "aggregate_reward_score": score, "metrics_list": [ { "name": "similarity_score", "value": score, "type": "Metric" } ] }) return {"statusCode": 200, "body": json.dumps(results)} except Exception as e: print(f"Error: {e}") return {"statusCode": 500, "body": json.dumps({"error": str(e)})}

入力形式

Lambda は、他の報酬関数と同じ入力形式を受け取ります。

{ "id": "sample-001", "messages": [ { "role": "user", "content": "Do you have a dedicated security team?" }, { "role": "assistant", "content": "As an AI developed by Amazon, I don't have a dedicated security team..." } ], "reference_answer": { "compliant": "No", "explanation": "As an AI developed by Company, I do not have a traditional security team..." }, "my_custom_field": "custom_value" }

出力形式

{ "id": "sample-001", "aggregate_reward_score": 0.85, "metrics_list": [ { "name": "similarity_score", "value": 0.85, "type": "Metric" } ] }

デプロイに関する考慮事項

選択したモデルの機能と API 形式に基づいて、プロンプトテンプレートと推論パラメータを調整する必要がある場合もあります。

  • IAM アクセス許可: Lambda 実行ロールには、選択したモデルに対する アクセスbedrock:InvokeModel許可が必要です

  • タイムアウト: Bedrock API のレイテンシーと再試行に対応するために Lambda タイムアウトを少なくとも 60 秒に設定する

  • リージョン: 選択した Bedrock モデルが利用可能なリージョンにデプロイする

  • コスト: 各評価でサンプルごとに 1 つの API コールが行われるため、Bedrock API の使用状況をモニタリングする

  • スループット: 大規模な評価では、スロットリングを避けるために Bedrock クォータの引き上げをリクエストする

Bedrock スループットの向上

評価中にスロットリングが発生した場合は、Bedrock モデルのクォータを増やします。

  • AWS Service Quotas コンソールに移動します。

  • 「Bedrock」を検索してリージョンを選択する

  • 選択したモデルのクォータを検索します (たとえば、「Claude 3.5 Sonnet の 1 分あたりの呼び出し数」)。

  • 「クォータの引き上げをリクエスト」をクリックし、希望するスループットを指定します。

  • 引き上げの根拠を記載します (例:「RFT 評価ワークロード」)

Lambda の組み込み再試行ロジックは時折スロットリングを処理しますが、持続的な大量の評価には適切なクォータの増加が必要です。

必要な IAM ポリシー:

{ "Version": "2012-10-17", "Statement": [ { "Effect": "Allow", "Action": [ "bedrock:InvokeModel" ], "Resource": "arn:aws:bedrock:*::foundation-model/*" } ] }