Skip to content

Commit 6b173cc

Browse files
committed
multiple small stylistic changes requested by reviewers
1 parent ce22c0f commit 6b173cc

5 files changed

Lines changed: 110 additions & 59 deletions

File tree

invokeai/app/services/model_records/model_records_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelConfigBase, ModelType
1111

1212
# should match the InvokeAI version when this is first released.
13-
CONFIG_FILE_VERSION = "3.2"
13+
CONFIG_FILE_VERSION = "3.2.0"
1414

1515

1616
class DuplicateModelException(Exception):

invokeai/app/services/model_records/model_records_sql.py

Lines changed: 23 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,8 @@
6666
class ModelRecordServiceSQL(ModelRecordServiceBase):
6767
"""Implementation of the ModelConfigStore ABC using a SQL database."""
6868

69-
_conn: sqlite3.Connection
69+
_db: SqliteDatabase
7070
_cursor: sqlite3.Cursor
71-
_lock: threading.Lock
7271

7372
def __init__(self, db: SqliteDatabase):
7473
"""
@@ -78,16 +77,15 @@ def __init__(self, db: SqliteDatabase):
7877
:param lock: threading Lock object
7978
"""
8079
super().__init__()
81-
self._conn = db.conn
82-
self._lock = db.lock
83-
self._conn.row_factory = sqlite3.Row
84-
self._cursor = self._conn.cursor()
80+
self._db = db
81+
self._db.conn.row_factory = sqlite3.Row
82+
self._cursor = self._db.conn.cursor()
8583

86-
with self._lock:
84+
with self._db.lock:
8785
# Enable foreign keys
88-
self._conn.execute("PRAGMA foreign_keys = ON;")
86+
self._db.conn.execute("PRAGMA foreign_keys = ON;")
8987
self._create_tables()
90-
self._conn.commit()
88+
self._db.conn.commit()
9189
assert (
9290
str(self.version) == CONFIG_FILE_VERSION
9391
), f"Model config version {self.version} does not match expected version {CONFIG_FILE_VERSION}"
@@ -172,7 +170,7 @@ def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConf
172170
"""
173171
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect.
174172
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
175-
with self._lock:
173+
with self._db.lock:
176174
try:
177175
self._cursor.execute(
178176
"""--sql
@@ -197,10 +195,10 @@ def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConf
197195
json_serialized,
198196
),
199197
)
200-
self._conn.commit()
198+
self._db.conn.commit()
201199

202200
except sqlite3.IntegrityError as e:
203-
self._conn.rollback()
201+
self._db.conn.rollback()
204202
if "UNIQUE constraint failed" in str(e):
205203
if "model_config.path" in str(e):
206204
msg = f"A model with path '{record.path}' is already installed"
@@ -210,15 +208,15 @@ def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConf
210208
else:
211209
raise e
212210
except sqlite3.Error as e:
213-
self._conn.rollback()
211+
self._db.conn.rollback()
214212
raise e
215213

216214
return self.get_model(key)
217215

218216
@property
219217
def version(self) -> str:
220218
"""Return the version of the database schema."""
221-
with self._lock:
219+
with self._db.lock:
222220
self._cursor.execute(
223221
"""--sql
224222
SELECT metadata_value FROM model_manager_metadata
@@ -239,7 +237,7 @@ def del_model(self, key: str) -> None:
239237
240238
Can raise an UnknownModelException
241239
"""
242-
with self._lock:
240+
with self._db.lock:
243241
try:
244242
self._cursor.execute(
245243
"""--sql
@@ -250,9 +248,9 @@ def del_model(self, key: str) -> None:
250248
)
251249
if self._cursor.rowcount == 0:
252250
raise UnknownModelException("model not found")
253-
self._conn.commit()
251+
self._db.conn.commit()
254252
except sqlite3.Error as e:
255-
self._conn.rollback()
253+
self._db.conn.rollback()
256254
raise e
257255

258256
def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConfigBase:
@@ -265,7 +263,7 @@ def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelC
265263
"""
266264
record = ModelConfigFactory.make_config(config, key=key) # ensure it is a valid config obect
267265
json_serialized = json.dumps(record.model_dump()) # and turn it into a json string.
268-
with self._lock:
266+
with self._db.lock:
269267
try:
270268
self._cursor.execute(
271269
"""--sql
@@ -281,9 +279,9 @@ def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelC
281279
)
282280
if self._cursor.rowcount == 0:
283281
raise UnknownModelException("model not found")
284-
self._conn.commit()
282+
self._db.conn.commit()
285283
except sqlite3.Error as e:
286-
self._conn.rollback()
284+
self._db.conn.rollback()
287285
raise e
288286

289287
return self.get_model(key)
@@ -296,7 +294,7 @@ def get_model(self, key: str) -> AnyModelConfig:
296294
297295
Exceptions: UnknownModelException
298296
"""
299-
with self._lock:
297+
with self._db.lock:
300298
self._cursor.execute(
301299
"""--sql
302300
SELECT config FROM model_config
@@ -317,7 +315,7 @@ def exists(self, key: str) -> bool:
317315
:param key: Unique key for the model to be deleted
318316
"""
319317
count = 0
320-
with self._lock:
318+
with self._db.lock:
321319
try:
322320
self._cursor.execute(
323321
"""--sql
@@ -360,7 +358,7 @@ def search_by_name(
360358
where_clause.append("type=?")
361359
bindings.append(model_type)
362360
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
363-
with self._lock:
361+
with self._db.lock:
364362
try:
365363
self._cursor.execute(
366364
f"""--sql
@@ -377,7 +375,7 @@ def search_by_name(
377375
def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
378376
"""Return models with the indicated path."""
379377
results = []
380-
with self._lock:
378+
with self._db.lock:
381379
try:
382380
self._cursor.execute(
383381
"""--sql
@@ -394,7 +392,7 @@ def search_by_path(self, path: Union[str, Path]) -> List[ModelConfigBase]:
394392
def search_by_hash(self, hash: str) -> List[ModelConfigBase]:
395393
"""Return models with the indicated original_hash."""
396394
results = []
397-
with self._lock:
395+
with self._db.lock:
398396
try:
399397
self._cursor.execute(
400398
"""--sql

invokeai/backend/model_manager/migrate_to_db.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,19 @@
1616
ModelsValidator = TypeAdapter(AnyModelConfig)
1717

1818

19-
class Migrate:
20-
"""Migration class."""
19+
class MigrateModelYamlToDb:
20+
"""
21+
Migrate the InvokeAI models.yaml format (VERSION 3.0.0) to SQL3 database format (VERSION 3.2.0)
22+
23+
The class has one externally useful method, migrate(), which scans the
24+
currently models.yaml file and imports all its entries into invokeai.db.
25+
26+
Use this way:
27+
28+
from invokeai.backend.model_manager/migrate_to_db import MigrateModelYamlToDb
29+
MigrateModelYamlToDb().migrate()
30+
31+
"""
2132

2233
config: InvokeAIAppConfig
2334
logger: InvokeAILogger
@@ -28,14 +39,17 @@ def __init__(self):
2839
self.logger = InvokeAILogger.get_logger()
2940

3041
def get_db(self) -> ModelRecordServiceSQL:
42+
"""Fetch the sqlite3 database for this installation."""
3143
db = SqliteDatabase(self.config, self.logger)
3244
return ModelRecordServiceSQL(db)
3345

3446
def get_yaml(self) -> DictConfig:
47+
"""Fetch the models.yaml DictConfig for this installation."""
3548
yaml_path = self.config.model_conf_path
3649
return OmegaConf.load(yaml_path)
3750

3851
def migrate(self):
52+
"""Do the migration from models.yaml to invokeai.db."""
3953
db = self.get_db()
4054
yaml = self.get_yaml()
4155

@@ -65,7 +79,7 @@ def migrate(self):
6579

6680

6781
def main():
68-
Migrate().migrate()
82+
MigrateModelYamlToDb().migrate()
6983

7084

7185
if __name__ == "__main__":

invokeai/backend/util/util.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -386,13 +386,10 @@ def __exit__(self, *args):
386386
class SilenceWarnings(object):
387387
"""Context manager to temporarily lower verbosity of diffusers & transformers warning messages."""
388388

389-
def __init__(self):
390-
"""Set up context, save current transformers and diffusers verbosity settings."""
391-
self.transformers_verbosity = transformers_logging.get_verbosity()
392-
self.diffusers_verbosity = diffusers_logging.get_verbosity()
393-
394389
def __enter__(self):
395390
"""Set verbosity to error."""
391+
self.transformers_verbosity = transformers_logging.get_verbosity()
392+
self.diffusers_verbosity = diffusers_logging.get_verbosity()
396393
transformers_logging.set_verbosity_error()
397394
diffusers_logging.set_verbosity_error()
398395
warnings.simplefilter("ignore")

tests/backend/model_manager_2/test_model_storage_sql.py renamed to tests/app/services/model_records/test_model_records_sql.py

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,10 @@ def test_add(store: ModelRecordServiceBase):
6767
def test_dup(store: ModelRecordServiceBase):
6868
config = example_config()
6969
store.add_model("key1", example_config())
70-
try:
70+
with pytest.raises(DuplicateModelException):
7171
store.add_model("key1", config)
72-
assert False, "Duplicate model key should have been caught"
73-
except DuplicateModelException:
74-
assert True
75-
try:
72+
with pytest.raises(DuplicateModelException):
7673
store.add_model("key2", config)
77-
assert False, "Duplicate model path should have been caught"
78-
except DuplicateModelException:
79-
assert True
8074

8175

8276
def test_update(store: ModelRecordServiceBase):
@@ -90,32 +84,21 @@ def test_update(store: ModelRecordServiceBase):
9084
new_config = store.get_model("key1")
9185
assert new_config.name == "new name"
9286

93-
try:
87+
88+
def test_unknown_key(store: ModelRecordServiceBase):
89+
config = example_config()
90+
store.add_model("key1", config)
91+
with pytest.raises(UnknownModelException):
9492
store.update_model("unknown_key", config)
95-
assert False, "expected UnknownModelException"
96-
except UnknownModelException:
97-
assert True
9893

9994

10095
def test_delete(store: ModelRecordServiceBase):
10196
config = example_config()
10297
store.add_model("key1", config)
10398
config = store.get_model("key1")
10499
store.del_model("key1")
105-
try:
100+
with pytest.raises(UnknownModelException):
106101
config = store.get_model("key1")
107-
assert False, "expected fetch of deleted model to raise exception"
108-
except UnknownModelException:
109-
assert True
110-
111-
# a bug in sqlite3 in python 3.9 prevents DEL from returning number of
112-
# deleted rows!
113-
if sys.version_info.major == 3 and sys.version_info.minor > 9:
114-
try:
115-
store.del_model("unknown")
116-
assert False, "expected delete of unknown model to raise exception"
117-
except UnknownModelException:
118-
assert True
119102

120103

121104
def test_exists(store: ModelRecordServiceBase):
@@ -167,3 +150,62 @@ def test_filter(store: ModelRecordServiceBase):
167150

168151
matches = store.all_models()
169152
assert len(matches) == 3
153+
154+
155+
def test_filter_2(store: ModelRecordServiceBase):
156+
config1 = DiffusersConfig(
157+
path="/tmp/config1",
158+
name="config1",
159+
base=BaseModelType("sd-1"),
160+
type=ModelType("main"),
161+
original_hash="CONFIG1HASH",
162+
)
163+
config2 = DiffusersConfig(
164+
path="/tmp/config2",
165+
name="config2",
166+
base=BaseModelType("sd-1"),
167+
type=ModelType("main"),
168+
original_hash="CONFIG2HASH",
169+
)
170+
config3 = DiffusersConfig(
171+
path="/tmp/config3",
172+
name="dup_name1",
173+
base=BaseModelType("sd-2"),
174+
type=ModelType("main"),
175+
original_hash="CONFIG3HASH",
176+
)
177+
config4 = DiffusersConfig(
178+
path="/tmp/config4",
179+
name="dup_name1",
180+
base=BaseModelType("sd-2"),
181+
type=ModelType("main"),
182+
original_hash="CONFIG3HASH",
183+
)
184+
config5 = VaeDiffusersConfig(
185+
path="/tmp/config5",
186+
name="dup_name1",
187+
base=BaseModelType("sd-1"),
188+
type=ModelType("vae"),
189+
original_hash="CONFIG3HASH",
190+
)
191+
for c in config1, config2, config3, config4, config5:
192+
store.add_model(sha256(c.path.encode("utf-8")).hexdigest(), c)
193+
194+
matches = store.search_by_name(
195+
model_type=ModelType("main"),
196+
model_name="dup_name1",
197+
)
198+
assert len(matches) == 2
199+
200+
matches = store.search_by_name(
201+
base_model=BaseModelType("sd-1"),
202+
model_type=ModelType("main"),
203+
)
204+
assert len(matches) == 2
205+
206+
matches = store.search_by_name(
207+
base_model=BaseModelType("sd-1"),
208+
model_type=ModelType("vae"),
209+
model_name="dup_name1",
210+
)
211+
assert len(matches) == 1

0 commit comments

Comments
 (0)