The NumPy min and max trap

30.08.2016 20:21

There's an interesting trap that I have managed to fall into a few times when doing calculations with Python. NumPy provides several functions with the same names as functions built into Python. These replacements typically integrate better with array types. Among them are min() and max(). In the vast majority of cases, the NumPy versions are drop-in replacements for the built-ins. In a few, however, they can cause some very hard-to-spot bugs. Consider the following:

import numpy as np

print(max(-1, 0))
print(np.max(-1, 0))

This prints (at least in NumPy 1.11.1 and earlier):

0
-1

Where is the catch? The built-in max() can be used in two distinct ways: you can either pass it an iterable as the single argument (in which case the largest element of the iterable is returned), or you can pass multiple arguments (in which case the largest argument is returned). In NumPy, max() is an alias for amax(), and that only supports the former convention. The second argument in the example above is interpreted as the array axis along which to compute the maximum. It appears that NumPy thinks axis zero is a reasonable choice for a zero-dimensional input and doesn't complain.
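To make the difference concrete, here is a small sketch of the calling conventions side by side. The variable names are only for illustration. np.maximum() is used as the NumPy function that compares two inputs element-wise, which is probably the closest match to the multi-argument form of the built-in.

```python
import numpy as np

# Built-in max(): both calling conventions give the largest value.
largest_of_iterable = max([-1, 0])   # single iterable argument
largest_of_args = max(-1, 0)         # several positional arguments

# np.max(): the second positional argument is the axis, not a value.
a = np.array([[1, 5],
              [7, 3]])
per_column = np.max(a, 0)            # maximum along axis 0

# np.maximum(): element-wise maximum of two inputs.
elementwise = np.maximum(-1, 0)

print(largest_of_iterable, largest_of_args, per_column, elementwise)
```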

Yes, recent versions of NumPy will complain if you pass anything other than 0 or -1 as the axis argument. Having max(x, 0) in code is not that unusual though. I use it a lot as a shorthand when I need to clip negative values to 0. When moving code around between scripts that use NumPy, those that don't, and IPython Notebooks (which do "from numpy import *" by default), it's easy to mess things up.
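For the clipping use case, one way to sidestep the ambiguity is to use a function whose second argument is unambiguously a value. The sketch below assumes x may be either a scalar or an array. np.maximum() and np.clip() are existing NumPy functions, while the sample data is made up.

```python
import numpy as np

x = np.array([-2.0, 0.5, 3.0])   # made-up sample data

# Element-wise maximum against 0: negative values become 0.
clipped = np.maximum(x, 0)

# np.clip() with only a lower bound does the same job.
clipped_alt = np.clip(x, 0, None)

# Both also accept a plain scalar.
clipped_scalar = np.maximum(-1, 0)

print(clipped, clipped_alt, clipped_scalar)
```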

I guess both sides are to blame here. I find that flexible functions that interpret arguments in multiple ways are usually bad practice and I try to leave them out of interfaces I design. Yes, they are convenient, but they also often lead to bugs. On the other hand, I would also expect NumPy to complain about the nonsensical axis argument. Axis -1 makes sense for a zero-dimensional input, axis 0 doesn't. The alias from max() to amax() is dangerous (and as far as I can see undocumented). A possible way to prevent such mistakes would be to support only the named version of the axis argument.
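As a rough illustration of the keyword-only idea, here is a hypothetical wrapper (strict_max is not part of NumPy) that forwards to np.max() but refuses a positional axis. With it, the mistaken two-argument call fails loudly instead of silently returning the first argument.

```python
import numpy as np

def strict_max(a, *, axis=None):
    """Hypothetical wrapper around np.max(); axis must be given by name."""
    return np.max(a, axis=axis)

a = np.array([[1, 5],
              [7, 3]])
per_column = strict_max(a, axis=0)   # fine: axis passed by keyword

# The built-in style call is now rejected by Python itself.
try:
    strict_max(-1, 0)
    rejected = False
except TypeError:
    rejected = True

print(per_column, rejected)
```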

Posted by Tomaž | Categories: Code

Comments

Many things about the numpy API are counterintuitive, inconsistent, or unfriendly. I don't use it enough to have memorized all the idiosyncrasies, so I feel like I'm constantly having to debug this type of issue.

Another one that gets me is the odd overloadings of standard operators like *, which sometimes does element-wise multiplication and sometimes matrix multiplication.
