Tuesday, January 13, 2015

argmax in Python

I started learning Python. I glanced through tutorials, downloaded and installed the interpreter, and decided to implement something I had wanted to build for a while: some game theory for a specific game (computing a GTO strategy for a stochastic game, if you wonder). I also wanted to write a genetic algorithm and a simulation part where the evolving strategies compete against each other. At some point I needed an $argmax$ function that takes a list and returns the (say, first) index at which the maximum value occurs. (This was in the simulation part, to determine the index of the winning player.)

After googling for a few minutes I realized that there is no built-in $argmax$ function in Python's standard library, and several people on the web are actually arguing about which approach is best.

Probably the "pythonic" solution is to use a dictionary (i.e. a map) instead whenever the indices actually matter; I'll try that as well. However, it seems that it's also a non-obvious question :P
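A minimal sketch of the dictionary idea, assuming a hypothetical `scores` dict that maps player names to payoffs (the names and numbers are made up for illustration, not taken from the actual simulation):

```python
# Hypothetical payoffs keyed by player name (illustrative data only).
scores = {"alice": 3.5, "bob": 7.25, "carol": 7.25, "dave": 1.0}

# Iterating a dict yields its keys, so max() with key=scores.get
# returns a key whose value is largest. On ties max() keeps the
# first maximal item it encounters.
winner = max(scores, key=scores.get)
print(winner)
```

With a dict there is no index juggling at all: the "argmax" is simply the winning key.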

We have this post on StackOverflow. From the accepted answer we get that if $keys$ is an iterable and $f$ is a function mapping keys to some values, then $max( keys, key=f )$ is exactly an $argmax$ function: it gives back the element $k\in keys$ on which $f$ takes its maximum.
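As a small illustration of $max$ acting as an $argmax$, here is a sketch with a made-up function $f$ on a made-up set of keys (both are assumptions for the example):

```python
# Hypothetical keys and function: f is a downward parabola peaking at k = 1.
keys = range(-3, 4)

def f(k):
    return -(k - 1) ** 2

# The function must be passed as the keyword argument key=.
# max() then returns the element of keys where f is largest,
# not the largest value of f itself.
best = max(keys, key=f)
print(best, f(best))
```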

Now if $t$ is an array, then its domain is $range(len(t))$: $len(t)$ gives back the size of the array and $range( k )$ produces a sort of iterable, the set $\{0,\ldots, k-1\}$ of indices, without explicitly constructing it (at least since Python 3.0; in previous versions I would have had to use $xrange$ to get this behaviour, while the old $range$ constructed the whole list. Since 3.0, $range$ is the new $xrange$).
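A quick sketch of what "without explicitly constructing it" means in Python 3; the size of a billion is an arbitrary choice for the example:

```python
import sys

r = range(10**9)         # no billion-element list is built here

print(len(r))            # the length is computed, not counted
print(r[123456])         # indexing works without materializing anything
print(500 in r)          # so does the membership test for integers
print(sys.getsizeof(r))  # the range object itself stays small
```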

So we have $keys$; what about $f$? We need a function that gets an index $i$ and returns $t[i]$. At first sight I thought immediately: that's what those $lambda$s are for! The expression $lambda x:t[x]$ is a function which gets $x$ as an argument and returns $t[x]$.

Hence we have our first $argmax$ variant:

def argmax1( t ):
    return max( range(len(t)), key=lambda x:t[x] )

I was happy for a couple of minutes after figuring this one out (recall that it's my very first Python code :P ), I even tweeted it. Then I decided to benchmark whether it's the most efficient option or not. Okay, Python is not for speed, I know that, but I'm a curious guy. Also, there is a debate on SO about whether those $lambda$s are usable at all, and scrolling through the library reference I found that lists have a $\_\_getitem\_\_$ method (and of course $t.\_\_getitem\_\_(i)$ returns $t[i]$). Hence the second variant became

def argmax2( t ):
    return max( range(len(t)), key=t.__getitem__ )

There are several other approaches, e.g. first calling $max$ to find the maximum value, then calling $index$ to find the index:

def argmax3( t ):
    return t.index( max( t ) )

In theory I agreed with the comment "This approach is theoretically suboptimal because it iterates over the list twice." However, I thought it did not hurt to implement it. Also, this guy here claims it beats another variant (without giving any specific detail on his benchmark) and I wanted to double-check. His use of $xrange$ also shows that it's pre-3.0 code, so things could have changed.

Another comment suggested

def argmax4( t ):
    return max(enumerate(t), key=lambda pair: pair[1])[0]

that is, converting the array first to an iterable of $(key,value)$ pairs and crafting a $lambda$ that returns the $value$ part. $Max$ing this we get a $(key,value)$ pair with the maximal value, and taking its $0$th entry returns the $key$ we seek. OK. Been there, done that.
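In the same spirit as swapping the $lambda$ for $\_\_getitem\_\_$ in the second variant, the $lambda$ here could probably be replaced by `operator.itemgetter(1)` from the standard library. This is only a sketch: the name `argmax4_itemgetter` is made up for the example and this variant is not part of the benchmark.

```python
from operator import itemgetter

def argmax4_itemgetter(t):
    # enumerate(t) yields (index, value) pairs; itemgetter(1) picks the
    # value as the comparison key, and [0] extracts the index of the
    # winning pair.
    return max(enumerate(t), key=itemgetter(1))[0]

print(argmax4_itemgetter([3, 9, 2, 9]))  # 1 (first index of the maximum)
```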

I was also curious how a totally non-Pythonic method performs:

def argmax5( t ):
    ret = 0
    maxval = t[0]
    for i in range(len(t)):
        if maxval < t[i]:
            ret, maxval = i, t[i]
    return ret

that is, a totally imperative approach as I would do in C, say.

The previous post compared the $index-max$ combo with

def argmax6( t ):
    return max(zip(t, range(len(t))))[1]

so I included this one as well. Here, the $zip$ function takes two lists $(a_1, a_2, \ldots, a_k)$ and $(b_1, b_2, \ldots, b_k)$ and produces $((a_1,b_1), (a_2,b_2), \ldots, (a_k,b_k))$, then $max$ computes the (lexicographic) maximum, which is exactly the $(value,index)$ pair we seek, and the projection to the $index$ coordinate gets rid of the value. Neat. He uses $izip$ and $xrange$ since "these iterators are lazy", but since Python 3.0 $zip$ and $range$ are lazy as well (and $izip$ and $xrange$ are non-existent), so I used the latter ones.
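One detail worth noting about the lexicographic maximum is how it treats ties. Among pairs with equal values the one with the larger index wins, so this variant returns the last index of the maximum, whereas the $index-max$ combo returns the first. A small sketch with a made-up list:

```python
def argmax3(t):
    return t.index(max(t))

def argmax6(t):
    return max(zip(t, range(len(t))))[1]

t = [1, 7, 3, 7]     # the maximum 7 occurs at indices 1 and 3

print(argmax3(t))    # 1: index() stops at the first occurrence
print(argmax6(t))    # 3: (7, 3) is lexicographically larger than (7, 1)
```

On lists without repeated maxima, such as random floats, the two agree.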

Since I already had six different implementations for such an obvious task (and the Zen of Python states that "There should be one-- and preferably only one --obvious way to do it."... okay, it also states that "Although that way may not be obvious at first unless you're Dutch." and I'm not Dutch, so it's okay I guess), I decided to run a small benchmark, generating a random array of length, say, 100,000:

import random
t = [random.random() for _ in range(100000)]

(easy enough), and timing the functions on this array:

import timeit
guys = [argmax1, argmax2, argmax3, argmax4, argmax5, argmax6 ]
for f in guys:
    # timeit calls its argument with no parameters, so wrap the call to pass t
    print(f.__name__, timeit.timeit(lambda: f(t), number=10))

timing each guy ten times.
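A single timing run can be noisy. A possible refinement, sketched here for one variant only, is `timeit.repeat`, which performs the whole measurement several times so that the minimum can be reported. The choice of `repeat=5` is arbitrary, and this is not how the numbers in this post were obtained.

```python
import random
import timeit

def argmax3(t):
    return t.index(max(t))

t = [random.random() for _ in range(100000)]

# Each of the 5 entries is the total time of 10 calls; the minimum is
# commonly taken as the estimate least disturbed by other processes.
runs = timeit.repeat(lambda: argmax3(t), number=10, repeat=5)
print(argmax3.__name__, min(runs))
```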

My expectation was that $argmax5$ and $argmax3$ would be the slowest due to their non-Pythonic and obviously inefficient ways; $argmax4$ and $argmax6$ seemed to be roughly equivalent ($zip$ vs $enumerate$ for essentially the same tupling thing), and $argmax1$ and $argmax2$ also seemed to be equivalent (but no tupling there), so I guessed these would score best, with $argmax2$ being slightly faster due to the absence of lambdas. Complete code is here.

Aaaand the winner is:
argmax1 0.23996164304831566
argmax2 0.13004148940632398
argmax3 0.04990375064590147
argmax4 0.25151111932153347
argmax5 0.12684567172410843
argmax6 0.14380688729499602


For those of you who cannot read, compare numbers, interpret colors or simply forgot which function was which:
1. The fastest method turned out to be the $max-index$ combo, which is theoretically the worst one since it traverses the array twice, with a runtime of about 0.05 secs. The runner-up took about 0.13 secs.
2. The second, third and fourth places go to variants 5, 2 and 6 with fairly small differences, roughly 0.13--0.14 secs. That is, to the totally non-Pythonic $for$ loop, to the $\_\_getitem\_\_$ variant of the $max$ function and to the $zip$per method.
3. At the end there is $argmax1$ with 0.24 secs (that's the first one, $max$ with the $lambda$ key) and $argmax4$ with 0.25 secs (the $enumerate$-$lambda$ combo). They are about five times slower than the best one and about two times slower than the other candidates.

Soo.. for me these results suggest that
1. I will probably not use $lambda$s ever, they seem to ruin readability AND performance at the same time.
2. Calling $range(len(t))$ and iterating over it probably has some heavy overhead. I really think that it's the bottleneck for the implementations in the 0.13--0.14 secs range, since it's the part they have in common.
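A sketch of how the second guess could be checked: time the bare index loop on its own and compare it with iterating the list directly. The helper names `walk_indices` and `walk_values` are made up for this example, and no results are claimed here; the numbers will depend on the machine and the Python version.

```python
import random
import timeit

t = [random.random() for _ in range(100000)]

def walk_indices():
    # Only the index iteration: no comparisons and no list access.
    for i in range(len(t)):
        pass

def walk_values():
    # Iterating the list directly, for comparison.
    for v in t:
        pass

for f in (walk_indices, walk_values):
    print(f.__name__, min(timeit.repeat(f, number=10, repeat=5)))
```

If `walk_indices` accounts for only a small part of the 0.13 secs, the overhead would have to come from somewhere else in those variants.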

Any thoughts on that? Python experts out there? :-)