So verwenden Sie den Algorithmus Textklassifizierung – TensorFlow von SageMaker AI - Amazon SageMaker AI

So verwenden Sie den Algorithmus Textklassifizierung – TensorFlow von SageMaker AI

Sie können Textklassifizierung – TensorFlow als integrierten Algorithmus von Amazon SageMaker AI verwenden. Im folgenden Abschnitt wird beschrieben, wie Sie Textklassifizierung – TensorFlow mit dem SageMaker AI Python SDK verwenden. Informationen zur Verwendung von Textklassifizierung – TensorFlow über die Benutzeroberfläche von Amazon SageMaker Studio Classic finden Sie unter Vortrainierte SageMaker-JumpStart-Modelle.

Der Textklassifizierung – TensorFlow-Algorithmus unterstützt Transfer Learning unter Verwendung eines der kompatiblen vortrainierten TensorFlow-Modelle. Eine Liste aller verfügbaren vortrainierten Modelle finden Sie unter TensorFlow-Hub-Modelle. Jedes vortrainierte Modell hat ein Unikat model_id. Im folgenden Beispiel wird BERT Base Uncased (model_id:tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2) zur Feinabstimmung eines benutzerdefinierten Datensatzes verwendet. Die vortrainierten Modelle werden alle vorab vom TensorFlow Hub heruntergeladen und in Amazon S3-Buckets gespeichert, sodass Trainingsauftrages netzwerkisoliert ausgeführt werden können. Verwenden Sie diese vorgenerierten Modelltrainingsartefakte, um einen SageMaker AI Estimator zu erstellen.

Rufen Sie zunächst den Docker-Image-URI, den Trainingsskript-URI und den vortrainierten Modell-URI ab. Ändern Sie dann die Hyperparameter nach Bedarf. Sie können ein Python-Wörterbuch mit allen verfügbaren Hyperparametern und ihren Standardwerten mit hyperparameters.retrieve_default sehen. Weitere Informationen finden Sie unter Textklassifizierungs- TensorFlow Hyperparameter. Verwenden Sie diese Werte, um einen SageMaker AI Estimator zu erstellen.

Anmerkung

Die Standard-Hyperparameterwerte sind für verschiedene Modelle unterschiedlich. Bei größeren Modellen ist die Standardstapelgröße beispielsweise kleiner.

In diesem Beispiel wird der SST2Datensatz verwendet, der positive und negative Filmkritiken enthält. Wir haben den Datensatz vorab heruntergeladen und mit Amazon S3 verfügbar gemacht. Rufen Sie zur Feinabstimmung Ihres Modells an, .fit indem Sie den Amazon S3-Speicherort Ihres Trainingsdatensatzes verwenden. Jeder S3-Bucket, der in einem Notebook verwendet wird, muss sich in derselben AWS-Region befinden wie die Notebook-Instance, die darauf zugreift.

from sagemaker import image_uris, model_uris, script_uris, hyperparameters from sagemaker.estimator import Estimator model_id, model_version = "tensorflow-tc-bert-en-uncased-L-12-H-768-A-12-2", "*" training_instance_type = "ml.p3.2xlarge" # Retrieve the Docker image train_image_uri = image_uris.retrieve(model_id=model_id,model_version=model_version,image_scope="training",instance_type=training_instance_type,region=None,framework=None) # Retrieve the training script train_source_uri = script_uris.retrieve(model_id=model_id, model_version=model_version, script_scope="training") # Retrieve the pretrained model tarball for transfer learning train_model_uri = model_uris.retrieve(model_id=model_id, model_version=model_version, model_scope="training") # Retrieve the default hyperparameters for fine-tuning the model hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version) # [Optional] Override default hyperparameters with custom values hyperparameters["epochs"] = "5" # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/SST2/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tc-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" # Create an Estimator instance tf_tc_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location, ) # Launch a training job tf_tc_estimator.fit({"training": training_dataset_s3_path}, logs=True)

Weitere Informationen zur Verwendung des SageMaker Textklassifizierung – TensorFlow-Algorithmus für Transfer-Leraning in einem benutzerdefinierten Datensatz finden Sie im Notebook Introduction to JumpStart – Textklassifizierung.