Tuesday, January 13, 2015

argmax in Python

I started learning Python. I glanced through some tutorials, downloaded and installed the interpreter, and decided to implement something I had wanted to build for a while: some game-theory code for a specific game (computing a GTO strategy for a stochastic game, if you wonder). I also wanted to write a genetic algorithm and a simulation part where the evolving strategies compete against each other. At some point I needed an $argmax$ function, taking a list and returning the (say, first) index at which the maximum value occurs. (This was in the simulation part, to determine the index of the winning player.)

After googling for a few minutes I realized that there is no built-in $argmax$ function in Python's standard library, and several people on the web are actually arguing about which approach is best.

Probably the "pythonic" solution is to use a dictionary (i.e. a map) instead whenever the indices actually matter; I'll try that too. However, it seems that it's also a non-obvious question :P

We have this post on StackOverflow. From the accepted answer we learn that if $keys$ is an iterable and $f$ is a function mapping keys to some values, then $max(keys, key=f)$ is exactly an $argmax$ function: it gives back the element $k\in keys$ on which $f$ takes its maximum (the first such element, if there are several).
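
To see the idiom in isolation, here is a minimal sketch (the lists and the names `words` and `scores` are made up for illustration). It shows that `max(iterable, key=f)` returns the element with the largest `f`-value, that ties go to the first such element, and that the dictionary idea from above reduces to the same call with `key=d.get`.

```python
# max(iterable, key=f) returns the element k that maximizes f(k).
words = ["fig", "apple", "banana", "cherry"]
longest = max(words, key=len)  # "banana" and "cherry" both have length 6
print(longest)  # banana: the first maximal element wins

# Iterating over a dict yields its keys, and d.get maps a key to its value,
# so this is an argmax over the dictionary's keys.
scores = {"alice": 3, "bob": 7, "carol": 5}
winner = max(scores, key=scores.get)
print(winner)  # bob
```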

Now if $t$ is an array (a Python list), then its domain is $range(len(t))$: $len(t)$ gives back the size of the list, and $range(k)$ produces a lazy sequence over the set $\{0,\ldots,k-1\}$ of indices without explicitly constructing it. This holds in Python 3.0 and later; in earlier versions one had to use $xrange$ to get this behaviour, while the old $range$ built the whole list. Since 3.0, $range$ is the new $xrange$.
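
A quick sketch of what "lazy" means here (the sizes are arbitrary): a Python 3 `range` object stores only its start, stop and step, yet it still supports `len()`, indexing and membership tests without ever building the list.

```python
# In Python 3, range is a lazy sequence: it only stores start/stop/step.
r = range(5)
print(list(r))  # materialized only when asked: [0, 1, 2, 3, 4]

# A huge range costs no extra memory, yet supports len(), indexing and 'in'.
big = range(10**9)
print(len(big))        # 1000000000
print(big[123456789])  # 123456789
print(10**8 in big)    # True; for int arguments CPython checks this arithmetically
```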

So we have $keys$; what about $f$? We need a function that takes an index $i$ and returns $t[i]$. My first thought was: that's exactly what $lambda$s are for! The expression $lambda x:t[x]$ is a function that takes $x$ as an argument and returns $t[x]$.

Hence we have our first $argmax$ variant:

def argmax1( t ):
    return max( range(len(t)), key=lambda x:t[x] )
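
A quick sanity check of argmax1 (the sample lists are arbitrary): it returns the first index of the maximum, and since $max$ refuses an empty iterable, an empty list raises a ValueError.

```python
def argmax1(t):
    return max(range(len(t)), key=lambda x: t[x])

print(argmax1([3, 1, 4, 1, 5, 9, 2, 6]))  # 5, the position of 9
print(argmax1([2, 7, 7, 1]))              # 1: ties go to the first index

try:
    argmax1([])
except ValueError:
    # max() over an empty range has nothing to return
    print("empty list -> ValueError")
```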

I was happy for a couple of minutes having figured this one out (recall that it's my very first Python code :P), and I even tweeted it. Then I decided to do some benchmarking to see whether it's the most efficient option. Okay, Python is not for speed, I know, but I'm a curious guy. Also, there is a debate on SO about whether those $lambda$s should be used at all, and scrolling through the library reference I found that lists have a $\_\_getitem\_\_$ method (and of course $t.\_\_getitem\_\_(i)$ returns $t[i]$). Hence the second variant became

def argmax2( t ):
    return max( range(len(t)), key=t.__getitem__ )
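
A small sketch to convince oneself that the two variants are interchangeable: `t.__getitem__` is a bound method that already remembers which list it belongs to, so it can be passed as the key directly. The random cross-check below (list sizes and value ranges chosen arbitrarily, with plenty of ties) compares argmax1 and argmax2 on many inputs.

```python
import random

def argmax1(t):
    return max(range(len(t)), key=lambda x: t[x])

def argmax2(t):
    return max(range(len(t)), key=t.__getitem__)

t = [3, 1, 4, 1, 5]
getter = t.__getitem__   # bound method: calling getter(i) is the same as t[i]
print(getter(2), t[2])   # 4 4

# Both variants should agree on every input, ties included.
for _ in range(1000):
    data = [random.randint(0, 5) for _ in range(random.randint(1, 20))]
    assert argmax1(data) == argmax2(data)
```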

There are several other approaches, e.g. first calling $max$ to find the maximum value, then calling $index$ to find the index:

def argmax3( t ):
    return t.index( max( t ) )
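
The same function with its two passes spelled out in comments. Since `list.index` stops at the first match, this variant also returns the first index of the maximum on ties (the sample list is arbitrary).

```python
def argmax3(t):
    # Pass 1: max() scans the whole list for the largest value.
    # Pass 2: list.index() scans again, stopping at the first position holding it.
    return t.index(max(t))

print(argmax3([2, 7, 7, 1]))  # 1, because index() stops at the first match
```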

In theory I agreed with the comment "This approach is theoretically suboptimal because it iterates over the list twice." Still, I thought it wouldn't hurt to implement it. Also, this guy here claims it beats another variant (without giving any details of his benchmark), and I wanted to double-check. Since he uses $xrange$, we know it's pre-3.0 code, so things could have changed.

Another comment suggested

def argmax4( t ):
    return max(enumerate(t), key=lambda pair: pair[1])[0]

that is, first converting the list to an iterable of $(index, value)$ pairs and crafting a $lambda$ that returns the $value$ part. Taking the $max$ of this gives the $(index, value)$ pair with the maximal value, and its $0$th entry is the index we seek. OK. Been there, done that.
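
What $enumerate$ actually produces, plus a variant that swaps the $lambda$ for `operator.itemgetter(1)` from the standard library. The name `argmax4_itemgetter` is my own, and this variant was not part of the original comment or of the benchmark below; it is just a sketch of the lambda-free spelling.

```python
from operator import itemgetter

t = [0.3, 0.9, 0.1]
print(list(enumerate(t)))  # [(0, 0.3), (1, 0.9), (2, 0.1)]

def argmax4(t):
    # Pick the (index, value) pair with the largest value, keep its index.
    return max(enumerate(t), key=lambda pair: pair[1])[0]

def argmax4_itemgetter(t):
    # Same idea; itemgetter(1) is a ready-made "take element 1" function.
    return max(enumerate(t), key=itemgetter(1))[0]

print(argmax4(t), argmax4_itemgetter(t))  # 1 1
```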

I was also curious how a totally non-Pythonic method performs:

def argmax5( t ):
    ret = 0
    maxval = t[0]
    for i in range(len(t)):
        if maxval < t[i]:
            ret, maxval = i, t[i]
    return ret

that is, a totally imperative approach as I would do in C, say.
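
Two small behavioral details of the loop version, shown on arbitrary inputs: the strict `<` keeps the earliest index on ties, and an empty list fails at `t[0]` with an IndexError, whereas the $max$-based variants raise a ValueError.

```python
def argmax5(t):
    ret = 0
    maxval = t[0]
    for i in range(len(t)):
        if maxval < t[i]:  # strict '<' keeps the earliest index on ties
            ret, maxval = i, t[i]
    return ret

print(argmax5([2, 7, 7, 1]))  # 1

try:
    argmax5([])
except IndexError:
    # t[0] fails before the loop even starts
    print("empty list -> IndexError")
```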

The post mentioned above compared the $index$-$max$ combo with

def argmax6( t ):
    return max(zip(t, range(len(t))))[1]

so I included this one as well. Here, the $zip$ function takes two sequences $(a_1, a_2, \ldots, a_k)$ and $(b_1, b_2, \ldots, b_k)$ and produces the pairs $((a_1,b_1), (a_2,b_2), \ldots, (a_k,b_k))$; their (lexicographic) maximum is exactly the $(value,index)$ pair we seek (on ties, the one with the largest index), and projecting to the $index$ coordinate gets rid of the value, neat. He uses $izip$ and $xrange$ since "these iterators are lazy", but since Python 3.0 $zip$ and $range$ are lazy as well (and $izip$ and $xrange$ no longer exist), so I used the latter.
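
A quick check of the tie case against argmax3 (the list is arbitrary): since tuples compare element by element, equal values are decided by the index, so argmax6 returns the last position of the maximum rather than the first. With random floats, as in the benchmark, ties are practically impossible, so this does not affect the timings.

```python
def argmax3(t):
    return t.index(max(t))

def argmax6(t):
    # Pairs are (value, index); tuples compare element by element,
    # so among equal values the larger index wins.
    return max(zip(t, range(len(t))))[1]

ties = [2, 7, 7, 1]
print(argmax3(ties), argmax6(ties))  # 1 2
```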

Since I already had six different implementations for such an obvious task (and the Zen of Python states that "There should be one-- and preferably only one --obvious way to do it."... okay, it also states that "Although that way may not be obvious at first unless you're Dutch.", and I'm not Dutch, so it's okay I guess), I decided to run a small benchmark, generating a random list of length, say, 100,000:

import random
import timeit

t = [random.random() for _ in range(100000)]

(easy enough), and timing the functions on this list:

guys = [argmax1, argmax2, argmax3, argmax4, argmax5, argmax6]
for f in guys:
    # timeit calls its argument without arguments, so wrap each call in a lambda
    print(f.__name__, timeit.timeit(lambda: f(t), number=10))

timing each guy ten times.
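
For anyone repeating the experiment, a sketch of a slightly more careful harness; this is not what produced the numbers below. It fixes the random seed, checks that the variants agree before timing them, and uses `timeit.repeat` keeping the minimum, as the `timeit` documentation suggests. The function name `benchmark` and its parameters are my own choices, and only two variants are included to keep it short; the others can be added the same way.

```python
import random
import timeit

def argmax1(t):
    return max(range(len(t)), key=lambda x: t[x])

def argmax3(t):
    return t.index(max(t))

def benchmark(funcs, n=100000, repeat=5, number=10, seed=0):
    """Time each argmax variant on one shared random list.

    Returns a dict mapping function name -> best total time in seconds
    over `repeat` rounds of `number` calls each.
    """
    rng = random.Random(seed)
    t = [rng.random() for _ in range(n)]
    expected = funcs[0](t)
    results = {}
    for f in funcs:
        # Random floats make ties practically impossible, so all variants should agree.
        assert f(t) == expected, f.__name__
        times = timeit.repeat(lambda: f(t), repeat=repeat, number=number)
        results[f.__name__] = min(times)  # the minimum is the least noisy estimate
    return results

if __name__ == "__main__":
    for name, secs in benchmark([argmax1, argmax3]).items():
        print(name, secs)
```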

My expectation was that $argmax5$ and $argmax3$ would be the slowest, due to their non-Pythonic or obviously inefficient approach; $argmax4$ and $argmax6$ seemed roughly equivalent ($zip$ vs $enumerate$ for essentially the same tupling), and $argmax1$ and $argmax2$ also seemed equivalent (but without tupling), so I guessed these two would score best, with $argmax2$ slightly faster due to the absence of lambdas. Complete code is here.

Aaaand the winner is:
argmax1 0.23996164304831566
argmax2 0.13004148940632398
argmax3 0.04990375064590147
argmax4 0.25151111932153347
argmax5 0.12684567172410843
argmax6 0.14380688729499602

WHAT. THE. HECK.

For those of you who cannot read, compare numbers, interpret colors or simply forgot which function was which:
1. The fastest method is the $max$-$index$ combo, which is theoretically the worst one since it traverses the list twice, with a runtime of about 0.05 seconds. The runner-up needs about 0.127.
2. Second, third and fourth place go to variants 5, 2 and 6, with small differences between them (0.12--0.14). That is, to the totally non-Pythonic $for$ loop, to $max$ with the $\_\_getitem\_\_$ key, and to the $zip$per method.
3. At the end come $argmax1$ with 0.24 seconds (that's the first one, $max$ with the $lambda$ key) and $argmax4$ with 0.25 seconds (the $enumerate$-$lambda$ combo). That is about five times slower than the best one and about two times slower than the other candidates.
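
Recomputing the ratios directly from the timings listed above (no new measurements, just arithmetic on those six numbers), relative to the fastest variant:

```python
# Timings reported above: seconds for ten calls on a list of 100,000 floats.
timings = {
    "argmax1": 0.23996164304831566,
    "argmax2": 0.13004148940632398,
    "argmax3": 0.04990375064590147,
    "argmax4": 0.25151111932153347,
    "argmax5": 0.12684567172410843,
    "argmax6": 0.14380688729499602,
}
best = min(timings.values())
# Sort the names by their time, i.e. an argmin over the dict via key=timings.get.
for name in sorted(timings, key=timings.get):
    print(f"{name}: {timings[name] / best:.1f}x the fastest")
```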

Soo... for me these results suggest that
1. I will probably never use $lambda$s: they seem to ruin readability AND performance at the same time.
2. Calling $range(len(t))$ and iterating over it probably has some heavy overhead. I really think it's the bottleneck for the implementations in the 0.12--0.14 range, since it's the part they have in common.
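
One way to probe the $range(len(t))$ hypothesis is to time the bare index loop on its own and compare it with a plain $max(t)$. The helper names below are made up for this sketch, and I have no numbers for it, so none are claimed. If `just_range` turns out much cheaper than argmax5, the overhead lies elsewhere; one thing worth keeping in mind is that in CPython $max$ and $list.index$ are implemented in C, while argmax5 runs Python bytecode for every element and argmax1/argmax2 call a key function once per element.

```python
import random
import timeit

t = [random.random() for _ in range(100000)]

def just_range(t):
    # Only walk over the indices; no element access, no comparisons.
    for i in range(len(t)):
        pass

def just_index(t):
    # Walk over the indices and fetch each element, as argmax5 does.
    for i in range(len(t)):
        t[i]

def builtin_max(t):
    # A single pass done entirely inside the built-in max().
    return max(t)

for f in (just_range, just_index, builtin_max):
    print(f.__name__, timeit.timeit(lambda: f(t), number=10))
```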

Any thoughts on that? Python experts out there? :-)