goodmodeler committed
Commit 7066d20 · 0 Parent(s)

add train lora

Files changed (15)
  1. .gitattributes +35 -0
  2. README.md +68 -0
  3. app.py +154 -0
  4. build_embeddings.py +11 -0
  5. diffusers +1 -0
  6. ds_config.json +20 -0
  7. image_download.py +35 -0
  8. image_gen.py +8 -0
  9. inference.py +97 -0
  10. ppo_tune.py +19 -0
  11. requirements.txt +16 -0
  12. reward_model.py +21 -0
  13. sft_train.py +41 -0
  14. train_lora.py +51 -0
  15. xformers +1 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,68 @@
+ ---
+ title: ZeroGPU
+ emoji: 🖼
+ colorFrom: purple
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.25.2
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+ Commands:
+
+ pip install git+https://github.com/huggingface/diffusers
+
+ accelerate launch \
+   --deepspeed_config_file ds_config.json \
+   diffusers/examples/dreambooth/train_dreambooth.py \
+   --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
+   --instance_data_dir="./nyc_ads_dataset" \
+   --instance_prompt="a photo of an urbanad nyc" \
+   --output_dir="./nyc-ad-model" \
+   --resolution=100 \
+   --train_batch_size=1 \
+   --gradient_accumulation_steps=1 \
+   --gradient_checkpointing \
+   --learning_rate=5e-6 \
+   --lr_scheduler="constant" \
+   --lr_warmup_steps=0 \
+   --max_train_steps=400 \
+   --mixed_precision="fp16" \
+   --checkpointing_steps=100 \
+   --checkpoints_total_limit=1 \
+   --report_to="tensorboard" \
+   --logging_dir="./nyc-ad-model/logs"
+
+ To fine-tune from an already-trained checkpoint, point the model path at it instead: --pretrained_model_name_or_path="./nyc-ad-model/checkpoint-400"
+
+ To reduce CUDA memory fragmentation and reset the allocator between runs:
+
+ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
+
+ import torch
+ torch.cuda.empty_cache()
+ torch.cuda.reset_peak_memory_stats()
+
+ # 1 Fine-tune image model LoRA+QLoRA
+ accelerate launch --deepspeed_config_file=ds_config_zero3.json train_lora.py
+ python train_lora.py
+
+ # 2 SFT the language model
+ python sft_train.py
+
+ # 3 Build RAG index
+ python build_embeddings.py
+
+ # 4 (Optional) Collect preferences → train reward model
+ python reward_model.py
+
+ # 5 PPO RLHF fine-tuning
+ python ppo_tune.py
+
+ # 6 Inference with RAG (see the sketch below)
+ python inference.py
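+
+ For reference, step 6 assembles a retrieval-augmented prompt before generating. A minimal sketch of that lookup (mirrors inference.py; assumes prompt.index and prompt.txt were produced by step 3):
+
+ import faiss, json
+ from sentence_transformers import SentenceTransformer
+
+ texts = json.load(open("prompt.txt"))
+ index = faiss.read_index("prompt.index")
+ emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+
+ def rag_prompt(query, k=3):
+     q = emb.encode(query, normalize_embeddings=True).astype("float32").reshape(1, -1)
+     _, ids = index.search(q, k)
+     return " ".join(texts[i] for i in ids[0]) + f". {query}"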
app.py ADDED
@@ -0,0 +1,154 @@
+ import gradio as gr
+ import numpy as np
+ import random
+
+ # import spaces #[uncomment to use ZeroGPU]
+ from diffusers import DiffusionPipeline
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model_repo_id = "stabilityai/sdxl-turbo"  # Replace with the model you would like to use
+
+ if torch.cuda.is_available():
+     torch_dtype = torch.float16
+ else:
+     torch_dtype = torch.float32
+
+ pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
+ pipe = pipe.to(device)
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MAX_IMAGE_SIZE = 1024
+
+
+ # @spaces.GPU #[uncomment to use ZeroGPU]
+ def infer(
+     prompt,
+     negative_prompt,
+     seed,
+     randomize_seed,
+     width,
+     height,
+     guidance_scale,
+     num_inference_steps,
+     progress=gr.Progress(track_tqdm=True),
+ ):
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     generator = torch.Generator().manual_seed(seed)
+
+     image = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         generator=generator,
+     ).images[0]
+
+     return image, seed
+
+
+ examples = [
+     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
+     "An astronaut riding a green horse",
+     "A delicious ceviche cheesecake slice",
+ ]
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 640px;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown(" # Text-to-Image Gradio Template")
+
+         with gr.Row():
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=1,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+
+             run_button = gr.Button("Run", scale=0, variant="primary")
+
+         result = gr.Image(label="Result", show_label=False)
+
+         with gr.Accordion("Advanced Settings", open=False):
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=1,
+                 placeholder="Enter a negative prompt",
+                 visible=False,
+             )
+
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+             )
+
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,  # Replace with defaults that work for your model
+                 )
+
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=256,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=32,
+                     value=1024,  # Replace with defaults that work for your model
+                 )
+
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.1,
+                     value=0.0,  # Replace with defaults that work for your model
+                 )
+
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=50,
+                     step=1,
+                     value=2,  # Replace with defaults that work for your model
+                 )
+
+         gr.Examples(examples=examples, inputs=[prompt])
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=infer,
+         inputs=[
+             prompt,
+             negative_prompt,
+             seed,
+             randomize_seed,
+             width,
+             height,
+             guidance_scale,
+             num_inference_steps,
+         ],
+         outputs=[result, seed],
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
build_embeddings.py ADDED
@@ -0,0 +1,11 @@
+ from sentence_transformers import SentenceTransformer
+ import faiss, json, glob, os, numpy as np
+
+ model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+ texts, vecs = [], []
+ for f in glob.glob("nyc_ads_dataset/*.json"):
+     cap = json.load(open(f))["caption"]
+     texts.append(cap)
+     vecs.append(model.encode(cap, normalize_embeddings=True))
+ vecs = np.vstack(vecs).astype("float32")
+ index = faiss.IndexFlatIP(vecs.shape[1])
+ index.add(vecs)
+ faiss.write_index(index, "prompt.index")
+ json.dump(texts, open("prompt.txt", "w"))
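+
+ # Design note: embeddings are L2-normalized, so the inner-product index
+ # (IndexFlatIP) ranks by cosine similarity. prompt.txt is a JSON list of
+ # captions (not line-delimited text); inference.py reads it with json.load.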
diffusers ADDED
@@ -0,0 +1 @@
+ Subproject commit 92fe689f06bcec27c4f48cb90574c2b9c42c643b
ds_config.json ADDED
@@ -0,0 +1,20 @@
+ {
+   "zero_optimization": {
+     "stage": 2,
+     "offload_param": {
+       "device": "cpu",
+       "pin_memory": true
+     },
+     "offload_optimizer": {
+       "device": "cpu",
+       "pin_memory": true
+     },
+     "overlap_comm": true,
+     "contiguous_gradients": true
+   },
+   "gradient_accumulation_steps": 1,
+   "train_batch_size": 1,
+   "fp16": {
+     "enabled": true
+   }
+ }
image_download.py ADDED
@@ -0,0 +1,35 @@
+ import flickrapi
+ import requests
+ import os
+
+ # Your Flickr API credentials
+ FLICKR_PUBLIC = '0ff89a88a2a61c24f452774ad32ee62c'
+ FLICKR_SECRET = '35c5034466630c82'
+
+ # Create Flickr API object
+ flickr = flickrapi.FlickrAPI(FLICKR_PUBLIC, FLICKR_SECRET, format='parsed-json')
+
+ # Search for images with relevant tags
+ results = flickr.photos.search(
+     text='advertisement',
+     per_page=50,
+     media='photos',
+     sort='relevance',
+     extras='url_o,url_l,url_c,tags',
+     content_type=1,
+     safe_search=1
+ )
+
+ photos = results['photos']['photo']
+
+ # Create the output folder the writes below rely on
+ os.makedirs('nyc_ads_dataset', exist_ok=True)
+
+ # Download images
+ for i, photo in enumerate(photos):
+     url = photo.get('url_o') or photo.get('url_l') or photo.get('url_c')
+     if url:
+         img_data = requests.get(url).content
+         with open(f'nyc_ads_dataset/img_{i}.jpg', 'wb') as handler:
+             handler.write(img_data)
+         print(f"Downloaded: img_{i}.jpg")
image_gen.py ADDED
@@ -0,0 +1,8 @@
+ from datasets import load_dataset
+ import os
+
+ dataset = load_dataset("cifar10", split="train", streaming=True)
+
+ os.makedirs("./nyc_ads_dataset", exist_ok=True)
+ for i, ex in zip(range(5), dataset):
+     ex["img"].save(f"./nyc_ads_dataset/{i+1:03d}.jpg")
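+
+ # Note: build_embeddings.py expects caption files at nyc_ads_dataset/*.json with
+ # a "caption" key. A minimal stand-in (hypothetical captions for these
+ # placeholder images; a sketch, not part of the original pipeline):
+ #   import json
+ #   for i in range(5):
+ #       json.dump({"caption": f"placeholder ad image {i+1}"},
+ #                 open(f"./nyc_ads_dataset/{i+1:03d}.json", "w"))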
inference.py ADDED
@@ -0,0 +1,97 @@
+ '''
+ from diffusers import StableDiffusionPipeline
+ import torch
+
+ # Load the fine-tuned DreamBooth model
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "./nyc-ad-model",
+     torch_dtype=torch.float16,
+ ).to("cuda")  # use "cpu" if no GPU
+
+ prompt = "brand name: xyc, fried chicken advertisement poster: a fried chicken in brooklyn street"
+ image = pipe(prompt, num_inference_steps=500, guidance_scale=7.5).images[0]
+
+ # Display or save the image
+ image.save("output_nyc_ad.png")
+ image.show()
+ '''
+ '''
+ import torch, faiss, json
+ from sentence_transformers import SentenceTransformer
+ from diffusers import StableDiffusionPipeline
+
+ texts = json.load(open("prompt.txt"))
+ index = faiss.read_index("prompt.index")
+ emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+ pipe = StableDiffusionPipeline.from_pretrained("./nyc-ad-model", torch_dtype=torch.float16).to("cuda")
+
+ def rag_prompt(query, k=3):
+     q = emb.encode(query, normalize_embeddings=True).astype("float32")
+     _, I = index.search(q.reshape(1, -1), k)
+     retrieved = " ".join(texts[i] for i in I[0])
+     return f"{retrieved}. {query}"
+
+ prompt = rag_prompt("fried chicken advertisement poster")
+ img = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+ img.save("rag_output.png")
+ '''
+
+ import torch, faiss, json
+ from sentence_transformers import SentenceTransformer
+ from diffusers import StableDiffusionPipeline
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load RAG index
+ texts = json.load(open("prompt.txt"))
+ index = faiss.read_index("prompt.index")
+ emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
+
+ # Load image generation pipeline
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "./nyc-ad-model",
+     torch_dtype=torch.float16
+ ).to("cuda")
+
+ # Load your own fine-tuned SFT model
+ text_model_path = "./sft-model"  # Path to your SFT-finetuned model
+ tokenizer = AutoTokenizer.from_pretrained(text_model_path)
+ text_model = AutoModelForCausalLM.from_pretrained(
+     text_model_path,
+     torch_dtype=torch.float16,
+     device_map="auto"
+ )
+
+ # Build retrieval-augmented prompt
+ def rag_prompt(query, k=3):
+     q = emb.encode(query, normalize_embeddings=True).astype("float32")
+     _, I = index.search(q.reshape(1, -1), k)
+     retrieved = " ".join(texts[i] for i in I[0])
+     return f"{retrieved}. {query}"
+
+ # Prompt for generation
+ user_prompt = "fried chicken advertisement poster"
+ full_prompt = rag_prompt(user_prompt)
+
+ # Generate image
+ image = pipe(full_prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
+ image.save("rag_output.png")
+
+ # Construct input prompt compatible with SFT format
+ copy_prompt = f"""### Instruction:
+ Generate a catchy advertisement slogan for: {user_prompt}
+
+ ### Response:"""
+
+ inputs = tokenizer(copy_prompt, return_tensors="pt").to("cuda")
+ output_ids = text_model.generate(
+     **inputs,
+     max_new_tokens=30,
+     do_sample=True,
+     top_p=0.95
+ )
+ response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+ # Output result
+ print("🖼️ Image saved to rag_output.png")
+ print("📝 Generated slogan:")
+ print(response.strip())
ppo_tune.py ADDED
@@ -0,0 +1,19 @@
+ from trl import PPOTrainer, PPOConfig
+ from peft import PeftModel
+ import torch, random, json, glob
+ from diffusers import StableDiffusionPipeline
+ from reward_model import CLIPModel, CLIPProcessor
+
+ # NOTE: schematic sketch, not runnable as-is. TRL's PPOTrainer targets causal
+ # language models, not diffusion UNets; diffusion RLHF would need a dedicated
+ # trainer (e.g. TRL's DDPOTrainer) with a reward function over generated images.
+ rm = CLIPModel.from_pretrained("rm").eval().half().cuda()
+ proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ pipe = StableDiffusionPipeline.from_pretrained("./nyc-ad-model", torch_dtype=torch.float16).to("cuda")
+ ppo_cfg = PPOConfig(batch_size=1, learning_rate=1e-6, target_kl=0.2)
+ trainer = PPOTrainer(model=pipe.unet, reward_model=rm, config=ppo_cfg)
+
+ prompts = json.load(open("prompt.txt"))  # prompt.txt is a JSON list (see build_embeddings.py)
+ for step in range(500):
+     p = random.choice(prompts)
+     img = pipe(p, num_inference_steps=20).images[0]
+     reward = rm(**proc(text=p, images=img, return_tensors="pt").to("cuda")).logits_per_image[0, 0].item()
+     trainer.step(prompts=[p], rewards=[reward])
+ pipe.save_pretrained("nyc-ad-model-rlhf")
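+
+ # Sketch of the CLIP scoring factored into a reward function (an assumption
+ # about how a dedicated diffusion-RLHF trainer would consume it per sample):
+ #   @torch.no_grad()
+ #   def clip_reward(prompt, image):
+ #       batch = proc(text=prompt, images=image, return_tensors="pt").to("cuda")
+ #       return rm(**batch).logits_per_image[0, 0].item()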
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ accelerate
+ diffusers
+ invisible_watermark
+ torch
+ transformers
+ xformers
+ torchvision
+ flickrapi
+ requests
+ peft>=0.9.0
+ bitsandbytes
+ faiss-cpu
+ sentence-transformers
+ trl[peft]
+ label-studio
+ datasets
reward_model.py ADDED
@@ -0,0 +1,21 @@
+ from transformers import CLIPProcessor, CLIPModel, TrainingArguments, Trainer
+ import datasets, torch, json, glob
+ from PIL import Image
+
+ model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+
+ data = []
+ for f in glob.glob("human_prefs/*.json"):
+     data.append(json.load(open(f)))  # {"prompt": …, "good": img_path, "bad": img_path}
+
+ dataset = datasets.Dataset.from_list(data)
+
+ def preprocess(ex):
+     # Two copies of the prompt, paired with the preferred and rejected images
+     inputs = processor(text=[ex["prompt"]] * 2,
+                        images=[Image.open(ex["good"]), Image.open(ex["bad"])],
+                        return_tensors="pt")
+     inputs["labels"] = torch.tensor([1, 0])  # preferred first, rejected second
+     return inputs
+
+ dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
+ # NOTE: sketch. CLIPModel does not derive a preference loss from "labels" by
+ # itself; a custom Trainer.compute_loss (pairwise ranking) is assumed here.
+ args = TrainingArguments("rm_ckpt", per_device_train_batch_size=2, fp16=True,
+                          learning_rate=5e-6, num_train_epochs=3)
+ trainer = Trainer(model, args, train_dataset=dataset)
+ trainer.train()
+ model.save_pretrained("rm")
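+
+ # Sketch of the pairwise (Bradley-Terry) preference loss such a custom
+ # Trainer.compute_loss could apply (an assumption, not original logic):
+ #   out = model(**batch)               # logits_per_image has shape (2, 2)
+ #   good, bad = out.logits_per_image[0, 0], out.logits_per_image[1, 0]
+ #   loss = -torch.nn.functional.logsigmoid(good - bad)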
sft_train.py ADDED
@@ -0,0 +1,41 @@
+ import torch, json
+ from datasets import load_dataset, Dataset
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
+ from peft import get_peft_model, LoraConfig, TaskType
+
+ # Load your dataset
+ data = [json.loads(l) for l in open("data/sft_data.jsonl")]
+ dataset = Dataset.from_list(data)
+
+ # Load model & tokenizer
+ base_model = "meta-llama/Llama-2-7b-hf"  # Or use Mistral, Falcon, etc.
+ tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
+ if tokenizer.pad_token is None:
+     tokenizer.pad_token = tokenizer.eos_token  # Llama-2 ships without a pad token
+ model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16)
+
+ # Add LoRA (optional)
+ lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.05,
+                          target_modules=["q_proj", "v_proj"])
+ model = get_peft_model(model, lora_config)
+
+ # Preprocessing
+ def tokenize(example):
+     prompt = f"### Instruction:\n{example['prompt']}\n\n### Response:\n{example['output']}"
+     return tokenizer(prompt, truncation=True, max_length=512, padding="max_length")
+ dataset = dataset.map(tokenize, remove_columns=dataset.column_names)
+
+ # Training setup
+ args = TrainingArguments(
+     output_dir="./sft-model",
+     per_device_train_batch_size=2,
+     num_train_epochs=3,
+     fp16=True,
+     evaluation_strategy="no",
+     save_strategy="epoch",
+     logging_steps=20,
+     learning_rate=2e-5,
+     report_to="tensorboard",
+ )
+
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+ trainer = Trainer(model=model, args=args, train_dataset=dataset, data_collator=data_collator)
+ trainer.train()
train_lora.py ADDED
@@ -0,0 +1,51 @@
+ # train_lora.py – QLoRA + DeepSpeed DreamBooth Fine-Tuning (Stable Diffusion)
+
+ import os, argparse, torch
+ from diffusers import StableDiffusionPipeline, DDPMScheduler
+ from diffusers import DreamBoothLoraTrainer
+ from peft import LoraConfig
+ from accelerate import Accelerator
+
+ # NOTE: schematic sketch. DreamBoothLoraTrainer is not a public diffusers API,
+ # StableDiffusionPipeline.from_pretrained does not accept load_in_4bit /
+ # quantization_config, and Accelerator is configured via a DeepSpeedPlugin
+ # rather than a config path. The runnable path is the train_dreambooth.py
+ # command in the README; this file sketches the intended QLoRA + DeepSpeed setup.
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data", default="./nyc_ads_dataset")  # your training image directory
+ args = parser.parse_args()
+
+ # LoRA config (QLoRA-compatible)
+ lora_cfg = LoraConfig(
+     r=8,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     target_modules=["q_proj", "v_proj"]
+ )
+
+ # Load SD-1.5 with 4-bit quantization
+ pipe = StableDiffusionPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5",
+     torch_dtype=torch.float16,
+     load_in_4bit=True,
+     quantization_config={
+         "bnb_4bit_compute_dtype": torch.float16,
+         "bnb_4bit_use_double_quant": True,
+         "bnb_4bit_quant_type": "nf4"
+     },
+ )
+
+ # DreamBooth LoRA trainer
+ trainer = DreamBoothLoraTrainer(
+     instance_data_root=args.data,
+     instance_prompt="a photo of an urbanad nyc",
+     lora_config=lora_cfg,
+     output_dir="./nyc-ad-model",
+     max_train_steps=400,
+     train_batch_size=1,
+     gradient_checkpointing=True,
+ )
+
+ # DeepSpeed ZeRO-3 acceleration / memory partitioning
+ accelerator = Accelerator(
+     mixed_precision="fp16",
+     deepspeed_config="./ds_config_zero3.json"  # must be provided in advance
+ )
+
+ # Start training
+ trainer.train(accelerator)
xformers ADDED
@@ -0,0 +1 @@
+ Subproject commit 8fc8ec5a4d6498ff81c0c418b89bbaf133ae3a44