package owl-base

  1. Overview
  2. Docs
Legend:
Page
Library
Module
Module type
Parameter
Class
Class type
Source

Source file owl_optimise_generic_sig.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# 1 "src/base/optimise/owl_optimise_generic_sig.ml"
(*
 * OWL - OCaml Scientific and Engineering Computing
 * Copyright (c) 2016-2020 Liang Wang <liang.wang@cl.cam.ac.uk>
 *)

module type Sig = sig
  module Algodiff : Owl_algodiff_generic_sig.Sig
  (** The ``Algodiff`` module that the declarations in this signature rely on. *)

  (* Opened so that the declarations below can use names exported by
     ``Algodiff`` without qualification. NOTE(review): the unqualified ``t``
     used throughout this signature is presumably ``Algodiff.t`` -- confirm
     against ``Owl_algodiff_generic_sig.Sig``. *)
  open Algodiff

  (** {6 Utils module} *)

  module Utils : sig
    val sample_num : t -> int
    (** Return the total number of samples in the passed-in ndarray. *)

    val draw_samples : t -> t -> int -> t * t
    (**
``draw_samples x y n`` draws samples from both ``x`` (observations) and ``y``
(labels). The samples will be drawn along axis 0, so ``x`` and ``y`` must agree
along axis 0. The result is a pair of values of type ``t``.

NOTE(review): the third argument ``n`` is presumably the number of samples to
draw, and the returned pair presumably holds the drawn observations and labels
in that order -- confirm against the implementation.
     *)

    val get_chunk : t -> t -> int -> int -> t * t
    (**
``get_chunk x y i c`` gets a contiguous chunk of ``c`` samples starting at
position ``i`` from ``x`` (observations) and ``y`` (labels).
     *)
  end

  (** {6 Learning_Rate module} *)

  module Learning_Rate : sig
    (** Types of learning rate. *)
    type typ =
      | Adagrad   of float
      | Const     of float
      | Decay     of float * float
      | Exp_decay of float * float
      | RMSprop   of float * float
      | Adam      of float * float * float
      | Schedule  of float array

    val run : typ -> int -> t -> t array -> t
    (**
Execute the computation selected by the given ``typ`` value.

NOTE(review): the roles of the ``int``, ``t`` and ``t array`` arguments are not
stated in this signature; they are presumably the iteration index, the current
gradient and the gradient cache -- confirm against the implementation.
     *)

    val default : typ -> typ
    (**
Return a ``typ`` value carrying default parameters.

NOTE(review): presumably the result uses the same constructor as the argument,
with its parameters replaced by defaults -- confirm against the implementation.
     *)

    val update_ch : typ -> t -> t array -> t array
    (** Update the cache of gradients and return the new cache. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Batch module} *)

  module Batch : sig
    (** Types of batches. *)
    type typ =
      | Full
      | Mini       of int
      | Sample     of int
      | Stochastic

    val run : typ -> t -> t -> int -> t * t
    (**
Execute the computation selected by the given ``typ`` value and return a pair
of values of type ``t``.

NOTE(review): the two ``t`` arguments are presumably the observations and the
labels, the ``int`` the index of the current batch, and the result the
corresponding batch of each -- confirm against the implementation.
     *)

    val batches : typ -> t -> int
    (** Return the total number of batches given a batch ``typ`` and an ndarray. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Loss module} *)

  module Loss : sig
    (**
Types of loss functions. ``Custom`` carries a user-supplied function of type
``t -> t -> t``.
     *)
    type typ =
      | Hinge
      | L1norm
      | L2norm
      | Quadratic
      | Cross_entropy
      | Custom        of (t -> t -> t)

    val run : typ -> t -> t -> t
    (**
Execute the loss computation selected by the given ``typ`` value on two values
of type ``t``.

NOTE(review): the order of the two ``t`` arguments (presumably expected values
and predictions) is not stated in this signature -- confirm against the
implementation.
     *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Gradient module} *)

  module Gradient : sig
    (** Types of gradient function. *)
    type typ =
      | GD
      | CG
      | CD
      | NonlinearCG
      | DaiYuanCG
      | NewtonCG
      | Newton

    val run : typ -> (t -> t) -> t -> t -> t -> t -> t
    (**
Execute the computation selected by the given ``typ`` value.

NOTE(review): the roles of the function argument and of the four ``t``
arguments are not stated in this signature -- confirm against the
implementation before relying on their order.
     *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Momentum module} *)

  module Momentum : sig
    (**
Types of momentum functions. ``None`` is a constructor of this type and is
distinct from the ``None`` of ``option``.
     *)
    type typ =
      | Standard of float
      | Nesterov of float
      | None

    val run : typ -> t -> t -> t
    (**
Execute the computation selected by the given ``typ`` value.

NOTE(review): the roles of the two ``t`` arguments are not stated in this
signature -- confirm against the implementation.
     *)

    val default : typ -> typ
    (**
Return a ``typ`` value carrying default parameters.

NOTE(review): presumably the result uses the same constructor as the argument,
with its parameters replaced by defaults -- confirm against the implementation.
     *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Regularisation module} *)

  module Regularisation : sig
    (**
Types of regularisation functions. ``None`` is a constructor of this type and
is distinct from the ``None`` of ``option``.
     *)
    type typ =
      | L1norm      of float
      | L2norm      of float
      | Elastic_net of float * float
      | None

    val run : typ -> t -> t
    (** Execute the computation selected by the given ``typ`` value on a ``t``. *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Clipping module} *)

  module Clipping : sig
    (**
Types of clipping functions. ``None`` is a constructor of this type and is
distinct from the ``None`` of ``option``.
     *)
    type typ =
      | L2norm of float
      | Value  of float * float
      | None

    val run : typ -> t -> t
    (** Execute the computation selected by the given ``typ`` value on a ``t``. *)

    val default : typ -> typ
    (**
Return a ``typ`` value carrying default parameters.

NOTE(review): presumably the result uses the same constructor as the argument,
with its parameters replaced by defaults -- confirm against the implementation.
     *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Stopping module} *)

  module Stopping : sig
    (**
Types of stopping functions. ``None`` is a constructor of this type and is
distinct from the ``None`` of ``option``.
     *)
    type typ =
      | Const of float
      | Early of int * int
      | None

    val run : typ -> float -> bool
    (**
Execute the computation selected by the given ``typ`` value on a ``float`` and
return a boolean.

NOTE(review): the ``float`` is presumably the current loss and a ``true``
result presumably means that optimisation should stop -- confirm against the
implementation.
     *)

    val default : typ -> typ
    (**
Return a ``typ`` value carrying default parameters.

NOTE(review): presumably the result uses the same constructor as the argument,
with its parameters replaced by defaults -- confirm against the implementation.
     *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Checkpoint module} *)

  module Checkpoint : sig
    type state =
      { mutable current_batch : int
      ; mutable batches_per_epoch : int
      ; mutable epochs : float
      ; mutable batches : int
      ; mutable loss : t array
      ; mutable start_at : float
      ; mutable stop : bool
      ; mutable gs : t array array
      ; mutable ps : t array array
      ; mutable us : t array array
      ; mutable ch : t array array array
      }
    (**
Type definition of the checkpoint state. Every field is mutable, so a ``state``
value can be updated in place.

NOTE(review): ``gs``, ``ps``, ``us`` and ``ch`` are presumably the gradients,
directions, updates and gradient caches of the previous iteration -- confirm
against the implementation.
     *)

    (**
Types of checkpoint. ``Custom`` carries a user-supplied function of type
``state -> unit``. ``None`` is a constructor of this type and is distinct from
the ``None`` of ``option``.
     *)
    type typ =
      | Batch  of int
      | Epoch  of float
      | Custom of (state -> unit)
      | None

    val init_state : int -> float -> state
    (**
``init_state batches_per_epoch epochs`` initialises a state by specifying the
number of batches per epoch and the number of epochs in total.
     *)

    val default_checkpoint_fun : (string -> 'a) -> 'a
    (**
This function is used for saving intermediate files during optimisation. It
takes a callback of type ``string -> 'a`` and returns a value of type ``'a``.

NOTE(review): the string passed to the callback is presumably the name of the
file to save to -- confirm against the implementation.
     *)

    val print_state_info : state -> unit
    (** Print out the detailed information of the current ``state``. *)

    val print_summary : state -> unit
    (** Print out the summary of the current ``state``. *)

    val run : typ -> (string -> unit) -> int -> t -> state -> unit
    (**
Execute the checkpoint action selected by the given ``typ`` value.

NOTE(review): the remaining arguments are presumably a save function, the
index of the current batch, the current loss and the state to act on --
confirm against the implementation.
     *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Params module} *)

  module Params : sig
    type typ =
      { mutable epochs : float
      ; mutable batch : Batch.typ
      ; mutable gradient : Gradient.typ
      ; mutable loss : Loss.typ
      ; mutable learning_rate : Learning_Rate.typ
      ; mutable regularisation : Regularisation.typ
      ; mutable momentum : Momentum.typ
      ; mutable clipping : Clipping.typ
      ; mutable stopping : Stopping.typ
      ; mutable checkpoint : Checkpoint.typ
      ; mutable verbosity : bool
      }
    (**
Type definition of the optimisation parameters. Every field is mutable, and
each field other than ``epochs`` and ``verbosity`` holds a ``typ`` value of
the module of this signature that is named after it.
     *)

    val default : unit -> typ
    (** Create a parameter value of type ``typ`` with default values. *)

    val config
      :  ?batch:Batch.typ
      -> ?gradient:Gradient.typ
      -> ?loss:Loss.typ
      -> ?learning_rate:Learning_Rate.typ
      -> ?regularisation:Regularisation.typ
      -> ?momentum:Momentum.typ
      -> ?clipping:Clipping.typ
      -> ?stopping:Stopping.typ
      -> ?checkpoint:Checkpoint.typ
      -> ?verbosity:bool
      -> float
      -> typ
    (**
This function creates a parameter object from the given configuration. Each
optional argument corresponds to the record field of the same name.

NOTE(review): the final positional ``float`` is presumably the number of
epochs (compare the ``epochs`` field), and omitted optional arguments
presumably fall back to the values used by ``default`` -- confirm against the
implementation.
     *)

    val to_string : typ -> string
    (** Convert a ``typ`` value to its string representation. *)
  end

  (** {6 Core functions} *)

  val minimise_weight
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t -> t)
    -> t
    -> t
    -> t
    -> Checkpoint.state * t
  (**
This function minimises the weight ``w`` of the passed-in function ``f``.

* ``f`` is a function ``f : w -> x -> y``.
* ``w`` is a row vector but ``y`` can have any shape.

The result is a pair of a checkpoint state and a value of type ``t``.

NOTE(review): the three positional ``t`` arguments are presumably ``w``, ``x``
and ``y`` in that order, and the returned ``t`` is presumably the optimised
weight -- confirm against the implementation.
   *)

  val minimise_network
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t * t array array)
    -> (t -> t array array * t array array)
    -> (t array array -> unit)
    -> (string -> unit)
    -> t
    -> t
    -> Checkpoint.state
  (**
This function is specifically designed for minimising the weights in a neural
network of graph structure. In Owl's earlier versions, the functions in the
regression module were actually implemented using this function.

It returns a checkpoint state.

NOTE(review): the four function arguments are presumably the forward pass, the
backward pass, the weight-update function and the save function, followed by
the observations and the labels -- confirm against the implementation.
   *)

  val minimise_fun
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t)
    -> t
    -> Checkpoint.state * t
  (**
This function minimises ``f : x -> y`` w.r.t ``x``.

``x`` is an ndarray; and ``y`` is a scalar value.

The result is a pair of a checkpoint state and a value of type ``t``.

NOTE(review): the returned ``t`` is presumably the value of ``x`` found by the
optimisation -- confirm against the implementation.
   *)

  val minimise_compiled_network
    :  ?state:Checkpoint.state
    -> Params.typ
    -> (t -> t -> t)
    -> (unit -> unit)
    -> (string -> unit)
    -> t
    -> t
    -> Checkpoint.state
  (**
TODO: this function is not documented yet. It returns a checkpoint state.

NOTE(review): judging by its name and signature, this is presumably the
counterpart of ``minimise_network`` for compiled networks, taking an
evaluation function, an update function, a save function, the observations
and the labels -- confirm against the implementation.
   *)
end
OCaml

Innovation. Community. Security.