Tuesday, January 13, 2015

argmax in Python

I started learning Python. I glanced through tutorials, downloaded & installed the interpreter, and decided to implement something I had wanted to craft for a while: some game theory stuff for a specific game (computing a GTO strategy for a stochastic game, in case you wonder). I also wanted to craft a genetic algorithm and a simulation part where the evolving strategies compete against each other. At some point I needed an argmax function, taking a list and returning the (say, first) index of the list where a maximum value is present. (This was in the simulation part, to determine the index of the winning player.)

After googling for a few minutes I realized that there is no built-in argmax function in the standard libs of Python, and several people on the web are actually arguing about which approach is the best.

Probably the "pythonic" solution is that I should use a dictionary (i.e. a Map) instead, whenever the indices actually matter - I'll try that as well. However, it seems that it's also a non-obvious question :P

We have this post on StackOverflow. From the accepted answer we get that if keys is an iterable and f is a function mapping keys to some values, then max(keys, key=f) is exactly an argmax function: it gives back the element k in keys on which f takes its maximum.
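To make the key argument concrete, here is a minimal sketch. The dictionary scores and its contents are made up for illustration and are not part of the original experiment.

```python
# Hypothetical data: player names mapped to scores (not from the post).
scores = {"alice": 3, "bob": 7, "carol": 5}

# max() compares the keys by f(key) but returns the key itself,
# which is exactly what an argmax does. Here f is scores.get.
best = max(scores, key=scores.get)
print(best)  # bob
```

This is also roughly what the dictionary-based approach mentioned above would look like.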

Now if t is an array, then its domain is range(len(t)): len(t) gives back the size of the array and range(k) produces sort of an iterable, the set {0, ..., k-1} of indices, without explicitly constructing it (at least since Python 3.0; in previous versions one had to use xrange to get this behaviour, while the old version of range constructed the whole set. But since 3.0, range is the new xrange).
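A small sketch of what "without explicitly constructing it" means in practice. The size comparison relies on sys.getsizeof, whose numbers are implementation details, so take it as an observation about CPython rather than a guarantee.

```python
import sys

# A range object stores only start, stop and step, so its memory
# footprint does not depend on how many indices it represents.
small = range(10)
huge = range(10**9)
same_size = sys.getsizeof(small) == sys.getsizeof(huge)

# It still behaves like a sequence: it has a length and supports indexing.
n = len(huge)
last = huge[-1]
print(same_size, n, last)
```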

So we have keys, what's with f? We need a function there which gets an index i and returns t[i]. At first sight I thought immediately, that's what those lambdas are for! The expression lambda x: t[x] is a function, which gets x as an argument and returns t[x].

Hence we have our first argmax variant:

def argmax1( t ):
    return max( range(len(t)), key=lambda x:t[x] )

I was happy for a couple of minutes having figured out this one (recall that's my very first Python code :P ), I even tweeted that one out. Then I decided to do some benchmarking to see whether it's the most efficient option or not. Okay, Python is not for speed, I know that, but I'm a curious guy. Also, there is a debate on SO whether those lambdas are usable at all, and scrolling through the library reference I found that arrays have a __getitem__ method (and of course t.__getitem__(i) returns t[i]). Hence the second variant became

def argmax2( t ):
    return max( range(len(t)), key=t.__getitem__ )
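A quick sanity check that the bound method and the lambda really are interchangeable as key functions. The sample list is made up, and the names getter and via_lambda are only for this sketch.

```python
t = [4, 9, 2, 9]  # made-up sample list

getter = t.__getitem__        # bound method: getter(i) is t[i]
via_lambda = lambda i: t[i]   # the lambda spelled out in argmax1

# Both callables agree on every valid index.
same = all(getter(i) == via_lambda(i) for i in range(len(t)))

# max() returns the first maximal element, so we get index 1, not 3.
first_max = max(range(len(t)), key=getter)
print(same, first_max)  # True 1
```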

There are several other approaches, e.g. first calling max to find the maximum value, then calling index to find the index:

def argmax3( t ):
    return t.index( max( t ) )

Theoretically I agreed with the comment "This approach is theoretically suboptimal because it iterates over the list twice." However, I thought it would not hurt to implement it. Also, this guy here claims it beats another variant (without giving any specific detail on his benchmark) and I wanted to double-check. Also, since he uses xrange we know that it's pre-3.0 code, so things could have changed.

Another comment suggested

def argmax4( t ):
    # the lambda parameter is named pair so it does not shadow the list t
    return max(enumerate(t), key=lambda pair: pair[1])[0]

that is, converting the array first to an iterable of (key, value) pairs, and crafting a lambda returning the value part. Maxing this we get a (key, value) pair with the maximal value, and getting its 0th entry returns the key we seek. OK. Been there, done that.
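For the record, the same idea can be written without a lambda by using operator.itemgetter from the standard library. This is only a sketch of an alternative spelling and was not part of the benchmark; the function name argmax4_itemgetter and the sample list are made up.

```python
from operator import itemgetter

t = [4, 9, 2, 9]  # made-up sample list

# enumerate() yields (index, value) pairs.
pairs = list(enumerate(t))  # [(0, 4), (1, 9), (2, 2), (3, 9)]

def argmax4_itemgetter(values):
    # itemgetter(1) picks the value out of each (index, value) pair,
    # doing the same job as the lambda in argmax4.
    return max(enumerate(values), key=itemgetter(1))[0]

print(argmax4_itemgetter(t))  # 1
```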

I was also curious how a totally non-Pythonic method performs:
def argmax5( t ):
    ret = 0
    maxval = t[0]
    for i in range(len(t)):
        if maxval < t[i]:
            ret, maxval = i, t[i]
    return ret
that is, a totally imperative approach as I would do in C, say.

The previous post compared the index-max combo with
def argmax6( t ):
    return max(zip(t, range(len(t))))[1]
so I included this one as well. Here, the zip function takes two lists (a1, a2, ..., ak) and (b1, b2, ..., bk) and produces ((a1,b1), (a2,b2), ..., (ak,bk)), then max computes the (lexicographic) maximum which is exactly the (value, index) pair we seek, then with the projection to the index coordinate it gets rid of the value, neat. He uses izip and xrange since "these iterators are lazy" but since Python 3.0 zip and range are lazy as well (and izip and xrange are non-existent) so I used the latter ones.
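One subtlety worth checking: when the maximum occurs more than once, the lexicographic comparison also looks at the index, so the zip variant does not pick the same position as the other ones. A minimal sketch with a made-up list:

```python
t = [4, 9, 2, 9]  # made-up list whose maximum 9 occurs at indices 1 and 3

first_index = t.index(max(t))              # as in argmax3
zip_index = max(zip(t, range(len(t))))[1]  # as in argmax6

# (9, 3) > (9, 1) lexicographically, so the zip variant reports the
# last position of the maximum rather than the first.
print(first_index, zip_index)  # 1 3
```

For random floats ties are practically impossible, so this does not affect the benchmark, but it matters if "first index" is part of the requirement.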

Since I already had six different implementations for such an obvious task (and the Zen of Python states that "There should be one-- and preferably only one --obvious way to do it."... okay, it also states that "Although that way may not be obvious at first unless you're Dutch." and I'm not Dutch, so it's okay I guess), I decided to run a small benchmark, generating a random array of length, say, 100,000:
import random
t = [random.random() for _ in range(100000)]
(easy enough), and timing the functions on this array:
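import timeit
guys = [argmax1, argmax2, argmax3, argmax4, argmax5, argmax6 ]
for f in guys:
    # timeit calls its argument with no parameters, so wrap the call to pass t
    print(f.__name__, timeit.timeit(lambda: f(t), number=10))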
timing each guy ten times.

My expectation was that argmax5 and argmax3 should be the slowest due to their non-Pythonic and obviously inefficient way; argmax4 and argmax6 seemed to be roughly equivalent (zip vs enumerate for essentially the same tupling thing), and argmax1 and argmax2 also seemed to be equivalent (but no tupling there), so I guessed these would score best, with argmax2 being slightly faster due to the absence of lambdas. Complete code is here.

Aaaand the winner is:
argmax1 0.23996164304831566
argmax2 0.13004148940632398
argmax3 0.04990375064590147
argmax4 0.25151111932153347
argmax5 0.12684567172410843
argmax6 0.14380688729499602

WHAT. THE. HECK.

For those of you who cannot read, compare numbers, interpret colors or simply forgot which function was which:
1. The fastest method turned out to be the max-index combo, which is theoretically the worst one since it traverses the array twice, with a runtime of about 0.05 secs. The runner-up needs about 0.13.
2. The second, third and fourth place go to variants 5, 2 and 6 with not too big differences, roughly 0.13--0.14. That is, to the totally non-Pythonic usage of a for loop, to the __getitem__ variant of the max function and to the zipper method.
3. At the end there is argmax1 with 0.24 secs (that's the first one, max with the lambda key) and argmax4 with 0.25 secs (the enumerate-lambda combo). About five times slower than the best one and about two times slower than the other candidates.

So... for me these results suggest that
1. I will probably not use lambdas ever, they seem to ruin readability AND performance at the same time.
2. Calling range(len(t)) and iterating over it probably has some heavy overhead. I really think that it's the bottleneck for those implementations ranging at roughly 0.13--0.14 secs since it's the common part of them.
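The second guess could be probed by timing the building blocks separately. Below is a rough sketch, assuming the same kind of random list as in the benchmark. The helper names walk_indices, builtin_max and builtin_index are made up here, and no numbers are claimed since they depend on the machine and the Python version.

```python
import random
import timeit

t = [random.random() for _ in range(100000)]

def walk_indices():
    # Only iterate over the indices: no lookups, no comparisons.
    for _ in range(len(t)):
        pass

def builtin_max():
    # A single pass over the list, done inside the built-in max().
    return max(t)

def builtin_index():
    # Like the second pass of argmax3; searching for the last element
    # makes index() scan the list until it finds an equal value.
    return t.index(t[-1])

for f in (walk_indices, builtin_max, builtin_index):
    print(f.__name__, timeit.timeit(f, number=10))
```

If walk_indices alone already takes a large share of the time measured for argmax2, argmax5 and argmax6, that would support the guess; if not, the per-element Python-level calls and comparisons are the more likely suspects.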

Any thoughts on that? Python experts out there? :-)