Over the past few months, I’ve seen multiple people ask how to correctly use fast.ai to predict labels for new data. The fast.ai inference tutorial 404s (at the time of publishing), so I decided to write a tutorial detailing how to use fast.ai for inference, how to save and load fast.ai models, and how to avoid the few pitfalls along the way. The code for this post can be found here.
# Train a Model
First, let’s quickly train a model on Imagenette using near-SOTA settings. We’ll use the DataBlock API to define the training and validation sets, define some minimal augmentation, directly create a `Learner` object (instead of using `vision_learner`), and then train the model for five epochs.

    from fastai.vision.all import *

    source = untar_data(URLs.IMAGENETTE_320)
    workers = min(8, num_cpus())

    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files,
                       get_y=parent_label,
                       item_tfms=[RandomResizedCrop(192, min_scale=0.35), FlipItem(0.5)],
                       batch_tfms=Normalize.from_stats(*imagenet_stats))

    dls = dblock.dataloaders(source, bs=64, num_workers=workers)

    learn = Learner(dls, xresnet50(n_out=dls.c), opt_func=ranger,
                    loss_func=LabelSmoothingCrossEntropyFlat(), metrics=accuracy)

    learn.fit_flat_cos(5, 8e-3)

On a Google Colab or Kaggle Tesla P100 instance, this model takes roughly six minutes to train and reaches an accuracy of nearly 80 percent.
# Saving and Loading
Fast.ai offers two out-of-the-box methods to save a model for later use: `Learner.export` and `Learner.save`, both of which use pickle under the hood. Each is paired with its own loading method: `load_learner` and `Learner.load` for export and save, respectively. (Fast.ai is a PyTorch framework and all fast.ai models are PyTorch models, so it’s also possible to follow the PyTorch saving and loading tutorial directly using `Learner.model`.)
# Export and load_learner
Pass an export file name and `Learner.export` will pickle and save the learner object to `Learner.path`, which by default is the current working directory unless a path was set in the dataloaders or the `Learner`.
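As a minimal sketch of the export step: the file name `xresnet50_export.pkl` matches the one loaded later in this post, `export_learner` is a hypothetical wrapper introduced only for illustration, and `learn` is assumed to be the trained `Learner` from above.

```python
from pathlib import Path

def export_learner(learn, fname="xresnet50_export.pkl"):
    """Pickle the whole learner via Learner.export and return the expected file location."""
    learn.export(fname)               # fast.ai writes the file relative to learn.path
    return Path(learn.path) / fname   # where the pickle should now live
```

Calling `export_learner(learn)` after training should leave the pickle in the learner's path, ready to be copied to the inference environment.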
Export pickles the entire learner object, including the dataloaders, loss function, optimizer, augmentations or transforms, and all callbacks. All of these are loaded and restored when calling `load_learner`. (A notable exception is the `MixedPrecision` callback, which is currently removed on load, and the model is set to full precision.) While `load_learner` restores the dataloader and all its settings, the loaded dataloader will not point to any data, since the assumption is that the learner is being loaded in production where the training set doesn’t exist. We’ll add data to the dataloader in the inference section of this tutorial.
Because everything is pickled, any custom code or packages the learner depends on must also be available in the inference environment. For example, if you use Weights & Biases to log training metrics via the `WandbCallback`, but don’t install and import the wandb package in your inference environment, `load_learner` will error out. To remove a callback before exporting, use either `Learner.remove_cb` or `Learner.remove_cbs`. Likewise, you can set the loss function (`loss_func`) and/or optimizer (`opt_func`) to a fast.ai version to limit duplicated code in the inference environment.
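One way to strip training-only callbacks before exporting might look like the sketch below. The helper name `strip_callbacks` is hypothetical, and the sketch assumes `Learner.remove_cbs` accepts a list of callback classes (or instances) and that `learn.cbs` is iterable.

```python
def strip_callbacks(learn, cb_types):
    """Remove every callback matching one of `cb_types`, then list what remains."""
    learn.remove_cbs(list(cb_types))                  # drop the unwanted callbacks in one call
    return [type(cb).__name__ for cb in learn.cbs]    # names of the callbacks still attached
```

In the training environment, where wandb is still installed, this would be called as `strip_callbacks(learn, [WandbCallback])` just before exporting.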
To load an exported fast.ai learner, call `load_learner` with the path to the exported learner.

    learn = load_learner('xresnet50_export.pkl', cpu=False)

By default, `load_learner` will load to CPU. If you are going to predict in batches, then you probably want to set `cpu=False` to load to the original device, or load to CPU and then manually transfer to the new device.
# Save and Load
`Learner.save` only pickles the model weights and, optionally, the optimizer state. `save` writes to `Learner.model_dir`, which by default is the `models` folder in the current working directory.
Since we won’t be training the model further, we can ignore the optimizer and only save the model weights.
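A weights-only save could be sketched as follows. The helper `save_weights` and the file name `xresnet50_weights` are hypothetical; the sketch relies on `Learner.save` taking a `with_opt` flag and returning the path it wrote to.

```python
def save_weights(learn, name="xresnet50_weights"):
    """Save only the model weights, skipping the optimizer state."""
    # with_opt=False leaves the optimizer state out of the saved file
    return learn.save(name, with_opt=False)
```

The returned path is handy for copying the weights file to wherever inference will run.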
On load we’ll need to recreate the `Learner` and dataloaders so the model inputs during prediction are the same as they would be during the validation step when we were training the model.

    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=GrandparentSplitter(valid_name='val'),
                       get_items=get_image_files,
                       get_y=parent_label,
                       item_tfms=RandomResizedCrop(192),
                       batch_tfms=Normalize.from_stats(*imagenet_stats))

    dls = dblock.dataloaders(source, bs=64, num_workers=workers)

    learn = Learner(dls, xresnet50(n_out=dls.c),
                    loss_func=LabelSmoothingCrossEntropyFlat(), metrics=accuracy)

A discerning eye will note that I changed some of the `item_tfms` from the original training datablock. We don’t need `FlipItem`, as it’s not applied on the validation set. Neither do we need to set a minimum scale, as the image isn’t randomly resized during validation, but rather resized and center cropped. There’s also no need to specify an optimizer in the `Learner`, as we won’t be training the model further. (If we were going to use test time augmentation, we could alter the transforms applied here, or pass them directly into `Learner.tta`.)
If you don’t know which augmentations are ignored during validation, there’s no harm in leaving them in the recreated dataloader and letting fast.ai take care of things under the hood. (Transforms with an explicit `split_idx=0` are ignored during validation.)
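To make the `split_idx` rule concrete, here is a toy model of the filtering in plain Python. It is not fast.ai's implementation: `ToyTransform` and `active_transforms` are invented for illustration, and the `split_idx` values assigned to each transform are my assumption about how the real transforms are configured.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyTransform:
    name: str
    split_idx: Optional[int] = None   # None: both splits, 0: training only, 1: validation only

def active_transforms(tfms, split_idx):
    """Names of the transforms that would run for the given split (0=train, 1=valid)."""
    return [t.name for t in tfms if t.split_idx is None or t.split_idx == split_idx]

pipeline = [
    ToyTransform("RandomResizedCrop"),       # runs on both splits (center crop on validation)
    ToyTransform("FlipItem", split_idx=0),   # random flip, training only
    ToyTransform("Normalize"),               # runs on both splits
]

print(active_transforms(pipeline, split_idx=0))   # transforms used while training
print(active_transforms(pipeline, split_idx=1))   # transforms used while validating
```

In this toy pipeline the flip drops out for the validation split, which is why leaving it in the recreated dataloader does no harm.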
Once the `Learner` is recreated, including its dataloaders, we can load the model weights using `Learner.load`, optionally changing the device the model will load to.
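Loading the weights into the recreated learner might look like this sketch. The helper `load_weights` and the file name `xresnet50_weights` are hypothetical; the sketch assumes `Learner.load` accepts `with_opt` and `device` arguments and returns the learner.

```python
def load_weights(learn, name="xresnet50_weights", device=None):
    """Load weights that were saved without optimizer state into a recreated learner."""
    # device=None leaves the choice to fast.ai; pass e.g. "cpu" to override it
    return learn.load(name, with_opt=False, device=device)
```

For CPU-only inference, the call would be `learn = load_weights(learn, device="cpu")`.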
# Loss Activation and Decodes
Before diving into inference, we need to take a quick detour to discuss the custom fast.ai loss functions. If we look at fast.ai’s version of cross entropy loss:

    class CrossEntropyLossFlat(BaseLoss):
        "Same as `nn.CrossEntropyLoss`, but flattens input and target."
        y_int = True
        @use_kwargs_dict(keep=True, weight=None, ignore_index=-100, reduction='mean')
        def __init__(self, *args, axis=-1, **kwargs): super().__init__(nn.CrossEntropyLoss, *args, axis=axis, **kwargs)
        def decodes(self, x):    return x.argmax(dim=self.axis)
        def activation(self, x): return F.softmax(x, dim=self.axis)

we see fast.ai adds two convenience methods which can be used during inference: `activation` and `decodes`. The `activation` method replicates the loss function’s fused activation, while the `decodes` method is used to transform the activation’s output to final predictions.
Inference using a loss function without these methods results in the raw output of the model. See the appendix for an example from our classification problem.
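To see what the two methods do numerically, the sketch below applies a plain-Python softmax and argmax to the raw outputs listed in the appendix. The functions are stand-ins written for this example, not fast.ai's code; they only mirror the `F.softmax` and `argmax` calls shown above.

```python
import math

# Raw model outputs (logits) for one image, copied from the appendix of this post.
logits = [3.2740, 0.1149, -1.1062, 0.6877, -2.3351,
          -0.6300, -0.8832, -1.6813, -1.1409, -1.4827]

def activation(xs):
    """Softmax: turn logits into probabilities that sum to one."""
    m = max(xs)                               # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def decodes(xs):
    """Argmax: index of the highest-scoring class."""
    return max(range(len(xs)), key=lambda i: xs[i])

probs = activation(logits)
print([round(p, 4) for p in probs])   # per-class probabilities
print(decodes(probs))                 # predicted numeric class
```

The first probability comes out at roughly 0.835 and the decoded class is 0, in line with the `predict` output shown in the next section.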
Now that our model is loaded, the dataloaders have been recreated or restored, and we know about `activation` and `decodes`, we are ready to begin inference.
# Predicting One Item
For inference on one item at a time, fast.ai has `Learner.predict`. This method is often used for inference on a CPU. `predict` expects the input to be in the same format the dataloader reads. In our case that’s a file name.
If we had defined a dataloader to expect a pandas `DataFrame`:

    dblock = DataBlock(blocks=(ImageBlock, CategoryBlock),
                       splitter=ColSplitter('valid'),
                       get_x=ColReader('image'),
                       get_y=ColReader('label'),
                       item_tfms=RandomResizedCrop(192),
                       batch_tfms=Normalize.from_stats(*imagenet_stats))

    dls = dblock.dataloaders(df, bs=64, num_workers=workers)

then we would need to pass a single `DataFrame` row with an ‘image’ column.
`predict` returns a tuple with three items: a fully decoded prediction, including reversing transforms from the dataloader; a decoded prediction using `decodes`; and the prediction from the model passed through the loss function’s `activation`.
Since this is a classification problem, `predict` using our model returns the predicted class label, predicted numeric class, and the probabilities for each class.

    ('n01440764', tensor(0), tensor([0.8351, 0.0355, 0.0105, 0.0629, 0.0031, 0.0168, 0.0131, 0.0059, 0.0101, 0.0072]))

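A small helper can unpack the tuple above into something easier to pass around. This is only a sketch: `predict_one` is a hypothetical name, and `learn` is assumed to be a classification `Learner` whose `predict` returns the three items described above.

```python
def predict_one(learn, item):
    """Run Learner.predict on one item and unpack its three results."""
    label, class_idx, probs = learn.predict(item)
    idx = int(class_idx)                        # works for a plain int or a 0-dim tensor
    return {"label": label,
            "class_idx": idx,
            "confidence": float(probs[idx])}    # probability of the predicted class
```

For the image above, `predict_one(learn, fname)` would report the label `n01440764`, class index 0, and a confidence of about 0.835.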
# Batch Prediction
For inference on the entirety of a dataset, predicting one item at a time is an inefficient use of the parallel computational power of a GPU. We’ll want to use batch prediction with `Learner.get_preds`. (It’s also possible to batch multiple simultaneous requests from multiple users for increased inference speed, but that is beyond the scope of this tutorial.)
First, we need to create a test dataloader. The easiest way is to use the `DataLoaders.test_dl` method. `test_dl` will use the validation augmentations from our original or recreated dataloaders.
Since Imagenette doesn’t have a test set, we’ll reuse the validation set for this tutorial. Since our original dataloader expects file names as the input, let’s use a list comprehension to grab a sorted list of file names and then create the test dataloader. (Directory listings such as `os.listdir` return files in an arbitrary order, hence the need for `sorted`.)

    test_files = [fn for fn in sorted((source/'val').glob('**/*')) if fn.is_file()]
    test_dl = learn.dls.test_dl(test_files)

As mentioned earlier, if our dataloader read pandas dataframes we’d pass it a `DataFrame` with the list of inputs.
Now that we have a test dataloader, we can pass it to `get_preds` for batch prediction. By default, `get_preds` returns a tuple of predictions and labels. Since this is a test set and we don’t have any labels (more accurately, we’re pretending not to have any labels), we can ignore the second half of the tuple.

    preds, _ = learn.get_preds(dl=test_dl)

If we look at the first predicted item, it has the same probabilities for each class as the third result from `predict`.

    tensor([0.8351, 0.0355, 0.0105, 0.0629, 0.0031, 0.0168, 0.0131, 0.0059, 0.0101, 0.0072])

To return decoded results, set the `with_decoded` argument to `True`.

    preds, _, decoded = learn.get_preds(dl=test_dl, with_decoded=True)

The first result of `decoded` will match the second result from `predict`, which in our case is the predicted numeric class.
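If class labels are wanted rather than numeric classes, the decoded predictions can be mapped through the dataloaders' vocabulary. A minimal sketch, assuming a classification setup where `learn.dls.vocab` holds the labels in class-index order; the helper `to_labels` is hypothetical.

```python
def to_labels(decoded, vocab):
    """Map numeric class predictions to their label strings."""
    return [vocab[int(i)] for i in decoded]   # int() also handles 0-dim tensors
```

With the results above, `to_labels(decoded, learn.dls.vocab)` would give one label per test file, in the same order as `test_files`.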
I will finish this tutorial with a list reiterating the four main things to remember when using fast.ai for inference.
- `load_learner` requires all custom code and packages to be replicated in the inference environment.
- `Learner.load` requires you to both recreate the original model and duplicate the original dataloader’s validation transforms.
- Use `DataLoaders.test_dl` to create a test dataloader for batch prediction when using `Learner.get_preds`.
- Fast.ai loss functions add `activation` and `decodes` convenience methods, which you might want to reproduce if using a non-fast.ai loss function.
# Appendix: Results Without Activation and Decodes
If we replace the fast.ai version of cross entropy loss with a PyTorch version, which doesn’t have the `decodes` and `activation` convenience methods, and predict a single image:

    learn = Learner(dls, xresnet50(n_out=dls.c), loss_func=nn.CrossEntropyLoss())

we get the raw results from the model along with all labels from the dataloader (if applicable).

    ("['n03000684', 'n01440764', 'n03888257', 'n01440764', 'n03445777', 'n01440764', 'n01440764', 'n03888257', 'n03888257', 'n03888257']", tensor([ 3.2740, 0.1149, -1.1062, 0.6877, -2.3351, -0.6300, -0.8832, -1.6813, -1.1409, -1.4827]), tensor([ 3.2740, 0.1149, -1.1062, 0.6877, -2.3351, -0.6300, -0.8832, -1.6813, -1.1409, -1.4827]))

`get_preds` using a loss function without the `decodes` and `activation` convenience methods results in the same raw model output.

    preds, _ = learn.get_preds(dl=test_dl)

    tensor([ 3.2740, 0.1149, -1.1062, 0.6877, -2.3351, -0.6300, -0.8832, -1.6813, -1.1409, -1.4827])