Wenden Sie SageMaker Smart Sifting auf Ihr Hugging Face Transformers-Skript an - Amazon SageMaker KI

Die vorliegende Übersetzung wurde maschinell erstellt. Im Falle eines Konflikts oder eines Widerspruchs zwischen dieser übersetzten Fassung und der englischen Fassung (einschließlich infolge von Verzögerungen bei der Übersetzung) ist die englische Fassung maßgeblich.

Wenden Sie SageMaker Smart Sifting auf Ihr Hugging Face Transformers-Skript an

Es gibt zwei Möglichkeiten, das SageMaker Smart Sifting in die Trainer Transformers-Klasse zu implementieren.

Anmerkung

Wenn Sie eines der DLCs for PyTorch verwenden, während das SageMaker Smart-Sifting-Paket installiert ist, beachten Sie, dass Sie die transformers Bibliothek installieren müssen. Sie können zusätzliche Pakete installieren, indem Sie die Klasse for PyTorch (sagemaker.pytorch.PyTorch) für requirements.txt den Trainingsjob-Launcher erweitern DLCs oder an sie im SageMaker AI Python SDK übergeben.

Einfache Einrichtung

Der einfachste Weg, SageMaker Smart Sifting in die Trainer Transformers-Klasse zu implementieren, ist die Verwendung der enable_sifting Funktion. Diese Funktion akzeptiert ein vorhandenes Trainer Objekt und umschließt das vorhandene DataLoader Objekt mit. SiftingDataloader Sie können dasselbe Trainingsobjekt weiterhin verwenden. Sehen Sie sich das folgende Anwendungsbeispiel an.

from smart_sifting.integrations.trainer import enable_sifting from smart_sifting.loss.abstract_sift_loss_module import Loss from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) class SiftingImplementedLoss(Loss): def loss(self, model, transformed_batch, original_batch): loss_fct = MSELoss(reduction="none") # make sure to set reduction to "none" logits = model.bert(**original_batch) return loss_fct(logits, original_batch.get("labels")) sift_config = RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) ) trainer = Trainer(...) enable_sifting(trainer, sift_config, loss=SiftingImplementedLoss()) # updates the trainer with Sifting Loss and config trainer.train()

Die SiftingDataloader Klasse ist ein iterierbarer Datenlader. Die genaue Größe des resultierenden Datensatzes ist aufgrund der Zufallsstichproben während der Sichtung im Voraus nicht bekannt. Infolgedessen Trainer erwartet das Hugging Face das max_stepsTrainingsargument. Beachten Sie, dass dieses Argument den Konfigurationsparameter epoch außer Kraft setzt. num_train_epochs Wenn Ihr ursprünglicher Datenlader auch iterierbar war oder Ihr Training eine einzelne Epoche verwendetmax_steps, dann funktioniert der genauso wie der SiftingDataloader vorhandene Dataloader. Wenn der ursprüngliche Dataloader nicht iterierbar war oder nicht bereitgestellt max_steps wurde, gibt der Hugging Face Trainer möglicherweise eine Fehlermeldung ähnlich der folgenden aus.

args.max_steps must be set to a positive value if dataloader does not have a length, was -1

Um dieses Problem zu beheben, stellt die enable_sifting Funktion einen optionalen Parameter bereit. set_epochs Dies ermöglicht das Training mit Epochen, wobei die Anzahl der Epochen verwendet wird, die durch das Argument num_train_epochs der Trainer Klasse bereitgestellt wird, und es wird auf die maximale System-Ganzzahl gesetztmax_steps, sodass das Training fortgesetzt werden kann, bis die angegebenen Epochen abgeschlossen sind.

Benutzerdefiniertes Setup

Für eine benutzerdefinierte Integration des SageMaker Smart Sifting Dataloaders können Sie eine benutzerdefinierte Hugging Face Face-Klasse verwenden. Trainer Innerhalb jeder Unterklasse von kann die get_train_dataloader() Funktion überschrieben werdenTrainer, um stattdessen ein Objekt der Klasse zurückzugeben. SiftingDataloader In Fällen, in denen bereits benutzerdefinierte Trainer vorhanden sind, ist dieser Ansatz möglicherweise weniger aufdringlich, erfordert jedoch Codeänderungen als die einfache Einrichtungsoption. Im Folgenden finden Sie eine Beispielimplementierung von SageMaker Smart Sifting in eine benutzerdefinierte Hugging Face Face-Klasse. Trainer

from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) from smart_sifting.dataloader.sift_dataloader import SiftingDataloader from smart_sifting.loss.abstract_sift_loss_module import Loss from smart_sifting.data_model.data_model_interface import SiftingBatch, SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch class SiftingListBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch class SiftingImplementedLoss(): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.celoss = torch.nn.CrossEntropyLoss(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # compute loss outputs = model(batch) return self.celoss(outputs.logits, batch[2]) class SiftingImplementedTrainer(Trainer): def get_train_dataloader(self): dl = super().get_train_dataloader() sift_config = RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) ) return SiftingDataloader( sift_config=sift_config, orig_dataloader=dl, batch_transforms=SiftingListBatchTransform(), loss_impl=SiftingImplementedLoss(), model=self.model )

Erstellen Sie mithilfe der umschlossenen Trainer Klasse wie folgt ein Objekt daraus.

trainer = SiftingImplementedTrainer( model=model, args=training_args, train_dataset=small_train_dataset, eval_dataset=small_eval_dataset ) trainer.train()