package kaun
Flax-inspired neural network library for OCaml
Install
Dune Dependency
Authors
Maintainers
Sources
raven-1.0.0.alpha0.tbz
sha256=a9a8a9787f8250337187bb7b21cb317c41bfd2ecf08bcfe0ab407c7b6660764d
sha512=fe13cf257c487e41efe2967be147d80fa94bac8996d3aab2b8fd16f0bbbd108c15e0e58c025ec9bf294d4a0d220ca2ba00c3b1b42fa2143f758c5f0ee4c15782
doc/src/kaun.datasets/kaun_datasets.ml.html
Source file kaun_datasets.ml
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
(** Ready-to-use datasets for Kaun *)

(* Set up logging *)
let src = Logs.Src.create "kaun.datasets" ~doc:"Kaun datasets module"

module Log = (val Logs.src_log src : Logs.LOG)

(* Cache for loaded datasets to avoid reloading. It holds the raw
   [Nx_datasets.load_mnist] result, so repeated calls to [mnist] (e.g. once
   for the train split and once for the test split) hit the disk only once. *)
let mnist_cache = ref None

(** [load_mnist_cached ()] returns the MNIST [(train, test)] pairs. The first
    call loads them with [Nx_datasets.load_mnist]; later calls reuse the
    cached value. *)
let load_mnist_cached () =
  match !mnist_cache with
  | Some data ->
      Log.debug (fun m -> m "Using cached MNIST data");
      data
  | None ->
      Log.info (fun m -> m "Loading MNIST dataset from disk...");
      let data = Nx_datasets.load_mnist () in
      mnist_cache := Some data;
      data

(* [timed label f] runs [f ()] and then emits the debug message
   "<label> in <seconds>s" with the wall-clock time [f] took. *)
let timed label f =
  let start = Unix.gettimeofday () in
  let result = f () in
  Log.debug (fun m -> m "%s in %.3fs" label (Unix.gettimeofday () -. start));
  result

(** [mnist ?train ?flatten ?normalize ?data_format ~device ()] builds a Kaun
    dataset of MNIST images paired with their labels.

    @param train selects the training split when [true] (default) and the
      test split otherwise.
    @param flatten when [true], reshapes images to [[N; features]], where
      [features] is the product of all non-batch dimensions. Default [false].
    @param normalize when [true] (default), divides pixel values by [255.0].
      This maps them to [[0, 1]] assuming raw values lie in [[0, 255]].
    @param data_format [`NCHW] (default) transposes images from [[N; H; W; C]]
      to [[N; C; H; W]]; [`NHWC] keeps the loaded layout.
    @param device the Rune device the tensors are created on.

    Labels are returned as float32 class indices of shape [[N]]. *)
let mnist ?(train = true) ?(flatten = false) ?(normalize = true)
    ?(data_format = `NCHW) ~device () =
  let (x_train, y_train), (x_test, y_test) =
    timed "Data loaded" load_mnist_cached
  in
  let x, y = if train then (x_train, y_train) else (x_test, y_test) in
  (* Cast to float32. The raw tensors are presumably uint8; confirm against
     Nx_datasets. *)
  let x, y =
    timed "Cast to float32" (fun () ->
        let x = Nx.cast Nx.float32 x in
        let y = Nx.cast Nx.float32 y in
        (x, y))
  in
  let dtype = Rune.float32 in
  (* Convert Nx tensors to Rune tensors via bigarray. *)
  let x, y =
    timed "Converted to Rune tensors" (fun () ->
        let x = Rune.of_bigarray device (Nx.to_bigarray x) in
        let y = Rune.of_bigarray device (Nx.to_bigarray y) in
        (x, y))
  in
  let x =
    timed "Normalization" (fun () ->
        if normalize then Rune.div x (Rune.scalar device dtype 255.0) else x)
  in
  let x =
    timed "Data format conversion" (fun () ->
        match data_format with
        | `NCHW ->
            (* Loaded layout is [N; H; W; C]; move channels to axis 1. No
               reshape is needed first, and the transpose works for any C. *)
            let x = Rune.transpose x ~axes:[| 0; 3; 1; 2 |] in
            Log.debug (fun m ->
                m "After transpose, is_c_contiguous: %b"
                  (Rune.is_c_contiguous x));
            x
        | `NHWC -> x)
  in
  let x =
    if flatten then
      let shape = Rune.shape x in
      let n = shape.(0) in
      (* Derive the feature count from the real shape instead of assuming
         28x28x1 images. *)
      let features =
        Array.fold_left ( * ) 1 (Array.sub shape 1 (Array.length shape - 1))
      in
      Rune.reshape [| n; features |] x
    else x
  in
  (* Labels come in as [N; 1]; drop the trailing axis to get [N]. They stay
     float32 class indices; casting to int is left to the consumer. *)
  let y = Rune.squeeze y ~axes:[| 1 |] in
  let result = timed "Dataset.of_xy" (fun () -> Kaun.Dataset.of_xy (x, y)) in
  let shape_str =
    Rune.shape x |> Array.to_list |> List.map string_of_int
    |> String.concat "×"
  in
  Log.info (fun m ->
      m "MNIST %s set loaded: %s (normalized=%b)"
        (if train then "train" else "test")
        shape_str normalize);
  result
sectionYPositions = computeSectionYPositions($el), 10)"
x-init="setTimeout(() => sectionYPositions = computeSectionYPositions($el), 10)"
>