Customize train 🐳

In an earlier post we went through how to run a training script using sklearn, PyTorch, or transformers with SageMaker by leveraging their preconfigured framework containers. The training scripts we used were self-contained, meaning they only used the respective framework and the Python standard library. As a result, we only had to worry about uploading our data to and fetching our model from S3, and deciding on the instance type we wanted to use.

This is excellent for quick prototyping, but in practice our code lives in separate modules to reduce duplication, and often requires additional libraries. Case in point: all our training scripts from the previous post shared the same load_data function, and our PyTorch script had to implement its own tokenizer in the absence of additional libraries.

In this post, we will look at how to customize SageMaker training to address those issues, and how to create our own containers to run training. As a reminder, this is our minimal project setup.
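The file names here are illustrative; only src, requirements.txt, and the scripts folder are fixed parts of the setup described below.

```
.
├── requirements.txt
├── data
│   └── train.csv
├── scripts
│   └── train.py
└── src
    ├── train_sklearn.py
    └── utils.py
```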

The code for our project lives in src, and the libraries it requires are defined in requirements.txt. The data comes from a sentiment classification task on Kaggle: a CSV file with two columns, text and label, where label is either 0 or 1. We use a scripts folder for all the functionality needed to interact with SageMaker; the intention is that you can simply take some of those scripts into your own project and get started. Here is an sklearn training script that trains a tf-idf naive Bayes model and imports load_data from a utils module.
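What follows is a minimal sketch of such a script. The SM_MODEL_DIR and SM_CHANNEL_TRAIN environment variables are the standard ones SageMaker sets inside its framework containers; the argument names and the exact load_data signature are assumptions.

```python
# src/train_sklearn.py (illustrative)
import argparse
import os

import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

from utils import load_data


def parse_args():
    parser = argparse.ArgumentParser()
    # Default to the paths SageMaker injects via environment variables,
    # so the same script runs both locally and inside the container.
    parser.add_argument("--model-dir", default=os.environ.get("SM_MODEL_DIR", "model"))
    parser.add_argument("--train", default=os.environ.get("SM_CHANNEL_TRAIN", "data"))
    return parser.parse_args()


def main():
    args = parse_args()
    texts, labels = load_data(args.train)

    # tf-idf features feeding a naive Bayes classifier
    model = make_pipeline(TfidfVectorizer(), MultinomialNB())
    model.fit(texts, labels)

    os.makedirs(args.model_dir, exist_ok=True)
    joblib.dump(model, os.path.join(args.model_dir, "model.joblib"))


if __name__ == "__main__":
    main()
```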

And here is the utils module
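Again a sketch: the exact signature of load_data is an assumption, but it reads the two-column CSV described above.

```python
# src/utils.py (illustrative)
import os

import pandas as pd


def load_data(data_dir, filename="train.csv"):
    """Load the sentiment CSV and return the text and label columns."""
    df = pd.read_csv(os.path.join(data_dir, filename))
    return df["text"], df["label"]
```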

Let’s run this outside of SageMaker first.
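With the sketch above, a run from the project root could look like this (the flag names come from the sketch, not the original script):

```bash
python src/train_sklearn.py --train data --model-dir model
```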

In order for SageMaker to run our job, we need to include any additional files it requires. We can do that via the source_dir parameter that all SageMaker Estimators (SKLearn, PyTorch, HuggingFace, etc.) expose. As the name hints, it lets us specify a directory containing the source code to be included; in our case, that is the src folder.

Here is the code for running the job, which is very similar to the one from the previous post.
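A minimal sketch, assuming version 2 of the SageMaker Python SDK; the role, bucket, and framework version are placeholders to fill in:

```python
# scripts/train.py (illustrative) -- launch the training job
from sagemaker.sklearn import SKLearn

estimator = SKLearn(
    entry_point="train_sklearn.py",  # relative to source_dir, no src/ prefix
    source_dir="src",                # the whole directory is packaged and shipped
    framework_version="1.2-1",
    instance_type="ml.m5.large",
    instance_count=1,
    role="<your-sagemaker-execution-role>",
    output_path="s3://<bucket>/<prefix>/model",  # where the model artifact lands
)

# The keys of this dict become channels; the train channel's contents
# end up at the path held in SM_CHANNEL_TRAIN inside the container.
estimator.fit({"train": "s3://<bucket>/<prefix>/data"})
```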

Notice that the entry_point no longer needs to include src, since the path is now relative to source_dir. This is how you run it.
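The wrapper script name and flags below are hypothetical; the point is simply that the data and model paths are passed through to fit and output_path:

```bash
python scripts/train.py \
    --data s3://<bucket>/<prefix>/data \
    --model s3://<bucket>/<prefix>/model
```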

This should take care of any additional modules of our own that the code depends on. Remember that we need to prepend the data and model paths with file:// or s3:// for local or S3 paths respectively.
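For example, a fully local run with the hypothetical wrapper above would look like:

```bash
# assumes the wrapper also switches the estimator to instance_type="local"
python scripts/train.py --data file://data --model file://model
```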