Utilisation de TabTransformer dans SageMaker AI - Amazon SageMaker AI

Les traductions sont fournies par des outils de traduction automatique. En cas de conflit entre le contenu d'une traduction et celui de la version originale en anglais, la version anglaise prévaudra.

Utilisation de TabTransformer dans SageMaker AI

Vous pouvez utiliser TabTransformer comme un algorithme intégré dans Amazon SageMaker AI. La section suivante explique comment utiliser TabTransformer avec le kit SDK Python SageMaker. Pour en savoir plus sur l’utilisation de TabTransformer à partir de l’interface utilisateur d’Amazon SageMaker Studio Classic, consultez SageMaker JumpStart modèles préentraînés.

  • Utilisation de TabTransformer en tant qu'algorithme intégré

    Utilisez l'algorithme intégré TabTransformer pour créer un conteneur d'entraînement TabTransformer comme indiqué dans l'exemple de code suivant. Vous pouvez repérer automatiquement l’URI d’image de l’algorithme intégré TabTransformer à l’aide de l’API image_uris.retrieve SageMaker AI (ou de l’API get_image_uri si vous utilisez le kit Amazon SageMaker Python SDK version 2).

    Après avoir spécifié l’URI d’image TabTransformer, vous pouvez utiliser le conteneur TabTransformer pour construire un évaluateur à l’aide de l’API SageMaker AI Estimator et lancer une tâche d’entraînement. L'algorithme intégré TabTransformer s'exécute en mode script, mais le script d'entraînement vous est fourni et n'a pas besoin d'être remplacé. Si vous avez une vaste expérience de l'utilisation du mode script pour créer une tâche d'entraînement SageMaker, vous pouvez intégrer vos propres scripts d'entraînement TabTransformer.

    from sagemaker import image_uris, model_uris, script_uris train_model_id, train_model_version, train_scope = "pytorch-tabtransformerclassification-model", "*", "training" training_instance_type = "ml.p3.2xlarge" # Retrieve the docker image train_image_uri = image_uris.retrieve( region=None, framework=None, model_id=train_model_id, model_version=train_model_version, image_scope=train_scope, instance_type=training_instance_type ) # Retrieve the training script train_source_uri = script_uris.retrieve( model_id=train_model_id, model_version=train_model_version, script_scope=train_scope ) train_model_uri = model_uris.retrieve( model_id=train_model_id, model_version=train_model_version, model_scope=train_scope ) # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tabular_binary/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/train" validation_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/validation" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tabular-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" from sagemaker import hyperparameters # Retrieve the default hyperparameters for training the model hyperparameters = hyperparameters.retrieve_default( model_id=train_model_id, model_version=train_model_version ) # [Optional] Override default hyperparameters with custom values hyperparameters[ "n_epochs" ] = "50" print(hyperparameters) from sagemaker.estimator import Estimator from sagemaker.utils import name_from_base training_job_name = name_from_base(f"built-in-algo-{train_model_id}-training") # Create SageMaker Estimator instance tabular_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location ) # Launch a SageMaker Training job by passing the S3 path of the training data tabular_estimator.fit( { "training": training_dataset_s3_path, "validation": validation_dataset_s3_path, }, logs=True, job_name=training_job_name )

    Pour plus d'informations sur la configuration de TabTransformer en tant qu'algorithme intégré, consultez les exemples de bloc-notes suivants.