Commit 7066d20 · add train lora
Files changed:
- .gitattributes +35 -0
- README.md +68 -0
- app.py +154 -0
- build_embeddings.py +11 -0
- diffusers +1 -0
- ds_config.json +20 -0
- image_download.py +35 -0
- image_gen.py +8 -0
- inference.py +97 -0
- ppo_tune.py +19 -0
- requirements.txt +16 -0
- reward_model.py +21 -0
- sft_train.py +41 -0
- train_lora.py +51 -0
- xformers +1 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
@@ -0,0 +1,68 @@
---
title: ZeroGPU
emoji: 🖼
colorFrom: purple
colorTo: red
sdk: gradio
sdk_version: 5.25.2
app_file: app.py
pinned: false
license: apache-2.0
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

Commands:

pip install git+https://github.com/huggingface/diffusers

accelerate launch \
  --deepspeed_config_file ds_config.json \
  diffusers/examples/dreambooth/train_dreambooth.py \
  --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
  --instance_data_dir="./nyc_ads_dataset" \
  --instance_prompt="a photo of an urbanad nyc" \
  --output_dir="./nyc-ad-model" \
  --resolution=100 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --gradient_checkpointing \
  --learning_rate=5e-6 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=400 \
  --mixed_precision="fp16" \
  --checkpointing_steps=100 \
  --checkpoints_total_limit=1 \
  --report_to="tensorboard" \
  --logging_dir="./nyc-ad-model/logs"

To continue fine-tuning from an already trained checkpoint, set --pretrained_model_name_or_path="./nyc-ad-model/checkpoint-400"

To reduce GPU memory pressure:

export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

import torch
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

7/12
# 1 Fine-tune the image model with LoRA + QLoRA
accelerate launch --deepspeed_config_file=ds_config_zero3.json train_lora.py
python train_lora.py

# 2 SFT the language model
python sft_train.py

# 3 Build the RAG index
python build_embeddings.py

# 4 (Optional) Collect preferences → train the reward model
python reward_model.py

# 5 PPO RLHF fine-tuning
python ppo_tune.py

# 6 Inference with RAG
python rag_infer.py
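Note: the step-1 accelerate command and train_lora.py both point at a ds_config_zero3.json that is not included in this commit. A minimal ZeRO-3 sketch in the spirit of the bundled ds_config.json (all values here are assumptions, not files from this repo) could look like:

{
  "zero_optimization": {
    "stage": 3,
    "offload_param": {"device": "cpu", "pin_memory": true},
    "offload_optimizer": {"device": "cpu", "pin_memory": true},
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "gradient_accumulation_steps": 1,
  "train_batch_size": 1,
  "fp16": {"enabled": true}
}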
app.py
ADDED
@@ -0,0 +1,154 @@
import gradio as gr
import numpy as np
import random

# import spaces #[uncomment to use ZeroGPU]
from diffusers import DiffusionPipeline
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/sdxl-turbo"  # Replace with the model you would like to use

if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    width,
    height,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator().manual_seed(seed)

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]

    return image, seed


examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image Gradio Template")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # Replace with defaults that work for your model
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # Replace with defaults that work for your model
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=0.0,  # Replace with defaults that work for your model
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=2,  # Replace with defaults that work for your model
                )

        gr.Examples(examples=examples, inputs=[prompt])
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )

if __name__ == "__main__":
    demo.launch()
build_embeddings.py
ADDED
@@ -0,0 +1,11 @@
from sentence_transformers import SentenceTransformer
import faiss, json, glob, os, numpy as np

model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
texts = []; vecs = []
for f in glob.glob("nyc_ads_dataset/*.json"):
    cap = json.load(open(f))["caption"]
    texts.append(cap); vecs.append(model.encode(cap, normalize_embeddings=True))
vecs = np.vstack(vecs).astype("float32")
index = faiss.IndexFlatIP(vecs.shape[1]); index.add(vecs)
faiss.write_index(index, "prompt.index"); json.dump(texts, open("prompt.txt", "w"))
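build_embeddings.py assumes one JSON sidecar per image in nyc_ads_dataset/ containing at least a "caption" key. A hypothetical nyc_ads_dataset/img_0.json (filename and caption text are illustrative, not part of this commit) might look like:

{"caption": "hand-painted fried chicken advertisement on a Brooklyn storefront, bold red lettering"}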
diffusers
ADDED
@@ -0,0 +1 @@
Subproject commit 92fe689f06bcec27c4f48cb90574c2b9c42c643b
ds_config.json
ADDED
@@ -0,0 +1,20 @@
{
  "zero_optimization": {
    "stage": 2,
    "offload_param": {
      "device": "cpu",
      "pin_memory": true
    },
    "offload_optimizer": {
      "device": "cpu",
      "pin_memory": true
    },
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "gradient_accumulation_steps": 1,
  "train_batch_size": 1,
  "fp16": {
    "enabled": true
  }
}
image_download.py
ADDED
@@ -0,0 +1,35 @@
import flickrapi
import requests
import os

# Your Flickr API credentials
FLICKR_PUBLIC = '0ff89a88a2a61c24f452774ad32ee62c'
FLICKR_SECRET = '35c5034466630c82'

# Create Flickr API object
flickr = flickrapi.FlickrAPI(FLICKR_PUBLIC, FLICKR_SECRET, format='parsed-json')

# Search for images with relevant tags
results = flickr.photos.search(
    text='advertisement',
    per_page=50,
    media='photos',
    sort='relevance',
    extras='url_o,url_l,url_c,tags',
    content_type=1,
    safe_search=1
)

photos = results['photos']['photo']

# Create folder to save images (the loop below writes into nyc_ads_dataset/)
os.makedirs('nyc_ads_dataset', exist_ok=True)

# Download images
for i, photo in enumerate(photos):
    url = photo.get('url_o') or photo.get('url_l') or photo.get('url_c')
    if url:
        img_data = requests.get(url).content
        with open(f'nyc_ads_dataset/img_{i}.jpg', 'wb') as handler:
            handler.write(img_data)
        print(f"Downloaded: img_{i}.jpg")
image_gen.py
ADDED
@@ -0,0 +1,8 @@
from datasets import load_dataset
import os

dataset = load_dataset("cifar10", split="train", streaming=True)

os.makedirs("./nyc_ads_dataset", exist_ok=True)
for i, ex in zip(range(5), dataset):
    ex["img"].save(f"./nyc_ads_dataset/{i+1:03d}.jpg")
inference.py
ADDED
@@ -0,0 +1,97 @@
'''
from diffusers import StableDiffusionPipeline
import torch

# Load the fine-tuned DreamBooth model
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model",
    torch_dtype=torch.float16,
).to("cuda")  # use "cpu" if no GPU

prompt = "brand name: xyc, fried chicken advertisement poster: a fried chicken in brooklyn street"
image = pipe(prompt, num_inference_steps=500, guidance_scale=7.5).images[0]

# Display or save the image
image.save("output_nyc_ad.png")
image.show()
'''
'''
import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline

texts = json.load(open("prompt.txt"))
index = faiss.read_index("prompt.index")
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
pipe = StableDiffusionPipeline.from_pretrained("./nyc-ad-model", torch_dtype=torch.float16).to("cuda")

def rag_prompt(query, k=3):
    q = emb.encode(query, normalize_embeddings=True).astype("float32")
    _, I = index.search(q.reshape(1, -1), k)
    retrieved = " ".join(texts[i] for i in I[0])
    return f"{retrieved}. {query}"

prompt = rag_prompt("fried chicken advertisement poster")
img = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
img.save("rag_output.png")
'''

import torch, faiss, json
from sentence_transformers import SentenceTransformer
from diffusers import StableDiffusionPipeline
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load RAG index
texts = json.load(open("prompt.txt"))
index = faiss.read_index("prompt.index")
emb = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

# Load image generation pipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "./nyc-ad-model",
    torch_dtype=torch.float16
).to("cuda")

# Load your own fine-tuned SFT model
text_model_path = "./sft-model"  # Path to your SFT-finetuned model
tokenizer = AutoTokenizer.from_pretrained(text_model_path)
text_model = AutoModelForCausalLM.from_pretrained(
    text_model_path,
    torch_dtype=torch.float16,
    device_map="auto"
)

# Build retrieval-augmented prompt
def rag_prompt(query, k=3):
    q = emb.encode(query, normalize_embeddings=True).astype("float32")
    _, I = index.search(q.reshape(1, -1), k)
    retrieved = " ".join(texts[i] for i in I[0])
    return f"{retrieved}. {query}"

# Prompt for generation
user_prompt = "fried chicken advertisement poster"
full_prompt = rag_prompt(user_prompt)

# Generate image
image = pipe(full_prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("rag_output.png")

# Construct input prompt compatible with SFT format
copy_prompt = f"""### Instruction:
Generate a catchy advertisement slogan for: {user_prompt}

### Response:"""

inputs = tokenizer(copy_prompt, return_tensors="pt").to("cuda")
output_ids = text_model.generate(
    **inputs,
    max_new_tokens=30,
    do_sample=True,
    top_p=0.95
)
response = tokenizer.decode(output_ids[0], skip_special_tokens=True)

# Output result
print("🖼️ Image saved to rag_output.png")
print("📝 Generated slogan:")
print(response.strip())
ppo_tune.py
ADDED
@@ -0,0 +1,19 @@
from trl import PPOTrainer, PPOConfig
from peft import PeftModel
import torch, random, json, glob
from diffusers import StableDiffusionPipeline
from reward_model import CLIPModel, CLIPProcessor

rm = CLIPModel.from_pretrained("rm").eval().half().cuda()
proc = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
pipe = StableDiffusionPipeline.from_pretrained("./nyc-ad-model", torch_dtype=torch.float16).to("cuda")
ppo_cfg = PPOConfig(batch_size=1, learning_rate=1e-6, target_kl=0.2)
trainer = PPOTrainer(model=pipe.unet, reward_model=rm, config=ppo_cfg)

# prompt.txt is the JSON list written by build_embeddings.py
prompts = json.load(open("prompt.txt"))
for step in range(500):
    p = random.choice(prompts)
    img = pipe(p, num_inference_steps=20).images[0]
    # CLIP image-text similarity of the generated image is used as the reward
    reward = rm(**proc(text=p, images=img, return_tensors="pt").to("cuda")).logits_per_image[0, 0].item()
    trainer.step(prompts=[p], rewards=[reward])
pipe.save_pretrained("nyc-ad-model-rlhf")
requirements.txt
ADDED
@@ -0,0 +1,16 @@
accelerate
diffusers
invisible_watermark
torch
transformers
xformers
torchvision
flickrapi
requests
peft>=0.9.0
bitsandbytes
faiss-cpu
sentence-transformers
trl[peft]
label-studio
datasets
reward_model.py
ADDED
@@ -0,0 +1,21 @@
from transformers import CLIPProcessor, CLIPModel, TrainingArguments, Trainer
from PIL import Image
import datasets, torch, json, glob

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

data = []
for f in glob.glob("human_prefs/*.json"):
    j = json.load(open(f)); data.append(j)  # {"prompt": …, "good": img_path, "bad": img_path}

dataset = datasets.Dataset.from_list(data)

def preprocess(ex):
    # Pair the prompt with the preferred ("good") and rejected ("bad") image
    inputs = processor(text=[ex["prompt"]] * 2,
                       images=[Image.open(ex["good"]), Image.open(ex["bad"])],
                       return_tensors="pt")
    inputs["labels"] = torch.tensor([1, 0])
    return inputs

dataset = dataset.map(preprocess, remove_columns=dataset.column_names)
args = TrainingArguments("rm_ckpt", per_device_train_batch_size=2, fp16=True,
                         learning_rate=5e-6, num_train_epochs=3)
trainer = Trainer(model, args, train_dataset=dataset)
trainer.train(); model.save_pretrained("rm")
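The preference files read above are expected to have the shape noted in the inline comment. A hypothetical human_prefs/0001.json (filename, paths, and text are illustrative only) could be:

{"prompt": "fried chicken advertisement poster", "good": "gen/candidate_a.png", "bad": "gen/candidate_b.png"}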
sft_train.py
ADDED
@@ -0,0 +1,41 @@
import torch, json
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
from peft import get_peft_model, LoraConfig, TaskType

# Load your dataset
data = [json.loads(l) for l in open("data/sft_data.jsonl")]
dataset = Dataset.from_list(data)

# Load model & tokenizer
base_model = "meta-llama/Llama-2-7b-hf"  # Or use Mistral, Falcon, etc.
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
model = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.float16)

# Add LoRA (optional)
lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, r=8, lora_alpha=32, lora_dropout=0.05,
                         target_modules=["q_proj", "v_proj"])
model = get_peft_model(model, lora_config)

# Preprocessing
def tokenize(example):
    prompt = f"### Instruction:\n{example['prompt']}\n\n### Response:\n{example['output']}"
    return tokenizer(prompt, truncation=True, max_length=512, padding="max_length")

dataset = dataset.map(tokenize, remove_columns=dataset.column_names)

# Training setup
args = TrainingArguments(
    output_dir="./sft-model",
    per_device_train_batch_size=2,
    num_train_epochs=3,
    fp16=True,
    evaluation_strategy="no",
    save_strategy="epoch",
    logging_steps=20,
    learning_rate=2e-5,
    report_to="tensorboard",
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
trainer = Trainer(model=model, args=args, train_dataset=dataset, data_collator=data_collator)
trainer.train()
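sft_train.py reads data/sft_data.jsonl, one JSON object per line, each carrying the prompt and output fields used by tokenize(). A hypothetical record (content is illustrative, not part of this commit) might be:

{"prompt": "Write a catchy slogan for a Brooklyn fried chicken ad", "output": "Crispy, local, unforgettable: Brooklyn's best bite."}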
train_lora.py
ADDED
@@ -0,0 +1,51 @@
# train_lora.py – QLoRA + DeepSpeed DreamBooth Fine-Tuning (Stable Diffusion)

import os, argparse, torch
from diffusers import StableDiffusionPipeline, DDPMScheduler
from diffusers import DreamBoothLoraTrainer
from peft import LoraConfig
from accelerate import Accelerator

parser = argparse.ArgumentParser()
parser.add_argument("--data", default="./nyc_ads_dataset")  # your training image directory
args = parser.parse_args()

# LoRA config (QLoRA-compatible)
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"]
)

# Load SD-1.5 with 4-bit quantization
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    load_in_4bit=True,
    quantization_config={
        "bnb_4bit_compute_dtype": torch.float16,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4"
    },
)

# DreamBooth LoRA Trainer
trainer = DreamBoothLoraTrainer(
    instance_data_root=args.data,
    instance_prompt="a photo of an urbanad nyc",
    lora_config=lora_cfg,
    output_dir="./nyc-ad-model",
    max_train_steps=400,
    train_batch_size=1,
    gradient_checkpointing=True,
)

# DeepSpeed ZeRO-3 acceleration / memory partitioning
accelerator = Accelerator(
    mixed_precision="fp16",
    deepspeed_config="./ds_config_zero3.json"  # must be provided beforehand
)

# Start training
trainer.train(accelerator)
xformers
ADDED
@@ -0,0 +1 @@
Subproject commit 8fc8ec5a4d6498ff81c0c418b89bbaf133ae3a44