Skip to content

Commit bac2a75

Browse files
RyanJDickhipsterusername
authored andcommitted
Replace deepcopy with a pickle roundtrip in apply_ti(...) to improve speed.
1 parent a4a7b60 commit bac2a75

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

  • invokeai/backend/model_management

invokeai/backend/model_management/lora.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
import copy
3+
import pickle
44
from contextlib import contextmanager
55
from pathlib import Path
66
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -165,7 +165,13 @@ def apply_ti(
165165
new_tokens_added = None
166166

167167
try:
168-
ti_tokenizer = copy.deepcopy(tokenizer)
168+
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
169+
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
170+
# exiting this `apply_ti(...)` context manager.
171+
#
172+
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
173+
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
174+
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
169175
ti_manager = TextualInversionManager(ti_tokenizer)
170176
init_tokens_count = text_encoder.resize_token_embeddings(None).num_embeddings
171177

@@ -439,7 +445,13 @@ def apply_ti(
439445
orig_embeddings = None
440446

441447
try:
442-
ti_tokenizer = copy.deepcopy(tokenizer)
448+
# HACK: The CLIPTokenizer API does not include a way to remove tokens after calling add_tokens(...). As a
449+
# workaround, we create a full copy of `tokenizer` so that its original behavior can be restored after
450+
# exiting this `apply_ti(...)` context manager.
451+
#
452+
# In a previous implementation, the deep copy was obtained with `ti_tokenizer = copy.deepcopy(tokenizer)`,
453+
# but a pickle roundtrip was found to be much faster (1 sec vs. 0.05 secs).
454+
ti_tokenizer = pickle.loads(pickle.dumps(tokenizer))
443455
ti_manager = TextualInversionManager(ti_tokenizer)
444456

445457
def _get_trigger(ti_name, index):

0 commit comments

Comments
 (0)