Skip to content

Commit 1703e40

Browse files
committed
Add basix SDXL model support to GUI
1 parent ac04f95 commit 1703e40

7 files changed

Lines changed: 157 additions & 20 deletions

File tree

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,8 @@ learning_rate = 4e-7 # SDXL original learning rate
8080

8181
[![LoRA Part 2 Tutorial](https://img.youtube.com/vi/k5imq01uvUY/0.jpg)](https://www.youtube.com/watch?v=k5imq01uvUY)
8282

83-
<<<<<<< HEAD
8483
Newer Tutorial: [Generate Studio Quality Realistic Photos By Kohya LoRA Stable Diffusion Training](https://www.youtube.com/watch?v=TpuDOsuKIBo):
85-
=======
8684
The scripts are tested with PyTorch 1.12.1 and 2.0.1, Diffusers 0.17.1.
87-
>>>>>>> 227a62e4c4d3c3c5269a244328609ce2da96ebda
8885

8986
[![Newer Tutorial: Generate Studio Quality Realistic Photos By Kohya LoRA Stable Diffusion Training](https://user-images.githubusercontent.com/19240467/235306147-85dd8126-f397-406b-83f2-368927fa0281.png)](https://www.youtube.com/watch?v=TpuDOsuKIBo)
9087

dreambooth_gui.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def save_configuration(
5959
file_path,
6060
pretrained_model_name_or_path,
6161
v2,
62-
v_parameterization,
62+
v_parameterization, sdxl,
6363
logging_dir,
6464
train_data_dir,
6565
reg_data_dir,
@@ -176,7 +176,7 @@ def open_configuration(
176176
file_path,
177177
pretrained_model_name_or_path,
178178
v2,
179-
v_parameterization,
179+
v_parameterization, sdxl,
180180
logging_dir,
181181
train_data_dir,
182182
reg_data_dir,
@@ -281,7 +281,7 @@ def train_model(
281281
print_only,
282282
pretrained_model_name_or_path,
283283
v2,
284-
v_parameterization,
284+
v_parameterization, sdxl,
285285
logging_dir,
286286
train_data_dir,
287287
reg_data_dir,
@@ -399,6 +399,10 @@ def train_model(
399399
output_name, output_dir, save_model_as, headless=headless_bool
400400
):
401401
return
402+
403+
if sdxl:
404+
output_message(msg='TI training is not compatible with an SDXL model.', headless=headless_bool)
405+
return
402406

403407
if optimizer == 'Adafactor' and lr_warmup != '0':
404408
output_message(
@@ -659,6 +663,7 @@ def dreambooth_tab(
659663
pretrained_model_name_or_path,
660664
v2,
661665
v_parameterization,
666+
sdxl,
662667
save_model_as,
663668
model_list,
664669
) = gradio_source_model(headless=headless)
@@ -884,6 +889,7 @@ def dreambooth_tab(
884889
pretrained_model_name_or_path,
885890
v2,
886891
v_parameterization,
892+
sdxl,
887893
logging_dir,
888894
train_data_dir,
889895
reg_data_dir,

finetune_gui.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def save_configuration(
5151
file_path,
5252
pretrained_model_name_or_path,
5353
v2,
54-
v_parameterization,
54+
v_parameterization, sdxl,
5555
train_dir,
5656
image_folder,
5757
output_dir,
@@ -174,7 +174,7 @@ def open_configuration(
174174
file_path,
175175
pretrained_model_name_or_path,
176176
v2,
177-
v_parameterization,
177+
v_parameterization, sdxl,
178178
train_dir,
179179
image_folder,
180180
output_dir,
@@ -285,7 +285,7 @@ def train_model(
285285
print_only,
286286
pretrained_model_name_or_path,
287287
v2,
288-
v_parameterization,
288+
v_parameterization, sdxl,
289289
train_dir,
290290
image_folder,
291291
output_dir,
@@ -473,7 +473,12 @@ def train_model(
473473
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
474474
log.info(f'lr_warmup_steps = {lr_warmup_steps}')
475475

476-
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "./fine_tune.py"'
476+
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process}'
477+
if sdxl:
478+
run_cmd += f' "./sdxl_train.py"'
479+
else:
480+
run_cmd += f' "./fine_tune.py"'
481+
477482
if v2:
478483
run_cmd += ' --v2'
479484
if v_parameterization:
@@ -626,6 +631,7 @@ def finetune_tab(headless=False):
626631
pretrained_model_name_or_path,
627632
v2,
628633
v_parameterization,
634+
sdxl,
629635
save_model_as,
630636
model_list,
631637
) = gradio_source_model(headless=headless)
@@ -852,6 +858,7 @@ def finetune_tab(headless=False):
852858
pretrained_model_name_or_path,
853859
v2,
854860
v_parameterization,
861+
sdxl,
855862
train_dir,
856863
image_folder,
857864
output_dir,

library/common_gui.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -477,13 +477,14 @@ def save_inference_file(output_dir, v2, v_parameterization, output_name):
477477

478478

479479
def set_pretrained_model_name_or_path_input(
480-
model_list, pretrained_model_name_or_path, v2, v_parameterization
480+
model_list, pretrained_model_name_or_path, v2, v_parameterization, sdxl
481481
):
482482
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v2 list
483483
if str(model_list) in V2_BASE_MODELS:
484484
log.info('SD v2 model detected. Setting --v2 parameter')
485485
v2 = True
486486
v_parameterization = False
487+
sdxl = False
487488
pretrained_model_name_or_path = str(model_list)
488489

489490
# check if $v2 and $v_parameterization are empty and if $pretrained_model_name_or_path contains any of the substrings in the v_parameterization list
@@ -493,11 +494,13 @@ def set_pretrained_model_name_or_path_input(
493494
)
494495
v2 = True
495496
v_parameterization = True
497+
sdxl = False
496498
pretrained_model_name_or_path = str(model_list)
497499

498500
if str(model_list) in V1_MODELS:
499501
v2 = False
500502
v_parameterization = False
503+
sdxl = False
501504
pretrained_model_name_or_path = str(model_list)
502505

503506
if model_list == 'custom':
@@ -509,7 +512,8 @@ def set_pretrained_model_name_or_path_input(
509512
pretrained_model_name_or_path = ''
510513
v2 = False
511514
v_parameterization = False
512-
return model_list, pretrained_model_name_or_path, v2, v_parameterization
515+
sdxl = False
516+
return model_list, pretrained_model_name_or_path, v2, sdxl
513517

514518

515519
def set_v2_checkbox(model_list, v2, v_parameterization):
@@ -535,14 +539,15 @@ def set_model_list(
535539
pretrained_model_name_or_path,
536540
v2,
537541
v_parameterization,
542+
sdxl
538543
):
539544

540545
if not pretrained_model_name_or_path in ALL_PRESET_MODELS:
541546
model_list = 'custom'
542547
else:
543548
model_list = pretrained_model_name_or_path
544549

545-
return model_list, v2, v_parameterization
550+
return model_list, v2, v_parameterization, sdxl
546551

547552

548553
###
@@ -655,6 +660,9 @@ def gradio_source_model(
655660
v_parameterization = gr.Checkbox(
656661
label='v_parameterization', value=False
657662
)
663+
sdxl = gr.Checkbox(
664+
label='SDXL Model', value=False
665+
)
658666
v2.change(
659667
set_v2_checkbox,
660668
inputs=[model_list, v2, v_parameterization],
@@ -674,12 +682,14 @@ def gradio_source_model(
674682
pretrained_model_name_or_path,
675683
v2,
676684
v_parameterization,
685+
sdxl,
677686
],
678687
outputs=[
679688
model_list,
680689
pretrained_model_name_or_path,
681690
v2,
682691
v_parameterization,
692+
sdxl,
683693
],
684694
show_progress=False,
685695
)
@@ -691,18 +701,21 @@ def gradio_source_model(
691701
pretrained_model_name_or_path,
692702
v2,
693703
v_parameterization,
704+
sdxl,
694705
],
695706
outputs=[
696707
model_list,
697708
v2,
698709
v_parameterization,
710+
sdxl,
699711
],
700712
show_progress=False,
701713
)
702714
return (
703715
pretrained_model_name_or_path,
704716
v2,
705717
v_parameterization,
718+
sdxl,
706719
save_model_as,
707720
model_list,
708721
)

lora_gui.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def save_configuration(
6969
file_path,
7070
pretrained_model_name_or_path,
7171
v2,
72-
v_parameterization,
72+
v_parameterization, sdxl,
7373
logging_dir,
7474
train_data_dir,
7575
reg_data_dir,
@@ -215,7 +215,7 @@ def open_configuration(
215215
file_path,
216216
pretrained_model_name_or_path,
217217
v2,
218-
v_parameterization,
218+
v_parameterization, sdxl,
219219
logging_dir,
220220
train_data_dir,
221221
reg_data_dir,
@@ -376,7 +376,7 @@ def train_model(
376376
print_only,
377377
pretrained_model_name_or_path,
378378
v2,
379-
v_parameterization,
379+
v_parameterization, sdxl,
380380
logging_dir,
381381
train_data_dir,
382382
reg_data_dir,
@@ -658,7 +658,11 @@ def train_model(
658658
lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
659659
log.info(f'lr_warmup_steps = {lr_warmup_steps}')
660660

661-
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process} "train_network.py"'
661+
run_cmd = f'accelerate launch --num_cpu_threads_per_process={num_cpu_threads_per_process}'
662+
if sdxl:
663+
run_cmd += f' "./sdxl_train_network.py"'
664+
else:
665+
run_cmd += f' "./train_network.py"'
662666

663667
if v2:
664668
run_cmd += ' --v2'
@@ -996,6 +1000,7 @@ def lora_tab(
9961000
pretrained_model_name_or_path,
9971001
v2,
9981002
v_parameterization,
1003+
sdxl,
9991004
save_model_as,
10001005
model_list,
10011006
) = gradio_source_model(
@@ -1588,6 +1593,7 @@ def update_LoRA_settings(LoRA_type):
15881593
pretrained_model_name_or_path,
15891594
v2,
15901595
v_parameterization,
1596+
sdxl,
15911597
logging_dir,
15921598
train_data_dir,
15931599
reg_data_dir,
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
{
2+
"LoRA_type": "Standard",
3+
"adaptive_noise_scale": 0,
4+
"additional_parameters": "",
5+
"block_alphas": "",
6+
"block_dims": "",
7+
"block_lr_zero_threshold": "",
8+
"bucket_no_upscale": true,
9+
"bucket_reso_steps": 1,
10+
"cache_latents": true,
11+
"cache_latents_to_disk": true,
12+
"caption_dropout_every_n_epochs": 0.0,
13+
"caption_dropout_rate": 0,
14+
"caption_extension": ".none-use-foldername",
15+
"clip_skip": "1",
16+
"color_aug": false,
17+
"conv_alpha": 64,
18+
"conv_alphas": "",
19+
"conv_dim": 64,
20+
"conv_dims": "",
21+
"decompose_both": false,
22+
"dim_from_weights": false,
23+
"down_lr_weight": "",
24+
"enable_bucket": true,
25+
"epoch": 4,
26+
"factor": -1,
27+
"flip_aug": false,
28+
"full_fp16": false,
29+
"gradient_accumulation_steps": 1,
30+
"gradient_checkpointing": false,
31+
"keep_tokens": "0",
32+
"learning_rate": 4e-07,
33+
"logging_dir": "",
34+
"lora_network_weights": "",
35+
"lr_scheduler": "constant_with_warmup",
36+
"lr_scheduler_num_cycles": "",
37+
"lr_scheduler_power": "",
38+
"lr_warmup": 8,
39+
"max_data_loader_n_workers": "0",
40+
"max_resolution": "512,512",
41+
"max_token_length": "75",
42+
"max_train_epochs": "",
43+
"mem_eff_attn": false,
44+
"mid_lr_weight": "",
45+
"min_snr_gamma": 10,
46+
"mixed_precision": "bf16",
47+
"model_list": "runwayml/stable-diffusion-v1-5",
48+
"module_dropout": 0,
49+
"multires_noise_discount": 0.2,
50+
"multires_noise_iterations": 8,
51+
"network_alpha": 64,
52+
"network_dim": 64,
53+
"network_dropout": 0,
54+
"no_token_padding": false,
55+
"noise_offset": 0.0357,
56+
"noise_offset_type": "Original",
57+
"num_cpu_threads_per_process": 2,
58+
"optimizer": "Adafactor",
59+
"optimizer_args": "scale_parameter=False relative_step=False warmup_init=False",
60+
"output_dir": "",
61+
"output_name": "",
62+
"persistent_data_loader_workers": false,
63+
"pretrained_model_name_or_path": "runwayml/stable-diffusion-v1-5",
64+
"prior_loss_weight": 1.0,
65+
"random_crop": false,
66+
"rank_dropout": 0,
67+
"reg_data_dir": "",
68+
"resume": "",
69+
"sample_every_n_epochs": 0,
70+
"sample_every_n_steps": 0,
71+
"sample_prompts": "",
72+
"sample_sampler": "euler_a",
73+
"save_every_n_epochs": 1,
74+
"save_every_n_steps": 0,
75+
"save_last_n_steps": 0,
76+
"save_last_n_steps_state": 0,
77+
"save_model_as": "safetensors",
78+
"save_precision": "fp16",
79+
"save_state": false,
80+
"scale_v_pred_loss_like_noise_pred": false,
81+
"scale_weight_norms": 0,
82+
"sdxl": true,
83+
"seed": "",
84+
"shuffle_caption": false,
85+
"stop_text_encoder_training": 0,
86+
"text_encoder_lr": 0.0,
87+
"train_batch_size": 1,
88+
"train_data_dir": "",
89+
"train_on_input": true,
90+
"training_comment": "",
91+
"unet_lr": 4e-07,
92+
"unit": 1,
93+
"up_lr_weight": "",
94+
"use_cp": false,
95+
"use_wandb": false,
96+
"v2": false,
97+
"v_parameterization": false,
98+
"vae_batch_size": 0,
99+
"wandb_api_key": "",
100+
"weighted_captions": false,
101+
"xformers": true
102+
}

0 commit comments

Comments
 (0)