davidberenstein1957 HF staff commited on
Commit
f9bfc2d
·
1 Parent(s): 84c9d20

fix: oauth

Browse files
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -1,11 +1,13 @@
1
  import io
2
  import multiprocessing
3
  import time
 
4
 
5
  import gradio as gr
6
  import pandas as pd
7
  from datasets import Dataset
8
  from distilabel.distiset import Distiset
 
9
  from huggingface_hub import upload_file
10
 
11
  from src.distilabel_dataset_generator.pipelines.sft import (
@@ -18,7 +20,6 @@ from src.distilabel_dataset_generator.pipelines.sft import (
18
  get_prompt_generation_step,
19
  )
20
  from src.distilabel_dataset_generator.utils import (
21
- OAuthToken,
22
  get_login_button,
23
  get_org_dropdown,
24
  swap_visibilty,
@@ -146,7 +147,7 @@ def push_to_hub(
146
  private: bool = True,
147
  org_name: str = None,
148
  repo_name: str = None,
149
- oauth_token: OAuthToken = None,
150
  ):
151
  distiset = Distiset(
152
  {
@@ -163,7 +164,7 @@ def push_to_hub(
163
 
164
 
165
  def upload_pipeline_code(
166
- pipeline_code, org_name, repo_name, oauth_token: OAuthToken = None
167
  ):
168
  with io.BytesIO(pipeline_code.encode("utf-8")) as f:
169
  upload_file(
 
1
  import io
2
  import multiprocessing
3
  import time
4
+ from typing import Union
5
 
6
  import gradio as gr
7
  import pandas as pd
8
  from datasets import Dataset
9
  from distilabel.distiset import Distiset
10
+ from gradio.oauth import OAuthToken
11
  from huggingface_hub import upload_file
12
 
13
  from src.distilabel_dataset_generator.pipelines.sft import (
 
20
  get_prompt_generation_step,
21
  )
22
  from src.distilabel_dataset_generator.utils import (
 
23
  get_login_button,
24
  get_org_dropdown,
25
  swap_visibilty,
 
147
  private: bool = True,
148
  org_name: str = None,
149
  repo_name: str = None,
150
+ oauth_token: Union[OAuthToken, None] = None,
151
  ):
152
  distiset = Distiset(
153
  {
 
164
 
165
 
166
  def upload_pipeline_code(
167
+ pipeline_code, org_name, repo_name, oauth_token: Union[OAuthToken, None] = None
168
  ):
169
  with io.BytesIO(pipeline_code.encode("utf-8")) as f:
170
  upload_file(