Buckets:

download
raw
73.2 kB
<meta charset="utf-8" /><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Diffusion 모델을 학습하기&quot;,&quot;local&quot;:&quot;diffusion-모델을-학습하기&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;학습 구성&quot;,&quot;local&quot;:&quot;학습-구성&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;데이터셋 불러오기&quot;,&quot;local&quot;:&quot;데이터셋-불러오기&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;UNet2DModel 생성하기&quot;,&quot;local&quot;:&quot;unet2dmodel-생성하기&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;스케줄러 생성하기&quot;,&quot;local&quot;:&quot;스케줄러-생성하기&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;모델 학습하기&quot;,&quot;local&quot;:&quot;모델-학습하기&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;다음 단계&quot;,&quot;local&quot;:&quot;다음-단계&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}">
<link href="/docs/diffusers/pr_11739/ko/_app/immutable/assets/0.e3b0c442.css" rel="stylesheet">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/entry/start.3a826cf6.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/scheduler.23542ac5.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/singletons.c67d8ae0.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/index.7166da64.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/paths.f1d82941.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/entry/app.694c4502.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/preload-helper.84b08082.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/index.9b1f405b.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/nodes/0.556378bd.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/each.e59479a4.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/nodes/33.4b22b211.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/MermaidChart.svelte_svelte_type_style_lang.99afe1ab.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/CodeBlock.1c9c6858.js">
<link rel="modulepreload" href="/docs/diffusers/pr_11739/ko/_app/immutable/chunks/DocNotebookDropdown.68a629d2.js"><!-- HEAD_svelte-u9bgzb_START --><meta name="hf:doc:metadata" content="{&quot;title&quot;:&quot;Diffusion 모델을 학습하기&quot;,&quot;local&quot;:&quot;diffusion-모델을-학습하기&quot;,&quot;sections&quot;:[{&quot;title&quot;:&quot;학습 구성&quot;,&quot;local&quot;:&quot;학습-구성&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;데이터셋 불러오기&quot;,&quot;local&quot;:&quot;데이터셋-불러오기&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;UNet2DModel 생성하기&quot;,&quot;local&quot;:&quot;unet2dmodel-생성하기&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;스케줄러 생성하기&quot;,&quot;local&quot;:&quot;스케줄러-생성하기&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;모델 학습하기&quot;,&quot;local&quot;:&quot;모델-학습하기&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2},{&quot;title&quot;:&quot;다음 단계&quot;,&quot;local&quot;:&quot;다음-단계&quot;,&quot;sections&quot;:[],&quot;depth&quot;:2}],&quot;depth&quot;:1}"><!-- HEAD_svelte-u9bgzb_END --> <p></p> <div class="items-center shrink-0 min-w-[100px] max-sm:min-w-[50px] justify-end ml-auto flex" style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"><div class="inline-flex rounded-md max-sm:rounded-sm"><button class="inline-flex items-center gap-1 max-sm:gap-0.5 h-6 max-sm:h-5 px-2 max-sm:px-1.5 text-[11px] max-sm:text-[9px] font-medium text-gray-800 border border-r-0 rounded-l-md max-sm:rounded-l-sm border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-live="polite"><span class="inline-flex items-center justify-center rounded-md p-0.5 max-sm:p-0"><svg class="w-3 h-3 max-sm:w-2.5 max-sm:h-2.5" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid 
meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg></span> <span>Copy page</span></button> <button class="inline-flex items-center justify-center w-6 max-sm:w-5 h-6 max-sm:h-5 disabled:pointer-events-none text-sm text-gray-500 hover:text-gray-700 dark:hover:text-white rounded-r-md max-sm:rounded-r-sm border border-l transition border-gray-200 bg-white hover:shadow-inner dark:border-gray-850 dark:bg-gray-950 dark:text-gray-200 dark:hover:bg-gray-800" aria-haspopup="menu" aria-expanded="false" aria-label="Open copy menu"><svg class="transition-transform text-gray-400 overflow-visible w-3 h-3 max-sm:w-2.5 max-sm:h-2.5 rotate-0" width="1em" height="1em" viewBox="0 0 12 7" fill="none" xmlns="http://www.w3.org/2000/svg"><path d="M1 1L6 6L11 1" stroke="currentColor"></path></svg></button></div> </div> <div class="flex space-x-1 " style="float: right; margin-left: 10px; display: inline-flex; position: relative; z-index: 10;"> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Colab" class="!m-0" src="https://colab.research.google.com/assets/colab-badge.svg"> </button> </div> <div class="relative colab-dropdown "> <button class=" " type="button"> <img alt="Open In Studio Lab" class="!m-0" src="https://studiolab.sagemaker.aws/studiolab.svg"> </button> </div></div> <h1 class="relative group"><a id="diffusion-모델을-학습하기" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#diffusion-모델을-학습하기"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" 
viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>Diffusion 모델을 학습하기</span></h1> <p data-svelte-h="svelte-1aqgth7">Unconditional 이미지 생성은 학습에 사용된 데이터셋과 유사한 이미지를 생성하는 diffusion 모델에서 인기 있는 어플리케이션입니다. 일반적으로, 가장 좋은 결과는 특정 데이터셋에 사전 훈련된 모델을 파인튜닝하는 것으로 얻을 수 있습니다. 이 <a href="https://huggingface.co/search/full-text?q=unconditional-image-generation&type=model" rel="nofollow">허브</a>에서 이러한 많은 체크포인트를 찾을 수 있지만, 만약 마음에 드는 체크포인트를 찾지 못했다면, 언제든지 스스로 학습할 수 있습니다!</p> <p data-svelte-h="svelte-32ojwt">이 튜토리얼은 나만의 🦋 나비 🦋를 생성하기 위해 <a href="https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset" rel="nofollow">Smithsonian Butterflies</a> 데이터셋의 하위 집합에서 <code>UNet2DModel</code> 모델을 학습하는 방법을 가르쳐줄 것입니다.</p> <blockquote class="tip" data-svelte-h="svelte-19biaqu"><p>💡 이 학습 튜토리얼은 <a href="https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb" rel="nofollow">Training with 🧨 Diffusers</a> 노트북을 기반으로 합니다. Diffusion 모델의 작동 방식 및 자세한 내용은 노트북을 확인하세요!</p></blockquote> <p data-svelte-h="svelte-1o0c2y8">시작 전에, 데이터셋을 불러오고 전처리하기 위해 🤗 Datasets가 설치되어 있는지, 그리고 다수 GPU에서 학습을 간소화하기 위해 🤗 Accelerate가 설치되어 있는지 확인하세요. 그 후 학습 메트릭을 시각화하기 위해 <a href="https://www.tensorflow.org/tensorboard" rel="nofollow">TensorBoard</a>를 또한 설치하세요. 
(또한 학습 추적을 위해 <a href="https://docs.wandb.ai/" rel="nofollow">Weights &amp; Biases</a>를 사용할 수 있습니다.)</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!pip install diffusers[training]<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1owzlqj">커뮤니티에 모델을 공유할 것을 권장하며, 이를 위해서 Hugging Face 계정에 로그인을 해야 합니다. (계정이 없다면 <a href="https://hf.co/join" rel="nofollow">여기</a>에서 만들 수 있습니다.) 
노트북에서 로그인할 수 있으며 메시지가 표시되면 토큰을 입력할 수 있습니다.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> notebook_login
<span class="hljs-meta">&gt;&gt;&gt; </span>notebook_login()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-10ayust">또는 터미널로 로그인할 수 있습니다:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->hf auth login<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1w5k095">모델 체크포인트가 상당히 크기 때문에, 대용량 파일의 버전 관리를 위해 <a href="https://git-lfs.com/" rel="nofollow">Git-LFS</a>를 설치하세요.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" 
preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START -->!sudo apt -qq install git-lfs
!git config --global credential.helper store<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="학습-구성" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#학습-구성"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>학습 구성</span></h2> <p data-svelte-h="svelte-1rpvlkg">편의를 위해 학습 파라미터들을 포함한 <code>TrainingConfig</code> 클래스를 생성합니다 (자유롭게 조정 가능):</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none 
transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> dataclasses <span class="hljs-keyword">import</span> dataclass
<span class="hljs-meta">&gt;&gt;&gt; </span>@dataclass
<span class="hljs-meta">... </span><span class="hljs-keyword">class</span> <span class="hljs-title class_">TrainingConfig</span>:
<span class="hljs-meta">... </span> image_size = <span class="hljs-number">128</span> <span class="hljs-comment"># 생성되는 이미지 해상도</span>
<span class="hljs-meta">... </span> train_batch_size = <span class="hljs-number">16</span>
<span class="hljs-meta">... </span> eval_batch_size = <span class="hljs-number">16</span> <span class="hljs-comment"># 평가 동안에 샘플링할 이미지 수</span>
<span class="hljs-meta">... </span> num_epochs = <span class="hljs-number">50</span>
<span class="hljs-meta">... </span> gradient_accumulation_steps = <span class="hljs-number">1</span>
<span class="hljs-meta">... </span> learning_rate = <span class="hljs-number">1e-4</span>
<span class="hljs-meta">... </span> lr_warmup_steps = <span class="hljs-number">500</span>
<span class="hljs-meta">... </span> save_image_epochs = <span class="hljs-number">10</span>
<span class="hljs-meta">... </span> save_model_epochs = <span class="hljs-number">30</span>
<span class="hljs-meta">... </span> mixed_precision = <span class="hljs-string">&quot;fp16&quot;</span> <span class="hljs-comment"># `no`는 float32, 자동 혼합 정밀도를 위한 `fp16`</span>
<span class="hljs-meta">... </span> output_dir = <span class="hljs-string">&quot;ddpm-butterflies-128&quot;</span> <span class="hljs-comment"># 로컬 및 HF Hub에 저장되는 모델명</span>
<span class="hljs-meta">... </span> push_to_hub = <span class="hljs-literal">True</span> <span class="hljs-comment"># 저장된 모델을 HF Hub에 업로드할지 여부</span>
<span class="hljs-meta">... </span> hub_private_repo = <span class="hljs-literal">None</span>
<span class="hljs-meta">... </span> overwrite_output_dir = <span class="hljs-literal">True</span> <span class="hljs-comment"># 노트북을 다시 실행할 때 이전 모델에 덮어씌울지</span>
<span class="hljs-meta">... </span> seed = <span class="hljs-number">0</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>config = TrainingConfig()<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="데이터셋-불러오기" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#데이터셋-불러오기"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>데이터셋 불러오기</span></h2> <p data-svelte-h="svelte-1our457">🤗 Datasets 라이브러리와 <a href="https://huggingface.co/datasets/huggan/smithsonian_butterflies_subset" rel="nofollow">Smithsonian Butterflies</a> 데이터셋을 쉽게 불러올 수 있습니다.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" 
transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> datasets <span class="hljs-keyword">import</span> load_dataset
<span class="hljs-meta">&gt;&gt;&gt; </span>config.dataset_name = <span class="hljs-string">&quot;huggan/smithsonian_butterflies_subset&quot;</span>
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset = load_dataset(config.dataset_name, split=<span class="hljs-string">&quot;train&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1hi7huh">💡<a href="https://huggingface.co/huggan" rel="nofollow">HugGan Community Event</a> 에서 추가의 데이터셋을 찾거나 로컬의 <a href="https://huggingface.co/docs/datasets/image_dataset#imagefolder" rel="nofollow"><code>ImageFolder</code></a>를 만듦으로써 나만의 데이터셋을 사용할 수 있습니다. HugGan Community Event 에 가져온 데이터셋의 경우 리포지토리의 id로 <code>config.dataset_name</code> 을 설정하고, 나만의 이미지를 사용하는 경우 <code>imagefolder</code> 를 설정합니다.</p> <p data-svelte-h="svelte-g2btn3">🤗 Datasets은 <code>Image</code> 기능을 사용해 자동으로 이미지 데이터를 디코딩하고 <a href="https://pillow.readthedocs.io/en/stable/reference/Image.html" rel="nofollow"><code>PIL.Image</code></a>로 불러옵니다. 이를 시각화 해보면:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> 
Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> matplotlib.pyplot <span class="hljs-keyword">as</span> plt
<span class="hljs-meta">&gt;&gt;&gt; </span>fig, axs = plt.subplots(<span class="hljs-number">1</span>, <span class="hljs-number">4</span>, figsize=(<span class="hljs-number">16</span>, <span class="hljs-number">4</span>))
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">for</span> i, image <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(dataset[:<span class="hljs-number">4</span>][<span class="hljs-string">&quot;image&quot;</span>]):
<span class="hljs-meta">... </span> axs[i].imshow(image)
<span class="hljs-meta">... </span> axs[i].set_axis_off()
<span class="hljs-meta">&gt;&gt;&gt; </span>fig.show()<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-12z3lda"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_ds.png" alt="Smithsonian Butterflies 데이터셋에서 가져온 나비 이미지 샘플 4장"></p> <p data-svelte-h="svelte-2vcep9">이미지는 모두 다른 사이즈이기 때문에, 우선 전처리가 필요합니다:</p> <ul data-svelte-h="svelte-lrd3tn"><li><code>Resize</code> 는 <code>config.image_size</code> 에 정의된 이미지 사이즈로 변경합니다.</li> <li><code>RandomHorizontalFlip</code> 은 랜덤적으로 이미지를 미러링하여 데이터셋을 보강합니다.</li> <li><code>Normalize</code> 는 모델이 예상하는 [-1, 1] 범위로 픽셀 값을 재조정 하는데 중요합니다.</li></ul> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> torchvision <span class="hljs-keyword">import</span> transforms
<span class="hljs-meta">&gt;&gt;&gt; </span>preprocess = transforms.Compose(
<span class="hljs-meta">... </span> [
<span class="hljs-meta">... </span> transforms.Resize((config.image_size, config.image_size)),
<span class="hljs-meta">... </span> transforms.RandomHorizontalFlip(),
<span class="hljs-meta">... </span> transforms.ToTensor(),
<span class="hljs-meta">... </span> transforms.Normalize([<span class="hljs-number">0.5</span>], [<span class="hljs-number">0.5</span>]),
<span class="hljs-meta">... </span> ]
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-mhh25q">학습 도중에 <code>preprocess</code> 함수를 적용하려면 🤗 Datasets의 <code>set_transform</code> 메서드를 사용합니다.</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">transform</span>(<span class="hljs-params">examples</span>):
<span class="hljs-meta">... </span> images = [preprocess(image.convert(<span class="hljs-string">&quot;RGB&quot;</span>)) <span class="hljs-keyword">for</span> image <span class="hljs-keyword">in</span> examples[<span class="hljs-string">&quot;image&quot;</span>]]
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> {<span class="hljs-string">&quot;images&quot;</span>: images}
<span class="hljs-meta">&gt;&gt;&gt; </span>dataset.set_transform(transform)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-jbxdac">이미지의 크기가 조정되었는지 확인하기 위해 이미지를 다시 시각화해보세요. 이제 <a href="https://pytorch.org/docs/stable/data#torch.utils.data.DataLoader" rel="nofollow">DataLoader</a>에 데이터셋을 포함해 학습할 준비가 되었습니다!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span>train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.train_batch_size, shuffle=<span class="hljs-literal">True</span>)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="unet2dmodel-생성하기" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#unet2dmodel-생성하기"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>UNet2DModel 생성하기</span></h2> <p data-svelte-h="svelte-ywj4cf">🧨 Diffusers에 사전학습된 모델들은 모델 클래스에서 원하는 파라미터로 쉽게 생성할 수 있습니다. 
예를 들어, <code>UNet2DModel</code>를 생성하려면:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> UNet2DModel
<span class="hljs-meta">&gt;&gt;&gt; </span>model = UNet2DModel(
<span class="hljs-meta">... </span> sample_size=config.image_size, <span class="hljs-comment"># 타겟 이미지 해상도</span>
<span class="hljs-meta">... </span> in_channels=<span class="hljs-number">3</span>, <span class="hljs-comment"># 입력 채널 수, RGB 이미지에서 3</span>
<span class="hljs-meta">... </span> out_channels=<span class="hljs-number">3</span>, <span class="hljs-comment"># 출력 채널 수</span>
<span class="hljs-meta">... </span> layers_per_block=<span class="hljs-number">2</span>, <span class="hljs-comment"># UNet 블럭당 몇 개의 ResNet 레이어가 사용되는지</span>
<span class="hljs-meta">... </span> block_out_channels=(<span class="hljs-number">128</span>, <span class="hljs-number">128</span>, <span class="hljs-number">256</span>, <span class="hljs-number">256</span>, <span class="hljs-number">512</span>, <span class="hljs-number">512</span>), <span class="hljs-comment"># 각 UNet 블럭을 위한 출력 채널 수</span>
<span class="hljs-meta">... </span> down_block_types=(
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>, <span class="hljs-comment"># 일반적인 ResNet 다운샘플링 블럭</span>
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;AttnDownBlock2D&quot;</span>, <span class="hljs-comment"># spatial self-attention이 포함된 일반적인 ResNet 다운샘플링 블럭</span>
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;DownBlock2D&quot;</span>,
<span class="hljs-meta">... </span> ),
<span class="hljs-meta">... </span> up_block_types=(
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>, <span class="hljs-comment"># 일반적인 ResNet 업샘플링 블럭</span>
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;AttnUpBlock2D&quot;</span>, <span class="hljs-comment"># spatial self-attention이 포함된 일반적인 ResNet 업샘플링 블럭</span>
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>,
<span class="hljs-meta">... </span> <span class="hljs-string">&quot;UpBlock2D&quot;</span>,
<span class="hljs-meta">... </span> ),
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1x8klru">샘플의 이미지 크기와 모델 출력 크기가 맞는지 빠르게 확인하기 위한 좋은 아이디어가 있습니다:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span>sample_image = dataset[<span class="hljs-number">0</span>][<span class="hljs-string">&quot;images&quot;</span>].unsqueeze(<span class="hljs-number">0</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Input shape:&quot;</span>, sample_image.shape)
Input shape: torch.Size([<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-built_in">print</span>(<span class="hljs-string">&quot;Output shape:&quot;</span>, model(sample_image, timestep=<span class="hljs-number">0</span>).sample.shape)
Output shape: torch.Size([<span class="hljs-number">1</span>, <span class="hljs-number">3</span>, <span class="hljs-number">128</span>, <span class="hljs-number">128</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-hu0j5c">훌륭해요! 다음, 이미지에 약간의 노이즈를 더하기 위해 스케줄러가 필요합니다.</p> <h2 class="relative group"><a id="스케줄러-생성하기" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#스케줄러-생성하기"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>스케줄러 생성하기</span></h2> <p data-svelte-h="svelte-1ickp2n">스케줄러는 모델을 학습 또는 추론에 사용하는지에 따라 다르게 작동합니다. 추론시에, 스케줄러는 노이즈로부터 이미지를 생성합니다. 
학습시 스케줄러는 diffusion 과정에서의 특정 포인트로부터 모델의 출력 또는 샘플을 가져와 <em>노이즈 스케줄</em>과 <em>업데이트 규칙</em>에 따라 이미지에 노이즈를 적용합니다.</p> <p data-svelte-h="svelte-boidtv"><code>DDPMScheduler</code>를 살펴보고, <code>add_noise</code> 메서드를 사용해 앞서 만든 <code>sample_image</code>에 랜덤한 노이즈를 더해봅시다:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> PIL <span class="hljs-keyword">import</span> Image
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DDPMScheduler
<span class="hljs-meta">&gt;&gt;&gt; </span>noise_scheduler = DDPMScheduler(num_train_timesteps=<span class="hljs-number">1000</span>)
<span class="hljs-meta">&gt;&gt;&gt; </span>noise = torch.randn(sample_image.shape)
<span class="hljs-meta">&gt;&gt;&gt; </span>timesteps = torch.LongTensor([<span class="hljs-number">50</span>])
<span class="hljs-meta">&gt;&gt;&gt; </span>noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)
<span class="hljs-meta">&gt;&gt;&gt; </span>Image.fromarray(((noisy_image.permute(<span class="hljs-number">0</span>, <span class="hljs-number">2</span>, <span class="hljs-number">3</span>, <span class="hljs-number">1</span>) + <span class="hljs-number">1.0</span>) * <span class="hljs-number">127.5</span>).<span class="hljs-built_in">type</span>(torch.uint8).numpy()[<span class="hljs-number">0</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-3yki19"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/noisy_butterfly.png"></p> <p data-svelte-h="svelte-1oo0f0r">모델의 학습 목적은 이미지에 더해진 노이즈를 예측하는 것입니다. 이 단계에서 손실은 다음과 같이 계산될 수 있습니다:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> torch.nn.functional <span class="hljs-keyword">as</span> F
<span class="hljs-meta">&gt;&gt;&gt; </span>noise_pred = model(noisy_image, timesteps).sample
<span class="hljs-meta">&gt;&gt;&gt; </span>loss = F.mse_loss(noise_pred, noise)<!-- HTML_TAG_END --></pre></div> <h2 class="relative group"><a id="모델-학습하기" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#모델-학습하기"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>모델 학습하기</span></h2> <p data-svelte-h="svelte-1syjdvo">지금까지, 모델 학습을 시작하기 위해 많은 부분을 갖추었으며 이제 남은 것은 모든 것을 조합하는 것입니다.</p> <p data-svelte-h="svelte-1x5az67">우선 옵티마이저(optimizer)와 학습률 스케줄러(learning rate scheduler)가 필요할 것입니다:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path 
d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> diffusers.optimization <span class="hljs-keyword">import</span> get_cosine_schedule_with_warmup
<span class="hljs-meta">&gt;&gt;&gt; </span>optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
<span class="hljs-meta">&gt;&gt;&gt; </span>lr_scheduler = get_cosine_schedule_with_warmup(
<span class="hljs-meta">... </span> optimizer=optimizer,
<span class="hljs-meta">... </span> num_warmup_steps=config.lr_warmup_steps,
<span class="hljs-meta">... </span> num_training_steps=(<span class="hljs-built_in">len</span>(train_dataloader) * config.num_epochs),
<span class="hljs-meta">... </span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-baczkn">그 후, 모델을 평가하는 방법이 필요합니다. 평가를 위해, <code>DDPMPipeline</code>을 사용해 배치의 이미지 샘플들을 생성하고 그리드 형태로 저장할 수 있습니다:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> diffusers <span class="hljs-keyword">import</span> DDPMPipeline
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> math
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> os
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">make_grid</span>(<span class="hljs-params">images, rows, cols</span>):
<span class="hljs-meta">... </span> w, h = images[<span class="hljs-number">0</span>].size
<span class="hljs-meta">... </span> grid = Image.new(<span class="hljs-string">&quot;RGB&quot;</span>, size=(cols * w, rows * h))
<span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> i, image <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(images):
<span class="hljs-meta">... </span> grid.paste(image, box=(i % cols * w, i // cols * h))
<span class="hljs-meta">... </span> <span class="hljs-keyword">return</span> grid
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">evaluate</span>(<span class="hljs-params">config, epoch, pipeline</span>):
<span class="hljs-meta">... </span> <span class="hljs-comment"># 랜덤한 노이즈로부터 이미지를 샘플링합니다. (이는 역방향 diffusion 과정입니다.)</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># 기본 파이프라인 출력 형태는 `List[PIL.Image]` 입니다.</span>
<span class="hljs-meta">... </span> images = pipeline(
<span class="hljs-meta">... </span> batch_size=config.eval_batch_size,
<span class="hljs-meta">... </span> generator=torch.manual_seed(config.seed),
<span class="hljs-meta">... </span> ).images
<span class="hljs-meta">... </span> <span class="hljs-comment"># 이미지들을 그리드로 만들어줍니다.</span>
<span class="hljs-meta">... </span> image_grid = make_grid(images, rows=<span class="hljs-number">4</span>, cols=<span class="hljs-number">4</span>)
<span class="hljs-meta">... </span> <span class="hljs-comment"># 이미지들을 저장합니다.</span>
<span class="hljs-meta">... </span> test_dir = os.path.join(config.output_dir, <span class="hljs-string">&quot;samples&quot;</span>)
<span class="hljs-meta">... </span> os.makedirs(test_dir, exist_ok=<span class="hljs-literal">True</span>)
<span class="hljs-meta">... </span> image_grid.save(<span class="hljs-string">f&quot;<span class="hljs-subst">{test_dir}</span>/<span class="hljs-subst">{epoch:04d}</span>.png&quot;</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-14zfk37">TensorBoard에 로깅, 그래디언트 누적 및 혼합 정밀도 학습을 쉽게 수행하기 위해 🤗 Accelerate를 학습 루프에 함께 앞서 말한 모든 구성 정보들을 묶어 진행할 수 있습니다. 허브에 모델을 업로드 하기 위해 리포지토리 이름 및 정보를 가져오기 위한 함수를 작성하고 허브에 업로드할 수 있습니다.</p> <p data-svelte-h="svelte-u719rq">💡아래의 학습 루프는 어렵고 길어 보일 수 있지만, 나중에 한 줄의 코드로 학습을 한다면 그만한 가치가 있을 것입니다! 만약 기다리지 못하고 이미지를 생성하고 싶다면, 아래 코드를 자유롭게 붙여넣고 작동시키면 됩니다. 🤗</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> Accelerator
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> huggingface_hub <span class="hljs-keyword">import</span> create_repo, upload_folder
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> tqdm.auto <span class="hljs-keyword">import</span> tqdm
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> pathlib <span class="hljs-keyword">import</span> Path
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> os
<span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">def</span> <span class="hljs-title function_">train_loop</span>(<span class="hljs-params">config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler</span>):
<span class="hljs-meta">... </span> <span class="hljs-comment"># Initialize accelerator and tensorboard logging</span>
<span class="hljs-meta">... </span> accelerator = Accelerator(
<span class="hljs-meta">... </span> mixed_precision=config.mixed_precision,
<span class="hljs-meta">... </span> gradient_accumulation_steps=config.gradient_accumulation_steps,
<span class="hljs-meta">... </span> log_with=<span class="hljs-string">&quot;tensorboard&quot;</span>,
<span class="hljs-meta">... </span> project_dir=os.path.join(config.output_dir, <span class="hljs-string">&quot;logs&quot;</span>),
<span class="hljs-meta">... </span> )
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> accelerator.is_main_process:
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> config.output_dir <span class="hljs-keyword">is</span> <span class="hljs-keyword">not</span> <span class="hljs-literal">None</span>:
<span class="hljs-meta">... </span> os.makedirs(config.output_dir, exist_ok=<span class="hljs-literal">True</span>)
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> config.push_to_hub:
<span class="hljs-meta">... </span> repo_id = create_repo(
<span class="hljs-meta">... </span> repo_id=config.hub_model_id <span class="hljs-keyword">or</span> Path(config.output_dir).name, exist_ok=<span class="hljs-literal">True</span>
<span class="hljs-meta">... </span> ).repo_id
<span class="hljs-meta">... </span> accelerator.init_trackers(<span class="hljs-string">&quot;train_example&quot;</span>)
<span class="hljs-meta">... </span> <span class="hljs-comment"># 모든 것이 준비되었습니다.</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># 기억해야 할 특정한 순서는 없으며, prepare 메서드에 전달한 것과 동일한 순서로 객체를 언패킹하면 됩니다.</span>
<span class="hljs-meta">... </span> model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
<span class="hljs-meta">... </span> model, optimizer, train_dataloader, lr_scheduler
<span class="hljs-meta">... </span> )
<span class="hljs-meta">... </span> global_step = <span class="hljs-number">0</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># 이제 모델을 학습합니다.</span>
<span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> epoch <span class="hljs-keyword">in</span> <span class="hljs-built_in">range</span>(config.num_epochs):
<span class="hljs-meta">... </span> progress_bar = tqdm(total=<span class="hljs-built_in">len</span>(train_dataloader), disable=<span class="hljs-keyword">not</span> accelerator.is_local_main_process)
<span class="hljs-meta">... </span> progress_bar.set_description(<span class="hljs-string">f&quot;Epoch <span class="hljs-subst">{epoch}</span>&quot;</span>)
<span class="hljs-meta">... </span> <span class="hljs-keyword">for</span> step, batch <span class="hljs-keyword">in</span> <span class="hljs-built_in">enumerate</span>(train_dataloader):
<span class="hljs-meta">... </span> clean_images = batch[<span class="hljs-string">&quot;images&quot;</span>]
<span class="hljs-meta">... </span> <span class="hljs-comment"># 이미지에 더할 노이즈를 샘플링합니다.</span>
<span class="hljs-meta">... </span> noise = torch.randn(clean_images.shape, device=clean_images.device)
<span class="hljs-meta">... </span> bs = clean_images.shape[<span class="hljs-number">0</span>]
<span class="hljs-meta">... </span> <span class="hljs-comment"># 각 이미지를 위한 랜덤한 타임스텝(timestep)을 샘플링합니다.</span>
<span class="hljs-meta">... </span> timesteps = torch.randint(
<span class="hljs-meta">... </span> <span class="hljs-number">0</span>, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,
<span class="hljs-meta">... </span> dtype=torch.int64
<span class="hljs-meta">... </span> )
<span class="hljs-meta">... </span> <span class="hljs-comment"># 각 타임스텝의 노이즈 크기에 따라 깨끗한 이미지에 노이즈를 추가합니다.</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># (이는 forward diffusion 과정입니다.)</span>
<span class="hljs-meta">... </span> noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
<span class="hljs-meta">... </span> <span class="hljs-keyword">with</span> accelerator.accumulate(model):
<span class="hljs-meta">... </span> <span class="hljs-comment"># 노이즈 잔차(noise residual)를 예측합니다.</span>
<span class="hljs-meta">... </span> noise_pred = model(noisy_images, timesteps, return_dict=<span class="hljs-literal">False</span>)[<span class="hljs-number">0</span>]
<span class="hljs-meta">... </span> loss = F.mse_loss(noise_pred, noise)
<span class="hljs-meta">... </span> accelerator.backward(loss)
<span class="hljs-meta">... </span> accelerator.clip_grad_norm_(model.parameters(), <span class="hljs-number">1.0</span>)
<span class="hljs-meta">... </span> optimizer.step()
<span class="hljs-meta">... </span> lr_scheduler.step()
<span class="hljs-meta">... </span> optimizer.zero_grad()
<span class="hljs-meta">... </span> progress_bar.update(<span class="hljs-number">1</span>)
<span class="hljs-meta">... </span> logs = {<span class="hljs-string">&quot;loss&quot;</span>: loss.detach().item(), <span class="hljs-string">&quot;lr&quot;</span>: lr_scheduler.get_last_lr()[<span class="hljs-number">0</span>], <span class="hljs-string">&quot;step&quot;</span>: global_step}
<span class="hljs-meta">... </span> progress_bar.set_postfix(**logs)
<span class="hljs-meta">... </span> accelerator.log(logs, step=global_step)
<span class="hljs-meta">... </span> global_step += <span class="hljs-number">1</span>
<span class="hljs-meta">... </span> <span class="hljs-comment"># 각 에포크가 끝난 후 evaluate()로 몇 가지 데모 이미지를 선택적으로 샘플링하고 모델을 저장합니다.</span>
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> accelerator.is_main_process:
<span class="hljs-meta">... </span> pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> (epoch + <span class="hljs-number">1</span>) % config.save_image_epochs == <span class="hljs-number">0</span> <span class="hljs-keyword">or</span> epoch == config.num_epochs - <span class="hljs-number">1</span>:
<span class="hljs-meta">... </span> evaluate(config, epoch, pipeline)
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> (epoch + <span class="hljs-number">1</span>) % config.save_model_epochs == <span class="hljs-number">0</span> <span class="hljs-keyword">or</span> epoch == config.num_epochs - <span class="hljs-number">1</span>:
<span class="hljs-meta">... </span> <span class="hljs-keyword">if</span> config.push_to_hub:
<span class="hljs-meta">... </span> upload_folder(
<span class="hljs-meta">... </span> repo_id=repo_id,
<span class="hljs-meta">... </span> folder_path=config.output_dir,
<span class="hljs-meta">... </span> commit_message=<span class="hljs-string">f&quot;Epoch <span class="hljs-subst">{epoch}</span>&quot;</span>,
<span class="hljs-meta">... </span> ignore_patterns=[<span class="hljs-string">&quot;step_*&quot;</span>, <span class="hljs-string">&quot;epoch_*&quot;</span>],
<span class="hljs-meta">... </span> )
<span class="hljs-meta">... </span> <span class="hljs-keyword">else</span>:
<span class="hljs-meta">... </span> pipeline.save_pretrained(config.output_dir)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-zyu14c">휴, 코드가 꽤 많았네요! 하지만 이제 🤗 Accelerate의 <code>notebook_launcher</code> 함수로 학습을 시작할 준비가 되었습니다. 함수에 학습 루프, 모든 학습 인수, 학습에 사용할 프로세스 수(이 값은 사용 가능한 GPU 수에 맞게 변경할 수 있음)를 전달합니다:</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">from</span> accelerate <span class="hljs-keyword">import</span> notebook_launcher
<span class="hljs-meta">&gt;&gt;&gt; </span>args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)
<span class="hljs-meta">&gt;&gt;&gt; </span>notebook_launcher(train_loop, args, num_processes=<span class="hljs-number">1</span>)<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1dbylv7">한번 학습이 완료되면, diffusion 모델로 생성된 최종 🦋이미지🦋를 확인해보길 바랍니다!</p> <div class="code-block relative "><div class="absolute top-2.5 right-4"><button class="inline-flex items-center relative text-sm focus:text-green-500 cursor-pointer focus:outline-none transition duration-200 ease-in-out opacity-0 mx-0.5 text-gray-600 " title="code excerpt" type="button"><svg class="" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M28,10V28H10V10H28m0-2H10a2,2,0,0,0-2,2V28a2,2,0,0,0,2,2H28a2,2,0,0,0,2-2V10a2,2,0,0,0-2-2Z" transform="translate(0)"></path><path d="M4,18H2V4A2,2,0,0,1,4,2H18V4H4Z" transform="translate(0)"></path><rect fill="none" width="32" height="32"></rect></svg> <div class="absolute pointer-events-none transition-opacity bg-black text-white py-1 px-2 leading-tight rounded font-normal shadow left-1/2 top-full transform -translate-x-1/2 translate-y-2 opacity-0"><div class="absolute bottom-full left-1/2 transform -translate-x-1/2 w-0 h-0 border-black border-4 border-t-0" style="border-left-color: transparent; border-right-color: transparent; "></div> Copied</div></button></div> <pre class=""><!-- HTML_TAG_START --><span class="hljs-meta">&gt;&gt;&gt; </span><span class="hljs-keyword">import</span> glob
<span class="hljs-meta">&gt;&gt;&gt; </span>sample_images = <span class="hljs-built_in">sorted</span>(glob.glob(<span class="hljs-string">f&quot;<span class="hljs-subst">{config.output_dir}</span>/samples/*.png&quot;</span>))
<span class="hljs-meta">&gt;&gt;&gt; </span>Image.<span class="hljs-built_in">open</span>(sample_images[-<span class="hljs-number">1</span>])<!-- HTML_TAG_END --></pre></div> <p data-svelte-h="svelte-1bzvmcv"><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/butterflies_final.png"></p> <h2 class="relative group"><a id="다음-단계" class="header-link block pr-1.5 text-lg no-hover:hidden with-hover:absolute with-hover:p-1.5 with-hover:opacity-0 with-hover:group-hover:opacity-100 with-hover:right-full" href="#다음-단계"><span><svg class="" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 256"><path d="M167.594 88.393a8.001 8.001 0 0 1 0 11.314l-67.882 67.882a8 8 0 1 1-11.314-11.315l67.882-67.881a8.003 8.003 0 0 1 11.314 0zm-28.287 84.86l-28.284 28.284a40 40 0 0 1-56.567-56.567l28.284-28.284a8 8 0 0 0-11.315-11.315l-28.284 28.284a56 56 0 0 0 79.196 79.197l28.285-28.285a8 8 0 1 0-11.315-11.314zM212.852 43.14a56.002 56.002 0 0 0-79.196 0l-28.284 28.284a8 8 0 1 0 11.314 11.314l28.284-28.284a40 40 0 0 1 56.568 56.567l-28.285 28.285a8 8 0 0 0 11.315 11.314l28.284-28.284a56.065 56.065 0 0 0 0-79.196z" fill="currentColor"></path></svg></span></a> <span>다음 단계</span></h2> <p data-svelte-h="svelte-1mf9wqw">Unconditional 이미지 생성은 학습될 수 있는 작업 중 하나의 예시입니다. 다른 작업과 학습 방법은 <a href="../training/overview">🧨 Diffusers 학습 예시</a> 페이지에서 확인할 수 있습니다. 
다음은 학습할 수 있는 몇 가지 예시입니다:</p> <ul data-svelte-h="svelte-y5d1yz"><li><a href="../training/text_inversion">Textual Inversion</a>, 특정 시각적 개념을 학습시켜 생성된 이미지에 통합시키는 알고리즘입니다.</li> <li><a href="../training/dreambooth">DreamBooth</a>, 주제에 대한 몇 가지 입력 이미지들이 주어지면 주제에 대한 개인화된 이미지를 생성하기 위한 기술입니다.</li> <li><a href="../training/text2image">Guide</a> 데이터셋에 Stable Diffusion 모델을 파인튜닝하는 방법입니다.</li> <li><a href="../training/lora">Guide</a> LoRA를 사용해 매우 큰 모델을 빠르게 파인튜닝하기 위한 메모리 효율적인 기술입니다.</li></ul> <a class="!text-gray-400 !no-underline text-sm flex items-center not-prose mt-4" href="https://github.com/huggingface/diffusers/blob/main/docs/source/ko/tutorials/basic_training.md" target="_blank"><svg class="mr-1" xmlns="http://www.w3.org/2000/svg" aria-hidden="true" fill="currentColor" focusable="false" role="img" width="1em" height="1em" preserveAspectRatio="xMidYMid meet" viewBox="0 0 32 32"><path d="M31,16l-7,7l-1.41-1.41L28.17,16l-5.58-5.59L24,9l7,7z"></path><path d="M1,16l7-7l1.41,1.41L3.83,16l5.58,5.59L8,23l-7-7z"></path><path d="M12.419,25.484L17.639,6.552l1.932,0.518L14.351,26.002z"></path></svg> <span data-svelte-h="svelte-zjs2n5"><span class="underline">Update</span> on GitHub</span></a> <p></p>
<script>
{
// Page bootstrap: publishes the deployment paths, then loads the two entry
// modules and starts the app on the element that contains this script.
// The surrounding braces keep `element` and `data` block-scoped.

// Assigned without let/const/var, so this name is not scoped to the block.
// NOTE(review): presumably read by the imported entry modules to resolve
// asset/base URLs — confirm against the generated bundles.
__sveltekit_18rq40q = {
assets: "/docs/diffusers/pr_11739/ko",
base: "/docs/diffusers/pr_11739/ko",
env: {}
};
// document.currentScript is only set while this script is executing, so the
// parent element is captured here, before the asynchronous imports resolve.
const element = document.currentScript.parentElement;
// NOTE(review): two entries, same length as node_ids below — looks like one
// data slot per node; confirm.
const data = [null,null];
// Fetch both entry modules in parallel; start only once both are available.
Promise.all([
import("/docs/diffusers/pr_11739/ko/_app/immutable/entry/start.3a826cf6.js"),
import("/docs/diffusers/pr_11739/ko/_app/immutable/entry/app.694c4502.js")
]).then(([kit, app]) => {
// Hand the app module, the mount element, and the initial state to start().
kit.start(app, element, {
node_ids: [0, 33],
data,
form: null,
error: null
});
});
}
</script>

Xet Storage Details

Size:
73.2 kB
·
Xet hash:
639e949387c437218f424c62e25b983e829975300cde93dc64a414eceec2f674

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.