package torch

  1. Overview
  2. Docs

Source file var_store.ml

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
open Base

(* Maybe we should also store the full path in the var stores ? *)
(** A var store: a tree of named tensors.  Each node owns the tensors
    registered on it through [new_var] plus a table of child stores that are
    created on demand by [sub]. *)
type t =
  { name : string
    (* Name given to [create]; stores created by [sub] reuse their parent's
       name rather than their own key. *)
  ; mutable trainable_tensors : Tensor.t list
    (* Tensors registered with [~trainable:true]; these are the ones whose
       requires-grad flag [freeze] and [unfreeze] toggle. *)
  ; all_tensors_by_name : (string, Tensor.t) Hashtbl.t
    (* Every tensor created on this node (trainable or not), keyed by its
       local, dot-free name. *)
  ; subs : (string, t) Hashtbl.t
    (* Child stores keyed by the name passed to [sub]. *)
  ; device : Device.t
    (* Device on which [new_var] places the tensors it creates. *)
  ; mutable frozen : bool }
    (* While [true], [new_var] creates tensors that do not require gradients. *)

(** [create ?frozen ?device ~name ()] returns an empty var store called
    [name]: it holds no tensors and no sub-stores.  [frozen] defaults to
    [false] and [device] to [Device.Cpu]. *)
let create ?(frozen = false) ?(device = Device.Cpu) ~name () =
  let subs = Hashtbl.create (module String) in
  let all_tensors_by_name = Hashtbl.create (module String) in
  { name; trainable_tensors = []; all_tensors_by_name; subs; device; frozen }

(** [first_free_name name table] returns [name] itself when it is not a key
    of [table]; otherwise the first of [name_1], [name_2], ... that is not a
    key of [table]. *)
let first_free_name name table =
  (* Try successive numeric suffixes until one is unused. *)
  let rec probe suffix =
    let candidate = Printf.sprintf "%s_%d" name suffix in
    if Hashtbl.mem table candidate then probe (suffix + 1) else candidate
  in
  if Hashtbl.mem table name then probe 1 else name

(** [sub t sub_name] returns the child store of [t] registered under
    [sub_name], creating and registering an empty one on first use.  A newly
    created child takes its name, device and frozen flag from [t].

    Raises [Failure] if [sub_name] contains a ['.']. *)
let sub t sub_name =
  (* Only invoked by [find_or_add] when no child exists yet for [sub_name]. *)
  let fresh_child () =
    { name = t.name
    ; trainable_tensors = []
    ; all_tensors_by_name = Hashtbl.create (module String)
    ; subs = Hashtbl.create (module String)
    ; device = t.device
    ; frozen = t.frozen }
  in
  let has_dot = String.contains sub_name '.' in
  if has_dot then Printf.failwithf "sub names cannot contain ., %s" sub_name ();
  Hashtbl.find_or_add t.subs sub_name ~default:fresh_child

(** Infix alias for [sub]: [t / "layer"] is [sub t "layer"]. *)
let ( / ) t sub_name = sub t sub_name

(** [freeze t] marks [t] as frozen, turns off gradient tracking on each of
    its trainable tensors, then does the same for every sub-store. *)
let rec freeze t =
  let stop_tracking tensor =
    ignore (Tensor.set_requires_grad tensor ~r:false : Tensor.t)
  in
  t.frozen <- true;
  List.iter t.trainable_tensors ~f:stop_tracking;
  Hashtbl.iter t.subs ~f:(fun child -> freeze child)

(** [unfreeze t] clears the frozen flag of [t], turns gradient tracking back
    on for each of its trainable tensors, then does the same for every
    sub-store. *)
let rec unfreeze t =
  let resume_tracking tensor =
    ignore (Tensor.set_requires_grad tensor ~r:true : Tensor.t)
  in
  t.frozen <- false;
  List.iter t.trainable_tensors ~f:resume_tracking;
  Hashtbl.iter t.subs ~f:(fun child -> unfreeze child)

(** [trainable_vars t] lists the trainable tensors of [t] followed by those
    of all its sub-stores, recursively. *)
let rec trainable_vars t =
  Hashtbl.data t.subs
  |> List.concat_map ~f:trainable_vars
  |> List.append t.trainable_tensors

(** [all_vars t] returns every tensor stored in [t] and its sub-stores,
    paired with its dot-separated path: the sub-store keys from [t] downwards
    followed by the tensor's own name.  A node's own tensors are listed
    before those of its sub-stores. *)
let all_vars t =
  (* [path] is accumulated innermost-first, hence the reversal. *)
  let qualified path key = List.rev (key :: path) |> String.concat ~sep:"." in
  let rec walk t ~path =
    let own =
      Hashtbl.to_alist t.all_tensors_by_name
      |> List.map ~f:(fun (key, tensor) -> qualified path key, tensor)
    in
    let nested =
      Hashtbl.to_alist t.subs
      |> List.concat_map ~f:(fun (key, child) -> walk child ~path:(key :: path))
    in
    own @ nested
  in
  walk t ~path:[]

(** [copy ~src ~dst] overwrites, in place and under [Tensor.no_grad], every
    tensor of [dst] (sub-stores included) with the tensor found at the same
    path in [src].  Tensors that only exist in [src] are ignored.

    Raises [Failure] if a tensor or a sub-store of [dst] has no counterpart
    in [src]. *)
let copy ~src ~dst =
  (* [path] is accumulated innermost-first, hence the reversal. *)
  let qualified path key = List.rev (key :: path) |> String.concat ~sep:"." in
  let rec walk ~src ~dst path =
    let copy_tensor ~key ~data:target =
      match Hashtbl.find src.all_tensors_by_name key with
      | Some source -> Tensor.copy_ target ~src:source
      | None ->
        Printf.failwithf
          "cannot find var %s from var-store %s in %s"
          (qualified path key)
          dst.name
          src.name
          ()
    in
    let copy_sub ~key ~data:dst_sub =
      match Hashtbl.find src.subs key with
      | Some src_sub -> walk ~src:src_sub ~dst:dst_sub (key :: path)
      | None ->
        Printf.failwithf
          "cannot find sub %s from var-store %s in %s"
          (qualified path key)
          dst_sub.name
          src.name
          ()
    in
    Hashtbl.iteri dst.all_tensors_by_name ~f:copy_tensor;
    Hashtbl.iteri dst.subs ~f:copy_sub
  in
  Tensor.no_grad (fun () -> walk ~src ~dst [])

(** [name t] is the name stored in [t]. *)
let name { name; _ } = name

(** [device t] is the device stored in [t]. *)
let device { device; _ } = device

(** Initialisation schemes accepted by [new_var]. *)
module Init = struct
  type t =
    | Zeros (* Created with [Tensor.zeros]. *)
    | Ones (* Created with [Tensor.ones]. *)
    | Const of float (* [Tensor.ones] with the payload passed as [~scale]. *)
    | Normal of {mean : float; stdev : float}
      (* [Tensor.randn] with [~scale:stdev]; [mean] is added when non-zero. *)
    | Uniform of float * float
      (* [Uniform (from, to_)]: filled by [Tensor.uniform_ ~from ~to_]. *)
    | Copy of Tensor.t
      (* Copy of the given tensor, moved to the store's device. *)
end

(** [new_var ?trainable t ~shape ~init ~name] creates a tensor of shape
    [shape] on the device of [t], initialised according to [init], and
    registers it in [t].

    - The tensor requires gradients only when [trainable] is [true] (the
      default) and [t] is not frozen.
    - If [name] is already used in [t], the tensor is registered under the
      first free name among [name_1], [name_2], ...
    - A tensor created with [trainable] set to [true] is also added to the
      trainable tensors of [t], even if [t] is currently frozen, so that a
      later [unfreeze] re-enables its gradients.

    Returns the newly created tensor.

    Raises [Failure] if [name] contains a ['.']. *)
let new_var ?(trainable = true) t ~shape ~init ~name =
  (* Validate the name first so that an invalid request allocates nothing. *)
  if String.contains name '.'
  then Printf.failwithf "tensor names cannot contain ., %s" name ();
  let device = device t in
  let requires_grad = trainable && not t.frozen in
  let tensor =
    match (init : Init.t) with
    | Zeros -> Tensor.zeros shape ~requires_grad ~device
    | Ones -> Tensor.ones shape ~requires_grad ~device
    | Const scale -> Tensor.ones shape ~requires_grad ~device ~scale
    | Normal {mean = 0.; stdev} -> Tensor.randn shape ~scale:stdev ~requires_grad ~device
    | Normal {mean; stdev} ->
      (* Sample and shift without gradient tracking, and only then flag the
         result.  Adding the mean to a tensor that already requires gradients
         would make the variable the output of an autograd operation instead
         of a leaf: no gradient would be accumulated on it and its
         requires-grad flag could no longer be toggled by [freeze]. *)
      let sample = Tensor.randn shape ~scale:stdev ~device in
      Tensor.( + ) sample (Tensor.f mean)
      |> Tensor.set_requires_grad ~r:requires_grad
    | Uniform (from, to_) ->
      Tensor.zeros shape ~device
      |> Tensor.uniform_ ~from ~to_
      |> Tensor.set_requires_grad ~r:requires_grad
    | Copy src ->
      Tensor.copy src
      |> Tensor.to_device ~device
      |> Tensor.set_requires_grad ~r:requires_grad
  in
  let name = first_free_name name t.all_tensors_by_name in
  Hashtbl.add_exn t.all_tensors_by_name ~key:name ~data:tensor;
  if trainable then t.trainable_tensors <- tensor :: t.trainable_tensors;
  tensor

(** [new_var_copy ?trainable t ~src ~name] registers in [t] a new variable
    with the shape of [src], initialised as a copy of [src].  See [new_var]
    for how [trainable] and [name] are handled. *)
let new_var_copy ?trainable t ~src ~name =
  let shape = Tensor.shape src in
  new_var ?trainable t ~shape ~init:(Init.Copy src) ~name
OCaml

Innovation. Community. Security.