Introducing Pytorch for fast.ai

The next fast.ai courses will be based nearly entirely on a new framework we have developed, built on Pytorch. Pytorch is a different kind of deep learning library (dynamic, rather than static), which has been adopted by many (if not most) of the researchers that we most respect, and in a recent Kaggle competition was used by nearly all of the top 10 finishers.

We have spent around a thousand hours this year working with Pytorch to get to this point, and we are very excited about what it is allowing us to do. We will be writing a number of articles in the coming weeks talking about each aspect of this. First, we will start with a quick summary of the background to, and implications of, this decision. Perhaps the best summary, however, is this snippet from the start of our first lesson:

Excerpt from Lesson 1, showing 99.32% accuracy

fast.ai’s teaching goal

Our goal at fast.ai is for there to be nothing to teach. We believe that the fact that we currently require high school math, one year of coding experience, and seven weeks of study to become a world-class deep learning practitioner is not an acceptable state of affairs (even though these are fewer prerequisites than for any other course of a similar level). Everybody should be able to use deep learning to solve their problems with no more education than it takes to use a smartphone. Therefore, each year our main research goal is to be able to teach a wider range of deep learning applications, which run faster and are more accurate, to people with fewer prerequisites.

We want our students to be able to solve their most challenging and important problems and to transform their industries and organisations, which is what we believe deep learning makes possible. We are not just trying to teach people how to get existing jobs in the field, but to go far beyond that.

Therefore, since we first ran our deep learning course, we have been constantly curating best practices, and benchmarking and developing many techniques, trialling them against Kaggle leaderboards and academic state-of-the-art results.

Why we tried Pytorch

As we developed our second course, Cutting-Edge Deep Learning for Coders, we started to hit the limits of the libraries we had chosen: Keras and Tensorflow. For example, perhaps the most important technique in natural language processing today is the use of attentional models. We discovered that there was no effective implementation of attentional models for Keras at the time, and the Tensorflow implementations were undocumented, rapidly changing, and unnecessarily complex. We ended up writing our own in Keras, which took a long time and was very hard to debug. We then turned our attention to implementing dynamic teacher forcing, a critical technique for accurate neural translation models, for which we could find no implementation in either Keras or Tensorflow. Again, we tried to write our own, but this time we just weren’t able to make anything work.
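
To make "dynamic teacher forcing" concrete, here is a minimal plain-Python sketch of a decoding loop, assuming one common formulation: at each output step the decoder is fed the true previous token with probability `forcing_ratio`, and its own previous prediction otherwise, with the ratio lowered as training progresses. The names `decode_with_teacher_forcing` and `forcing_ratio_for_epoch`, the linear schedule, and the stand-in `step` function are hypothetical illustrations, not the course's code; a real decoder step would be a neural network operating on tensors.

```python
import random
from typing import Callable, List, Tuple

def decode_with_teacher_forcing(
    step: Callable[[int], int],
    targets: List[int],
    start_token: int,
    forcing_ratio: float,
    rng: random.Random,
) -> Tuple[List[int], List[int]]:
    """Run a toy decoder for len(targets) steps; return (predictions, inputs fed)."""
    predictions, inputs_fed = [], []
    prev = start_token
    for target in targets:
        inputs_fed.append(prev)
        prediction = step(prev)  # stand-in: a real decoder would be a neural network
        predictions.append(prediction)
        # Per-step coin flip: feed the true token (teacher forcing) or the model's own guess.
        prev = target if rng.random() < forcing_ratio else prediction
    return predictions, inputs_fed

def forcing_ratio_for_epoch(epoch: int, n_epochs: int) -> float:
    """Hypothetical linear schedule: 1.0 in the first epoch, 0.0 in the last."""
    return max(0.0, 1.0 - epoch / max(1, n_epochs - 1))
```

Because the branch is an ordinary Python `if` evaluated at every step, it is easy to write in a dynamic framework; a static graph typically needs special graph-level conditional operations to express the same thing.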

At that point the first pre-release of Pytorch had just come out. The promise of Pytorch was that it was built as a dynamic, rather than static, computation graph framework (more on this in a later post). Dynamic frameworks, it was claimed, would allow us to write regular Python code, and use regular Python debugging, to develop our neural network logic. The claims turned out to be entirely accurate: within a few hours of first using Pytorch, we had implemented attentional models and dynamic teacher forcing from scratch.
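
The define-by-run idea can be illustrated without Pytorch itself. The toy `Value` class below is a hypothetical, scalar-only sketch of reverse-mode automatic differentiation, not Pytorch's implementation: each arithmetic operation records its inputs as the program runs, so an ordinary Python `while` loop whose number of iterations depends on the data produces a differently shaped graph on each call, and `backward()` walks whichever graph was actually built.

```python
class Value:
    """A scalar that records how it was computed, so gradients can flow back.

    The graph is recorded while ordinary Python executes (define-by-run),
    so loops, ifs, and print() calls work as usual.
    """
    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward_fn = None  # pushes self.grad back to the parents

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))
        def backward():
            self.grad += out.grad
            other.grad += out.grad
        out._backward_fn = backward
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))
        def backward():
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad
        out._backward_fn = backward
        return out

    def backward(self):
        # Topologically sort the graph recorded during the forward pass,
        # then apply the chain rule from the output back to the inputs.
        order, seen = [], set()
        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)
        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()

def repeated_scale(x, w):
    # Data-dependent loop: the number of multiplications depends on x's value,
    # so the recorded graph differs between inputs. Assumes abs(w.data) > 1
    # so that the loop terminates.
    y = x
    while abs(y.data) < 10:
        y = y * w
    return y
```

With x = 1 and w = 2 the loop runs four times (y = x·w⁴ = 16), while with x = 3 it runs only twice (y = x·w² = 12), and the gradients follow the graph that was actually executed. In Pytorch the same pattern is written with tensors created with `requires_grad=True` and a call to `.backward()`, and you can drop a `print()` or a debugger breakpoint anywhere inside the loop.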

Some Pytorch benefits for us and our students

The focus of our second course is to enable students to read and implement recent research papers. This is important because the range of deep learning applications studied so far has been extremely limited, concentrated in the few areas that the academic community happens to be interested in. Therefore, solving many real-world problems with deep learning requires an in-depth understanding of the underlying techniques, and the ability to implement customised versions of them appropriate for your particular problem and data. Because Pytorch allowed us, and our students, to use all of the flexibility and capability of regular Python code to build and train neural networks, we were able to tackle a much wider range of problems.

An additional benefit of Pytorch is that it allowed us to give our students a much more in-depth understanding of what was going on in each algorithm that we covered. With a static computation graph library like Tensorflow, once you have declaratively expressed your computation, you send it off to the GPU where it gets handled like a black box. But with a dynamic approach, you can fully dive into every level of the computation, and see exactly what is going on. We believe that the best way to learn deep learning is through coding and experiments, so the dynamic approach is exactly what we need for our students.

Much to our surprise, we also found that many models trained quite a lot faster on Pytorch than they had on Tensorflow. This ran against the prevailing wisdom, which said that static computation graphs should allow for more optimization, and therefore higher performance in Tensorflow. In practice, we’re seeing some models run a bit faster and some a bit slower, and things change in this respect every month. The key issues seem to be that:

Why we built a new framework on top of Pytorch

Unfortunately, Pytorch was a long way from being a good option for part one of the course, which is designed to be accessible to people with no machine learning background. It had nothing like the clear, simple API of Keras for training models. Every project required dozens of lines of code just to implement the basics of training a neural network. Unlike Keras, where the defaults are thoughtfully chosen to be as useful as possible, Pytorch required everything to be specified in detail. However, we also realised that Keras could be even better. We noticed that we kept on making the same mistakes in Keras, such as failing to shuffle our data when we needed to, or shuffling it when we shouldn’t have. Also, many recent best practices were not being incorporated into Keras, particularly in the rapidly developing field of natural language processing. We wondered if we could build something that could be even better than Keras for rapidly training world-class deep learning models.

After a lot of research and development it turned out that the answer was yes, we could (in our biased opinion). We built models that were faster, more accurate, and more complex than those built with Keras, yet required much less code. We’ve implemented recent papers that allow much more reliable training of more accurate models, across a number of fields.

The key was to create an object-oriented (OO) class which encapsulated all of the important data choices (such as preprocessing, augmentation, test, training, and validation sets, multiclass versus single class classification versus regression, et cetera) along with the choice of model architecture. Once we did that, we were able to largely automatically figure out the best architecture, preprocessing, and training parameters for that model, for that data. Suddenly, we were dramatically more productive, and made far fewer errors, because everything that could be automated was automated. But we also provided the ability to customise every stage, so we could easily experiment with different approaches.
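
The framework’s own classes are not shown in this post, so the following is only a hypothetical plain-Python sketch of the idea: one object (called `ModelData` here) holds the data splits and the task type, and everything that depends on those choices, such as whether a split is shuffled or augmented and which output activation and loss the model gets, is derived from it rather than specified by hand. The class names, fields, task labels, and loss names are all illustrative assumptions.

```python
from dataclasses import dataclass
from typing import Optional, Sequence

TASKS = ("single_label", "multi_label", "regression")

@dataclass
class ModelData:
    """Hypothetical bundle of data choices that must stay consistent with each other."""
    train: Sequence
    valid: Sequence
    task: str                       # one of TASKS
    n_outputs: int                  # number of classes, labels, or regression targets
    test: Optional[Sequence] = None
    augment: bool = True            # only ever applied to the training split

    def __post_init__(self):
        if self.task not in TASKS:
            raise ValueError(f"unknown task {self.task!r}; expected one of {TASKS}")

    def loader_options(self, split: str) -> dict:
        # Shuffle and augment only the training split; keep validation and test
        # order fixed so predictions line up with their labels.
        is_train = split == "train"
        return {"shuffle": is_train, "augment": self.augment and is_train}

@dataclass
class Learner:
    """Hypothetical pairing of a ModelData with a model architecture name."""
    data: ModelData
    arch: str

    def head(self):
        # Derive the final layer's size, activation, and loss from the task,
        # so they cannot silently disagree with the data.
        activation, loss = {
            "single_label": ("softmax", "cross_entropy"),
            "multi_label": ("sigmoid", "binary_cross_entropy"),
            "regression": (None, "mse"),
        }[self.data.task]
        return self.data.n_outputs, activation, loss
```

Because shuffling and augmentation are derived from the split, the shuffling mistakes described earlier cannot happen by default, and the loss is chosen to match the task rather than typed in separately for each project.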

With the increased productivity this enabled, we were able to try far more techniques, and in the process we discovered a number of current standard practices that are actually extremely poor approaches. For example, we found that the combination of batch normalisation (which nearly all modern CNN architectures use) and model pretraining and fine-tuning (which you should use in every project if possible) can result in a 500% decrease in accuracy using standard training approaches. (We will be discussing this issue in-depth in a future post.) The results of this research are being incorporated directly into our framework.

There will be a limited release for our in-person students at USF first, at the end of October, and a public release towards the end of the year. (By which time we’ll need to pick a name! Suggestions welcome…) (If you want to join the in-person course, there’s still room in the International Fellowship program.)

What should you be learning?

If it feels like new deep learning libraries are appearing at a rapid pace nowadays, then you need to be prepared for a much faster rate of change in the coming months and years. As more people enter the field, they will bring more skills and ideas, and try more things. You should assume that whatever specific libraries and software you learn today will be obsolete in a year or two. Just think about the number of changes of libraries and technology stacks that occur all the time in the world of web programming — and yet this is a much more mature and slow-growing area than deep learning. So we strongly believe that the focus in learning needs to be on understanding the underlying techniques and how to apply them in practice, and how to quickly build expertise in new tools and techniques as they are released.

By the end of the course, you’ll understand nearly all of the code inside the framework, because in each lesson we’ll dig a level deeper to understand exactly what’s going on as we build and train our models. This means that you’ll have learnt the most important best practices used in modern deep learning: not just how to use them, but how they really work and are implemented. If you want to use those approaches in another framework, you’ll have the knowledge you need to implement them yourself.

To help students learn new frameworks as they need them, we will be spending one lesson learning to use Tensorflow, MXNet, CNTK, and Keras. We will work with our students to port our framework to these other libraries, which will make for some great class projects.

We will also spend some time looking at how to productionize deep learning models. Unless you are working at Google-scale, your best approach will probably be to create a simple REST interface on top of your Pytorch model, running inference on the CPU. If you need to scale up to very high volume, you can export your model (as long as it does not use certain kinds of customisations) to Caffe2 or CNTK. If you need computation to be done on a mobile device, you can either export as mentioned above, or use an on-device library.
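
As a rough illustration of the "simple REST interface" approach, here is a minimal sketch using only Python’s standard-library `http.server`: a POST to `/predict` with a JSON body such as `{"features": [2.0, 4.0]}` returns `{"prediction": ...}`. The endpoint path, the JSON field names, and the stand-in linear `predict()` function are assumptions made for illustration; in practice `predict()` would put your trained Pytorch model in evaluation mode and run it on CPU tensors.

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer

def predict(features):
    """Stand-in for a trained model; replace with your model's CPU inference call."""
    # Hypothetical linear model so the example runs without Pytorch installed.
    weights, bias = [0.5, -0.25], 0.1
    return sum(w * x for w, x in zip(weights, features)) + bias

class PredictHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        if self.path != "/predict":
            self.send_error(404)
            return
        length = int(self.headers.get("Content-Length", 0))
        try:
            payload = json.loads(self.rfile.read(length))
            result, status = {"prediction": predict(payload["features"])}, 200
        except (ValueError, KeyError, TypeError) as exc:
            # Malformed JSON, a missing "features" field, or a non-list payload.
            result, status = {"error": str(exc)}, 400
        body = json.dumps(result).encode()
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def log_message(self, *args):
        pass  # keep the example quiet

def make_server(host="127.0.0.1", port=8000):
    return HTTPServer((host, port), PredictHandler)

if __name__ == "__main__":
    make_server().serve_forever()
```

`HTTPServer` handles one request at a time, so for anything beyond light traffic the same handler logic would usually move behind a proper web framework and application server.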

How we feel about Keras

We still really like Keras. It’s a great library and is far better for fairly simple models than anything that came before. It’s very easy to move between Keras and our new framework, at least for the subset of tasks and architectures that Keras supports. Keras supports lots of backend libraries which means you can run Keras code in many places.

It has a unique (to our knowledge) approach to defining architectures: authors of custom layers implement a build() method, which Keras calls with the input shape before the layer is first used, and a companion method that reports the layer’s output shape for a given input shape. This allows users to create simple architectures more easily, because they almost never have to specify the number of input channels for a layer. For architectures like DenseNet, which concatenate layers, it can make the code quite a bit simpler.
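
Here is a toy plain-Python illustration of why deferred shape inference helps, assuming a DenseNet-style block in which each layer receives the concatenation of the block’s input and all previous layers’ outputs. `LazyDense` and `dense_block` are hypothetical names, not Keras classes; the point is only that each layer creates its weights when it first sees an input, much as a Keras layer’s build() does, so the growing input widths never have to be worked out by hand.

```python
import random
from typing import List

class LazyDense:
    """Toy fully connected layer that discovers its input size on first call.

    Loosely mirrors a Keras layer's build(): weights are created only once the
    input shape is known, so the user specifies just the number of outputs.
    """
    def __init__(self, units: int, seed: int = 0):
        self.units = units
        self.weights = None          # created lazily by build()
        self._rng = random.Random(seed)

    def build(self, input_dim: int) -> None:
        self.weights = [[self._rng.uniform(-0.1, 0.1) for _ in range(input_dim)]
                        for _ in range(self.units)]

    def __call__(self, x: List[float]) -> List[float]:
        if self.weights is None:
            self.build(len(x))
        return [sum(w * xi for w, xi in zip(row, x)) for row in self.weights]

def dense_block(x: List[float], layers: List[LazyDense]) -> List[float]:
    # DenseNet-style: each layer sees the concatenation of the input and all
    # earlier layers' outputs, so input widths grow at every step, yet no layer
    # is ever told its width explicitly.
    features = list(x)
    for layer in layers:
        features = features + layer(features)
    return features
```

With a two-element input and three layers of four units each, the layers infer input widths of 2, 6, and 10, and the block outputs 14 features. Pytorch layers such as `nn.Linear`, by contrast, take their input size as an explicit constructor argument.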

On the other hand, Keras tends to make models harder to customize, especially during training. More importantly, the static computation graph on the backend, along with Keras’ need for an extra compile() phase, means that it’s hard to change a model’s behaviour once it’s built.

What’s next for fast.ai and Pytorch

We expect to see our framework, and how we teach Pytorch, develop a lot as we teach the course and get feedback and ideas from our students. In past courses students have developed a lot of interesting projects, many of which have helped other students, and we expect that to continue. Given the accelerating progress in deep learning, it’s quite possible that by this time next year there will be very different hardware or software options that make today’s technology quite obsolete. That said, based on how quickly we’ve seen the Pytorch developers adopt new technologies, we suspect that they might stay ahead of the curve for a while at least…

In our next post, we’ll be talking more about some of the standout features of Pytorch, and dynamic libraries more generally.
