package prbnmcn-mcts

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file mcts.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
open Ucb1

(** ['a gen] is the type of samplers of ['a]: functions producing a value
    from a pseudo-random generator state. *)
type 'a gen = Random.State.t -> 'a

(** Description of the sequential decision problem to be searched: states,
    actions, transitions, rewards, and the parameters of the random
    exploration phase. *)
module type S = sig
  (** Type of terminal states, i.e. states that carry a reward
      (see {!reward}). *)
  type terminal

  (** Type of nonterminal states, i.e. states from which an action can be
      taken (see {!actions} and {!next}). *)
  type nonterminal

  (** States must be distinguished as either terminal or nonterminal. *)
  type state = Terminal of terminal | Nonterminal of nonterminal

  (** Actions that can be taken at each state. *)
  type action

  (** Actions available at a given state. *)
  val actions : nonterminal -> action array

  (** Given a state and an action, one can move to the next state. *)
  val next : nonterminal -> action -> state

  (** Reward at a terminal state. *)
  val reward : terminal -> float

  (** Bound on the random exploration started from a newly reached state.
      With [`Unbounded], exploration proceeds until a terminal state is
      reached. With [`Bounded n], exploration is cut short once the step
      budget derived from [n] is spent, in which case the reward is taken
      to be [0.0]. *)
  val exploration_depth : [ `Unbounded | `Bounded of int ]

  (** The MCTS is parameterised by a Monte-Carlo exploration kernel, used to
      sample the successor of a nonterminal state during random exploration.
      Setting it to [`Uniform] picks uniformly at random among the available
      {!actions}; [`Kernel f] samples the successor with [f] instead. *)
  val exploration_kernel : [ `Uniform | `Kernel of nonterminal -> state gen ]

  (** Pretty-printer for actions. *)
  val pp_action : Format.formatter -> action -> unit

  (** Pretty-printer for terminal states. *)
  val pp_terminal : Format.formatter -> terminal -> unit

  (** Pretty-printer for nonterminal states. *)
  val pp_nonterminal : Format.formatter -> nonterminal -> unit
end

(** Monte-Carlo Tree Search yields a policy, i.e. a way to decide which
    action to take at each nonterminal state. *)
module type Policy = sig
  (** Type of the states at which a decision is requested. *)
  type t

  (** Type of the actions returned by the policy. *)
  type action

  (** [policy ~playouts state] is a generator for the action to take at
      [state]. [playouts] is the search budget: in {!MCTS} it is the number
      of playouts run from [state] before an action is chosen. *)
  val policy : playouts:int -> t -> action gen
end

(** [MCTS (X)] builds a {!Policy} for the decision problem [X] using
    Monte-Carlo Tree Search.

    Each call to [policy] builds a fresh search tree rooted at the given
    state. Every expanded node owns a UCB1 bandit whose arms are the indices
    of the actions available at that node. A playout descends the tree by
    following the bandits, expands the first unexplored state it meets,
    estimates its value with a random exploration, and feeds the resulting
    reward back to every bandit on the path.

    @raise Invalid_argument
      during uniform exploration if a nonterminal state has no available
      action. *)
module MCTS : functor (X : S) ->
  Policy with type t = X.nonterminal and type action = X.action =
functor
  (X : S)
  ->
  struct
    type t = X.nonterminal

    type action = X.action

    (* Bandits range over action indices rather than actions themselves, so
       that no ordering is required on [X.action]. *)
    module Bandit = Ucb1.Make (struct
      type t = int

      let compare (x : int) (y : int) =
        if x < y then -1 else if x > y then 1 else 0

      let pp = Format.pp_print_int
    end)

    (* Search tree.
       - [Terminal]: leaf holding a terminal state and its cached reward.
       - [Nonterminal]: state that has been expanded into a node.
       - [Unexplored]: nonterminal state reached by a transition but not
         expanded yet. *)
    type tree =
      | Terminal of terminal_node
      | Nonterminal of nonterminal_node
      | Unexplored of X.nonterminal

    and bandit = ready_to_move Bandit.t

    (* [actions] and [branches] have the same length and are indexed by the
       arms of [bandit]; [branches.(i)] is the subtree reached by taking
       [actions.(i)], computed on first access. *)
    and nonterminal_node =
      { state : X.nonterminal;
        mutable bandit : bandit;
        actions : action array;
        branches : tree lazy_t array
      }

    and terminal_node = { final : X.terminal; reward : float }

    (* [uniform_exploration state rng_state] takes an action picked uniformly
       at random among [X.actions state] and returns the resulting state.
       Raises [Invalid_argument] with an explicit message when [state] has no
       action, instead of the opaque failure of [Random.State.int] on a zero
       bound. *)
    let uniform_exploration : t -> X.state gen =
     fun state ->
      let actions = X.actions state in
      let count = Array.length actions in
      fun rng_state ->
        if count = 0 then
          invalid_arg "Mcts.uniform_exploration: state has no actions" ;
        let act =
          let index = Random.State.int rng_state count in
          actions.(index)
        in
        X.next state act

    (* Transition sampler used by random exploration, as selected by
       [X.exploration_kernel]. *)
    let exploration =
      match X.exploration_kernel with
      | `Uniform -> uniform_exploration
      | `Kernel f -> f

    (* Samples transitions from [node] until a terminal state is reached and
       returns its reward. Does not return if no terminal state is ever
       reached. *)
    let rec explore_until_termination (node : X.nonterminal) rng_state =
      let next = exploration node rng_state in
      match next with
      | X.Terminal state -> X.reward state
      | X.Nonterminal state -> explore_until_termination state rng_state

    (* Same as [explore_until_termination], but gives up with a reward of
       [0.0] once [gas] becomes negative. One transition is sampled for each
       value of the counter from [gas] down to [0], hence at most [gas + 1]
       transitions when [gas >= 0] and none when [gas < 0]. *)
    let rec explore_until_termination_bounded (gas : int) (node : X.nonterminal)
        rng_state =
      if gas < 0 then 0.0
      else
        let next = exploration node rng_state in
        match next with
        | X.Terminal state -> X.reward state
        | X.Nonterminal state ->
            explore_until_termination_bounded (gas - 1) state rng_state

    (* Random exploration procedure, as selected by [X.exploration_depth]. *)
    let exploration_loop =
      match X.exploration_depth with
      | `Unbounded -> explore_until_termination
      | `Bounded gas -> explore_until_termination_bounded gas

    (* [assign_reward path reward] hands [reward] to every bandit awaiting one
       along [path] and stores the updated bandit back into its node. *)
    let rec assign_reward path reward =
      match path with
      | [] -> ()
      | (node, bandit) :: tl ->
          let bandit = Bandit.set_reward bandit reward in
          node.bandit <- bandit ;
          assign_reward tl reward

    (* [playout node path rng_state] performs one playout from [node].
       [path] accumulates the nodes traversed so far, each paired with its
       bandit in the awaiting-reward state. *)
    let rec playout (node : nonterminal_node) path rng_state =
      let (act, awaiting) = Bandit.next_action node.bandit in
      let path = (node, awaiting) :: path in
      match Lazy.force node.branches.(act) with
      | Terminal { reward; _ } -> assign_reward path reward
      | Nonterminal node' -> playout node' path rng_state
      | Unexplored nonterminal ->
          (* First visit: expand the state so that later playouts descend
             into it, then estimate its value by random exploration. *)
          let new_node = expand_node nonterminal in
          node.branches.(act) <- Lazy.from_val (Nonterminal new_node) ;
          let reward = exploration_loop nonterminal rng_state in
          assign_reward path reward

    (* [expand_node nonterminal] creates a node for [nonterminal] with a
       fresh bandit over its action indices. Successor states are computed
       lazily, the first time the corresponding branch is forced. *)
    and expand_node nonterminal =
      let actions = X.actions nonterminal in
      let arms = Array.init (Array.length actions) (fun i -> i) in
      let bandit = Bandit.create arms in
      let branches =
        Array.map
          (fun act ->
            Lazy.from_fun (fun () ->
                match X.next nonterminal act with
                | X.Terminal final ->
                    Terminal { final; reward = X.reward final }
                | X.Nonterminal state -> Unexplored state))
          actions
      in
      { state = nonterminal; bandit; actions; branches }

    (* [policy ~playouts initial_state rng_state] runs [playouts] playouts
       from a fresh tree rooted at [initial_state] (none if [playouts <= 0])
       and returns the action designated by the root bandit.
       NOTE(review): the final action is the one [Bandit.next_action] would
       play next; whether that selection still includes an exploration term
       depends on the Ucb1 library - confirm this is the intended choice. *)
    let policy ~playouts initial_state rng_state =
      let root = expand_node initial_state in
      for _i = 0 to playouts - 1 do
        playout root [] rng_state
      done ;
      let (act, _) = Bandit.next_action root.bandit in
      root.actions.(act)
  end
OCaml

Innovation. Community. Security.