Implementing a reward function - Amazon SageMaker AI

This page is a machine translation of the English original. If there is any ambiguity or inconsistency, the English version takes precedence.

Implementing a reward function

Overview

A reward function (also called a scorer or grader) is the core component that evaluates model responses and provides the feedback signal for training. It must be implemented as a Lambda function that accepts model responses and returns reward scores.

Interface format

Your reward function must accept and return data in the following formats:

Example training input

{
  "messages": [
    {
      "role": "user",
      "content": "Do you have a dedicated security team?"
    }
  ],
  "reference_answer": {
    "compliant": "No",
    "explanation": "As an AI developed by Company, I do not have a traditional security team..."
  }
}

Example payload for the reward Lambda

The container automatically transforms your data before sending it to the Lambda function by:

  1. Generating a model response for each prompt

  2. Appending the assistant turn (the generated response) to the messages array

  3. Adding a unique id field for tracking

Your Lambda function receives the data in this transformed format:

{
  "id": "123",
  "messages": [
    {
      "role": "user",
      "content": "Do you have a dedicated security team?"
    },
    {
      "role": "assistant",
      "content": "As an AI developed by Amazon, I do not have a dedicated security team..."
    }
  ],
  # The following section is identical to your training dataset sample
  "reference_answer": {
    "compliant": "No",
    "explanation": "As an AI developed by Company, I do not have a traditional security team..."
  }
}

Reward Lambda contract

def lambda_handler(event, context):
    return lambda_grader(event)

def lambda_grader(samples: list[dict]) -> list[dict]:
    """
    Args:
        samples: List of dictionaries in OpenAI format

        Example input:
        {
            "id": "123",
            "messages": [
                {
                    "role": "user",
                    "content": "Do you have a dedicated security team?"
                },
                {
                    "role": "assistant",
                    "content": "As an AI developed by Company, I do not have a dedicated security team..."
                }
            ],
            # This section will be the same as your training dataset
            "reference_answer": {
                "compliant": "No",
                "explanation": "As an AI developed by Company, I do not have a traditional security team..."
            }
        }

    Returns:
        List of dictionaries with reward scores:
        {
            "id": str,                        # Same id as input sample
            "aggregate_reward_score": float,  # Overall score for the sample
            "metrics_list": [                 # OPTIONAL: Component scores
                {
                    "name": str,    # Name of the component score
                    "value": float, # Value of the component score
                    "type": str     # "Reward" or "Metric"
                }
            ]
        }
    """
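As a concrete (illustrative) instance of this contract, the sketch below implements a toy grader that awards 1.0 when the assistant reply contains the reference verdict. The substring heuristic is an assumption made for illustration only, not production scoring logic:

```python
def lambda_grader(samples: list[dict]) -> list[dict]:
    """Toy grader: 1.0 if the assistant reply mentions the reference verdict, else 0.0."""
    results = []
    for sample in samples:
        # Pull the first assistant turn out of the messages array
        assistant_text = next(
            (m["content"] for m in sample.get("messages", []) if m.get("role") == "assistant"),
            "",
        )
        reference = sample.get("reference_answer", {})
        verdict = reference.get("compliant", "") if isinstance(reference, dict) else str(reference)
        # Crude heuristic for illustration: does the reply contain the verdict string?
        score = 1.0 if verdict and verdict.lower() in assistant_text.lower() else 0.0
        results.append({
            "id": sample.get("id", "unknown"),
            "aggregate_reward_score": score,
            "metrics_list": [{"name": "verdict_match", "value": score, "type": "Metric"}],
        })
    return results

def lambda_handler(event, context):
    return lambda_grader(event if isinstance(event, list) else [event])
```

A real grader would replace the substring check with task-specific logic, but the input and output shapes stay exactly as the contract specifies.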

Input and output fields

Input fields

Field                Description                              Notes
id                   Unique identifier for the sample         Echoed back in the output; string format
messages             Ordered chat history in OpenAI format    Array of message objects
messages[].role      Speaker of the message                   Common values: "user", "assistant", "system"
messages[].content   Text content of the message              String of plain text
metadata             Free-form information to aid grading     Object; optional field passed through from the training data

Output fields

Field                   Description                                  Notes
id                      Same identifier as the input sample          Must match the input
aggregate_reward_score  Overall score for the sample                 Float (e.g. 0.0–1.0, or a task-defined range)
metrics_list            Component scores making up the aggregate     Array of metric objects
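A small validation helper (hypothetical, not part of any SageMaker API) can check grader outputs against the field table above before you start a training run:

```python
def validate_reward_output(record: dict) -> bool:
    """Check that one grader output record matches the documented schema."""
    assert isinstance(record.get("id"), str), "id must be a string matching the input sample"
    assert isinstance(record.get("aggregate_reward_score"), float), \
        "aggregate_reward_score must be a float"
    # metrics_list is optional; validate entries only when present
    for metric in record.get("metrics_list", []):
        assert isinstance(metric.get("name"), str), "metric name must be a string"
        assert isinstance(metric.get("value"), float), "metric value must be a float"
        assert metric.get("type") in ("Reward", "Metric"), "type must be 'Reward' or 'Metric'"
    return True
```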

Technical constraints

  • Timeout limit – each Lambda invocation may run for at most 15 minutes

  • Concurrency – must handle rollout_worker_replicas * 64 concurrent requests

  • Reliability – must implement proper error handling and consistently return valid scores

  • Performance – optimize for fast execution (seconds, not minutes) to enable efficient training
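The concurrency constraint can be made explicit when sizing your Lambda; this trivial sketch just encodes the documented rollout_worker_replicas * 64 arithmetic (the helper name is my own):

```python
def required_lambda_concurrency(rollout_worker_replicas: int) -> int:
    """Peak concurrent requests the reward Lambda must handle:
    each rollout worker can issue up to 64 requests at once."""
    return rollout_worker_replicas * 64
```

For example, a recipe with 4 rollout worker replicas implies the Lambda (and any downstream service it calls) must tolerate 256 concurrent invocations.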

Best practices

  • Minimize external API calls

  • Use efficient algorithms and data structures

  • Implement retry logic for transient failures

  • Cache reusable computations

  • Test thoroughly before training to ensure error-free execution
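Two of these practices, retrying transient failures and caching repeated computations, can be sketched as follows. TransientError, with_retries, and cached_length_score are illustrative names, not part of any AWS SDK:

```python
import functools
import time

class TransientError(Exception):
    """Stand-in for throttling or timeout errors from an external service."""

def with_retries(fn, max_retries: int = 3, base_delay: float = 1.0):
    """Retry a callable on transient failures with exponential backoff (1s, 2s, 4s by default)."""
    for attempt in range(max_retries):
        try:
            return fn()
        except TransientError:
            if attempt == max_retries - 1:
                raise  # Out of retries: surface the error to the caller
            time.sleep(base_delay * (2 ** attempt))

@functools.lru_cache(maxsize=4096)
def cached_length_score(response_text: str) -> float:
    """Toy cached computation: identical responses across rollouts are scored only once."""
    return min(1.0, len(response_text) / 100.0)
```

During RFT, many rollouts can produce identical or near-identical responses, so even a simple in-memory cache like lru_cache can noticeably cut per-invocation latency.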

Using a custom reward function

Implement a custom reward function when you have task-specific evaluation criteria:

  • Define evaluation criteria – determine what makes a good response for your task

  • Implement the Lambda function – build it following the interface format

  • Test locally – verify that your function returns correct scores for sample inputs

  • Deploy to AWS – deploy your Lambda and note its ARN

  • Configure the recipe – add the Lambda ARN to the recipe's reward_lambda_arn field

  • Test with a small dataset – run RFT on minimal data to validate the integration
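The local-testing step amounts to calling your handler directly with a sample in the documented format. The harness and constant_handler below are hypothetical helpers for illustration:

```python
# A sample in the documented format (content copied from the examples above)
sample = {
    "id": "123",
    "messages": [
        {"role": "user", "content": "Do you have a dedicated security team?"},
        {"role": "assistant", "content": "As an AI developed by Company, I do not have a dedicated security team..."},
    ],
    "reference_answer": {"compliant": "No"},
}

def check_locally(handler):
    """Invoke a grader handler the way the training container would, then sanity-check the result."""
    record = handler([sample], None)[0]
    assert record["id"] == sample["id"], "id must round-trip unchanged"
    assert 0.0 <= record["aggregate_reward_score"] <= 1.0, "score outside the expected range"
    return record

# Stand-in grader used only to demonstrate the harness
def constant_handler(event, context):
    return [{"id": s["id"], "aggregate_reward_score": 1.0} for s in event]
```

Running this kind of check before deployment catches schema mistakes (wrong id, missing score field, out-of-range values) far more cheaply than a failed training run.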

IAM permissions

Required permissions

Your SageMaker execution role must have permission to invoke the Lambda function. Add this policy to the SageMaker execution role:

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "lambda:InvokeFunction"
      ],
      "Resource": "arn:aws:lambda:region:account-id:function:function-name"
    }
  ]
}
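If you provision roles from a deployment script rather than the console, a small helper can build this policy document scoped to a single function ARN. invoke_policy is an illustrative name, not an AWS API:

```python
def invoke_policy(region: str, account_id: str, function_name: str) -> dict:
    """Build the lambda:InvokeFunction policy document for the SageMaker execution role,
    scoped to one reward-function ARN."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Action": ["lambda:InvokeFunction"],
                "Resource": f"arn:aws:lambda:{region}:{account_id}:function:{function_name}",
            }
        ],
    }
```

Scoping Resource to the exact function ARN (rather than a wildcard) keeps the execution role least-privilege.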

Lambda execution role

Your Lambda function's execution role needs basic Lambda execution permissions:

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "logs:CreateLogGroup",
        "logs:CreateLogStream",
        "logs:PutLogEvents"
      ],
      "Resource": "arn:aws:logs:*:*:*"
    }
  ]
}

Additional permissions: if your Lambda function accesses other AWS services (for example, S3 for reference data or DynamoDB for logging), add those permissions to the Lambda execution role.

Example: LLM-as-a-judge reward function

This example demonstrates how to use an Amazon Bedrock model as a judge that evaluates model responses by comparing them against reference answers. The Lambda template provides the scaffolding for implementing calls to Amazon Bedrock for the inference requests that perform the judge evaluation. The Lambda function maintains the same input/output contract as any other reward function.

Implementation

The Lambda function implements a two-stage evaluation process: lambda_handler extracts the model response and the reference answer from each incoming sample, then the lambda_graded function calls Amazon Bedrock to score the semantic similarity between them. The implementation includes robust error handling that automatically retries transient failures, and it supports flexible reference-answer formats (both string and structured dictionary formats).

Implementation details:

  • Retry logic: implements exponential backoff (1s, 2s, 4s) on throttling exceptions to handle Bedrock API rate limits

  • Error handling: returns a score of 0.0 for failed evaluations instead of raising exceptions

  • Deterministic scoring: uses temperature = 0.0 to keep scores consistent across evaluations

  • Flexible reference formats: automatically handles both string and dictionary reference answers

  • Score clamping: ensures all scores fall within the valid [0.0, 1.0] range

  • Model agnostic: change JUDGE_MODEL_ID to use any Amazon Bedrock model (Nova, Llama, Mistral, and so on)
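The flexible-reference and score-clamping behaviors reduce to two small helpers that can be unit-tested in isolation. normalize_reference and clamp_score are illustrative names mirroring the logic in the full implementation:

```python
def normalize_reference(reference_answer) -> str:
    """Accept both string and dict reference answers.
    Dicts prefer the 'explanation' field, then fall back to 'compliant'."""
    if isinstance(reference_answer, dict):
        return reference_answer.get("explanation", reference_answer.get("compliant", ""))
    return str(reference_answer)

def clamp_score(score: float) -> float:
    """Clamp a judge score into the valid [0.0, 1.0] range."""
    return max(0.0, min(1.0, score))
```

Keeping these as standalone functions makes it easy to verify the edge cases (empty dicts, out-of-range model outputs) without invoking Bedrock.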

"""
LLM Judge Lambda POC - Working implementation using Amazon Bedrock
"""
import json
import time

import boto3

bedrock_runtime = boto3.client('bedrock-runtime', region_name='us-east-1')

JUDGE_MODEL_ID = "anthropic.claude-3-5-sonnet-20240620-v1:0"

SYSTEM_PROMPT = "You must output ONLY a number between 0.0 and 1.0. No explanations, no text, just the number."

JUDGE_PROMPT_TEMPLATE = """Compare the following two responses and rate how similar they are on a scale of 0.0 to 1.0, where:
- 1.0 means the responses are semantically equivalent (same meaning, even if worded differently)
- 0.5 means the responses are partially similar
- 0.0 means the responses are completely different or contradictory

Response A: {response_a}

Response B: {response_b}

Output ONLY a number between 0.0 and 1.0. No explanations."""


def lambda_graded(response_a: str, response_b: str, max_retries: int = 3) -> float:
    """Call Bedrock to compare responses and return similarity score."""
    prompt = JUDGE_PROMPT_TEMPLATE.format(response_a=response_a, response_b=response_b)
    for attempt in range(max_retries):
        try:
            response = bedrock_runtime.converse(
                modelId=JUDGE_MODEL_ID,
                messages=[{"role": "user", "content": [{"text": prompt}]}],
                system=[{"text": SYSTEM_PROMPT}],
                inferenceConfig={"temperature": 0.0, "maxTokens": 10}
            )
            print(f"Bedrock call successful: {response}")
            output = response['output']['message']['content'][0]['text'].strip()
            score = float(output)
            print(f"Score parsed: {score}")
            return max(0.0, min(1.0, score))
        except Exception as e:
            if "ThrottlingException" in str(e) and attempt < max_retries - 1:
                # Exponential backoff: 1s, 2s, 4s
                time.sleep(2 ** attempt)
            else:
                print(f"Bedrock call failed: {e}")
                # Return 0.0 instead of raising, per the error-handling contract above
                return 0.0
    return 0.0


def lambda_handler(event, context):
    """AWS Lambda handler - processes samples from RFTEvalInvoker."""
    try:
        samples = event if isinstance(event, list) else [event]
        results = []
        for sample in samples:
            sample_id = sample.get("id", "unknown")
            messages = sample.get("messages", [])

            # Extract assistant response (response A)
            response_a = ""
            for msg in messages:
                if msg.get("role") in ["assistant", "nova_assistant"]:
                    response_a = msg.get("content", "")
                    break

            # Extract reference answer from root level (no longer in metadata)
            reference_answer = sample.get("reference_answer", "")

            # Handle both string and dict reference_answer formats
            if isinstance(reference_answer, dict):
                # If reference_answer is a dict, extract the explanation or compliant field
                response_b = reference_answer.get("explanation", reference_answer.get("compliant", ""))
            else:
                response_b = reference_answer

            if not response_a or not response_b:
                results.append({
                    "id": sample_id,
                    "aggregate_reward_score": 0.0,
                    "metrics_list": [{"name": "similarity_score", "value": 0.0, "type": "Metric"}]
                })
                continue

            # Get similarity score
            score = lambda_graded(response_a, response_b)
            results.append({
                "id": sample_id,
                "aggregate_reward_score": score,
                "metrics_list": [
                    {"name": "similarity_score", "value": score, "type": "Metric"}
                ]
            })
        return {"statusCode": 200, "body": json.dumps(results)}
    except Exception as e:
        print(f"Error: {e}")
        return {"statusCode": 500, "body": json.dumps({"error": str(e)})}

Input format

The Lambda receives the same input format as every other reward function:

{
  "id": "sample-001",
  "messages": [
    {
      "role": "user",
      "content": "Do you have a dedicated security team?"
    },
    {
      "role": "assistant",
      "content": "As an AI developed by Amazon, I don't have a dedicated security team..."
    }
  ],
  "reference_answer": {
    "compliant": "No",
    "explanation": "As an AI developed by Company, I do not have a traditional security team..."
  },
  "my_custom_field": "custom_value"
}

Output format

{
  "id": "sample-001",
  "aggregate_reward_score": 0.85,
  "metrics_list": [
    {
      "name": "similarity_score",
      "value": 0.85,
      "type": "Metric"
    }
  ]
}

Deployment considerations

You may also need to adjust the prompt template and inference parameters to match your chosen model's capabilities and API format.

  • IAM permissions: the Lambda execution role must have bedrock:InvokeModel permission for the chosen model

  • Timeout: set the Lambda timeout to at least 60 seconds to accommodate Bedrock API latency and retries

  • Region: deploy in a region where your chosen Bedrock model is available

  • Cost: monitor Bedrock API usage, since each evaluation makes one API call per sample

  • Throughput: for large-scale evaluation, request higher Bedrock quotas to avoid throttling

Increasing Bedrock throughput

If you encounter throttling during evaluation, request higher quotas for your Bedrock model:

  • Navigate to the AWS Service Quotas console

  • Search for "Bedrock" and select your region

  • Find the quota for your chosen model (for example, "Claude 3.5 Sonnet invocations per minute")

  • Click "Request quota increase" and specify the throughput you need

  • Provide a justification for the increase (for example, "RFT evaluation workload")

The Lambda's built-in retry logic handles occasional throttling, but sustained high-volume evaluation requires an appropriate quota increase.

Required IAM policy:

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": [
        "bedrock:InvokeModel"
      ],
      "Resource": "arn:aws:bedrock:*::foundation-model/*"
    }
  ]
}