本文為英文版的機器翻譯版本,如內容有任何歧義或不一致之處,概以英文版為準。
SageMaker AI Spark for Scala 範例
Amazon SageMaker AI 提供 Apache Spark AI 程式庫 (SageMaker AI Spark
下載 Spark for Scala
您可以從 SageMaker AI Spark
如需安裝 SageMaker AI Spark 程式庫的詳細說明,請參閱 SageMaker AI Spark
SageMaker AI Spark SDK for Scala 可在 Maven 中央儲存庫中取得。在您的 pom.xml 檔案中新增以下相依性,將 Spark 程式庫新增至專案:
-
如果您的專案是使用 Maven 建置的,請將下列項目新增至 pom.xml 檔案:
<dependency> <groupId>com.amazonaws</groupId> <artifactId>sagemaker-spark_2.11</artifactId> <version>spark_2.2.0-1.0</version> </dependency> -
如果您的專案依賴 Spark 2.1,請將下列項目新增至 pom.xml 檔案:
<dependency> <groupId>com.amazonaws</groupId> <artifactId>sagemaker-spark_2.11</artifactId> <version>spark_2.1.1-1.0</version> </dependency>
Spark for Scala 範例
本節會提供範例程式碼,其可使用 SageMaker AI 提供的 Apache Spark Scala 程式庫,在您的 Spark 叢集中使用 DataFrame 來訓練 SageMaker AI 中的模型。接下來是如何 使用自訂演算法搭配 Apache Spark,在 Amazon SageMaker AI 進行模型訓練和託管 和 在 Spark 管道中使用 SageMakerEstimator 的範例。
下列範例會使用 SageMaker AI 託管服務,託管產生的模型成品。如需此範例的更多詳細資訊,請參閱入門:使用 SageMaker AI Spark SDK 在 SageMaker AI 上進行 K-Means 叢集化
-
使用
KMeansSageMakerEstimator,擬合 (或訓練) 資料上的模型因為此範例使用 SageMaker AI 提供的 k-means 演算法訓練模型,所以您會使用
KMeansSageMakerEstimator。您可以善用來自 MNIST 資料集的手寫個位數字影像,加以訓練模型。請將該影像提供為輸入DataFrame。為方便起見,SageMaker AI 會在 Amazon S3 儲存貯體中提供此資料集。估算器會在回應中傳回
SageMakerModel物件。 -
使用訓練過的
SageMakerModel獲取推論若要從 SageMaker AI 託管的模型取得推論,請呼叫
SageMakerModel.transform方法。您可以將DataFrame傳遞為輸入。該方法會將輸入DataFrame轉換為另一個DataFrame,其將包含從模型取得的推論。針對指定的手寫個位數字輸入影像,推論功能會識別該影像所屬的叢集。如需詳細資訊,請參閱K 平均數演算法。
import org.apache.spark.sql.SparkSession import com.amazonaws.services.sagemaker.sparksdk.IAMRole import com.amazonaws.services.sagemaker.sparksdk.algorithms import com.amazonaws.services.sagemaker.sparksdk.algorithms.KMeansSageMakerEstimator val spark = SparkSession.builder.getOrCreate // load mnist data as a dataframe from libsvm val region = "us-east-1" val trainingData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/") val testData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/") val roleArn = "arn:aws:iam::account-id:role/rolename" val estimator = new KMeansSageMakerEstimator( sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1) .setK(10).setFeatureDim(784) // train val model = estimator.fit(trainingData) val transformedData = model.transform(testData) transformedData.show
此範例程式碼可做到以下操作:
-
將 MNIST 資料集從 SageMaker AI (
awsai-sparksdk-dataset) 提供的 S3 儲存貯體載入至 SparkDataFrame(mnistTrainingDataFrame):// Get a Spark session. val spark = SparkSession.builder.getOrCreate // load mnist data as a dataframe from libsvm val region = "us-east-1" val trainingData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/train/") val testData = spark.read.format("libsvm") .option("numFeatures", "784") .load(s"s3://sagemaker-sample-data-$region/spark/mnist/test/") val roleArn = "arn:aws:iam::account-id:role/rolename" trainingData.show()show方法會在資料框架中顯示前 20 個資料列:+-----+--------------------+ |label| features| +-----+--------------------+ | 5.0|(784,[152,153,154...| | 0.0|(784,[127,128,129...| | 4.0|(784,[160,161,162...| | 1.0|(784,[158,159,160...| | 9.0|(784,[208,209,210...| | 2.0|(784,[155,156,157...| | 1.0|(784,[124,125,126...| | 3.0|(784,[151,152,153...| | 1.0|(784,[152,153,154...| | 4.0|(784,[134,135,161...| | 3.0|(784,[123,124,125...| | 5.0|(784,[216,217,218...| | 3.0|(784,[143,144,145...| | 6.0|(784,[72,73,74,99...| | 1.0|(784,[151,152,153...| | 7.0|(784,[211,212,213...| | 2.0|(784,[151,152,153...| | 8.0|(784,[159,160,161...| | 6.0|(784,[100,101,102...| | 9.0|(784,[209,210,211...| +-----+--------------------+ only showing top 20 rows在每個資料列中:
-
label欄位會識別影像的標籤。例如,如果手寫數字的影像為數字 5,標籤值即為 5。 -
features欄位會存放org.apache.spark.ml.linalg.Vector值的向量 (Double)。這些值即為手寫數字的 784 特徵。(每個手寫數字的影像均為 28 x 28 像素,因此稱為 784 特徵。)
-
-
建立 SageMaker AI 估算器 (
KMeansSageMakerEstimator)此估算器的
fit方法會使用 SageMaker AI 提供的 k-means 演算法,以使用輸入DataFrame來訓練模型。該方法會在回應中傳回SageMakerModel物件,讓您可以獲取推論。注意
KMeansSageMakerEstimator會擴充 SageMaker AISageMakerEstimator,進而擴充 Apache SparkEstimator。val estimator = new KMeansSageMakerEstimator( sagemakerRole = IAMRole(roleArn), trainingInstanceType = "ml.p2.xlarge", trainingInstanceCount = 1, endpointInstanceType = "ml.c4.xlarge", endpointInitialInstanceCount = 1) .setK(10).setFeatureDim(784)建構函式參數提供用於訓練模型並在 SageMaker AI 上部署該模型的資訊:
-
trainingInstanceType與trainingInstanceCount- 可識別用來訓練模型的機器學習 (ML) 運算執行個體類型和數量。 -
endpointInstanceType- 識別在 SageMaker AI 中託管模型時要使用的 ML 運算執行個體類型。而根據預設,系統會採用一個機器學習 (ML) 運算執行個體。 -
endpointInitialInstanceCount- 識別最初支援端點在 SageMaker AI 託管模型的 ML 運算執行個體數量。 -
sagemakerRole- SageMaker AI 會擔任 IAM 角色,以代表您執行任務。以模型訓練任務為例,該參數會自 S3 讀取資料並將訓練結果 (模型成品) 寫入至 S3。注意
此範例會隱含地建立 SageMaker AI 用戶端。而您必須提供登入資料,才能建立此用戶端。API 會使用這些憑證,驗證對 SageMaker AI 的請求。例如,它會使用這些憑證驗證建立訓練任務和 API 呼叫的請求,以使用 SageMaker AI 託管服務來部署模型。
-
KMeansSageMakerEstimator物件建立完成後,您即可設定下列參數,以便進行模型訓練:-
訓練模型期間,K 平均數演算法應該建立的叢集數量。您可以指定 10 個叢集,並以數字 0 至 9 編號各叢集。
-
識別每個輸入影像是否皆具備 784 特徵 (每個手寫數字的影像均為 28 x 28 像素,因此稱為 784 特徵)。
-
-
-
呼叫估算器
fit方法// train val model = estimator.fit(trainingData)您可以將輸入
DataFrame傳遞為參數。該模型會執行訓練模型將其部署至 SageMaker AI 的所有工作。如需詳細資訊,請參閱整合 Apache Spark 應用程式與 SageMaker AI。在回應中,您會取得SageMakerModel物件,您可以將其用來從您在 SageMaker AI 中部署的模型取得推論。您僅需提供輸入
DataFrame。不需要為用來訓練模型的 K 平均數演算法指定登錄檔路徑,因為KMeansSageMakerEstimator已掌握該路徑。 -
呼叫
SageMakerModel.transform方法,從 SageMaker AI 中部署的模型取得推論。transform方法會採用DataFrame做為輸入並進行轉換,接著傳回另一個DataFrame,其將包含從模型取得的推論。val transformedData = model.transform(testData) transformedData.show為簡化程序,做為輸入的
DataFrame會與此範例中用來訓練模型的transform方法相同。transform方法會執行下列作業:-
將輸入
DataFrame中的features欄序列化為 protobuf,並將其傳送至 SageMaker AI 端點以進行推論。 -
將 protobuf 回應還原序列化為兩個額外欄位 (
distance_to_cluster與closest_cluster),而這兩個欄位會位於轉換後的DataFrame。
show方法會取得輸入DataFrame前 20 個資料列中的推論:+-----+--------------------+-------------------+---------------+ |label| features|distance_to_cluster|closest_cluster| +-----+--------------------+-------------------+---------------+ | 5.0|(784,[152,153,154...| 1767.897705078125| 4.0| | 0.0|(784,[127,128,129...| 1392.157470703125| 5.0| | 4.0|(784,[160,161,162...| 1671.5711669921875| 9.0| | 1.0|(784,[158,159,160...| 1182.6082763671875| 6.0| | 9.0|(784,[208,209,210...| 1390.4002685546875| 0.0| | 2.0|(784,[155,156,157...| 1713.988037109375| 1.0| | 1.0|(784,[124,125,126...| 1246.3016357421875| 2.0| | 3.0|(784,[151,152,153...| 1753.229248046875| 4.0| | 1.0|(784,[152,153,154...| 978.8394165039062| 2.0| | 4.0|(784,[134,135,161...| 1623.176513671875| 3.0| | 3.0|(784,[123,124,125...| 1533.863525390625| 4.0| | 5.0|(784,[216,217,218...| 1469.357177734375| 6.0| | 3.0|(784,[143,144,145...| 1736.765869140625| 4.0| | 6.0|(784,[72,73,74,99...| 1473.69384765625| 8.0| | 1.0|(784,[151,152,153...| 944.88720703125| 2.0| | 7.0|(784,[211,212,213...| 1285.9071044921875| 3.0| | 2.0|(784,[151,152,153...| 1635.0125732421875| 1.0| | 8.0|(784,[159,160,161...| 1436.3162841796875| 6.0| | 6.0|(784,[100,101,102...| 1499.7366943359375| 7.0| | 9.0|(784,[209,210,211...| 1364.6319580078125| 6.0| +-----+--------------------+-------------------+---------------+您即可解譯資料,如下所示:
-
label為 5 的手寫數字屬於叢集 4 (closest_cluster)。 -
label為 0 的手寫數字屬於叢集 5。 -
label為 4 的手寫數字屬於叢集 9。 -
label為 1 的手寫數字屬於叢集 6。
-