Skip to content

Commit 49279bb

Browse files
committed
Update IP-Adapter unit test for multi-image.
1 parent 8464450 commit 49279bb

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

tests/backend/ip_adapter/test_ip_adapter.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ def test_ip_adapter_unet_patch(model_params, model_installer, torch_device):
6565
ip_adapter.to(torch_device, dtype=torch.float32)
6666
unet.to(torch_device, dtype=torch.float32)
6767

68-
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [torch.randn((1, 4, 768)).to(torch_device)]}
68+
# ip_embeds shape: (batch_size, num_ip_images, seq_len, ip_image_embedding_len)
69+
ip_embeds = torch.randn((1, 3, 4, 768)).to(torch_device)
70+
71+
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": [ip_embeds]}
6972
ip_adapter_unet_patcher = UNetPatcher([ip_adapter])
7073
with ip_adapter_unet_patcher.apply_ip_adapter_attention(unet):
7174
output = unet(**dummy_unet_input, cross_attention_kwargs=cross_attention_kwargs).sample

0 commit comments

Comments
 (0)