Skip to content

Commit ce22c0f

Browse files
committed
sync pydantic and sql field names; merge routes
1 parent 55f8865 commit ce22c0f

6 files changed

Lines changed: 82 additions & 45 deletions

File tree

invokeai/app/api/routers/model_records.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from ..dependencies import ApiDependencies
1818

19-
model_records_router = APIRouter(prefix="/v1/model_records", tags=["model_records"])
19+
model_records_router = APIRouter(prefix="/v1/model/record", tags=["models"])
2020

2121
ModelConfigValidator = TypeAdapter(AnyModelConfig)
2222

@@ -34,7 +34,7 @@ class ModelsList(BaseModel):
3434

3535
@model_records_router.get(
3636
"/",
37-
operation_id="list_model_configs",
37+
operation_id="list_model_recordss",
3838
responses={200: {"model": ModelsList}},
3939
)
4040
async def list_model_records(

invokeai/app/services/model_records/model_records_sql.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
2828
# fetching config
2929
new_config = store.get_model('key1')
30-
print(new_config.name, new_config.base_model)
30+
print(new_config.name, new_config.base)
3131
assert new_config.key == 'key1'
3232
3333
# deleting
@@ -100,11 +100,11 @@ def _create_tables(self) -> None:
100100
"""--sql
101101
CREATE TABLE IF NOT EXISTS model_config (
102102
id TEXT NOT NULL PRIMARY KEY,
103-
-- These 4 fields are enums in python, unrestricted string here
104-
base_model TEXT NOT NULL,
105-
model_type TEXT NOT NULL,
106-
model_name TEXT NOT NULL,
107-
model_path TEXT NOT NULL,
103+
-- The next 3 fields are enums in python, unrestricted string here
104+
base TEXT NOT NULL,
105+
type TEXT NOT NULL,
106+
name TEXT NOT NULL,
107+
path TEXT NOT NULL,
108108
original_hash TEXT, -- could be null
109109
-- Serialized JSON representation of the whole config object,
110110
-- which will contain additional fields from subclasses
@@ -139,6 +139,15 @@ def _create_tables(self) -> None:
139139
"""
140140
)
141141

142+
# Add indexes for searchable fields
143+
for stmt in [
144+
"CREATE INDEX IF NOT EXISTS base_index ON model_config(base);",
145+
"CREATE INDEX IF NOT EXISTS type_index ON model_config(type);",
146+
"CREATE INDEX IF NOT EXISTS name_index ON model_config(name);",
147+
"CREATE UNIQUE INDEX IF NOT EXISTS path_index ON model_config(path);",
148+
]:
149+
self._cursor.execute(stmt)
150+
142151
# Add our version to the metadata table
143152
self._cursor.execute(
144153
"""--sql
@@ -169,18 +178,18 @@ def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConf
169178
"""--sql
170179
INSERT INTO model_config (
171180
id,
172-
base_model,
173-
model_type,
174-
model_name,
175-
model_path,
181+
base,
182+
type,
183+
name,
184+
path,
176185
original_hash,
177186
config
178187
)
179188
VALUES (?,?,?,?,?,?,?);
180189
""",
181190
(
182191
key,
183-
record.base_model,
192+
record.base,
184193
record.type,
185194
record.name,
186195
record.path,
@@ -193,7 +202,11 @@ def add_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelConf
193202
except sqlite3.IntegrityError as e:
194203
self._conn.rollback()
195204
if "UNIQUE constraint failed" in str(e):
196-
raise DuplicateModelException(f"A model with key '{key}' is already installed") from e
205+
if "model_config.path" in str(e):
206+
msg = f"A model with path '{record.path}' is already installed"
207+
else:
208+
msg = f"A model with key '{key}' is already installed"
209+
raise DuplicateModelException(msg) from e
197210
else:
198211
raise e
199212
except sqlite3.Error as e:
@@ -257,14 +270,14 @@ def update_model(self, key: str, config: Union[dict, ModelConfigBase]) -> ModelC
257270
self._cursor.execute(
258271
"""--sql
259272
UPDATE model_config
260-
SET base_model=?,
261-
model_type=?,
262-
model_name=?,
263-
model_path=?,
273+
SET base=?,
274+
type=?,
275+
name=?,
276+
path=?,
264277
config=?
265278
WHERE id=?;
266279
""",
267-
(record.base_model, record.type, record.name, record.path, json_serialized, key),
280+
(record.base, record.type, record.name, record.path, json_serialized, key),
268281
)
269282
if self._cursor.rowcount == 0:
270283
raise UnknownModelException("model not found")
@@ -338,13 +351,13 @@ def search_by_name(
338351
where_clause = []
339352
bindings = []
340353
if model_name:
341-
where_clause.append("model_name=?")
354+
where_clause.append("name=?")
342355
bindings.append(model_name)
343356
if base_model:
344-
where_clause.append("base_model=?")
357+
where_clause.append("base=?")
345358
bindings.append(base_model)
346359
if model_type:
347-
where_clause.append("model_type=?")
360+
where_clause.append("type=?")
348361
bindings.append(model_type)
349362
where = f"WHERE {' AND '.join(where_clause)}" if where_clause else ""
350363
with self._lock:

invokeai/backend/model_manager/config.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from invokeai.backend.model_manager import ModelConfigFactory
88
raw = dict(path='models/sd-1/main/foo.ckpt',
99
name='foo',
10-
base_model='sd-1',
11-
model_type='main',
10+
base='sd-1',
11+
type='main',
1212
config='configs/stable-diffusion/v1-inference.yaml',
1313
variant='normal',
1414
format='checkpoint'
@@ -103,7 +103,7 @@ class ModelConfigBase(BaseModel):
103103

104104
path: str
105105
name: str
106-
base_model: BaseModelType
106+
base: BaseModelType
107107
type: ModelType
108108
format: ModelFormat
109109
key: str = Field(description="unique key for model", default="<NOKEY>")
@@ -181,29 +181,38 @@ class MainConfig(ModelConfigBase):
181181

182182
vae: Optional[str] = Field(None)
183183
variant: ModelVariantType = ModelVariantType.Normal
184+
ztsnr_training: bool = False
184185

185186

186187
class MainCheckpointConfig(CheckpointConfig, MainConfig):
187188
"""Model config for main checkpoint models."""
188189

190+
# Note that we do not need prediction_type or upcast_attention here
191+
# because they are provided in the checkpoint's own config file.
192+
189193

190194
class MainDiffusersConfig(DiffusersConfig, MainConfig):
191195
"""Model config for main diffusers models."""
192196

197+
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
198+
upcast_attention: bool = False
199+
193200

194201
class ONNXSD1Config(MainConfig):
195202
"""Model config for ONNX format models based on sd-1."""
196203

197204
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
205+
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
206+
upcast_attention: bool = False
198207

199208

200209
class ONNXSD2Config(MainConfig):
201210
"""Model config for ONNX format models based on sd-2."""
202211

203212
format: Literal[ModelFormat.Onnx, ModelFormat.Olive]
204213
# No yaml config file for ONNX, so these are part of config
205-
prediction_type: SchedulerPredictionType
206-
upcast_attention: bool
214+
prediction_type: SchedulerPredictionType = SchedulerPredictionType.VPrediction
215+
upcast_attention: bool = True
207216

208217

209218
class IPAdapterConfig(ModelConfigBase):
@@ -305,7 +314,7 @@ def make_config(
305314
try:
306315
format = model_data.get("format")
307316
type = model_data.get("type")
308-
model_base = model_data.get("base_model")
317+
model_base = model_data.get("base")
309318
class_to_return = dest_class or cls._class_map[format][type]
310319
if isinstance(class_to_return, dict): # additional level allowed
311320
class_to_return = class_to_return[model_base]

invokeai/backend/model_manager/migrate_to_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def migrate(self):
5050
hash = FastModelHash.hash(self.config.models_path / stanza.path)
5151
new_key = sha1(model_key.encode("utf-8")).hexdigest()
5252

53-
stanza["base_model"] = BaseModelType(base_type)
53+
stanza["base"] = BaseModelType(base_type)
5454
stanza["type"] = ModelType(model_type)
5555
stanza["name"] = model_name
5656
stanza["original_hash"] = hash

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ dependencies = [
5151
"fastapi~=0.103.2",
5252
"fastapi-events~=0.9.1",
5353
"huggingface-hub~=0.16.4",
54+
"imohash",
5455
"invisible-watermark~=0.2.0", # needed to install SDXL base and refiner using their repo_ids
5556
"matplotlib", # needed for plotting of Penner easing functions
5657
"mediapipe", # needed for "mediapipeface" controlnet model

tests/backend/model_manager_2/test_model_storage_sql.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,12 @@
88
import pytest
99

1010
from invokeai.app.services.config import InvokeAIAppConfig
11-
from invokeai.app.services.model_records import ModelRecordServiceBase, ModelRecordServiceSQL, UnknownModelException
11+
from invokeai.app.services.model_records import (
12+
DuplicateModelException,
13+
ModelRecordServiceBase,
14+
ModelRecordServiceSQL,
15+
UnknownModelException,
16+
)
1217
from invokeai.app.services.shared.sqlite import SqliteDatabase
1318
from invokeai.backend.model_manager.config import (
1419
BaseModelType,
@@ -32,8 +37,8 @@ def example_config() -> TextualInversionConfig:
3237
return TextualInversionConfig(
3338
path="/tmp/pokemon.bin",
3439
name="old name",
35-
base_model="sd-1",
36-
type="embedding",
40+
base=BaseModelType("sd-1"),
41+
type=ModelType("embedding"),
3742
format="embedding_file",
3843
original_hash="ABC123",
3944
)
@@ -43,7 +48,7 @@ def test_add(store: ModelRecordServiceBase):
4348
raw = dict(
4449
path="/tmp/foo.ckpt",
4550
name="model1",
46-
base_model="sd-1",
51+
base=BaseModelType("sd-1"),
4752
type="main",
4853
config="/tmp/foo.yaml",
4954
variant="normal",
@@ -53,16 +58,25 @@ def test_add(store: ModelRecordServiceBase):
5358
store.add_model("key1", raw)
5459
config1 = store.get_model("key1")
5560
assert config1 is not None
56-
raw["name"] = "model2"
57-
raw["base_model"] = "sd-2"
58-
raw["format"] = "diffusers"
59-
raw.pop("config")
60-
store.add_model("key2", raw)
61-
config2 = store.get_model("key2")
61+
assert config1.base == BaseModelType("sd-1")
6262
assert config1.name == "model1"
63-
assert config2.name == "model2"
64-
assert config1.base_model == "sd-1"
65-
assert config2.base_model == "sd-2"
63+
assert config1.original_hash == "111222333444"
64+
assert config1.current_hash is None
65+
66+
67+
def test_dup(store: ModelRecordServiceBase):
68+
config = example_config()
69+
store.add_model("key1", example_config())
70+
try:
71+
store.add_model("key1", config)
72+
assert False, "Duplicate model key should have been caught"
73+
except DuplicateModelException:
74+
assert True
75+
try:
76+
store.add_model("key2", config)
77+
assert False, "Duplicate model path should have been caught"
78+
except DuplicateModelException:
79+
assert True
6680

6781

6882
def test_update(store: ModelRecordServiceBase):
@@ -115,21 +129,21 @@ def test_filter(store: ModelRecordServiceBase):
115129
config1 = DiffusersConfig(
116130
path="/tmp/config1",
117131
name="config1",
118-
base_model=BaseModelType("sd-1"),
132+
base=BaseModelType("sd-1"),
119133
type=ModelType("main"),
120134
original_hash="CONFIG1HASH",
121135
)
122136
config2 = DiffusersConfig(
123137
path="/tmp/config2",
124138
name="config2",
125-
base_model=BaseModelType("sd-1"),
139+
base=BaseModelType("sd-1"),
126140
type=ModelType("main"),
127141
original_hash="CONFIG2HASH",
128142
)
129143
config3 = VaeDiffusersConfig(
130144
path="/tmp/config3",
131145
name="config3",
132-
base_model=BaseModelType("sd-2"),
146+
base=BaseModelType("sd-2"),
133147
type=ModelType("vae"),
134148
original_hash="CONFIG3HASH",
135149
)

0 commit comments

Comments
 (0)