@@ -67,16 +67,10 @@ def test_add(store: ModelRecordServiceBase):
6767def test_dup (store : ModelRecordServiceBase ):
6868 config = example_config ()
6969 store .add_model ("key1" , example_config ())
70- try :
70+ with pytest . raises ( DuplicateModelException ) :
7171 store .add_model ("key1" , config )
72- assert False , "Duplicate model key should have been caught"
73- except DuplicateModelException :
74- assert True
75- try :
72+ with pytest .raises (DuplicateModelException ):
7673 store .add_model ("key2" , config )
77- assert False , "Duplicate model path should have been caught"
78- except DuplicateModelException :
79- assert True
8074
8175
8276def test_update (store : ModelRecordServiceBase ):
@@ -90,32 +84,21 @@ def test_update(store: ModelRecordServiceBase):
9084 new_config = store .get_model ("key1" )
9185 assert new_config .name == "new name"
9286
93- try :
87+
88+ def test_unknown_key (store : ModelRecordServiceBase ):
89+ config = example_config ()
90+ store .add_model ("key1" , config )
91+ with pytest .raises (UnknownModelException ):
9492 store .update_model ("unknown_key" , config )
95- assert False , "expected UnknownModelException"
96- except UnknownModelException :
97- assert True
9893
9994
10095def test_delete (store : ModelRecordServiceBase ):
10196 config = example_config ()
10297 store .add_model ("key1" , config )
10398 config = store .get_model ("key1" )
10499 store .del_model ("key1" )
105- try :
100+ with pytest . raises ( UnknownModelException ) :
106101 config = store .get_model ("key1" )
107- assert False , "expected fetch of deleted model to raise exception"
108- except UnknownModelException :
109- assert True
110-
111- # a bug in sqlite3 in python 3.9 prevents DEL from returning number of
112- # deleted rows!
113- if sys .version_info .major == 3 and sys .version_info .minor > 9 :
114- try :
115- store .del_model ("unknown" )
116- assert False , "expected delete of unknown model to raise exception"
117- except UnknownModelException :
118- assert True
119102
120103
121104def test_exists (store : ModelRecordServiceBase ):
@@ -167,3 +150,62 @@ def test_filter(store: ModelRecordServiceBase):
167150
168151 matches = store .all_models ()
169152 assert len (matches ) == 3
153+
154+
155+ def test_filter_2 (store : ModelRecordServiceBase ):
156+ config1 = DiffusersConfig (
157+ path = "/tmp/config1" ,
158+ name = "config1" ,
159+ base = BaseModelType ("sd-1" ),
160+ type = ModelType ("main" ),
161+ original_hash = "CONFIG1HASH" ,
162+ )
163+ config2 = DiffusersConfig (
164+ path = "/tmp/config2" ,
165+ name = "config2" ,
166+ base = BaseModelType ("sd-1" ),
167+ type = ModelType ("main" ),
168+ original_hash = "CONFIG2HASH" ,
169+ )
170+ config3 = DiffusersConfig (
171+ path = "/tmp/config3" ,
172+ name = "dup_name1" ,
173+ base = BaseModelType ("sd-2" ),
174+ type = ModelType ("main" ),
175+ original_hash = "CONFIG3HASH" ,
176+ )
177+ config4 = DiffusersConfig (
178+ path = "/tmp/config4" ,
179+ name = "dup_name1" ,
180+ base = BaseModelType ("sd-2" ),
181+ type = ModelType ("main" ),
182+ original_hash = "CONFIG3HASH" ,
183+ )
184+ config5 = VaeDiffusersConfig (
185+ path = "/tmp/config5" ,
186+ name = "dup_name1" ,
187+ base = BaseModelType ("sd-1" ),
188+ type = ModelType ("vae" ),
189+ original_hash = "CONFIG3HASH" ,
190+ )
191+ for c in config1 , config2 , config3 , config4 , config5 :
192+ store .add_model (sha256 (c .path .encode ("utf-8" )).hexdigest (), c )
193+
194+ matches = store .search_by_name (
195+ model_type = ModelType ("main" ),
196+ model_name = "dup_name1" ,
197+ )
198+ assert len (matches ) == 2
199+
200+ matches = store .search_by_name (
201+ base_model = BaseModelType ("sd-1" ),
202+ model_type = ModelType ("main" ),
203+ )
204+ assert len (matches ) == 2
205+
206+ matches = store .search_by_name (
207+ base_model = BaseModelType ("sd-1" ),
208+ model_type = ModelType ("vae" ),
209+ model_name = "dup_name1" ,
210+ )
211+ assert len (matches ) == 1
0 commit comments