Sakalti commited on
Commit
d48bb37
·
verified ·
1 Parent(s): 9287bf6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
3
+ from datasets import load_dataset, Dataset, DatasetDict
4
+ import os
5
+
6
+ def train_and_deploy(write_token, repo_name, license_text):
7
+ # トークンを環境変数に設定
8
+ os.environ['HF_WRITE_TOKEN'] = write_token
9
+
10
+ # ライセンスファイルを作成
11
+ with open("LICENSE", "w") as f:
12
+ f.write(license_text)
13
+
14
+ # モデルとトークナイザーの読み込み
15
+ model_name = "EleutherAI/pythia-14m" # トレーニング対象のモデル
16
+ model = AutoModelForCausalLM.from_pretrained(model_name)
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+
19
+ # FBK-MT/mosel データセットの読み込み
20
+ dataset = load_dataset("FBK-MT/mosel")
21
+
22
+ # データセットのキーを確認
23
+ print(f"Dataset keys: {dataset.keys()}")
24
+ if "train" not in dataset:
25
+ raise KeyError("The dataset does not contain a 'train' split.")
26
+ if "test" not in dataset:
27
+ raise KeyError("The dataset does not contain a 'test' split.")
28
+
29
+ # データセットの最初のエントリのキーを確認
30
+ print(f"Sample keys in 'train' split: {dataset['train'][0].keys()}")
31
+
32
+ # データセットのトークン化
33
+ def tokenize_function(examples):
34
+ try:
35
+ texts = examples['text']
36
+ return tokenizer(texts, padding="max_length", truncation=True, max_length=128)
37
+ except KeyError as e:
38
+ print(f"KeyError: {e}")
39
+ print(f"Available keys: {examples.keys()}")
40
+ raise
41
+
42
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
43
+
44
+ # トレーニング設定
45
+ training_args = TrainingArguments(
46
+ output_dir="./results",
47
+ per_device_train_batch_size=8,
48
+ per_device_eval_batch_size=8,
49
+ evaluation_strategy="epoch",
50
+ save_strategy="epoch",
51
+ logging_dir="./logs",
52
+ logging_steps=10,
53
+ num_train_epochs=3, # トレーニングエポック数
54
+ push_to_hub=True, # Hugging Face Hubにプッシュ
55
+ hub_token=write_token,
56
+ hub_model_id=repo_name # ユーザーが入力したリポジトリ名
57
+ )
58
+
59
+ # Trainerの設定
60
+ trainer = Trainer(
61
+ model=model,
62
+ args=training_args,
63
+ train_dataset=tokenized_datasets["train"],
64
+ eval_dataset=tokenized_datasets["test"],
65
+ )
66
+
67
+ # トレーニング実行
68
+ trainer.train()
69
+
70
+ # モデルをHugging Face Hubにプッシュ
71
+ trainer.push_to_hub()
72
+
73
+ return f"モデルが'{repo_name}'リポジトリにデプロイされました!"
74
+
75
+ # Gradio UI
76
+ with gr.Blocks() as demo:
77
+ gr.Markdown("### pythia トレーニングとデプロイ")
78
+ token_input = gr.Textbox(label="Hugging Face Write Token", placeholder="トークンを入力してください...")
79
+ repo_input = gr.Textbox(label="リポジトリ名", placeholder="デプロイするリポジトリ名を入力してください...")
80
+ license_input = gr.Textbox(label="ライセンス", placeholder="ライセンス情報を入力してください...")
81
+ output = gr.Textbox(label="出力")
82
+ train_button = gr.Button("デプロイ")
83
+
84
+ train_button.click(fn=train_and_deploy, inputs=[token_input, repo_input, license_input], outputs=output)
85
+
86
+ demo.launch()