Las traducciones son generadas a través de traducción automática. En caso de conflicto entre la traducción y la version original de inglés, prevalecerá la version en inglés.
Cómo usar SageMaker AI TabTransformer
Puede utilizar TabTransformer como un algoritmo integrado de Amazon SageMaker AI. En la siguiente sección, se describe cómo utilizar TabTransformer con el SageMaker Python SDK. Para obtener información sobre cómo utilizar TabTransformer desde la interfaz de usuario de Amazon SageMaker Studio Classic, consulte SageMaker JumpStart modelos preentrenados.
-
Uso de TabTransformer como algoritmo integrado
Utilice el algoritmo integrado TabTransformer para crear un contenedor de entrenamiento de TabTransformer, como se ve en el siguiente ejemplo de código. Puede detectar automáticamente el URI de la imagen del algoritmo integrado TabTransformer mediante la API
image_uris.retrievede SageMaker AI (o la APIget_image_urisi utiliza la versión 2 de Amazon SageMaker Python SDK). Después de especificar el URI de imagen de TabTransformer, puede utilizar el contenedor de TabTransformer para construir un estimador con la API Estimator de SageMaker AI e iniciar un trabajo de entrenamiento. El algoritmo integrado TabTransformer se ejecuta en modo script, pero el script de entrenamiento se proporciona automáticamente y no es necesario reemplazarlo. Si tiene mucha experiencia en el uso del modo script para crear un trabajo de entrenamiento de SageMaker, puede incorporar sus propios scripts de entrenamiento de TabTransformer.
from sagemaker import image_uris, model_uris, script_uris train_model_id, train_model_version, train_scope = "pytorch-tabtransformerclassification-model", "*", "training" training_instance_type = "ml.p3.2xlarge" # Retrieve the docker image train_image_uri = image_uris.retrieve( region=None, framework=None, model_id=train_model_id, model_version=train_model_version, image_scope=train_scope, instance_type=training_instance_type ) # Retrieve the training script train_source_uri = script_uris.retrieve( model_id=train_model_id, model_version=train_model_version, script_scope=train_scope ) train_model_uri = model_uris.retrieve( model_id=train_model_id, model_version=train_model_version, model_scope=train_scope ) # Sample training data is available in this bucket training_data_bucket = f"jumpstart-cache-prod-{aws_region}" training_data_prefix = "training-datasets/tabular_binary/" training_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/train" validation_dataset_s3_path = f"s3://{training_data_bucket}/{training_data_prefix}/validation" output_bucket = sess.default_bucket() output_prefix = "jumpstart-example-tabular-training" s3_output_location = f"s3://{output_bucket}/{output_prefix}/output" from sagemaker import hyperparameters # Retrieve the default hyperparameters for training the model hyperparameters = hyperparameters.retrieve_default( model_id=train_model_id, model_version=train_model_version ) # [Optional] Override default hyperparameters with custom values hyperparameters[ "n_epochs" ] = "50" print(hyperparameters) from sagemaker.estimator import Estimator from sagemaker.utils import name_from_base training_job_name = name_from_base(f"built-in-algo-{train_model_id}-training") # Create SageMaker Estimator instance tabular_estimator = Estimator( role=aws_role, image_uri=train_image_uri, source_dir=train_source_uri, model_uri=train_model_uri, entry_point="transfer_learning.py", instance_count=1, instance_type=training_instance_type, max_run=360000, hyperparameters=hyperparameters, output_path=s3_output_location ) # Launch a SageMaker Training job by passing the S3 path of the training data tabular_estimator.fit( { "training": training_dataset_s3_path, "validation": validation_dataset_s3_path, }, logs=True, job_name=training_job_name )Para obtener más información sobre cómo configurar TabTransformer como un algoritmo integrado, consulte los siguientes ejemplos de cuadernos.