Skip to content

Commit acc0a29

Browse files
author
Lincoln Stein
committed
fixed ruff formatting issues
1 parent 38c1436 commit acc0a29

3 files changed

Lines changed: 11 additions & 35 deletions

File tree

invokeai/app/api/routers/model_records.py

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,23 +41,15 @@ class ModelsList(BaseModel):
4141
operation_id="list_model_records",
4242
)
4343
async def list_model_records(
44-
base_models: Optional[List[BaseModelType]] = Query(
45-
default=None, description="Base models to include"
46-
),
47-
model_type: Optional[ModelType] = Query(
48-
default=None, description="The type of model to get"
49-
),
44+
base_models: Optional[List[BaseModelType]] = Query(default=None, description="Base models to include"),
45+
model_type: Optional[ModelType] = Query(default=None, description="The type of model to get"),
5046
) -> ModelsList:
5147
"""Get a list of models."""
5248
record_store = ApiDependencies.invoker.services.model_records
5349
found_models: list[AnyModelConfig] = []
5450
if base_models:
5551
for base_model in base_models:
56-
found_models.extend(
57-
record_store.search_by_attr(
58-
base_model=base_model, model_type=model_type
59-
)
60-
)
52+
found_models.extend(record_store.search_by_attr(base_model=base_model, model_type=model_type))
6153
else:
6254
found_models.extend(record_store.search_by_attr(model_type=model_type))
6355
return ModelsList(models=found_models)
@@ -97,9 +89,7 @@ async def get_model_record(
9789
)
9890
async def update_model_record(
9991
key: Annotated[str, Path(description="Unique key of model")],
100-
info: Annotated[
101-
AnyModelConfig, Body(description="Model config", discriminator="type")
102-
],
92+
info: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")],
10393
) -> AnyModelConfig:
10494
"""Update model contents with a new config. If the model name or base fields are changed, then the model is renamed."""
10595
logger = ApiDependencies.invoker.services.logger
@@ -145,17 +135,13 @@ async def del_model_record(
145135
operation_id="add_model_record",
146136
responses={
147137
201: {"description": "The model added successfully"},
148-
409: {
149-
"description": "There is already a model corresponding to this path or repo_id"
150-
},
138+
409: {"description": "There is already a model corresponding to this path or repo_id"},
151139
415: {"description": "Unrecognized file/folder format"},
152140
},
153141
status_code=201,
154142
)
155143
async def add_model_record(
156-
config: Annotated[
157-
AnyModelConfig, Body(description="Model config", discriminator="type")
158-
]
144+
config: Annotated[AnyModelConfig, Body(description="Model config", discriminator="type")]
159145
) -> AnyModelConfig:
160146
"""
161147
Add a model using the configuration information appropriate for its type.

invokeai/backend/model_manager/config.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,7 @@ class ModelConfigBase(BaseModel):
115115
description="current fasthash of model contents", default=None
116116
) # if model is converted or otherwise modified, this will hold updated hash
117117
description: Optional[str] = Field(default=None)
118-
source: Optional[str] = Field(
119-
description="Model download source (URL or repo_id)", default=None
120-
)
118+
source: Optional[str] = Field(description="Model download source (URL or repo_id)", default=None)
121119

122120
model_config = ConfigDict(
123121
use_enum_values=False,
@@ -251,19 +249,13 @@ class T2IConfig(ModelConfigBase):
251249
format: Literal[ModelFormat.Diffusers]
252250

253251

254-
_ONNXConfig = Annotated[
255-
Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")
256-
]
252+
_ONNXConfig = Annotated[Union[ONNXSD1Config, ONNXSD2Config], Field(discriminator="base")]
257253
_ControlNetConfig = Annotated[
258254
Union[ControlNetDiffusersConfig, ControlNetCheckpointConfig],
259255
Field(discriminator="format"),
260256
]
261-
_VaeConfig = Annotated[
262-
Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")
263-
]
264-
_MainModelConfig = Annotated[
265-
Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")
266-
]
257+
_VaeConfig = Annotated[Union[VaeDiffusersConfig, VaeCheckpointConfig], Field(discriminator="format")]
258+
_MainModelConfig = Annotated[Union[MainDiffusersConfig, MainCheckpointConfig], Field(discriminator="format")]
267259

268260
AnyModelConfig = Union[
269261
_MainModelConfig,

tests/app/services/model_records/test_model_records_sql.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,7 @@ def test_filter(store: ModelRecordServiceBase):
159159
assert len(matches) == 1
160160
assert matches[0].name == "config3"
161161
assert matches[0].key == sha256("config3".encode("utf-8")).hexdigest()
162-
assert isinstance(
163-
matches[0].type, ModelType
164-
) # This tests that we get proper enums back
162+
assert isinstance(matches[0].type, ModelType) # This tests that we get proper enums back
165163

166164
matches = store.search_by_hash("CONFIG1HASH")
167165
assert len(matches) == 1

0 commit comments

Comments
 (0)