Saturday, February 15, 2020

03 - tail recursion

Nézzük a következő két rekurzív függvényt:

def butaSum(n: Int ): Int =
  if (n <= 0) 0 else n + butaSum(n - 1)

println(butaSum(20000)) //StackOverflowError

def tailSum(n: Int, acc: Int ): Int =
  if ( n <= 0 ) acc else tailSum(n - 1, acc + n)

def tailSum(n: Int) = tailSum(n, 0)

println(tailSum(20000)) //prints 200010000
Mindkét függvény, a $\mathrm{butaSum}$ és a $\mathrm{tailSum}$ is ha megkapja az $n$ egész számot, akkor visszaadja az $1+2+\ldots+n$ összeget. Az első implementáció, ha (nálam, a default call stack mérettel)  $n=20000$-re hívjuk, elszáll egy $\mathrm{StackOverflowError}$ral, pont azért, ami az előző poszt végén áll: betelik a call stack.

Amit mindenesetre már láthatunk a kódból: Scalában nem baj, ha két függvénynek ugyanaz a neve, mindaddig, amíg a paraméterlistájuk nem egyforma hosszú, vagy legalább egy típusban nem térnek el. (tehát csak a paraméterek típusa számít, a nevük nem.) Mivel a $\mathrm{tailSum}$ függvény első változata két $\mathrm{Int}$et, a második pedig csak egy $\mathrm{Int}$et vár, így ez két különböző függvény, a fordító képes eldönteni, hogy mikor melyiket hívja a programozó.
Ezt úgy is nevezik, hogy a függvény polimorf (polymorphic) -- konkrétan, ha ugyanazon a néven több implementációja is van a függvénynek, melyek egészen eltérő paraméterekre másképp viselkednek, azt "ad hoc polymorphism"nak nevezik. Látni fogunk később másmilyen polimorfizmust is.

Értékeljük ki a $\mathrm{tailSum}(5)$ kifejezést, hogy legyen valami képünk arról, hogyan is működik és miért nem tölti a call stacket:
$\begin{align*}
\mathrm{tailSum}(5) &\triangleright \mathrm{tailSum}(5,0)\\
&\triangleright \mathrm{if}(5\leq 0)~0~\mathrm{else}~\mathrm{tailSum}(5-1,0+5)\\
&\triangleright \mathrm{if}(\mathrm{false})~0~\mathrm{else}~\mathrm{tailSum}(5-1,0+5)\\
&\triangleright \mathrm{tailSum}(5-1,0+5)\\
&\triangleright \mathrm{tailSum}(4,0+5)\\
&\triangleright \mathrm{tailSum}(4,5)\\
&\triangleright \mathrm{if}(4\leq 0)~5~\mathrm{else}~\mathrm{tailSum}(4-1,5+4)\\
&\triangleright \mathrm{if}(\mathrm{false})~5~\mathrm{else}~\mathrm{tailSum}(4-1,5+4)\\
&\triangleright \mathrm{tailSum}(4-1,5+4)\\
&\triangleright \mathrm{tailSum}(3,5+4)\\
&\triangleright \mathrm{tailSum}(3,9)\\
&\triangleright \mathrm{if}(3\leq 0)~9~\mathrm{else}~\mathrm{tailSum}(3-1,9+3)\\
&\triangleright \mathrm{if}(\mathrm{false})~9~\mathrm{else}~\mathrm{tailSum}(3-1,9+3)\\
&\triangleright \mathrm{tailSum}(3-1,9+3)\\
&\triangleright \mathrm{tailSum}(2,9+3)\\
&\triangleright \mathrm{tailSum}(2,12)\\
&\triangleright \mathrm{if}(2\leq 0)~12~\mathrm{else}~\mathrm{tailSum}(2-1,12+2)\\
&\triangleright \mathrm{if}(\mathrm{false})~12~\mathrm{else}~\mathrm{tailSum}(2-1,12+2)\\
&\triangleright \mathrm{tailSum}(2-1,12+2)\\
&\triangleright \mathrm{tailSum}(1,12+2)\\
&\triangleright \mathrm{tailSum}(1,14)\\
&\triangleright \mathrm{if}(1\leq 0)~14~\mathrm{else}~\mathrm{tailSum}(1-1,14+1)\\
&\triangleright \mathrm{if}(\textrm{false})~14~\mathrm{else}~\mathrm{tailSum}(1-1,14+1)\\
&\triangleright \mathrm{tailSum}(1-1,14+1)\\
&\triangleright \mathrm{tailSum}(0,14+1)\\
&\triangleright \mathrm{tailSum}(0,15)\\
&\triangleright \mathrm{if}(0\leq 0)~15~\mathrm{else}~\mathrm{tailSum}(0-1,15+0)\\
&\triangleright \mathrm{if}(\mathrm{true})~15~\mathrm{else}~\mathrm{tailSum}(0-1,15+0)\\
&\triangleright 15
\end{align*}$
és valóban, $1+2+3+4+5=15$.

Ha megfigyeljük a kifejezés kiértékelés lépéseit, láthatjuk azt is, hogy ez egy olyan rekurzió, amit bármekkora $n$-re is hívunk meg, a kifejezés mérete egy konstans korlát alatt marad (ha a számokat konstans méretűnek, pl $32$ bitesnek vesszük), pontosan ezért nem dobott $\mathrm{StackOverflowError}t, szemben a másik implementációval. Ha egy rekurzív függvény rendelkezik azzal a tulajdonsággal, hogy minden rekurzív hívás "végérték" pozícióban szerepel, azaz: a rekurzív hívás kiértékelése a végeredményt fogja adni, és nem csinálunk utána vele már semmit, akkor ezt farokrekurzív (tail recursive) függvénynek nevezzük és nagyon hasznos - ahogy a fenti levezetésben is látjuk, ha farokrekurzió van, az nem kell töltse a call stacket.
Legalábbis, Scalában nem tölti, mert a fordító kioptimalizálja, ezt a tevékenységét nevezik "tail call optimization"nak, TCO-nak. Mivel funkcionális programozási nyelvekben a rekurzió a fő "vezérlési szerkezet", így a funkcionális nyelvek jellemzően támogatják a TCO-t. A Java nem támogatja a TCO-t, ha Javában implementáljuk ezt az összeadó függvényt, ugyanúgy kivételt fog dobni kiértékeléskor.

Lássuk csak most, miért is működik helyesen ez a tail recursive implementáció? Nézegetve a futást, azt látjuk, hogy mintha a $\mathrm{tailSum}(n,acc)$ kifejezés értéke mindig $1+2+\ldots+n+acc$ lenne. Egy ilyen állítást, sejtést, összefüggést, whatevert például indukcióval tudunk belátni - a funkcionális nyelvek előnye, hogy nagyon sok esetben sokkal egyszerűbb reasoningolni a kód helyességéről, mint ha a kódban van mutable data is.
Az indukció pl. természetes számokra így működik: ha valamit be akarunk látni, hogy minden természetes számra igaz, akkor belátjuk $n=0$-ra, hogy igaz, majd, abból kiindulva, hogy $n$-re igaz, kihozzuk, hogy $n+1$-re is igaz lesz. Nézzük meg, ez most itt hogy megy:
  • azt akarjuk bebizonyítani, hogy $\mathrm{tailSum}(n,acc)=1+2+\ldots+n+acc$ minden $n$-re és $acc$-ra.
  • ezt $n$ szerinti indukcióval tesszük: minden $n$-re belátjuk, hogy bármi is $acc$, ez az összefüggés fennáll.
  • először megnézzük $n=0$-ra: akkor az $\mathrm{if}(0\leq 0)~acc~\mathrm{else}~\mathrm{tailSum}(n-1,acc+n)$ kifejezés $acc$-ra értékelődik ki végül, ekkor tényleg igaz, hiszen $1+2+\ldots+0$ egy $0$-tagú összeg, $0$ lesz, így az $1+2+\ldots+0+acc=acc$.
  • eztán megnézzük $n+1$-re, ami egy pozitív szám: ekkor az $\mathrm{if}(n+1\leq 0)~acc~\mathrm{else}~\mathrm{tailSum}(n+1-1,acc+n+1)$ kifejezés, mivel a feltétel hamis lesz, így $\mathrm{tailSum}(n+1-1,acc+n)$-re, azaz $\mathrm{tailSum}(n,acc+n+1)$-re íródik át. Itt most (indukció!) valid feltételezni, hogy $n$-re már tudjuk, hogy kijön, azaz ennek a kifejezésnek az értéke $1+\ldots+n+acc+n+1$ lesz. Ami ugyanaz, mint $1+\ldots+(n+1)+acc$, és pont ezt akartuk igazolni.
Tehát, a $\mathrm{tailSum}(n,acc)$ függvényt most már tudjuk, hogy mit számol ki: így a $\mathrm{tailSum}(n)=\mathrm{tailSum}(n,0)$ függvény pedig az első $n$ pozitív egész összegét, plusz $0$-t, azaz tényleg az első $n$ pozitív egész összegét adja vissza.

 Hagyhatnánk ennyiben is, de ez is ér és szebb:

import scala.annotation.tailrec

object Main extends App {
  def tailSum(n: Int) = {
    @tailrec
    def tailSum(n: Int, acc: Int): Int = 
      if ( n <= 0 ) acc else tailSum(n - 1, acc + n)
    
    tailSum(n, 0)
  }

  println(tailSum(20000)) //prints 200010000
}

Mit látunk itt? Egyrészt, Scalában szabad függvényen belül deklarálni másik függvényt. Ennek az az előnye, hogy ha a belső függvény egy "helper" függvény, mint most, amit alapvetően nem akarunk kiajánlani másoknak, hogy hívogassák, akkor berakva a függvény scope-jába elérjük, hogy ezt senki ne lássa és csak itt, ebben a függvényben létezzen, less clutter. Másrészt, a $\mathrm{@tailrec}$ egy annotáció (hogy csak így keresztnéven emlegethessük és ne az első sorban az import után látható teljes nevét kelljen kiírni, arra szolgál az import parancs föntebb; az idea ctrl-space kiegészíti nekünk a tailrec annotációt és felírja az importot automatikusan, ezzel egyelőre nem kell foglalkoznunk). A lényege: a fordítónak azt jelezzük vele, hogy itt az állt a szándékunkban, hogy egy tail recursive metódust írjunk. A fordító ennek hatására nem csak hogy lefordítja nekünk a kódot mint máskor, hanem még ellenőrzi is, hogy tényleg tail recursive-e a $\mathrm{@tailrec}$ után álló függvény, és ha nem az (pl. mert nem is rekurzív, vagy mert nem mindig végpozícióban hívja magát, hanem van legalább egy olyan pont, ahol van rekurzív hívás, de annak az értékével még utána csinálunk valamit), akkor fordítási hibát dob.
A fordítási időben előjövő hibák a legjobbak, sokkal könnyebb detektálni és általában lejavítani is őket, mint a futásidejűeket :) ezért ahol csak lehet, használni fogunk annotációkat, hogy a fordító a kezünkre üthessen még időben, ha valamit rosszul csinálunk.

Még egy részlet: ebben a $\mathrm{tailSum}$ implementációban az egyváltozós függvényen belül két kifejezés szerepel: egy függvénydeklaráció (a kétváltozós függvényé), majd egy függvényhívás-kifejezés. Ezért ide kell a kapcsos zárójel (a kétváltozós belső függvény egyetlen if-else kifejezés, aköré nem muszáj tenni). Ha több kifejezés egymás után téve alkot egy blokkot (mint most), akkor a kiszámítás úgy zajlik, hogy kiértékeljük az első kifejezést, aztán a másodikat, stb, és a végső érték az utolsó kifejezés értéke lesz. Ezért most amit kapunk értéknek, az végül a $\mathrm{tailSum}(n,0)$ hívás értéke lesz.

No comments:

Post a Comment