/*
 * Compiled Svelte page module for the Datasets documentation page
 * "Use with JAX". Its source document is
 * https://github.com/huggingface/datasets/blob/main/docs/source/use_with_jax.mdx
 * (the URL passed to the EditOnGithub component below).
 *
 * NOTE(review): this file is generated docs-build output. Any prose or
 * example fix made here must also be made in use_with_jax.mdx, otherwise the
 * next docs build overwrites it. The `data-svelte-h` hashes were produced by
 * the compiler from the page text. They appear to be used only during
 * hydration, to decide whether a server-rendered node's innerHTML/textContent
 * is rewritten, so they are left untouched here; a rebuild regenerates them.
 */
import{s as ga,o as Ma,n as Ws}from"../chunks/scheduler.bdbef820.js";
import{S as ya,i as Ja,g as m,s as n,r as h,A as fa,h as c,f as t,c as p,j as da,u as i,x as g,k as ua,y as ba,a as e,v as o,d as j,t as d,w as u}from"../chunks/index.c0aea24a.js";
import{T as Ds}from"../chunks/Tip.31005f7d.js";
import{C as J}from"../chunks/CodeBlock.6ccca92e.js";
import{H as ns,E as Ua}from"../chunks/EditOnGithub.725ee0c1.js";

/*
 * Slot fragment for the first Tip: a single <p> holding the jax/jaxlib
 * installation note. During claim, innerHTML is only rewritten when the
 * claimed node's svelte hash differs from the compiled one. `p` is the
 * imported `Ws` helper (presumably a no-op: the content is static).
 */
function Ta(f){let l,M=`<code>jax</code> and <code>jaxlib</code> are required to reproduce the code above, so please make sure you
install them as <code>pip install datasets[jax]</code>.`;return{c(){l=m("p"),l.innerHTML=M},l(r){l=c(r,"P",{"data-svelte-h":!0}),g(l)!=="svelte-1i2qrbm"&&(l.innerHTML=M)},m(r,y){e(r,l,y)},p:Ws,d(r){r&&t(l)}}}

/* Slot fragment for the Tip explaining that a Dataset wraps an Arrow table. */
function wa(f){let l,M='A <a href="/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.Dataset">Dataset</a> object is a wrapper of an Arrow table, which allows fast reads from arrays in the dataset to JAX arrays.';return{c(){l=m("p"),l.innerHTML=M},l(r){l=c(r,"P",{"data-svelte-h":!0}),g(l)!=="svelte-1eeedok"&&(l.innerHTML=M)},m(r,y){e(r,l,y)},p:Ws,d(r){r&&t(l)}}}

/* Slot fragment for the Tip about installing the `vision` extra for Image. */
function $a(f){let l,M=`To use the <a href="/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.Image">Image</a> feature type, you’ll need to install the <code>vision</code> extra as
<code>pip install datasets[vision]</code>.`;return{c(){l=m("p"),l.innerHTML=M},l(r){l=c(r,"P",{"data-svelte-h":!0}),g(l)!=="svelte-pyl3xs"&&(l.innerHTML=M)},m(r,y){e(r,l,y)},p:Ws,d(r){r&&t(l)}}}

/* Slot fragment for the Tip about installing the `audio` extra for Audio. */
function Ca(f){let l,M=`To use the <a href="/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.Audio">Audio</a> feature type, you’ll need to install the <code>audio</code> extra as
<code>pip install datasets[audio]</code>.`;return{c(){l=m("p"),l.innerHTML=M},l(r){l=c(r,"P",{"data-svelte-h":!0}),g(l)!=="svelte-1bf13fm"&&(l.innerHTML=M)},m(r,y){e(r,l,y)},p:Ws,d(r){r&&t(l)}}}

/*
 * Main page fragment. Declares the DOM/component handles (compiler-assigned
 * short names) together with the page's static prose, instantiates the child
 * components (section headings `ns`, tips `Ds`, code samples `J`, the GitHub
 * link `Ua`), and returns the fragment object whose methods create, claim,
 * mount, update, transition and destroy them in document order.
 */
function Ia(f){let l,M,r,y,$,rs,C,Hs=`This document is a quick introduction to using <code>datasets</code> with JAX, with a particular focus on how to get
<code>jax.Array</code> objects out of our datasets, and how to use them to train JAX models.`,ms,b,cs,I,hs,R,Ls=`By default, datasets return regular Python objects: integers, floats, strings, lists, etc., and
string and binary objects are unchanged, since JAX only supports numbers.`,is,Z,Ss="To get JAX arrays (numpy-like) instead, you can set the format of the dataset to <code>jax</code>:",os,k,js,U,ds,X,Ps=`Note that the exact same procedure applies to <code>DatasetDict</code> objects, so that
when setting the format of a <code>DatasetDict</code> to <code>jax</code>, all the <code>Dataset</code>s there
will be formatted as <code>jax</code>:`,us,v,gs,x,Ks=`Another thing you’ll need to take into consideration is that the formatting is not applied
until you actually access the data. So if you want to get a JAX array out of a dataset,
you’ll need to access the data first, otherwise the format will remain the same.`,Ms,N,Os=`Finally, to load the data in the device of your choice, you can specify the <code>device</code> argument,
but note that <code>jaxlib.xla_extension.Device</code> is not supported as it’s not serializable with either
<code>pickle</code> or <code>dill</code>, so you’ll need to use its string identifier instead:`,ys,Q,Js,A,sa=`Note that if the <code>device</code> argument is not provided to <code>with_format</code> then it will use the default
device which is <code>jax.devices()[0]</code>.`,fs,q,bs,G,aa="If your dataset consists of N-dimensional arrays, you will see that by default they are returned as a single tensor if the shape is fixed:",Us,_,Ts,E,ws,V,ta=`However, this logic often requires slow shape comparisons and data copies.
To avoid this, you must explicitly use the <code>Array</code> feature type and specify the shape of your tensors:`,$s,F,Cs,z,Is,Y,ea='<a href="/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.ClassLabel">ClassLabel</a> data is properly converted to arrays:',Rs,B,Zs,D,la="String and binary objects are unchanged, since JAX only supports numbers.",ks,W,na='The <a href="/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.Image">Image</a> and <a href="/docs/datasets/v3.2.0/en/package_reference/main_classes#datasets.Audio">Audio</a> feature types are also supported.',Xs,T,vs,H,xs,w,Ns,L,Qs,S,As,P,pa=`JAX doesn’t have any built-in data loading capabilities, so you’ll need to use a library such
as <a href="https://pytorch.org/" rel="nofollow">PyTorch</a> to load your data using a <code>DataLoader</code> or <a href="https://www.tensorflow.org/" rel="nofollow">TensorFlow</a>
using a <code>tf.data.Dataset</code>. Citing the <a href="https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html#data-loading-with-pytorch" rel="nofollow">JAX documentation</a> on this topic:
“JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don’t
include data loading or munging in the JAX library. There are already a lot of great data loaders
out there, so let’s just use them instead of reinventing anything. We’ll grab PyTorch’s data loader,
and make a tiny shim to make it work with NumPy arrays.”`,qs,K,ra=`That’s why JAX formatting in <code>datasets</code> is so useful: it lets you use
any dataset from the HuggingFace Hub with JAX, without having to worry about the data loading
part.`,Gs,O,_s,ss,ma=`The easiest way to get JAX arrays out of a dataset is to use the <code>with_format(&#39;jax&#39;)</code> method. Let’s assume
that we want to train a neural network on the <a href="http://yann.lecun.com/exdb/mnist/" rel="nofollow">MNIST dataset</a> available
at the HuggingFace Hub at <a href="https://huggingface.co/datasets/mnist" rel="nofollow">https://huggingface.co/datasets/mnist</a>.`,Es,as,Vs,ts,ca=`Once the format is set we can feed the dataset to the JAX model in batches using the <code>Dataset.iter()</code>
method:`,Fs,es,zs,ls,Ys,ps,Bs;
/*
 * Child components. Each code sample (J) receives `code` (base64 of the
 * URL-encoded Python source; presumably what the copy button yields) and
 * `highlighted` (pre-rendered highlight.js HTML shown on the page). The two
 * must describe the same snippet.
 */
return $=new ns({props:{title:"Use with JAX",local:"use-with-jax",headingTag:"h1"}}),
b=new Ds({props:{$$slots:{default:[Ta]},$$scope:{ctx:f}}}),
I=new ns({props:{title:"Dataset format",local:"dataset-format",headingTag:"h2"}}),
k=new J({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwRGF0YXNldCUwQWRhdGElMjAlM0QlMjAlNUIlNUIxJTJDJTIwMiU1RCUyQyUyMCU1QjMlMkMlMjA0JTVEJTVEJTBBZHMlMjAlM0QlMjBEYXRhc2V0LmZyb21fZGljdCglN0IlMjJkYXRhJTIyJTNBJTIwZGF0YSU3RCklMEFkcyUyMCUzRCUyMGRzLndpdGhfZm9ybWF0KCUyMmpheCUyMiklMEFkcyU1QjAlNUQlMEFkcyU1QiUzQTIlNUQ=",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]]
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;jax&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: DeviceArray([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], dtype=int32)}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: DeviceArray([
    [<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
    [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]], dtype=int32)}`,wrap:!1}}),
U=new Ds({props:{$$slots:{default:[wa]},$$scope:{ctx:f}}}),
v=new J({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwRGF0YXNldERpY3QlMEFkYXRhJTIwJTNEJTIwJTdCJTIydHJhaW4lMjIlM0ElMjAlN0IlMjJkYXRhJTIyJTNBJTIwJTVCJTVCMSUyQyUyMDIlNUQlMkMlMjAlNUIzJTJDJTIwNCU1RCU1RCU3RCUyQyUyMCUyMnRlc3QlMjIlM0ElMjAlN0IlMjJkYXRhJTIyJTNBJTIwJTVCJTVCNSUyQyUyMDYlNUQlMkMlMjAlNUI3JTJDJTIwOCU1RCU1RCU3RCU3RCUwQWRkcyUyMCUzRCUyMERhdGFzZXREaWN0LmZyb21fZGljdChkYXRhKSUwQWRkcyUyMCUzRCUyMGRkcy53aXRoX2Zvcm1hdCglMjJqYXglMjIpJTBBZGRzJTVCJTIydHJhaW4lMjIlNUQlNUIlM0EyJTVE",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> DatasetDict
<span class="hljs-meta">&gt;&gt;&gt; </span>data = {<span class="hljs-string">&quot;train&quot;</span>: {<span class="hljs-string">&quot;data&quot;</span>: [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]]}, <span class="hljs-string">&quot;test&quot;</span>: {<span class="hljs-string">&quot;data&quot;</span>: [[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>], [<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]}}
<span class="hljs-meta">&gt;&gt;&gt; </span>dds = DatasetDict.from_dict(data)
<span class="hljs-meta">&gt;&gt;&gt; </span>dds = dds.with_format(<span class="hljs-string">&quot;jax&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>dds[<span class="hljs-string">&quot;train&quot;</span>][:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: DeviceArray([
    [<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
    [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]], dtype=int32)}`,wrap:!1}}),
/* Device-placement sample: an `assert` statement prints nothing in a REPL, so no output line follows it. */
Q=new J({props:{code:"aW1wb3J0JTIwamF4JTBBZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwRGF0YXNldCUwQWRhdGElMjAlM0QlMjAlNUIlNUIxJTJDJTIwMiU1RCUyQyUyMCU1QjMlMkMlMjA0JTVEJTVEJTBBZHMlMjAlM0QlMjBEYXRhc2V0LmZyb21fZGljdCglN0IlMjJkYXRhJTIyJTNBJTIwZGF0YSU3RCklMEFkZXZpY2UlMjAlM0QlMjBzdHIoamF4LmRldmljZXMoKSU1QjAlNUQpJTIwJTIwJTIzJTIwTm90JTIwY2FzdGluZyUyMHRvJTIwJTYwc3RyJTYwJTIwYmVmb3JlJTIwcGFzc2luZyUyMGl0JTIwdG8lMjAlNjB3aXRoX2Zvcm1hdCU2MCUyMHdpbGwlMjByYWlzZSUyMGElMjAlNjBWYWx1ZUVycm9yJTYwJTBBZHMlMjAlM0QlMjBkcy53aXRoX2Zvcm1hdCglMjJqYXglMjIlMkMlMjBkZXZpY2UlM0RkZXZpY2UpJTBBZHMlNUIwJTVEJTBBZHMlNUIwJTVEJTVCJTIyZGF0YSUyMiU1RC5kZXZpY2UoKSUwQWFzc2VydCUyMGRzJTVCMCU1RCU1QiUyMmRhdGElMjIlNUQuZGV2aWNlKCklMjAlM0QlM0QlMjBqYXguZGV2aWNlcygpJTVCMCU1RA==",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> jax
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]]
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>device = <span class="hljs-built_in">str</span>(jax.devices()[<span class="hljs-number">0</span>])  <span class="hljs-comment"># Not casting to \`str\` before passing it to \`with_format\` will raise a \`ValueError\`</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;jax&quot;</span>, device=device)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: DeviceArray([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], dtype=int32)}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;data&quot;</span>].device()
TFRT_CPU_0
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">assert</span> ds[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;data&quot;</span>].device() == jax.devices()[<span class="hljs-number">0</span>]`,wrap:!1}}),
q=new ns({props:{title:"N-dimensional arrays",local:"n-dimensional-arrays",headingTag:"h3"}}),
_=new J({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwRGF0YXNldCUwQWRhdGElMjAlM0QlMjAlNUIlNUIlNUIxJTJDJTIwMiU1RCUyQyU1QjMlMkMlMjA0JTVEJTVEJTJDJTIwJTVCJTVCNSUyQyUyMDYlNUQlMkMlNUI3JTJDJTIwOCU1RCU1RCU1RCUyMCUyMCUyMyUyMGZpeGVkJTIwc2hhcGUlMEFkcyUyMCUzRCUyMERhdGFzZXQuZnJvbV9kaWN0KCU3QiUyMmRhdGElMjIlM0ElMjBkYXRhJTdEKSUwQWRzJTIwJTNEJTIwZHMud2l0aF9mb3JtYXQoJTIyamF4JTIyKSUwQWRzJTVCMCU1RA==",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]], [[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]]  <span class="hljs-comment"># fixed shape</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;jax&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: Array([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
        [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]], dtype=int32)}`,wrap:!1}}),
E=new J({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwRGF0YXNldCUwQWRhdGElMjAlM0QlMjAlNUIlNUIlNUIxJTJDJTIwMiU1RCUyQyU1QjMlNUQlNUQlMkMlMjAlNUIlNUI0JTJDJTIwNSUyQyUyMDYlNUQlMkMlNUI3JTJDJTIwOCU1RCU1RCU1RCUyMCUyMCUyMyUyMHZhcnlpbmclMjBzaGFwZSUwQWRzJTIwJTNEJTIwRGF0YXNldC5mcm9tX2RpY3QoJTdCJTIyZGF0YSUyMiUzQSUyMGRhdGElN0QpJTBBZHMlMjAlM0QlMjBkcy53aXRoX2Zvcm1hdCglMjJqYXglMjIpJTBBZHMlNUIwJTVE",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>]], [[<span class="hljs-number">4</span>, <span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]]  <span class="hljs-comment"># varying shape</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;jax&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: [Array([<span class="hljs-number">1</span>, <span class="hljs-number">2</span>], dtype=int32), Array([<span class="hljs-number">3</span>], dtype=int32)]}`,wrap:!1}}),
/* Array2D sample: this is a JAX page, so the format must be "jax" (it previously said "torch" in both `code` and `highlighted`). */
F=new J({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwRGF0YXNldCUyQyUyMEZlYXR1cmVzJTJDJTIwQXJyYXkyRCUwQWRhdGElMjAlM0QlMjAlNUIlNUIlNUIxJTJDJTIwMiU1RCUyQyU1QjMlMkMlMjA0JTVEJTVEJTJDJTVCJTVCNSUyQyUyMDYlNUQlMkMlNUI3JTJDJTIwOCU1RCU1RCU1RCUwQWZlYXR1cmVzJTIwJTNEJTIwRmVhdHVyZXMoJTdCJTIyZGF0YSUyMiUzQSUyMEFycmF5MkQoc2hhcGUlM0QoMiUyQyUyMDIpJTJDJTIwZHR5cGUlM0QnaW50MzInKSU3RCklMEFkcyUyMCUzRCUyMERhdGFzZXQuZnJvbV9kaWN0KCU3QiUyMmRhdGElMjIlM0ElMjBkYXRhJTdEJTJDJTIwZmVhdHVyZXMlM0RmZWF0dXJlcyklMEFkcyUyMCUzRCUyMGRzLndpdGhfZm9ybWF0KCUyMmpheCUyMiklMEFkcyU1QjAlNUQlMEFkcyU1QiUzQTIlNUQ=",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, Array2D
<span class="hljs-meta">&gt;&gt;&gt; </span>data = [[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],[<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],[[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],[<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]]
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;data&quot;</span>: Array2D(shape=(<span class="hljs-number">2</span>, <span class="hljs-number">2</span>), dtype=<span class="hljs-string">&#x27;int32&#x27;</span>)})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;data&quot;</span>: data}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;jax&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: Array([[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
        [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]], dtype=int32)}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;data&#x27;</span>: Array([[[<span class="hljs-number">1</span>, <span class="hljs-number">2</span>],
         [<span class="hljs-number">3</span>, <span class="hljs-number">4</span>]],
 
        [[<span class="hljs-number">5</span>, <span class="hljs-number">6</span>],
         [<span class="hljs-number">7</span>, <span class="hljs-number">8</span>]]], dtype=int32)}`,wrap:!1}}),
z=new ns({props:{title:"Other feature types",local:"other-feature-types",headingTag:"h3"}}),
B=new J({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwRGF0YXNldCUyQyUyMEZlYXR1cmVzJTJDJTIwQ2xhc3NMYWJlbCUwQWxhYmVscyUyMCUzRCUyMCU1QjAlMkMlMjAwJTJDJTIwMSU1RCUwQWZlYXR1cmVzJTIwJTNEJTIwRmVhdHVyZXMoJTdCJTIybGFiZWwlMjIlM0ElMjBDbGFzc0xhYmVsKG5hbWVzJTNEJTVCJTIybmVnYXRpdmUlMjIlMkMlMjAlMjJwb3NpdGl2ZSUyMiU1RCklN0QpJTBBZHMlMjAlM0QlMjBEYXRhc2V0LmZyb21fZGljdCglN0IlMjJsYWJlbCUyMiUzQSUyMGxhYmVscyU3RCUyQyUyMGZlYXR1cmVzJTNEZmVhdHVyZXMpJTBBZHMlMjAlM0QlMjBkcy53aXRoX2Zvcm1hdCglMjJqYXglMjIpJTBBZHMlNUIlM0EzJTVE",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, ClassLabel
<span class="hljs-meta">&gt;&gt;&gt; </span>labels = [<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>]
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;label&quot;</span>: ClassLabel(names=[<span class="hljs-string">&quot;negative&quot;</span>, <span class="hljs-string">&quot;positive&quot;</span>])})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;label&quot;</span>: labels}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;jax&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">3</span>]
{<span class="hljs-string">&#x27;label&#x27;</span>: DeviceArray([<span class="hljs-number">0</span>, <span class="hljs-number">0</span>, <span class="hljs-number">1</span>], dtype=int32)}`,wrap:!1}}),
T=new Ds({props:{$$slots:{default:[$a]},$$scope:{ctx:f}}}),
H=new J({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwRGF0YXNldCUyQyUyMEZlYXR1cmVzJTJDJTIwSW1hZ2UlMEFpbWFnZXMlMjAlM0QlMjAlNUIlMjJwYXRoJTJGdG8lMkZpbWFnZS5wbmclMjIlNUQlMjAqJTIwMTAlMEFmZWF0dXJlcyUyMCUzRCUyMEZlYXR1cmVzKCU3QiUyMmltYWdlJTIyJTNBJTIwSW1hZ2UoKSU3RCklMEFkcyUyMCUzRCUyMERhdGFzZXQuZnJvbV9kaWN0KCU3QiUyMmltYWdlJTIyJTNBJTIwaW1hZ2VzJTdEJTJDJTIwZmVhdHVyZXMlM0RmZWF0dXJlcyklMEFkcyUyMCUzRCUyMGRzLndpdGhfZm9ybWF0KCUyMmpheCUyMiklMEFkcyU1QjAlNUQlNUIlMjJpbWFnZSUyMiU1RC5zaGFwZSUwQWRzJTVCMCU1RCUwQWRzJTVCJTNBMiU1RCU1QiUyMmltYWdlJTIyJTVELnNoYXBlJTBBZHMlNUIlM0EyJTVE",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, Image
<span class="hljs-meta">&gt;&gt;&gt; </span>images = [<span class="hljs-string">&quot;path/to/image.png&quot;</span>] * <span class="hljs-number">10</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;image&quot;</span>: Image()})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;image&quot;</span>: images}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;jax&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;image&quot;</span>].shape
(<span class="hljs-number">512</span>, <span class="hljs-number">512</span>, <span class="hljs-number">3</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;image&#x27;</span>: DeviceArray([[[ <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>],
              [ <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>],
              ...,
              [ <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>],
              [ <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>]]], dtype=uint8)}
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>][<span class="hljs-string">&quot;image&quot;</span>].shape
(<span class="hljs-number">2</span>, <span class="hljs-number">512</span>, <span class="hljs-number">512</span>, <span class="hljs-number">3</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[:<span class="hljs-number">2</span>]
{<span class="hljs-string">&#x27;image&#x27;</span>: DeviceArray([[[[ <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>],
              [ <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>],
              ...,
              [ <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>],
              [ <span class="hljs-number">255</span>, <span class="hljs-number">255</span>, <span class="hljs-number">255</span>]]]], dtype=uint8)}`,wrap:!1}}),
w=new Ds({props:{$$slots:{default:[Ca]},$$scope:{ctx:f}}}),
L=new J({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwRGF0YXNldCUyQyUyMEZlYXR1cmVzJTJDJTIwQXVkaW8lMEFhdWRpbyUyMCUzRCUyMCU1QiUyMnBhdGglMkZ0byUyRmF1ZGlvLndhdiUyMiU1RCUyMColMjAxMCUwQWZlYXR1cmVzJTIwJTNEJTIwRmVhdHVyZXMoJTdCJTIyYXVkaW8lMjIlM0ElMjBBdWRpbygpJTdEKSUwQWRzJTIwJTNEJTIwRGF0YXNldC5mcm9tX2RpY3QoJTdCJTIyYXVkaW8lMjIlM0ElMjBhdWRpbyU3RCUyQyUyMGZlYXR1cmVzJTNEZmVhdHVyZXMpJTBBZHMlMjAlM0QlMjBkcy53aXRoX2Zvcm1hdCglMjJqYXglMjIpJTBBZHMlNUIwJTVEJTVCJTIyYXVkaW8lMjIlNUQlNUIlMjJhcnJheSUyMiU1RCUwQWRzJTVCMCU1RCU1QiUyMmF1ZGlvJTIyJTVEJTVCJTIyc2FtcGxpbmdfcmF0ZSUyMiU1RA==",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> Dataset, Features, Audio
<span class="hljs-meta">&gt;&gt;&gt; </span>audio = [<span class="hljs-string">&quot;path/to/audio.wav&quot;</span>] * <span class="hljs-number">10</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>features = Features({<span class="hljs-string">&quot;audio&quot;</span>: Audio()})
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = Dataset.from_dict({<span class="hljs-string">&quot;audio&quot;</span>: audio}, features=features)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;jax&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;audio&quot;</span>][<span class="hljs-string">&quot;array&quot;</span>]
DeviceArray([-<span class="hljs-number">0.059021</span>  , -<span class="hljs-number">0.03894043</span>, -<span class="hljs-number">0.00735474</span>, ...,  <span class="hljs-number">0.0133667</span> ,
              <span class="hljs-number">0.01809692</span>,  <span class="hljs-number">0.00268555</span>], dtype=float32)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;audio&quot;</span>][<span class="hljs-string">&quot;sampling_rate&quot;</span>]
DeviceArray(<span class="hljs-number">44100</span>, dtype=int32, weak_type=<span class="hljs-literal">True</span>)`,wrap:!1}}),
S=new ns({props:{title:"Data loading",local:"data-loading",headingTag:"h2"}}),
O=new ns({props:{title:"Using with_format('jax')",local:"using-withformatjax",headingTag:"h3"}}),
as=new J({props:{code:"ZnJvbSUyMGRhdGFzZXRzJTIwaW1wb3J0JTIwbG9hZF9kYXRhc2V0JTBBZHMlMjAlM0QlMjBsb2FkX2RhdGFzZXQoJTIybW5pc3QlMjIpJTBBZHMlMjAlM0QlMjBkcy53aXRoX2Zvcm1hdCglMjJqYXglMjIpJTBBZHMlNUIlMjJ0cmFpbiUyMiU1RCU1QjAlNUQ=",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = load_dataset(<span class="hljs-string">&quot;mnist&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds = ds.with_format(<span class="hljs-string">&quot;jax&quot;</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>ds[<span class="hljs-string">&quot;train&quot;</span>][<span class="hljs-number">0</span>]
{<span class="hljs-string">&#x27;image&#x27;</span>: DeviceArray([[  <span class="hljs-number">0</span>,   <span class="hljs-number">0</span>,   <span class="hljs-number">0</span>, ...],
                       [  <span class="hljs-number">0</span>,   <span class="hljs-number">0</span>,   <span class="hljs-number">0</span>, ...],
                       ...,
                       [  <span class="hljs-number">0</span>,   <span class="hljs-number">0</span>,   <span class="hljs-number">0</span>, ...],
                       [  <span class="hljs-number">0</span>,   <span class="hljs-number">0</span>,   <span class="hljs-number">0</span>, ...]], dtype=uint8),
 <span class="hljs-string">&#x27;label&#x27;</span>: DeviceArray(<span class="hljs-number">5</span>, dtype=int32)}`,wrap:!1}}),
es=new J({props:{code:"Zm9yJTIwZXBvY2glMjBpbiUyMHJhbmdlKGVwb2NocyklM0ElMEElMjAlMjAlMjAlMjBmb3IlMjBiYXRjaCUyMGluJTIwZHMlNUIlMjJ0cmFpbiUyMiU1RC5pdGVyKGJhdGNoX3NpemUlM0QzMiklM0ElMEElMjAlMjAlMjAlMjAlMjAlMjAlMjAlMjB4JTJDJTIweSUyMCUzRCUyMGJhdGNoJTVCJTIyaW1hZ2UlMjIlNUQlMkMlMjBiYXRjaCU1QiUyMmxhYmVsJTIyJTVEJTBBJTIwJTIwJTIwJTIwJTIwJTIwJTIwJTIwLi4u",highlighted:`<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(epochs):
<span class="hljs-meta">... </span>    <span class="hljs-keyword">for</span> batch <span class="hljs-keyword">in</span> ds[<span class="hljs-string">&quot;train&quot;</span>].<span class="hljs-built_in">iter</span>(batch_size=<span class="hljs-number">32</span>):
<span class="hljs-meta">... </span>        x, y = batch[<span class="hljs-string">&quot;image&quot;</span>], batch[<span class="hljs-string">&quot;label&quot;</span>]
<span class="hljs-meta">... </span>        ...`,wrap:!1}}),
ls=new Ua({props:{source:"https://github.com/huggingface/datasets/blob/main/docs/source/use_with_jax.mdx"}}),
{
/* create: build detached <meta>/<p>/text nodes and each child component's fragment */
c(){l=m("meta"),M=n(),r=m("p"),y=n(),h($.$$.fragment),rs=n(),C=m("p"),C.innerHTML=Hs,ms=n(),h(b.$$.fragment),cs=n(),h(I.$$.fragment),hs=n(),R=m("p"),R.textContent=Ls,is=n(),Z=m("p"),Z.innerHTML=Ss,os=n(),h(k.$$.fragment),js=n(),h(U.$$.fragment),ds=n(),X=m("p"),X.innerHTML=Ps,us=n(),h(v.$$.fragment),gs=n(),x=m("p"),x.textContent=Ks,Ms=n(),N=m("p"),N.innerHTML=Os,ys=n(),h(Q.$$.fragment),Js=n(),A=m("p"),A.innerHTML=sa,fs=n(),h(q.$$.fragment),bs=n(),G=m("p"),G.textContent=aa,Us=n(),h(_.$$.fragment),Ts=n(),h(E.$$.fragment),ws=n(),V=m("p"),V.innerHTML=ta,$s=n(),h(F.$$.fragment),Cs=n(),h(z.$$.fragment),Is=n(),Y=m("p"),Y.innerHTML=ea,Rs=n(),h(B.$$.fragment),Zs=n(),D=m("p"),D.textContent=la,ks=n(),W=m("p"),W.innerHTML=na,Xs=n(),h(T.$$.fragment),vs=n(),h(H.$$.fragment),xs=n(),h(w.$$.fragment),Ns=n(),h(L.$$.fragment),Qs=n(),h(S.$$.fragment),As=n(),P=m("p"),P.innerHTML=pa,qs=n(),K=m("p"),K.innerHTML=ra,Gs=n(),h(O.$$.fragment),_s=n(),ss=m("p"),ss.innerHTML=ma,Es=n(),h(as.$$.fragment),Vs=n(),ts=m("p"),ts.innerHTML=ca,Fs=n(),h(es.$$.fragment),zs=n(),h(ls.$$.fragment),Ys=n(),ps=m("p"),this.h()},
/* claim: adopt server-rendered nodes (and the <meta> in document.head); static HTML is only re-written when its svelte hash differs */
l(s){const a=fa("svelte-u9bgzb",document.head);l=c(a,"META",{name:!0,content:!0}),a.forEach(t),M=p(s),r=c(s,"P",{}),da(r).forEach(t),y=p(s),i($.$$.fragment,s),rs=p(s),C=c(s,"P",{"data-svelte-h":!0}),g(C)!=="svelte-1a2kcgx"&&(C.innerHTML=Hs),ms=p(s),i(b.$$.fragment,s),cs=p(s),i(I.$$.fragment,s),hs=p(s),R=c(s,"P",{"data-svelte-h":!0}),g(R)!=="svelte-1ix49d8"&&(R.textContent=Ls),is=p(s),Z=c(s,"P",{"data-svelte-h":!0}),g(Z)!=="svelte-1upv66g"&&(Z.innerHTML=Ss),os=p(s),i(k.$$.fragment,s),js=p(s),i(U.$$.fragment,s),ds=p(s),X=c(s,"P",{"data-svelte-h":!0}),g(X)!=="svelte-1b7i338"&&(X.innerHTML=Ps),us=p(s),i(v.$$.fragment,s),gs=p(s),x=c(s,"P",{"data-svelte-h":!0}),g(x)!=="svelte-1c29vp0"&&(x.textContent=Ks),Ms=p(s),N=c(s,"P",{"data-svelte-h":!0}),g(N)!=="svelte-1ooxy7c"&&(N.innerHTML=Os),ys=p(s),i(Q.$$.fragment,s),Js=p(s),A=c(s,"P",{"data-svelte-h":!0}),g(A)!=="svelte-hj2cy7"&&(A.innerHTML=sa),fs=p(s),i(q.$$.fragment,s),bs=p(s),G=c(s,"P",{"data-svelte-h":!0}),g(G)!=="svelte-smjp9l"&&(G.textContent=aa),Us=p(s),i(_.$$.fragment,s),Ts=p(s),i(E.$$.fragment,s),ws=p(s),V=c(s,"P",{"data-svelte-h":!0}),g(V)!=="svelte-1gw41y9"&&(V.innerHTML=ta),$s=p(s),i(F.$$.fragment,s),Cs=p(s),i(z.$$.fragment,s),Is=p(s),Y=c(s,"P",{"data-svelte-h":!0}),g(Y)!=="svelte-p5lpqv"&&(Y.innerHTML=ea),Rs=p(s),i(B.$$.fragment,s),Zs=p(s),D=c(s,"P",{"data-svelte-h":!0}),g(D)!=="svelte-gkri0z"&&(D.textContent=la),ks=p(s),W=c(s,"P",{"data-svelte-h":!0}),g(W)!=="svelte-cbvp2m"&&(W.innerHTML=na),Xs=p(s),i(T.$$.fragment,s),vs=p(s),i(H.$$.fragment,s),xs=p(s),i(w.$$.fragment,s),Ns=p(s),i(L.$$.fragment,s),Qs=p(s),i(S.$$.fragment,s),As=p(s),P=c(s,"P",{"data-svelte-h":!0}),g(P)!=="svelte-1msw6w0"&&(P.innerHTML=pa),qs=p(s),K=c(s,"P",{"data-svelte-h":!0}),g(K)!=="svelte-1xmbhz7"&&(K.innerHTML=ra),Gs=p(s),i(O.$$.fragment,s),_s=p(s),ss=c(s,"P",{"data-svelte-h":!0}),g(ss)!=="svelte-1pw5xoa"&&(ss.innerHTML=ma),Es=p(s),i(as.$$.fragment,s),Vs=p(s),ts=c(s,"P",{"data-svelte-h":!0}),g(ts)!=="svelte-lnmbh3"&&(ts.innerHTML=ca),Fs=p(s),i(es.$$.fragment,s),zs=p(s),i(ls.$$.fragment,s),Ys=p(s),ps=c(s,"P",{}),da(ps).forEach(t),this.h()},
/* hydrate attributes: <meta name="hf:doc:metadata" content={page outline JSON in Ra}> */
h(){ua(l,"name","hf:doc:metadata"),ua(l,"content",Ra)},
/* mount: append the <meta> to document.head and insert every node/component in page order */
m(s,a){ba(document.head,l),e(s,M,a),e(s,r,a),e(s,y,a),o($,s,a),e(s,rs,a),e(s,C,a),e(s,ms,a),o(b,s,a),e(s,cs,a),o(I,s,a),e(s,hs,a),e(s,R,a),e(s,is,a),e(s,Z,a),e(s,os,a),o(k,s,a),e(s,js,a),o(U,s,a),e(s,ds,a),e(s,X,a),e(s,us,a),o(v,s,a),e(s,gs,a),e(s,x,a),e(s,Ms,a),e(s,N,a),e(s,ys,a),o(Q,s,a),e(s,Js,a),e(s,A,a),e(s,fs,a),o(q,s,a),e(s,bs,a),e(s,G,a),e(s,Us,a),o(_,s,a),e(s,Ts,a),o(E,s,a),e(s,ws,a),e(s,V,a),e(s,$s,a),o(F,s,a),e(s,Cs,a),o(z,s,a),e(s,Is,a),e(s,Y,a),e(s,Rs,a),o(B,s,a),e(s,Zs,a),e(s,D,a),e(s,ks,a),e(s,W,a),e(s,Xs,a),o(T,s,a),e(s,vs,a),o(H,s,a),e(s,xs,a),o(w,s,a),e(s,Ns,a),o(L,s,a),e(s,Qs,a),o(S,s,a),e(s,As,a),e(s,P,a),e(s,qs,a),e(s,K,a),e(s,Gs,a),o(O,s,a),e(s,_s,a),e(s,ss,a),e(s,Es,a),o(as,s,a),e(s,Vs,a),e(s,ts,a),e(s,Fs,a),o(es,s,a),e(s,zs,a),o(ls,s,a),e(s,Ys,a),e(s,ps,a),Bs=!0},
/* update: when dirty bit 2 is set, forward the new slot scope to the four Tip components */
p(s,[a]){const ha={};a&2&&(ha.$$scope={dirty:a,ctx:s}),b.$set(ha);const ia={};a&2&&(ia.$$scope={dirty:a,ctx:s}),U.$set(ia);const oa={};a&2&&(oa.$$scope={dirty:a,ctx:s}),T.$set(oa);const ja={};a&2&&(ja.$$scope={dirty:a,ctx:s}),w.$set(ja)},
/* intro: transition every child fragment in once (guarded by Bs) */
i(s){Bs||(j($.$$.fragment,s),j(b.$$.fragment,s),j(I.$$.fragment,s),j(k.$$.fragment,s),j(U.$$.fragment,s),j(v.$$.fragment,s),j(Q.$$.fragment,s),j(q.$$.fragment,s),j(_.$$.fragment,s),j(E.$$.fragment,s),j(F.$$.fragment,s),j(z.$$.fragment,s),j(B.$$.fragment,s),j(T.$$.fragment,s),j(H.$$.fragment,s),j(w.$$.fragment,s),j(L.$$.fragment,s),j(S.$$.fragment,s),j(O.$$.fragment,s),j(as.$$.fragment,s),j(es.$$.fragment,s),j(ls.$$.fragment,s),Bs=!0)},
/* outro: transition every child fragment out and clear the intro guard */
o(s){d($.$$.fragment,s),d(b.$$.fragment,s),d(I.$$.fragment,s),d(k.$$.fragment,s),d(U.$$.fragment,s),d(v.$$.fragment,s),d(Q.$$.fragment,s),d(q.$$.fragment,s),d(_.$$.fragment,s),d(E.$$.fragment,s),d(F.$$.fragment,s),d(z.$$.fragment,s),d(B.$$.fragment,s),d(T.$$.fragment,s),d(H.$$.fragment,s),d(w.$$.fragment,s),d(L.$$.fragment,s),d(S.$$.fragment,s),d(O.$$.fragment,s),d(as.$$.fragment,s),d(es.$$.fragment,s),d(ls.$$.fragment,s),Bs=!1},
/* destroy: detach page nodes when `s` is truthy, always remove the head <meta>, then destroy child components */
d(s){s&&(t(M),t(r),t(y),t(rs),t(C),t(ms),t(cs),t(hs),t(R),t(is),t(Z),t(os),t(js),t(ds),t(X),t(us),t(gs),t(x),t(Ms),t(N),t(ys),t(Js),t(A),t(fs),t(bs),t(G),t(Us),t(Ts),t(ws),t(V),t($s),t(Cs),t(Is),t(Y),t(Rs),t(Zs),t(D),t(ks),t(W),t(Xs),t(vs),t(xs),t(Ns),t(Qs),t(As),t(P),t(qs),t(K),t(Gs),t(_s),t(ss),t(Es),t(Vs),t(ts),t(Fs),t(zs),t(Ys),t(ps)),t(l),u($,s),u(b,s),u(I,s),u(k,s),u(U,s),u(v,s),u(Q,s),u(q,s),u(_,s),u(E,s),u(F,s),u(z,s),u(B,s),u(T,s),u(H,s),u(w,s),u(L,s),u(S,s),u(O,s),u(as,s),u(es,s),u(ls,s)}}}

/* Page outline (headings and anchors) serialized into the hf:doc:metadata <meta> tag by Ia's h(). */
const Ra=`{"title":"Use with JAX","local":"use-with-jax","sections":[{"title":"Dataset format","local":"dataset-format","sections":[{"title":"N-dimensional arrays","local":"n-dimensional-arrays","sections":[],"depth":3},{"title":"Other feature types","local":"other-feature-types","sections":[],"depth":3}],"depth":2},{"title":"Data loading","local":"data-loading","sections":[{"title":"Using with_format('jax')","local":"using-withformatjax","sections":[],"depth":3}],"depth":2}],"depth":1}`;

/*
 * Component instance function: registers a mount callback via the imported
 * `Ma` (presumably Svelte's onMount) and returns an empty context.
 * NOTE(review): the `fw` query-parameter value is read and discarded, so this
 * callback has no observable effect as written; it looks like a leftover from
 * the doc-builder page template.
 */
function Za(f){return Ma(()=>{new URLSearchParams(window.location.search).get("fw")}),[]}

/* The exported page component: wires the instance function Za to the fragment factory Ia. */
class Qa extends ya{constructor(l){super(),Ja(this,l,Za,Ia,ga,{})}}export{Qa as component};
