Memoizacija v OCamlu

Osnovna memoizacija v OCamlu poteka podobno tisti v Pythonu. Za začetek si spet poglejmo funkcijo, ki vrne kvadrat celega števila:

let kvadrat x =
    print_endline ("Računam " ^ string_of_int x);
    x * x
val kvadrat : int -> int = <fun>
kvadrat 10
Računam 10
- : int = 100
kvadrat 10
Računam 10
- : int = 100

S pomočjo te funkcije lahko definiramo funkcijo mem_kvadrat, ki si shranjuje že izračunane vrednosti. Za shranjevanje uporabimo knjižnico Hashtbl za delo z zgoščevalnimi tabelami, s katerimi so implementirani tudi Pythonovi slovarji.

let kvadrati = Hashtbl.create 512 (* argument 512 predstavlja pričakovano začetno velikost tabele *)
let mem_kvadrat x =
  match Hashtbl.find_opt kvadrati x with
  | Some y -> y
  | None ->
      let y = kvadrat x in
      Hashtbl.add kvadrati x y;
      y
val kvadrati : ('_weak1, '_weak2) Hashtbl.t = <abstr>
val mem_kvadrat : int -> int = <fun>
mem_kvadrat 10
Računam 10
- : int = 100
mem_kvadrat 10
- : int = 100

Tip tabele kvadrati lahko ignorirate, sporoča pa, da sta tipa ključev in vrednosti zaenkrat še neznana, vendar nista polimorfna. V resnici že definicija funkcije mem_kvadrat povzroči, da se oba nastavita na int.

kvadrati
- : (int, int) Hashtbl.t = <abstr>

Tudi v OCamlu lahko napišemo funkcijo višjega reda, ki memoizira dano funkcijo:

let memoiziraj f =
  let rezultati = Hashtbl.create 512 in
  let mem_f x =
    match Hashtbl.find_opt rezultati x with
    | None ->
        let y = f x in
        Hashtbl.add rezultati x y;
        y
    | Some y ->
        y
  in
  mem_f
val memoiziraj : ('a -> 'b) -> 'a -> 'b = <fun>
let mem_kvadrat2 = memoiziraj kvadrat
val mem_kvadrat2 : int -> int = <fun>
mem_kvadrat2 10
Računam 10
- : int = 100
mem_kvadrat2 10
- : int = 100

Memoizacija rekurzivnih funkcij

Pri memoizaciji rekurzivnih funkcij pa nastopijo težave.

let rec fib n =
  print_endline ("Računam " ^ string_of_int n);
  match n with
  | 0 | 1 -> n
  | n -> fib (n - 1) + fib (n - 2)

let mem_fib = memoiziraj fib
val fib : int -> int = <fun>
val mem_fib : int -> int = <fun>
mem_fib 4
Računam 4
Računam 2
Računam 0
Računam 1
Računam 3
Računam 1
Računam 2
Računam 0
- : int = 3
mem_fib 4
- : int = 3

Na prvi pogled je videti, kot da memoizacija deluje pravilno, saj je drugi klic mem_fib vrnil že izračunano vrednost. Vendar ob natančnem pregledu vidimo, da se ja primer vrednost pri 2 izračunala večkrat. Težava je v tem, da si mem_fib shrani vrednosti, na katerih je bil poklican. Če pa rezultata še ne pozna, pokliče funkcij fib, ki pa o že izračunanih vrednostih ne ve nič, kar vodi do velikega števila klicev. Če mem_fib na primer pokličemo na 5, mu poprej izračunana vrednost nič ne pomaga, saj fib 5 pokliče fib 4 in ne mem_fib 4.

mem_fib 5
- : int = 5
Računam 5
Računam 3
Računam 1
Računam 2
Računam 0
Računam 1
Računam 4
Računam 2
Računam 0
Računam 1
Računam 3
Računam 1
Računam 2
Računam 0
Računam 1

Tudi če si mem_fib shranimo pod isto ime, ne rešimo ničesar.

let fib = memoiziraj fib
val fib : int -> int = <fun>
fib 5
- : int = 5

Kljub istemu imenu gre za dve različni funkciji: eno, ki smo jo zgoraj definirali rekurzivno, in drugo, ki je bila rezultat klica memoiziraj. V Pythonu ta težava ne nastopi, saj je dinamičen jezik. To pomeni, da računalnik ob klicu funkcije ne skoči na vnaprej (statično) določeno mesto v programski kodi, temveč šele takrat pogleda, kaj se skriva pod tem imenom. V našem primeru lahko to izkoristimo, da pod to ime shranimo drugo funkcijo. Seveda so dinamični jeziki zaradi te fleksibilnosti počasnejši in tudi manj varni.

Še vedno pa si želimo splošnega načina za memoizacijo rekurzivnih funkcij. Kot smo videli, je težava v tem, da rekurzivne funkcije kličejo same sebe, mi pa se želimo v te klice vriniti. To dosežemo tako, da funkciji podamo dodaten argument, s katerim povemo, katero funkcijo naj pokliče namesto sebe. Na primer, rekurzivni definiciji

let rec fib n =
  print_endline ("Računam " ^ string_of_int n);
  match n with
  | 0 | 1 -> n
  | n -> fib (n - 1) + fib (n - 2)
val fib : int -> int = <fun>

kot smo jo videli prej, kot dodaten argument f podamo funkcijo, ki naj jo pokliče namesto sebe (opazimo, da v tem primeru funkcija ni več rekurzivna, zato ključna beseda rec ni potrebna). Takim funkcijam pravimo, da so odvite (unrolled), saj smo rekurzivno zanko prekinili.

let odviti_fib f n =
  print_endline ("Računam " ^ string_of_int n);
  match n with
  | 0 | 1 -> n
  | n -> f (n - 1) + f (n - 2)
val odviti_fib : (int -> int) -> int -> int = <fun>

Za f lahko podamo poljubno funkcijo, na primer tako, ki vedno vrača 42:

let nagajivi_fib n = odviti_fib (fun _ -> 42) n
val nagajivi_fib : int -> int = <fun>
nagajivi_fib 10
- : int = 84
nagajivi_fib 200
- : int = 84

Če za f podamo dobljeno funkcijo, dobimo ravno prvotno rekurzivno definicijo:

let rec fib n = odviti_fib fib n
val fib : int -> int = <fun>
fib 5
- : int = 5

Seveda pa je naš namen, da v klic vrinemo funkcijo, ki hrani rezultate:

let rezultati = Hashtbl.create 512
let rec mem_fib x =
    match Hashtbl.find_opt rezultati x with
    | None ->
        let y = odviti_fib mem_fib x in
        Hashtbl.add rezultati x y;
        y
    | Some y ->
        y
val rezultati : ('_weak3, '_weak4) Hashtbl.t = <abstr>
val mem_fib : int -> int = <fun>
mem_fib 5
Računam 5
Računam 3
Računam 1
Računam 2
Računam 0
Računam 4
- : int = 5
mem_fib 6
- : int = 8

Vidimo, da se je vsaka vrednost izračunala natanko enkrat. Postopek sedaj lahko naredimo tudi v splošnem:

let memoiziraj_rec odviti_f =
  let rezultati = Hashtbl.create 512 in
  let rec mem_f x =
    match Hashtbl.find_opt rezultati x with
    | None ->
        let y = odviti_f mem_f x in
        Hashtbl.add rezultati x y;
        y
    | Some y ->
        y
  in
  mem_f
val memoiziraj_rec : (('a -> 'b) -> 'a -> 'b) -> 'a -> 'b = <fun>
let mem_fib = memoiziraj_rec odviti_fib
val mem_fib : int -> int = <fun>
mem_fib 5
Računam 5
Računam 3
Računam 1
Računam 2
Računam 0
Računam 4
- : int = 5
mem_fib 6
Računam 6
- : int = 8

Z ustreznim poimenovanjem lahko pridemo do oblike, ki je z izjemo prve vrstice (in zamika in oklepaja) enaka naši prvotni naivni rekurzivni definiciji:

let fib = memoiziraj_rec (fun fib n ->
  print_endline ("Računam " ^ string_of_int n);
  match n with
  | 0 | 1 -> n
  | n -> fib (n - 1) + fib (n - 2)
)
val fib : int -> int = <fun>
fib 5
Računam 5
Računam 3
Računam 1
Računam 2
Računam 0
Računam 4
- : int = 5
fib 6
- : int = 8