Inference With fastai

Model Saving, Loading, and Prediction

Over the past few months, I’ve seen multiple people ask how to correctly use fast.ai to predict labels for new data. The fast.ai inference tutorial 404s (at the time of publishing), so I decided to write a tutorial detailing how to use fast.ai for inference, how to save and load fast.ai models, and how to avoid the few pitfalls along the way. The code for this post can be found here.

# Train a Model

First, let’s quickly train a model on Imagenette using near-SOTA settings. We’ll use the DataBlock API to define the training and validation sets, add some minimal augmentation, create a Learner object directly (instead of using vision_learner), and then train the model for five epochs.

```python
from fastai.vision.all import *

source = untar_data(URLs.IMAGENETTE_320)
workers = min(8, num_cpus())

dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files,
                   get_y=parent_label,
                   item_tfms=[RandomResizedCrop(192, min_scale=0.35),
                              FlipItem(0.5)],
                   batch_tfms=Normalize.from_stats(*imagenet_stats))

dls = dblock.dataloaders(source, bs=64, num_workers=workers)

learn = Learner(dls, xresnet50(n_out=dls.c), opt_func=ranger,
                loss_func=LabelSmoothingCrossEntropyFlat(),
                metrics=accuracy)

learn.fit_flat_cos(5, 8e-3)
```

On a Google Colab or Kaggle Tesla P100 instance, this model takes roughly six minutes to train and reaches nearly 80 percent accuracy.

# Saving and Loading

Fast.ai offers two out-of-the-box methods to save a model for later use: Learner.export and Learner.save (callbacks like SaveModelCallback use Learner.save under the hood). Each is paired with its own loading method: load_learner and Learner.load for export and save, respectively. (Fast.ai is a PyTorch framework and all fast.ai models are PyTorch models, so it’s also possible to follow the PyTorch saving and loading tutorial directly on Learner.model.)

# Export and load_learner

Pass a file name to Learner.export and it will pickle and save the learner object to Learner.path, which defaults to the current working directory unless a path was set in DataBlock.dataloaders or the Learner.

```python
learn.export('xresnet50_export.pkl')
```

Export pickles the entire learner object, including the dataloaders, loss function, optimizer function (but not the optimizer’s state), augmentations or transforms, and all callbacks. All of these are restored when calling load_learner. A notable exception is the MixedPrecision callback, which is currently removed on load, so the model is restored in full precision. While load_learner restores the dataloader and all its settings, the loaded dataloader will not point to any data, since the assumption is that it is being loaded in production where the training set doesn’t exist. We’ll add data to the dataloader in the inference section of this tutorial.

Because everything is pickled, load_learner requires all custom code and packages used by the learner to be available in the inference environment. For example, if you log training metrics to Weights & Biases via the WandbCallback but don’t install and import the wandb package in your inference environment, load_learner will error out. To avoid this, remove the callback before exporting with either Learner.remove_cb or Learner.remove_cbs. Likewise, before exporting you can set the loss function (loss_func) and/or optimizer (opt_func) to a fast.ai version to limit duplicated code in the inference environment.
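The failure comes from how pickling works, and it can be reproduced with nothing but Python’s standard pickle module. A pickled object records where its class lives (module and name) rather than the class’s code, so unpickling fails when that module can’t be imported. In the sketch below, fake_logging_pkg and LoggingCallback are hypothetical stand-ins for an optional package like wandb and its callback; fast.ai’s export goes through torch.save, which relies on pickle in the same way.

```python
import pickle
import sys
import types


class LoggingCallback:
    """Hypothetical callback that 'lives' in an optional third-party package."""
    def __init__(self, project):
        self.project = project


# Pretend the class comes from a separately installed package.
fake_pkg = types.ModuleType("fake_logging_pkg")
fake_pkg.LoggingCallback = LoggingCallback
LoggingCallback.__module__ = "fake_logging_pkg"
sys.modules["fake_logging_pkg"] = fake_pkg

# "Export": pickle an object that holds the callback.
learner_like = {"cbs": [LoggingCallback("demo")]}
payload = pickle.dumps(learner_like)

# Simulate an inference environment where the package isn't installed.
del sys.modules["fake_logging_pkg"]
try:
    pickle.loads(payload)
    load_error = None
except ModuleNotFoundError as err:
    load_error = err

# Making the package importable again fixes loading.
sys.modules["fake_logging_pkg"] = fake_pkg
restored = pickle.loads(payload)
print(type(load_error).__name__, restored["cbs"][0].project)
```

The pickle only stores the pointer `fake_logging_pkg.LoggingCallback`, which is why removing the callback before exporting, rather than after loading, is the fix.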

To load an exported fast.ai learner, call load_learner with the path to the exported learner.

```python
learn = load_learner('xresnet50_export.pkl', cpu=False)
```

By default, load_learner loads to the CPU. If you are going to predict in batches, you probably want to set cpu=False to load to the original device, or load to the CPU and then manually transfer to the new device:

```python
# Only needed if the learner was loaded with cpu=True (the default)
learn.dls.to(device='cuda')
learn.model.to(device='cuda')
```

# Save and Load

Learner.save only saves the model weights and, optionally, the optimizer state. It writes to Learner.path/Learner.model_dir, which by default is the models folder in the current working directory, and appends a .pth extension to the file name.

Since we won’t be training the model further, we can ignore the optimizer and only save the model weights.

```python
learn.save('xresnet50_save', with_opt=False)
```
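To make the difference from export concrete, here is a plain-Python sketch of what a weights-only checkpoint holds. The helpers save_weights and load_weights are hypothetical and only loosely mirror fast.ai’s approach, which calls torch.save on the model’s state_dict. The point is that the file stores parameter values keyed by name, optionally bundled with the optimizer state, but not the architecture, dataloaders, or loss function.

```python
import io
import pickle


def save_weights(buffer, model_state, opt_state=None, with_opt=True):
    """Hypothetical helper: write the weights alone, or bundle them with the optimizer state."""
    if with_opt and opt_state is not None:
        state = {"model": model_state, "opt": opt_state}
    else:
        state = model_state
    pickle.dump(state, buffer)


def load_weights(buffer):
    """Return (model_state, opt_state); opt_state is None for a weights-only file."""
    state = pickle.load(buffer)
    if isinstance(state, dict) and set(state) == {"model", "opt"}:
        return state["model"], state["opt"]
    return state, None


# Toy "state dict": parameter names mapped to values. There is no
# architecture in here, which is why the Learner must be recreated first.
weights = {"layer.weight": [0.5, -0.25], "layer.bias": [0.1]}
optimizer_state = {"step": 500}

buf = io.BytesIO()
save_weights(buf, weights, optimizer_state, with_opt=False)
buf.seek(0)
model_state, opt_state = load_weights(buf)
print(model_state == weights, opt_state)  # True None
```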

Before loading, we need to recreate the Learner and its dataloaders, so the model inputs during prediction are the same as they were during the validation step of training.

```python
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=GrandparentSplitter(valid_name='val'),
                   get_items=get_image_files,
                   get_y=parent_label,
                   item_tfms=RandomResizedCrop(192),
                   batch_tfms=Normalize.from_stats(*imagenet_stats))

dls = dblock.dataloaders(source, bs=64, num_workers=workers)

learn = Learner(dls, xresnet50(n_out=dls.c),
                loss_func=LabelSmoothingCrossEntropyFlat(),
                metrics=accuracy)
```

A discerning eye will note that I changed some of the item_tfms from the original training DataBlock. We don’t need FlipItem, as it isn’t applied to the validation set. Neither do we need to set a minimum scale, as the image isn’t randomly resized during validation, but rather resized and center cropped. There’s also no need to specify an optimizer in the Learner, as we won’t be training the model further. (If we were going to use test time augmentation, we could alter the transforms applied here, or pass them directly into Learner.tta.)

If you don’t know which augmentations are ignored during validation (transforms with an explicit split_idx=0 are ignored), there’s no harm in leaving them in the recreated dataloader and letting fast.ai take care of things under the hood.
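As a rough illustration of the split_idx rule, the sketch below uses a hypothetical, heavily simplified SketchTransform class (not fastcore’s actual Transform) that skips itself whenever it is called with a split index other than its own. Index 0 stands for the training split and 1 for validation; a transform with split_idx=None always runs.

```python
class SketchTransform:
    """Hypothetical, simplified stand-in for a fastcore Transform.
    split_idx=0: training only, 1: validation only, None: always apply."""

    def __init__(self, fn, split_idx=None):
        self.fn = fn
        self.split_idx = split_idx

    def __call__(self, x, split_idx=None):
        # Skip the transform when it is tied to a different split.
        if self.split_idx is not None and split_idx != self.split_idx:
            return x
        return self.fn(x)


flip = SketchTransform(lambda xs: xs[::-1], split_idx=0)  # like a training-only augmentation
scale = SketchTransform(lambda xs: [v / 2 for v in xs])   # like a transform used on every split

item = [1, 2, 3]
train_out = scale(flip(item, split_idx=0), split_idx=0)
valid_out = scale(flip(item, split_idx=1), split_idx=1)
print(train_out, valid_out)  # [1.5, 1.0, 0.5] [0.5, 1.0, 1.5]
```

This is why leaving a training-only augmentation in the recreated dataloader is harmless: on the validation or test split it simply passes the input through.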

Once the Learner and its dataloaders are recreated, we can load the model weights using Learner.load, optionally changing the device the model loads to.

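```python
# with_opt=False: the file holds only weights, so don't try to restore an optimizer state
learn.load('xresnet50_save', with_opt=False)
```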

# Loss Activation and Decodes

Before diving into inference, we need to take a quick detour to discuss the custom fast.ai loss functions. If we look at fast.ai’s version of cross entropy loss:

```python
class CrossEntropyLossFlat(BaseLoss):
    "Same as `nn.CrossEntropyLoss`, but flattens input and target."
    y_int = True
    @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
    def __init__(self, *args, axis=-1, **kwargs):
        super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)

    def decodes(self, x):
        return x.argmax(dim=self.axis)

    def activation(self, x):
        return F.softmax(x, dim=self.axis)
```

we see that fast.ai adds two convenience methods which can be used during inference: decodes and activation. The activation method replicates the activation fused into the loss function (nn.CrossEntropyLoss applies log-softmax internally, so the model itself outputs raw logits), while the decodes method transforms the activation’s output into final predictions.
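To see what these two methods compute, here is a framework-free sketch for the single-label case. SketchCrossEntropy is a hypothetical class, not fast.ai’s: it computes no loss, and it uses plain Python lists where fast.ai uses tensors with F.softmax and argmax along self.axis. The logits are toy values.

```python
import math


class SketchCrossEntropy:
    """Hypothetical stand-in mirroring only the activation and decodes
    helpers of CrossEntropyLossFlat; it computes no loss."""

    def activation(self, logits):
        # Softmax, the activation nn.CrossEntropyLoss applies internally
        # (as log-softmax); subtracting the max keeps exp() from overflowing.
        m = max(logits)
        exps = [math.exp(v - m) for v in logits]
        total = sum(exps)
        return [e / total for e in exps]

    def decodes(self, probs):
        # Argmax: the index of the most probable class.
        return max(range(len(probs)), key=probs.__getitem__)


loss = SketchCrossEntropy()
probs = loss.activation([2.0, 1.0, 0.1])  # toy logits, not real model output
pred = loss.decodes(probs)
print([round(p, 3) for p in probs], pred)
```

Because softmax preserves the ordering of its inputs, decodes would pick the same class even from raw logits; the activation matters when you need probabilities rather than just the top class.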

Inference using a loss function without these methods results in the raw output of the model. See the appendix for an example from our classification problem.

# Inference

Now that our model is loaded, the dataloaders were recreated or restored, and we know about decodes and activation, we are ready to begin inference.

# Predicting One Item

For inference on one item at a time, fast.ai has Learner.predict. This method is often used for inference on a CPU. predict expects its input in the same format as the dataloader’s inputs, which in our case is a file name.

```python
learn.predict(source/'val/n01440764/ILSVRC2012_val_00009111.JPEG')
```

If we had defined a dataloader to expect a pandas DataFrame:

```python
# df: a DataFrame with 'image' (file path), 'label', and boolean 'valid' columns
dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                   splitter=ColSplitter('valid'),
                   get_x=ColReader('image'),
                   get_y=ColReader('label'),
                   item_tfms=RandomResizedCrop(192),
                   batch_tfms=Normalize.from_stats(*imagenet_stats))

dls = dblock.dataloaders(df, bs=64, num_workers=workers)
```

then we would need to pass a single DataFrame row with an ‘image’ column.

By default, predict returns a tuple with three items: a fully decoded prediction including reversing transforms from the dataloader, a decoded prediction using decodes, and the prediction from the model passed through the loss function’s activation.

Since this is a classification problem, predict using our model returns the predicted class label, predicted numeric class, and the probabilities for each class.

```
('n01440764',
 tensor(0),
 tensor([0.8351, 0.0355, 0.0105, 0.0629, 0.0031, 0.0168, 0.0131, 0.0059, 0.0101,
         0.0072]))
```
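For a single-label classifier, going from the third item of that tuple back to the first two amounts to an argmax (decodes) followed by a vocab lookup (the full decode). The sketch below rebuilds them from the probabilities printed above. sketch_predict is a hypothetical helper, and the hard-coded vocab assumes CategoryBlock’s default of sorting the ten Imagenette class folder names; in practice you would read it from learn.dls.vocab.

```python
# Assumption: the vocab is the sorted list of Imagenette class folder names.
vocab = ['n01440764', 'n02102040', 'n02979186', 'n03000684', 'n03028079',
         'n03394916', 'n03417042', 'n03425413', 'n03445777', 'n03888257']


def sketch_predict(probs, vocab):
    """Hypothetical helper: assemble a predict-style tuple from probabilities."""
    idx = max(range(len(probs)), key=probs.__getitem__)  # decodes: argmax
    return vocab[idx], idx, probs                        # label, index, probabilities


# Probabilities copied from the predict output above.
probs = [0.8351, 0.0355, 0.0105, 0.0629, 0.0031,
         0.0168, 0.0131, 0.0059, 0.0101, 0.0072]
label, idx, _ = sketch_predict(probs, vocab)
print(label, idx)  # n01440764 0
```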

# Batch Prediction

For inference on an entire dataset, predicting one item at a time is an inefficient use of a GPU’s parallel computational power. Instead, we’ll use batch prediction via Learner.get_preds. (It’s also possible to batch multiple simultaneous requests from multiple users for increased inference speed, but that is beyond the scope of this tutorial.)

First, we need to create a test dataloader. The easiest way is the DataLoaders.test_dl method, which applies the validation transforms from our original or recreated dataloaders.

Since Imagenette doesn’t have a test set, we’ll reuse the validation set for this tutorial. Our original dataloader expects file names as input, so let’s use a list comprehension to grab a sorted list of file names (glob returns files in an arbitrary order, hence the need for sorted) and then create the test dataloader.

```python
test_files = [fn for fn in sorted((source/'val').glob('**/*')) if fn.is_file()]

test_dl = learn.dls.test_dl(test_files)
```

As mentioned earlier, if our dataloader read pandas DataFrames, we’d pass test_dl a DataFrame of inputs instead.

Now that we have a test dataloader, we can pass it to get_preds for batch prediction. By default, get_preds returns a tuple of predictions and labels. Since this is a test set and we don’t have any labels (more accurately, we’re pretending not to have any labels), we can ignore the second half of the tuple.

```python
preds, _ = learn.get_preds(dl=test_dl)
```

The first predicted item contains the same class probabilities as the third result from predict.

```
tensor([0.8351, 0.0355, 0.0105, 0.0629, 0.0031, 0.0168, 0.0131, 0.0059, 0.0101,
        0.0072])
```

To also return decoded results, set with_decoded=True.

```python
preds, _, decoded = learn.get_preds(dl=test_dl, with_decoded=True)
```

The first result of decoded will match the second result from predict, which in our case is the predicted numeric class.

```
tensor(0)
```
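A common next step is pairing each file with its predicted label. The sketch below assumes, as seems to be the case, that decoded comes back in the same order as test_files, since test_dl is built from the validation dataloader, which doesn’t shuffle. label_predictions is a hypothetical helper, and the inputs are toy values; with the real objects you might call label_predictions(test_files, decoded, learn.dls.vocab).

```python
def label_predictions(files, decoded, vocab):
    """Hypothetical helper: pair each test file with its predicted class label.
    Assumes `decoded` is in the same order as `files`."""
    if len(files) != len(decoded):
        raise ValueError("expected exactly one prediction per file")
    # int() accepts plain ints as well as zero-dimensional tensors like tensor(0)
    return [(f, vocab[int(i)]) for f, i in zip(files, decoded)]


# Toy inputs: hypothetical file names and a two-class vocab.
files = ["val/n01440764/a.JPEG", "val/n03000684/b.JPEG"]
vocab = ["n01440764", "n03000684"]
pairs = label_predictions(files, [0, 1], vocab)
print(pairs)
```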

# Conclusion

I’ll finish this tutorial with a list of the four main things to remember when using fast.ai for inference.

  1. load_learner requires all custom code and packages to be replicated in the inference environment.
  2. Learner.load requires you to both recreate the original model and duplicate the original dataloader’s validation transforms.
  3. Use DataLoaders.test_dl to create a test dataloader for batch prediction with Learner.get_preds.
  4. Fast.ai loss functions add activation and decodes convenience methods, which you might want to reproduce if using a non-fast.ai loss function.

# Appendix: Results Without Activation and Decodes

If we replace the fast.ai version of cross entropy loss with a PyTorch version which doesn’t have the decodes or activation convenience methods and predict a single image:

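```python
learn = Learner(dls, xresnet50(n_out=dls.c),
                loss_func=nn.CrossEntropyLoss())
# load the trained weights saved earlier so the outputs below are comparable
learn.load('xresnet50_save', with_opt=False)
```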

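we get the raw output of the model as both the second and third results, plus a meaningless list of labels as the first. Without a decodes method, the raw outputs appear to be truncated to integers and used as indices into the vocab (for example, 3.2740 becomes index 3 and -1.1062 becomes index -1, the last class).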

```python
learn.predict(source/'val/n01440764/ILSVRC2012_val_00009111.JPEG')
("['n03000684', 'n01440764', 'n03888257', 'n01440764', 'n03445777', 'n01440764', 'n01440764', 'n03888257', 'n03888257', 'n03888257']",
 tensor([ 3.2740,  0.1149, -1.1062,  0.6877, -2.3351, -0.6300, -0.8832, -1.6813,
         -1.1409, -1.4827]),
 tensor([ 3.2740,  0.1149, -1.1062,  0.6877, -2.3351, -0.6300, -0.8832, -1.6813,
         -1.1409, -1.4827]))
```

Likewise, calling get_preds with a loss function that lacks the decodes and activation convenience methods returns the same raw model output.

```python
preds, _ = learn.get_preds(dl=test_dl)
preds[0]
tensor([ 3.2740,  0.1149, -1.1062,  0.6877, -2.3351, -0.6300, -0.8832, -1.6813,
        -1.1409, -1.4827])
```
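Since this loss’s activation and decodes are just a softmax and an argmax, raw output like this can be post-processed by hand. As a check, the plain-Python sketch below applies a softmax to the logits above; the result should closely match the probabilities predict returned earlier (0.8351 for the first class), which suggests these logits come from the same trained weights. With tensors you would typically write F.softmax(preds, dim=-1) and preds.argmax(dim=-1) instead.

```python
import math

# Raw logits copied from the get_preds output above.
logits = [3.2740, 0.1149, -1.1062, 0.6877, -2.3351,
          -0.6300, -0.8832, -1.6813, -1.1409, -1.4827]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax(logits)                               # what activation would return
pred = max(range(len(probs)), key=probs.__getitem__)  # what decodes would return
print([round(p, 4) for p in probs], pred)
```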