

本文為英文版的機器翻譯版本，如內容有任何歧義或不一致之處，概以英文版為準。

# 使用自訂演算法搭配 Apache Spark，在 Amazon SageMaker AI 進行模型訓練和託管
<a name="apache-spark-example1-cust-algo"></a>

在 [SageMaker AI Spark for Scala 範例](apache-spark-example1.md)，您會使用 `kMeansSageMakerEstimator`，因為該範例會使用 Amazon SageMaker AI 提供的 k-means 演算法進行模型訓練。不過，您也可以選擇使用專屬的自訂演算法來訓練模型。假設您已建立 Docker 影像，就可以建立您專屬的 `SageMakerEstimator`，並指定自訂影像的 Amazon Elastic Container Registry 路徑。

以下範例會說明從 `SageMakerEstimator` 建立 `KMeansSageMakerEstimator` 的方式。請在新的估算器中明確地指定 Docker 登錄檔路徑，以便訓練和推論程式碼影像。

```
import com.amazonaws.services.sagemaker.sparksdk.IAMRole
import com.amazonaws.services.sagemaker.sparksdk.SageMakerEstimator
import com.amazonaws.services.sagemaker.sparksdk.transformation.serializers.ProtobufRequestRowSerializer
import com.amazonaws.services.sagemaker.sparksdk.transformation.deserializers.KMeansProtobufResponseRowDeserializer

val estimator = new SageMakerEstimator(
  trainingImage =
    "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1",
  modelImage =
    "811284229777.dkr.ecr.us-east-1.amazonaws.com/kmeans:1",
  requestRowSerializer = new ProtobufRequestRowSerializer(),
  responseRowDeserializer = new KMeansProtobufResponseRowDeserializer(),
  hyperParameters = Map("k" -> "10", "feature_dim" -> "784"),
  sagemakerRole = IAMRole(roleArn),
  trainingInstanceType = "ml.p2.xlarge",
  trainingInstanceCount = 1,
  endpointInstanceType = "ml.c4.xlarge",
  endpointInitialInstanceCount = 1,
  trainingSparkDataFormat = "sagemaker")
```

`SageMakerEstimator` 建構函式中的參數會包含以下程式碼：
+ `trainingImage` - 可識別訓練影像的 Docker 登錄檔路徑，該訓練影像包含自訂程式碼。
+ `modelImage` - 可識別影像的 Docker 登錄檔路徑，該影像包含推論程式碼。
+ `requestRowSerializer` - 實作 `com.amazonaws.services.sagemaker.sparksdk.transformation.RequestRowSerializer`。

  此參數會序列化輸入 `DataFrame` 中的資料列，以將它們傳送至 SageMaker AI 中託管的模型進行推論。
+ `responseRowDeserializer` - 實作 

  `com.amazonaws.services.sagemaker.sparksdk.transformation.ResponseRowDeserializer`.

  此參數會還原序列化來自模型 (託管於 SageMaker AI) 的回應，回到 `DataFrame`。
+ `trainingSparkDataFormat` - 可指定 `DataFrame` 訓練資料上傳至 S3 期間，Spark 會使用的資料格式。例如，`"sagemaker"` 適用於 protobuf 格式、`"csv"` 適用於逗號分隔值，而 `"libsvm"` 適用於 LibSVM 格式。

您可以實作專屬的 `RequestRowSerializer` 和 `ResponseRowDeserializer`，將使用您推論程式碼支援之資料格式 (如 libsvm 或 .csv) 的資料列序列化及還原序列化。