Skip to content

Commit e92b849

Browse files
RyanJDick authored and hipsterusername committed
Add minimal unit tests for ModelPatcher.apply_lora(...)
1 parent 61b17c4 commit e92b849

1 file changed

Lines changed: 102 additions & 0 deletions

File tree

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Test that if the model's device changes while the LoRA is applied, the weights can still be restored.
2+
3+
# Test that LoRA patching works on both CPU and CUDA.
4+
5+
import pytest
6+
import torch
7+
8+
from invokeai.backend.model_management.lora import ModelPatcher
9+
from invokeai.backend.model_management.models.lora import LoRALayer, LoRAModelRaw
10+
11+
12+
@pytest.mark.parametrize(
    "device",
    [
        "cpu",
        pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")),
    ],
)
@torch.no_grad()
def test_apply_lora(device):
    """Run one patch/unpatch cycle of ModelPatcher.apply_lora(...) on `device`.

    A single float16 linear layer is created on `device` and patched with a LoRA whose tensors are created on the
    CPU. While the patch is active, the LoRA tensors must be on the CPU, the model must still be on `device`, and
    the layer weight must equal the original weight plus `rank * scale`. Once the context exits, the original
    weight must be back in place on `device`.
    """
    layer_key = "linear_layer_1"
    in_features, out_features, rank = 4, 8, 2
    scale = 0.5

    model = torch.nn.ModuleDict(
        {layer_key: torch.nn.Linear(in_features, out_features, device=device, dtype=torch.float16)}
    )

    def layer_weight():
        # Look the weight up afresh on every call so each check sees the model's current state.
        return model[layer_key].weight.data

    # All-ones LoRA matrices, created on the CPU regardless of where the model lives.
    down = torch.ones((rank, in_features), device="cpu", dtype=torch.float16)
    up = torch.ones((out_features, rank), device="cpu", dtype=torch.float16)
    lora_layers = {
        layer_key: LoRALayer(
            layer_key=layer_key,
            values={"lora_down.weight": down, "lora_up.weight": up},
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)

    # Snapshot the weight before patching; the patched weight is expected to be offset by rank * scale.
    weight_before = layer_weight().detach().clone()
    weight_expected = weight_before + (rank * scale)

    with ModelPatcher.apply_lora(model, [(lora, scale)], prefix=""):
        # While patched, the LoRA tensors should be on the CPU.
        patched_layer = lora_layers[layer_key]
        assert patched_layer.up.device.type == "cpu"
        assert patched_layer.down.device.type == "cpu"

        # While patched, the model should not have left its original device.
        assert layer_weight().device.type == device

        torch.testing.assert_close(layer_weight(), weight_expected)

    # Once the context has exited, the original weight should be restored on the original device.
    assert layer_weight().device.type == device
    torch.testing.assert_close(layer_weight(), weight_before)
60+
61+
62+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA device")
@torch.no_grad()
def test_apply_lora_change_device():
    """Test that if LoRA patching is applied on the CPU, and then the patched model is moved to the GPU, unpatching
    still behaves correctly.

    The model and the LoRA tensors are both created on the CPU. The model is moved to CUDA while the patch is
    still active; after the context exits, the weight must be on CUDA and equal to the pre-patch values.
    """
    linear_in_features = 4
    linear_out_features = 8
    lora_dim = 2
    # Initialize the model on the CPU.
    model = torch.nn.ModuleDict(
        {"linear_layer_1": torch.nn.Linear(linear_in_features, linear_out_features, device="cpu", dtype=torch.float16)}
    )

    lora_layers = {
        "linear_layer_1": LoRALayer(
            layer_key="linear_layer_1",
            values={
                "lora_down.weight": torch.ones((lora_dim, linear_in_features), device="cpu", dtype=torch.float16),
                "lora_up.weight": torch.ones((linear_out_features, lora_dim), device="cpu", dtype=torch.float16),
            },
        )
    }
    lora = LoRAModelRaw("lora_name", lora_layers)

    lora_weight = 0.5
    orig_linear_weight = model["linear_layer_1"].weight.data.detach().clone()
    # With all-ones LoRA matrices, every weight element is expected to shift by lora_dim * lora_weight.
    expected_patched_linear_weight = orig_linear_weight + (lora_dim * lora_weight)

    with ModelPatcher.apply_lora(model, [(lora, lora_weight)], prefix=""):
        # After patching, all LoRA layer weights should have been moved back to the cpu.
        assert lora_layers["linear_layer_1"].up.device.type == "cpu"
        assert lora_layers["linear_layer_1"].down.device.type == "cpu"

        # After patching, the patched model should still be on the CPU.
        assert model["linear_layer_1"].weight.data.device.type == "cpu"

        # Confirm the patch was actually applied before the device changes, so the restore check below is
        # meaningful.
        torch.testing.assert_close(model["linear_layer_1"].weight.data, expected_patched_linear_weight)

        # Move the model to the GPU. This is a plain statement rather than an `assert`: the move is a required
        # side effect of the test, and assert statements are removed when Python runs with -O.
        model.to("cuda")

    # After unpatching, the original model weights should have been restored on the GPU.
    assert model["linear_layer_1"].weight.data.device.type == "cuda"
    # orig_linear_weight was cloned on the CPU, so compare values only.
    torch.testing.assert_close(model["linear_layer_1"].weight.data, orig_linear_weight, check_device=False)

0 commit comments

Comments
 (0)