package torch
PyTorch bindings for OCaml
Sources
0.10.tar.gz
md5=63540fcb4a4aa85a63207b8ed6eee137
sha512=a6f01cc4e4d4835f54766490be9145829032e2b75330b810819058883b93562871bc6d68dbdf9d346e10d7911b9474ceacbd20942246c48689102bcfda1ee32a
Source file var_store.ml
open Base

module Tensor_id : sig
  include Hashable.Key

  val create : unit -> t
end = struct
  include Int

  let create =
    let current = ref 0 in
    fun () ->
      Int.incr current;
      !current
end

(* Maybe we should also store the full path in the var stores ? *)
type t =
  { name : string
  ; trainable_tensors : (Tensor_id.t, Tensor.t) Hashtbl.t
  ; all_tensors_by_name : (string, Tensor.t) Hashtbl.t
  ; subs : (string, t) Hashtbl.t
  ; device : Device.t
  ; mutable frozen : bool
  }

let create ?(frozen = false) ?(device = Device.Cpu) ~name () =
  { name
  ; trainable_tensors = Hashtbl.create (module Tensor_id)
  ; subs = Hashtbl.create (module String)
  ; all_tensors_by_name = Hashtbl.create (module String)
  ; device
  ; frozen
  }

let first_free_name name table =
  if Hashtbl.mem table name
  then (
    let rec loop idx =
      let name = Printf.sprintf "%s_%d" name idx in
      if Hashtbl.mem table name then loop (idx + 1) else name
    in
    loop 1)
  else name

let sub t sub_name =
  if String.contains sub_name '.'
  then Printf.failwithf "sub names cannot contain ., %s" sub_name ();
  Hashtbl.find_or_add t.subs sub_name ~default:(fun () ->
      { name = t.name
      ; trainable_tensors = Hashtbl.create (module Tensor_id)
      ; subs = Hashtbl.create (module String)
      ; all_tensors_by_name = Hashtbl.create (module String)
      ; device = t.device
      ; frozen = t.frozen
      })

let subi t i = sub t (Int.to_string i)
let ( / ) = sub
let ( // ) = subi

let rec freeze t =
  t.frozen <- true;
  Hashtbl.iter t.trainable_tensors ~f:(fun tensor ->
      ignore (Tensor.set_requires_grad tensor ~r:false : Tensor.t));
  Hashtbl.iter t.subs ~f:freeze

let rec unfreeze t =
  t.frozen <- false;
  Hashtbl.iter t.trainable_tensors ~f:(fun tensor ->
      ignore (Tensor.set_requires_grad tensor ~r:true : Tensor.t));
  Hashtbl.iter t.subs ~f:unfreeze

let rec num_trainable_vars t =
  let sub_vars =
    Hashtbl.data t.subs
    |> List.fold ~init:0 ~f:(fun acc t -> acc + num_trainable_vars t)
  in
  sub_vars + Hashtbl.length t.trainable_tensors

let iter_trainable_vars t ~f =
  let f ~key ~data = f key data in
  let rec loop t =
    Hashtbl.iter t.subs ~f:loop;
    Hashtbl.iteri t.trainable_tensors ~f
  in
  loop t

let all_vars t =
  let rec walk t ~path =
    let sub_vars =
      Hashtbl.to_alist t.subs
      |> List.concat_map ~f:(fun (key, t) -> walk t ~path:(key :: path))
    in
    let vars =
      Hashtbl.to_alist t.all_tensors_by_name
      |> List.map ~f:(fun (key, tensor) ->
             List.rev (key :: path) |> String.concat ~sep:".", tensor)
    in
    vars @ sub_vars
  in
  walk t ~path:[]

let copy ~src ~dst =
  Tensor.no_grad (fun () ->
      let rec walk ~src ~dst path =
        Hashtbl.iteri dst.all_tensors_by_name ~f:(fun ~key ~data ->
            match Hashtbl.find src.all_tensors_by_name key with
            | Some src -> Tensor.copy_ data ~src
            | None ->
              Printf.failwithf
                "cannot find var %s from var-store %s in %s"
                (List.rev (key :: path) |> String.concat ~sep:".")
                dst.name
                src.name
                ());
        Hashtbl.iteri dst.subs ~f:(fun ~key ~data:dst ->
            match Hashtbl.find src.subs key with
            | Some src -> walk ~src ~dst (key :: path)
            | None ->
              Printf.failwithf
                "cannot find sub %s from var-store %s in %s"
                (List.rev (key :: path) |> String.concat ~sep:".")
                dst.name
                src.name
                ())
      in
      walk ~src ~dst [])

let name t = t.name
let device t = t.device

module Init = struct
  type t =
    | Zeros
    | Ones
    | Const of float
    | Normal of
        { mean : float
        ; stdev : float
        }
    | Uniform of float * float
    | Copy of Tensor.t
end

let new_var ?(trainable = true) t ~shape ~init ~name =
  let device = device t in
  let requires_grad = trainable && not t.frozen in
  let tensor =
    match (init : Init.t) with
    | Zeros -> Tensor.zeros shape ~requires_grad ~device
    | Ones -> Tensor.ones shape ~requires_grad ~device
    | Const scale -> Tensor.ones shape ~requires_grad ~device ~scale
    | Normal { mean = 0.; stdev } -> Tensor.randn shape ~scale:stdev ~requires_grad ~device
    | Normal { mean; stdev } ->
      Tensor.( + )
        (Tensor.randn shape ~scale:stdev ~requires_grad ~device)
        (Tensor.f mean)
    | Uniform (from, to_) ->
      Tensor.zeros shape ~device
      |> Tensor.uniform_ ~from ~to_
      |> Tensor.set_requires_grad ~r:requires_grad
    | Copy src ->
      Tensor.copy src
      |> Tensor.to_device ~device
      |> Tensor.set_requires_grad ~r:requires_grad
  in
  if String.contains name '.'
  then Printf.failwithf "tensor names cannot contain ., %s" name ();
  let name = first_free_name name t.all_tensors_by_name in
  Hashtbl.add_exn t.all_tensors_by_name ~key:name ~data:tensor;
  if trainable
  then Hashtbl.add_exn t.trainable_tensors ~key:(Tensor_id.create ()) ~data:tensor;
  tensor

let new_var_copy ?trainable t ~src ~name =
  new_var ?trainable t ~shape:(Tensor.shape src) ~init:(Copy src) ~name
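This file implements the variable store: a tree of named sub-stores in which tensors are registered, tracked as trainable or frozen, and addressed by dot-separated names. As a rough, non-authoritative sketch of how the functions above fit together, assuming the module is exposed as Torch.Var_store as in the rest of these bindings (the store name, variable names and shapes below are made up for illustration):

open Torch

let () =
  (* Root store on the default CPU device; sub-stores form a dot-separated
     name hierarchy ("linear.weight", "linear.bias", ...). *)
  let vs = Var_store.create ~name:"my_model" () in
  let linear_vs = Var_store.sub vs "linear" in
  (* Register a trainable weight and bias; [new_var] appends a numeric
     suffix ("weight_1", ...) if the requested name is already taken. *)
  let _weight =
    Var_store.new_var
      linear_vs
      ~shape:[ 10; 5 ]
      ~init:(Var_store.Init.Normal { mean = 0.; stdev = 0.1 })
      ~name:"weight"
  in
  let _bias =
    Var_store.new_var linear_vs ~shape:[ 10 ] ~init:Var_store.Init.Zeros ~name:"bias"
  in
  Printf.printf "trainable vars: %d\n" (Var_store.num_trainable_vars vs);
  (* [freeze] recursively clears requires_grad on every trainable tensor;
     [unfreeze] restores it. *)
  Var_store.freeze vs

In the library's layer constructors the same store is typically threaded through, so that an optimizer can later collect every trainable tensor from the single root store.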