G_prompt for cut_turbo for dataset with single prompt #662

Open
wr0124 wants to merge 7 commits into jolibrain:master from wr0124:G_prompt

wr0124 commented Jun 20, 2024

Collaborator

Add G_prompt for cut_turbo on unaligned datasets; it also works for batch_size larger than 1.

  • inference
  • unit tests
  • documentation

Training works with the following command line:

python3 train.py \
--dataroot /data1/juliew/dataset/horse2zebra \
--checkpoints_dir /data1/juliew/checkpoints \
--name horse2zebra_turbo \
--config_json examples/example_cut_turbo_horse2zebra.json \
--train_batch_size 2 \
--output_print_freq 10 \
--data_crop_size 64 \
--data_load_size 64 \
--G_prompt zebra

The --G_prompt option is mandatory if there is no prompt file in the dataset.
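
The note on `--G_prompt` above (a single prompt is required when the dataset ships no prompt file) can be sketched as follows. This is only an illustration of the rule, not code from the repository: `resolve_prompts`, `dataset_prompts`, and `g_prompt` are hypothetical names, and giving per-image prompts precedence over the single prompt is an assumption.

```python
from typing import List, Optional


def resolve_prompts(
    batch_size: int,
    dataset_prompts: Optional[List[str]] = None,
    g_prompt: Optional[str] = None,
) -> List[str]:
    """Return one prompt per image in the batch (hypothetical helper)."""
    if dataset_prompts is not None:
        # Assumption: per-image prompts from the dataset take precedence.
        if len(dataset_prompts) != batch_size:
            raise ValueError("expected one prompt per image")
        return list(dataset_prompts)
    if g_prompt is None:
        raise ValueError("no prompt file in the dataset: --G_prompt is required")
    # Single prompt: reuse it for every image, so batch sizes above 1 work.
    return [g_prompt] * batch_size


print(resolve_prompts(2, g_prompt="zebra"))
```

With `--train_batch_size 2` and `--G_prompt zebra`, such a helper would hand the generator two identical prompts, one per image.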

Inference works with the following command line:

cd scripts
python3 gen_single_image.py \
--model_in_file /data1/juliew/checkpoints/horse2zebra_turbo/latest_net_G_A.pth \
--img_in /data1/juliew/dataset/horse2zebra/testA/n02381460_1000.jpg \
--img_out /data1/juliew/target.jpg \
--prompt zebra \
--gpuid 0

beniz changed the title from "G_prompt for cut_turbo for unaligned horse2zebra dataset" to "G_prompt for cut_turbo for dataset with single prompts" on Jun 21, 2024
Comment thread models/cut_model.py Outdated
self.fake_B = self.netG_A(self.real_with_z, G_prompt)
else:
fake_B = self.netG_A(real_A_with_z)
self.fake_B = self.netG_A(self.real_with_z)

Contributor

Remove self


# match batch size
captions_enc = caption_enc.repeat(x.shape[0], 1, 1)
batch_size = caption_enc.shape[0]

Contributor

Unneeded?

Collaborator Author

Inside cut_model, x is created by concatenating real_A and real_B, which doubles the batch size of the x tensor. The prompt tensor keeps the original batch size, so I made this modification to match the two tensors. A detailed toy example is here: https://colab.research.google.com/drive/1RMvHt2PuQufH4zEc2Lrds561L9NEYzYf?usp=sharing

Contributor

This should be fixed by modifying the prompt tensor outside turbo, in cut when A & B are concatenated for inference, not here.
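
One possible reading of this suggestion, sketched with plain Python lists instead of the project's tensors: the prompts are extended at the same place where cut concatenates A and B, so the generator always receives one prompt per image and needs no `repeat` of its own. The helper `concat_with_prompts` and its arguments are hypothetical, and reusing the A prompts for the B half of the batch is an assumption (trivially true in the single-prompt case).

```python
from typing import List, Sequence, Tuple


def concat_with_prompts(
    real_A: Sequence, real_B: Sequence, prompts: List[str]
) -> Tuple[list, List[str]]:
    """Concatenate two image batches and extend the prompts to match."""
    if len(prompts) != len(real_A) or len(real_A) != len(real_B):
        raise ValueError("real_A, real_B and prompts must have the same length")
    # The batch size doubles here: [A0, A1, ..., B0, B1, ...].
    images = list(real_A) + list(real_B)
    # Double the prompts in the same order so index i still matches image i.
    return images, list(prompts) + list(prompts)


images, prompts = concat_with_prompts(["a0", "a1"], ["b0", "b1"], ["zebra", "zebra"])
print(len(images), len(prompts))
```

Keeping this step next to the concatenation leaves the generator's `forward` free of any batch-size bookkeeping.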

"D_lr": 0.0001,
"G_ema": false,
"G_ema_beta": 0.999,
"G_lr": 0.0002,

Contributor

Set G_lr to 0.0001.

beniz changed the title from "G_prompt for cut_turbo for dataset with single prompts" to "G_prompt for cut_turbo for dataset with single prompt" on Jun 26, 2024
@@ -201,17 +201,14 @@ def forward(self, x, prompt):
).input_ids.cuda()
caption_enc = self.text_encoder(caption_tokens)[0]

Contributor

Are you sure about the [0]?

Collaborator Author

With refs:
1. https://huggingface.co/transformers/v4.8.0/model_doc/clip.html#flaxcliptextmodel
2. https://github.com/huggingface/transformers/blob/f91c16d270e5e3ff32fdb32ccf286d05c03dfa66/src/transformers/models/clip/modeling_clip.py#L759
For "outputs = self.text_encoder(caption_tokens)":
type(outputs) = <class 'transformers.modeling_outputs.BaseModelOutputWithPooling'>
len(outputs) = 2
outputs[0].shape = torch.Size([4, 77, 1024]) (this is last_hidden_state)
outputs[1].shape = torch.Size([4, 1024]) (this is the pooler_output)
According to the explanation in ref 1, it should be outputs[0].
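
Since index 0 and `last_hidden_state` refer to the same tensor for this output type, reading the field by name might make the intent clearer to reviewers. A minimal sketch, using stand-in objects instead of a real `transformers` output so that it runs without the library; `token_embeddings` is a hypothetical helper, not part of the PR.

```python
from types import SimpleNamespace


def token_embeddings(outputs):
    """Return the per-token hidden states from a text-encoder output."""
    # Prefer the named field when the output object exposes it.
    hidden = getattr(outputs, "last_hidden_state", None)
    if hidden is not None:
        return hidden
    # Otherwise fall back to positional access, as in the PR.
    return outputs[0]


# Stand-ins for the encoder output: strings replace the real tensors.
named = SimpleNamespace(last_hidden_state="per-token states", pooler_output="pooled")
plain = ("per-token states", "pooled")
print(token_embeddings(named))
print(token_embeddings(plain))
```

In the PR itself this would amount to writing `self.text_encoder(caption_tokens).last_hidden_state` in place of `[0]`.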

2 participants