Terjemahan disediakan oleh mesin penerjemah. Jika konten terjemahan yang diberikan bertentangan dengan versi bahasa Inggris aslinya, utamakan versi bahasa Inggris.
Terapkan penyaringan SageMaker cerdas ke skrip Anda PyTorch
Instruksi ini menunjukkan cara mengaktifkan penyaringan SageMaker cerdas dengan skrip pelatihan Anda.
-
Konfigurasikan antarmuka penyaringan SageMaker cerdas.
Pustaka penyaringan SageMaker cerdas mengimplementasikan teknik pengambilan sampel berbasis kerugian ambang batas relatif yang membantu menyaring sampel dengan dampak yang lebih rendah dalam mengurangi nilai kerugian. Algoritma penyaringan SageMaker cerdas menghitung nilai kerugian dari setiap sampel data input menggunakan pass maju, dan menghitung persentil relatifnya terhadap nilai kehilangan data sebelumnya.
Dua parameter berikut adalah apa yang perlu Anda tentukan ke
RelativeProbabilisticSiftConfigkelas untuk membuat objek konfigurasi penyaringan.-
Tentukan proporsi data yang harus digunakan untuk pelatihan ke
beta_valueparameter. -
Tentukan jumlah sampel yang digunakan dalam perbandingan dengan
loss_history_lengthparameter.
Contoh kode berikut menunjukkan pengaturan sebuah objek dari
RelativeProbabilisticSiftConfigkelas.from smart_sifting.sift_config.sift_configs import ( RelativeProbabilisticSiftConfig LossConfig SiftingBaseConfig ) sift_config=RelativeProbabilisticSiftConfig( beta_value=0.5, loss_history_length=500, loss_based_sift_config=LossConfig( sift_config=SiftingBaseConfig(sift_delay=0) ) )Untuk informasi selengkapnya tentang
loss_based_sift_configparameter dan class terkait, lihat SageMaker modul konfigurasi penyaringan cerdas di bagian referensi SageMaker Smart Sifting Python SDK.sift_configObjek dalam contoh kode sebelumnya digunakan pada langkah 4 untuk menyiapkan kelas.SiftingDataloader -
-
(Opsional) Konfigurasikan kelas transformasi batch penyaringan SageMaker cerdas.
Kasus penggunaan pelatihan yang berbeda memerlukan format data pelatihan yang berbeda. Mengingat berbagai format data, algoritma penyaringan SageMaker cerdas perlu mengidentifikasi cara melakukan penyaringan pada batch tertentu. Untuk mengatasi hal ini, SageMaker smart sifting menyediakan modul transformasi batch yang membantu mengonversi batch menjadi format standar yang dapat disaring secara efisien.
-
SageMaker smart sifting menangani transformasi batch data pelatihan dalam format berikut: Daftar Python, kamus, tupel, dan tensor. Untuk format data ini, SageMaker smart sifting secara otomatis menangani konversi format data batch, dan Anda dapat melewati sisa langkah ini. Jika Anda melewati langkah ini, pada langkah 4 untuk mengkonfigurasi
SiftingDataloader, biarkanbatch_transformsparameterSiftingDataloaderke nilai defaultnya, yaituNone. -
Jika kumpulan data Anda tidak dalam format ini, Anda harus melanjutkan ke sisa langkah ini untuk membuat transformasi batch khusus menggunakan
SiftingBatchTransform.Dalam kasus di mana kumpulan data Anda tidak berada dalam salah satu format yang didukung oleh penyaringan SageMaker cerdas, Anda mungkin mengalami kesalahan. Kesalahan format data tersebut dapat diatasi dengan menambahkan
batch_transformsparameterbatch_format_indexor keSiftingDataloaderkelas, yang Anda atur di langkah 4. Berikut ini menunjukkan contoh kesalahan karena format data yang tidak kompatibel dan resolusi untuk mereka.Pesan Kesalahan Resolusi Batch tipe tidak
{type(batch)}didukung secara default.Kesalahan ini menunjukkan format batch tidak didukung secara default. Anda harus menerapkan kelas transformasi batch kustom, dan menggunakannya dengan menentukannya ke batch_transformsparameterSiftingDataloaderkelas.Tidak dapat mengindeks kumpulan jenis
{type(batch)}Kesalahan ini menunjukkan objek batch tidak dapat diindeks secara normal. Pengguna harus mengimplementasikan transformasi batch khusus dan meneruskan ini menggunakan batch_transformsparameter.Ukuran Batch
{batch_size}tidak sesuai dengan dimensi 0 atau dimensi 1 ukuranKesalahan ini terjadi ketika ukuran batch yang disediakan tidak sesuai dengan dimensi ke-0 atau ke-1 dari batch. Pengguna harus mengimplementasikan transformasi batch khusus dan meneruskan ini menggunakan batch_transformsparameter.Dimensi 0 dan dimensi 1 cocok dengan ukuran batch
Kesalahan ini menunjukkan bahwa karena beberapa dimensi cocok dengan ukuran batch yang disediakan, informasi lebih lanjut diperlukan untuk menyaring batch. Pengguna dapat memberikan batch_format_indexparameter untuk menunjukkan apakah batch dapat diindeks berdasarkan sampel atau fitur. Pengguna juga dapat menerapkan transformasi batch khusus, tetapi ini lebih banyak pekerjaan daripada yang diperlukan.Untuk mengatasi masalah yang disebutkan di atas, Anda perlu membuat kelas transformasi batch khusus menggunakan
SiftingBatchTransformmodul. Kelas transformasi batch harus terdiri dari sepasang fungsi transformasi dan reverse-transform. Pasangan fungsi mengonversi format data Anda ke format yang dapat diproses oleh algoritme penyaringan SageMaker cerdas. Setelah Anda membuat kelas transformasi batch, kelas mengembalikanSiftingBatchobjek yang akan Anda berikan keSiftingDataloaderkelas di langkah 4.Berikut ini adalah contoh kelas transformasi batch kustom
SiftingBatchTransformmodul.-
Contoh implementasi transformasi batch daftar kustom dengan penyaringan SageMaker cerdas untuk kasus di mana potongan dataloader memiliki input, mask, dan label.
from typing import Any import torch from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.list_batch import ListBatch classListBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): inputs = batch[0].tolist() labels = batch[-1].tolist() # assume the last one is the list of labels return ListBatch(inputs, labels) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs), torch.tensor(list_batch.labels)] return a_batch -
Contoh implementasi transformasi batch daftar kustom dengan penyaringan SageMaker cerdas untuk kasus di mana tidak ada label yang diperlukan untuk transformasi terbalik.
classListBatchTransformNoLabels(SiftingBatchTransform): def transform(self, batch: Any): return ListBatch(batch[0].tolist()) def reverse_transform(self, list_batch: ListBatch): a_batch = [torch.tensor(list_batch.inputs)] return a_batch -
Contoh implementasi batch tensor khusus dengan penyaringan SageMaker cerdas untuk kasus di mana potongan pemuat data memiliki input, masker, dan label.
from typing import Any from smart_sifting.data_model.data_model_interface import SiftingBatchTransform from smart_sifting.data_model.tensor_batch import TensorBatch classTensorBatchTransform(SiftingBatchTransform): def transform(self, batch: Any): a_tensor_batch = TensorBatch( batch[0], batch[-1] ) # assume the last one is the list of labels return a_tensor_batch def reverse_transform(self, tensor_batch: TensorBatch): a_batch = [tensor_batch.inputs, tensor_batch.labels] return a_batch
Setelah Anda membuat
SiftingBatchTransformkelas transformasi batch yang diimplementasikan, Anda menggunakan kelas ini di langkah 4 untuk menyiapkan kelas.SiftingDataloaderSisa dari panduan ini mengasumsikan bahwaListBatchTransformkelas dibuat. Pada langkah 4, kelas ini diteruskan kebatch_transforms. -
-
-
Buat kelas untuk mengimplementasikan
Lossantarmuka penyaringan SageMaker cerdas. Tutorial ini mengasumsikan bahwa kelas diberi namaSiftingImplementedLoss. Saat menyiapkan kelas ini, kami sarankan Anda menggunakan fungsi kerugian yang sama dalam loop pelatihan model. Ikuti sublangkah berikut untuk membuat kelasLossimplementasi penyaringan SageMaker cerdas.-
SageMaker smart sifting menghitung nilai kerugian untuk setiap sampel data pelatihan, sebagai lawan menghitung nilai kerugian tunggal untuk batch. Untuk memastikan bahwa penyaringan SageMaker cerdas menggunakan logika perhitungan kerugian yang sama, buat fungsi smart-sifting-implemented kerugian menggunakan
Lossmodul penyaringan SageMaker pintar yang menggunakan fungsi kerugian Anda dan hitung kerugian per sampel pelatihan.Tip
SageMaker algoritma penyaringan cerdas berjalan pada setiap sampel data, bukan pada seluruh batch, jadi Anda harus menambahkan fungsi inisialisasi untuk mengatur fungsi PyTorch kerugian tanpa strategi pengurangan apa pun.
classSiftingImplementedLoss(Loss): def __init__(self): self.loss =torch.nn.CrossEntropyLoss(reduction='none')Ini juga ditunjukkan dalam contoh kode berikut.
-
Tentukan fungsi kerugian yang menerima
original_batch(atautransformed_batchjika Anda telah menyiapkan transformasi batch pada langkah 2) dan PyTorch model. Menggunakan fungsi kerugian yang ditentukan tanpa pengurangan, SageMaker smart sifting menjalankan forward pass untuk setiap sampel data untuk mengevaluasi nilai kerugiannya.
Kode berikut adalah contoh dari smart-sifting-implemented
Lossantarmuka bernamaSiftingImplementedLoss.from typing import Any import torch import torch.nn as nn from torch import Tensor from smart_sifting.data_model.data_model_interface import SiftingBatch from smart_sifting.loss.abstract_sift_loss_module import Loss model=... # a PyTorch model based on torch.nn.Module classSiftingImplementedLoss(Loss): # You should add the following initializaztion function # to calculate loss per sample, not per batch. def __init__(self): self.loss_no_reduction=torch.nn.CrossEntropyLoss(reduction='none') def loss( self, model: torch.nn.Module, transformed_batch: SiftingBatch, original_batch: Any = None, ) -> torch.Tensor: device = next(model.parameters()).device batch = [t.to(device) for t in original_batch] # use this if you use original batch and skipped step 2 # batch = [t.to(device) for t in transformed_batch] # use this if you transformed batches in step 2 # compute loss outputs = model(batch) return self.loss_no_reduction(outputs.logits, batch[2])Sebelum loop pelatihan mencapai pass maju yang sebenarnya, perhitungan kerugian penyaringan ini dilakukan selama fase pemuatan data pengambilan batch di setiap iterasi. Nilai kerugian individu kemudian dibandingkan dengan nilai kerugian sebelumnya, dan persentil relatifnya diperkirakan per objek yang telah
RelativeProbabilisticSiftConfigAnda atur pada langkah 1. -
-
Bungkus pemuat PyTroch data dengan
SiftingDataloaderkelas SageMaker AI.Terakhir, gunakan semua kelas implementasi penyaringan SageMaker cerdas yang Anda konfigurasikan pada langkah sebelumnya ke kelas
SiftingDataloderkonfigurasi SageMaker AI. Kelas ini adalah pembungkus untuk PyTorchDataLoader. Dengan membungkus PyTorch DataLoader, penyaringan SageMaker cerdas terdaftar untuk dijalankan sebagai bagian dari pemuatan data di setiap iterasi pekerjaan pelatihan. PyTorch Contoh kode berikut menunjukkan penerapan penyaringan data SageMaker AI ke file. PyTorchDataLoaderfrom smart_sifting.dataloader.sift_dataloader import SiftingDataloader from torch.utils.data import DataLoader train_dataloader = DataLoader(...) # PyTorch data loader # Wrap the PyTorch data loader by SiftingDataloder train_dataloader = SiftingDataloader( sift_config=sift_config, # config object of RelativeProbabilisticSiftConfig orig_dataloader=train_dataloader, batch_transforms=ListBatchTransform(), # Optional, this is the custom class from step 2 loss_impl=SiftingImplementedLoss(), # PyTorch loss function wrapped by the Sifting Loss interface model=model, log_batch_data=False)