Kleiner Rekursions-Trick

Das Thema Endrekursion hatte ich schon einmal behandelt. Zur Erinnerung, eine naive Implementierung der Fakultät wie diese hier…

//Variante 1 - nicht endrekursiv
def fac1(n:Int):BigInt =
  if (n <= 1) 1 else fac1(n-1) * n

… ist nicht endrekursiv, da der rekursive Aufruf nicht die letzte Operation der Methode ist, und würde dementsprechend irgendwann zu einem Stackoverflow führen. Eine einfache Lösung besteht darin, einen Akkumulator einzuführen. Der Code sieht dann etwa so aus:

//Variante 2 - endrekursiv
def fac2(n:Int):BigInt = {
  @tailrec def loop(k:Int, acc:BigInt):BigInt =
     if (k <= 1) acc
     else loop(k-1, acc * k)
  loop(n,1)
}

Nun ist das ziemlich langatmig, und ich habe eine bessere Formulierung gefunden:

//Variante 3 - auch endrekursiv
@tailrec def fac3(n:Int, acc:BigInt = 1):BigInt =
   if(n <= 1) acc else fac3(n-1, acc*n)

Besser, nicht wahr? Statt in einer separaten lokalen Methode gibt man den Akkumulator sofort mit. Der Clou ist, dass er als Default-Argument übergeben wird, also bei einem externen Aufruf wie fac3(42) nicht angegeben werden braucht (und auch nicht angegeben werden soll). Natürlich wird damit die Methoden-Signatur etwas verwirrender, und für eine externe API sollte man doch lieber auf Variante 2 zurückgreifen. Für Code, der hinter den Kulissen werkelt, sehe ich das aber nicht so problematisch.

Nach uns die Sintflut!

In einer Diskussion auf www.java-forum.org wurde argumentiert, dass einem die tolle Endrekursions-Optimierung nicht viel nützt, wenn man eine Rekursion hat, die „verzweigt“. Ein einfaches Beispiel dafür sei der Flood-Fill Algorithmus.

Dabei ist gerade Flood-Fill sehr einfach endrekursiv umzuschreiben, weil es nicht auf die Reihenfolge der Pixel ankommt. Rein intuitiv würde ich sagen, dass sich mit entsprechendem Aufwand jede verzweigte Rekursion „linearisieren“ lässt, aber das wissen die Theoretiker sicher besser.

Hier eine ganz einfache endrekursive Flood-Fill-Lösung in Scala, die statt Bilder „ASCII-Kunst“ bearbeitet:

object flood {
  val neighbors = List((-1,0),(1,0),(0,-1),(0,1))
  val pic = Array(
    "#####################################",
    "#........#...................#......#",
    "#..................####.......#######",
    "#...#....#..........................#",
    "#...##...####################.......#",
    "#...#....#..........................#",
    "#####################################").map(_.toArray)

  def floodFill(pixels:Set[(Int,Int)]) {
   if (pixels.isEmpty) pic.foreach{line => line.foreach(print); println}
   else floodFill(pixels.foldLeft(Set[(Int,Int)]()){ (set, point) =>
        val (x,y) = point
        pic(y)(x) = 'X'
        set ++ neighbors.map{case (px,py) => (px + x, py + y)}.
                         filter{case (px,py) => pic(py)(px) == '.'}
    })}

  def main(args:Array[String]) {
    floodFill(Set((2,2)))
  }
}

Wie im „richtigen Leben“ habe ich auf eine veränderliche Daten zurückgegriffen, denn bei einem größeren Bild müsste man erheblichen Aufwand betreiben, damit eine „pure“, auf unveränderlichen Daten basierende Implementierung auch nur halbwegs performant wäre.

Das Ergebnis sieht aus wie erhofft:

#####################################
#XXXXXXXX#XXXXXXXXXXXXXXXXXXX#......#
#XXXXXXXXXXXXXXXXXX####XXXXXXX#######
#XXX#XXXX#XXXXXXXXXXXXXXXXXXXXXXXXXX#
#XXX##XXX####################XXXXXXX#
#XXX#XXXX#XXXXXXXXXXXXXXXXXXXXXXXXXX#
#####################################

Selbstverständlich ist das nur eine Spiel-Implementierung, aber es ist schon erstaunlich, wie kurz sich das Ganze formulieren läßt. Dass die neu zu zeichnenden Punkte in einem Set statt einer Liste vorgehalten werden, spart das lästige Vergleichen auf Duplikate. Apropos Punkte: Die Verwendung einer Punkt-Klasse wäre eine sinnvolle Verbesserung, denn auf die Dauer wird die Arbeit mit Tupeln doch lästig, so praktisch sie als „Wegwerf-Lösung“ auch sein mögen.

Würde man das Bild bei jedem floodFill-Aufruf ausgeben, würde man sehen, wie sich trotz der „Linearisierung“ die „Farbe“ rings um den vorhandenen „Fleck“ ausbreitet, und nicht erst in eine Richtung, denn die verwendete Strategie ist eine Breitensuche. Da ich gerade ziemlich müde bin, sei die Implementierung einer Version mit Tiefensuche dem geneigten Leser als Übung überlassen 🙂

Endrekursion

Ich musste erst mal Tante Wikipedia fragen, wie man „tail recursion“ in vernünftiges Deutsch übersetzt – „Schwanzrekursion“ schien mir irgendwie nicht ganz passend zu sein. Was Rekursion ist, sollte eigentlich jeder Programmierer wissen (falls nicht, wird es gut im 2. Kapitel dieses Skripts erklärt), aber „Endrekursion“ dürfte für die meisten Java-Programmierer ein Fremdwort sein. In Scala, wo Rekursion der Normalfall und Schleifen eher die Ausnahme sind, gehört dieses Konzept zum grundlegenden Rüstzeug – aber keine Angst, es ist wirklich nicht kompliziert.

Schauen wir uns zwei verschiedene Implementierungsvarianten für die Fakultät an:

//Variante 1 - nicht endrekursiv
def fac1(n:Int):BigInt = 
  if (n <= 1) 1
  else fac1(n-1) * n

fac1(20)
//--> res0: BigInt = 2432902008176640000

//Variante 2 - endrekursiv
def fac2(n:Int):BigInt = {
  def loop(k:Int, prod:BigInt):BigInt = 
     if (k <= 1) prod 
     else loop(k-1, prod * k)
 
  if (n <= 1) 1 else loop(n,1) 
}

fac2(20)
//--> res1: BigInt = 2432902008176640000

Schön, zumindest kommt bei beiden das gleiche Ergebnis heraus. Variante 1 sieht kürzer, hübscher und verständlicher aus, Variante 2 ist der Stil, den man gewöhnlich in Scala findet. Warum? Weil die erste Version nicht nur langsamer ist, sondern sogar zu Stack-Überläufen führen kann, während das bei Variante 2 nicht vorkommen kann.

Variante 1 hat nämlich das Problem, dass die Rechnung noch nicht „fertig“ ist, wenn der rekursive Aufruf erfolgt. So wird der Aufruf von fac1(5) als ((((1*2)*3)*4)*5) ausgeführt. Die Rekursion steigt tiefer und tiefer ab, und am Ende liefert sie die Ergebnisse jeweils dem darüberliegenden Aufruf zurück. Dieser führt noch eine Multiplikation aus und gibt seinerseits das Ergebnis nach „oben“ weiter.

Variante 2 dagegen lagert die Rekursion in eine innere Methode loop aus, die ein zusätzliches Argument – nämlich das Produkt der bisherigen Werte – besitzt. Durch diesen Trick ist der rekursive Aufruf das allerletzte, was in loop passiert. Aus diesem Grund braucht man die gerade abgegearbeitete Version des Aufrufs nicht mehr, sie ist für die weitere Berechnung nicht mehr nötig, und kann sozusagen (nach einer Neubelegung der Argumente) „recycelt“ werden. Und genau das tut die JVM (allerdings nicht so raffiniert, wie das funktionale Compiler tun): sie steigt nicht eine Ebene „tiefer“, sondern verwendet die aktuelle Umgebung von loop erneut, ohne dass ein echter Methodenaufruf stattfinden würde. Der generierte Byte-Code ähnelt dann mehr einer while-Schleife als dem von Variante 1, und ist entsprechend effizient. Diese Optimierung nennt sich Endrekursion.

Es gibt also einen guten Grund dafür, Variante 2 zu bevorzugen, auch wenn sie ein wenig komplizierter ist. Es lohnt sich, etwas Gehirnschmalz zu investieren, um seine Methoden endrekursiv zu formulieren – mit etwas Übung ist das gar kein Problem. Manchmal reicht schon eine andere Berechnungsreihenfolge, aber oft muss man auch auf den Trick mit dem „Sammelparameter“ (der offizielle Name ist „Akkumulator“ oder kurz „accu“) zurückgreifen.

In Scala 2.8 gibt es übrigens eine Annotation namens @tailrec, bei der der Compiler eine Warnung ausgibt, wenn die entsprechende Methode nicht so übersetzt werden kann, dass sie endrekursiv ausgeführt wird – eine großartige Hilfe, um Performanceprobleme und Stacküberläufe zu vermeiden.

Da Endrekursion eine Optimierungstechnik der JVM und nicht des verwendeten Compilers ist, gelten alle Ausführungen auch für Java. Wenn man also in Java einen Stacküberlauf durch Rekursion bekommt, kann man die betroffene Methode genauso in der hier präsentierten Weise umformulieren, um Endrekursion zu ermöglichen.
… dachte ich zumindest, bin aber eines besseren belehrt worden. Es ist der Scala-Compiler und nicht die JVM, der die Endrekursions-Optimierung vornimmt. Später mehr dazu.

Dann wünsche ich euch noch fröhliches Rekursieren!

Russische Bauernmultiplikation

Zufälligerweise habe ich von diesem kleinen Programmierwettberb auf „The Daily WTF“ gelesen. Die Aufgabe ist einfach: Implementiere die Russische Bauernmultiplikation. Das ist ein perfektes Beispiel für die Arbeit mit Scala-Collections, und auch unser alter Bekannter foldLeft kommt zum Einsatz.

Hier ist nun meine Lösung, die ich Schritt für Schritt sezieren möchte:

def russianMultiply(a:Int, b:Int) = {
  def loop(x:Int, y:Int, list:List[(Int,Int)]):List[(Int,Int)] = 
      if(x == 0) list else loop(x >> 1, y << 1, (x, y) :: list)
  loop(a, b, Nil).filter(_._1 % 2 == 1).foldLeft(0)(_+_._2)
}
&#91;/sourcecode&#93;

Das ist wieder eines dieser typische Scala-Beispiele: Kurz und unverständlich. Aber nicht mehr lange!

Fangen wir mit der inneren Funktion "loop" an. 

Als erstes muss man wissen, was ein Tupel ist. Ein Tupel kann man sich wie ein Array vorstellen, nur besitzt jedes Feld einen eigenen Typ. Ein Tupel ist also ein bequemer Weg, momentan zusammengehörige Werte zu "bündeln". Wenn in Java eine Funktion gleichzeitig einen String, ein int und ein Date zurückgeben soll, müsste man dafür einen extra Klasse schreiben, die die drei Werte aufnimmt. In Scala hat man dafür Tupel, und da sie so praktisch ist, gibt es einen Kurzschreibwiese, nämlich einfach nur Klammern. Wollten wir also ein Int, einen String und ein Date zurückgeben, würden wir schreiben (42, "answer", new Date), und dieses Konstrukt hätte den Typ Tuple3&#91;Int, String, Date&#93; oder kurz (Int, String, Date). 

Was wollen wir nun vertupeln? Nun, die Wertepaare mit den halbierten und verdoppelten Multiplikanden. Und die Tupel packen wir in eine Liste. So eine Liste ist entweder leer (dafür gibt es den Untertyp Nil) oder sie besteht aus einem Listenkopf und der Restliste, und beide sind durch den Operator :: verknüpft. Eine Liste der Zahlen 1 bis 3 kann man bequem als List(1, 2, 3) schreiben, aber genauso gut ginge 1 :: List(2, 3) oder 1 :: 2 :: 3 :: Nil.

Die Funktion loop hat drei Argumente: den zu halbierenden Wert, den zu verdoppelnden Werte und eine "Sammel-Liste" mit allen Tupel-Paaren, die wir schon berechnet haben. Ist der erste Wert gleich 0, sind wir fertig und geben einfach unsere Sammel-Liste zurück (man nennt so ein Sammel-Argument auch "Akkumulator"). Ist das nicht der Fall, packen wir ein Tupel aus unseren aktuellen Argument-Werten zur vorhandenen Liste dazu, berechnen die neuen Werte (wir halbieren und verdoppeln ganz elegant mit den beiden Bit-Schubs-Operatoren &gt;&gt; und &lt;&lt;) und rufen mit diesen Argumenten loop rekursiv auf.

So, probieren wir loop einmal alleine aus, z.B. auf <a href="http://www.simplyscala.com">Simply Scala</a>:


def loop(x:Int, y:Int, list:List[(Int,Int)]):List[(Int,Int)] = 
   if(x == 0) list else loop(x >> 1, y <<1, (x,y)::list)
//--> loop: (Int,Int,List[(Int, Int)])List[(Int, Int)]
loop(23,34,Nil)
//--> res1: List[(Int, Int)] = List((1,544), (2,272), (5,136), (11,68), (23,34))

Gut, die Wertepaare sind rückwärts geordnen, weil wir ja immer am Listenanfang angefügt haben. Aber an dieser Stelle ist uns das egal, Hauptsache wir haben die Wertepaare. Nun wäre es schön, wenn wir nur die Paare mit einem ungeraden ersten Wert hätten. Auf das erste Element eines Tupels tupel greift man mit tupel._1 zu, auf das zweite mit tupel._2 usw. Und filtern tut man eine Collection mit der Funktion „filter“ – wer hätte das gedacht! Kurz ausprobieren:

List((1,544), (2,272), (5,136), (11,68), (23,34)).filter(_._1 % 2 == 1)
//--> res2: List[(Int, Int)] = List((1,544), (5,136), (11,68), (23,34))

Sehr schön! Nun hatte ich letzten Mal geschrieben, dass man foldLeft oder foldRight oft verwenden kann, wenn es irgendwie ums „Daten sammeln“ geht, und auch in unserem Fall passt es prima: Starte mit 0 und addiere fortlaufend die zweiten Werte der Tupel hinzu. Wieder ein kleiner Test:

List((1,544), (5,136), (11,68), (23,34)).foldLeft(0)(_+_._2)
//-->res3: Int = 782

Damit sind hätten wir alle Zutaten beisammen und sind fertig. Der Trick beim Schreiben solcher Funktionen ist, sich vom typischen „Schleifendenken“ in Java zu lösen, und stattdessen die Aufgabe in logische Schritte zu zerteilen: „Erst einmal brauche ich die ganzen Wertepaare, vielleicht in einer Liste oder so. Wenn ich die habe, werfe ich alle mit geradem ersten Wert weg und addiere die zweiten Werte vom Rest zusammen“.

So, ich hoffe, dass jetzt der obige Code nicht mehr ganz so furchteinflößend aussieht. War doch gar nicht so schlimm, oder?