6060from contextlib import contextmanager , nullcontext
6161import time
6262import math
63+ import re
6364
6465from ldm .util import instantiate_from_config
6566from ldm .models .diffusion .ddim import DDIMSampler
@@ -171,7 +172,6 @@ def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
171172
172173 # make directories and establish names for the output files
173174 os .makedirs (outdir , exist_ok = True )
174- base_count = len (os .listdir (outdir ))- 1
175175
176176 start_code = None
177177 if self .fixed_code :
@@ -185,7 +185,7 @@ def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
185185 sampler = self .sampler
186186 images = list ()
187187 seeds = list ()
188-
188+ filename = None
189189 tic = time .time ()
190190
191191 with torch .no_grad ():
@@ -218,10 +218,11 @@ def txt2img(self,prompt,outdir=None,batch_size=None,iterations=None,
218218 if not grid :
219219 for x_sample in x_samples_ddim :
220220 x_sample = 255. * rearrange (x_sample .cpu ().numpy (), 'c h w -> h w c' )
221- filename = os .path .join (outdir , f"{ base_count :05} .png" )
221+ filename = self ._unique_filename (outdir ,previousname = filename ,
222+ seed = seed ,isbatch = (batch_size > 1 ))
223+ assert not os .path .exists (filename )
222224 Image .fromarray (x_sample .astype (np .uint8 )).save (filename )
223225 images .append ([filename ,seed ])
224- base_count += 1
225226 else :
226227 all_samples .append (x_samples_ddim )
227228 seeds .append (seed )
@@ -283,7 +284,6 @@ def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=Non
283284
284285 # make directories and establish names for the output files
285286 os .makedirs (outdir , exist_ok = True )
286- base_count = len (os .listdir (outdir ))- 1
287287
288288 assert os .path .isfile (init_img )
289289 init_image = self ._load_img (init_img ).to (self .device )
@@ -304,7 +304,8 @@ def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=Non
304304
305305 images = list ()
306306 seeds = list ()
307-
307+ filename = None
308+
308309 tic = time .time ()
309310
310311 with torch .no_grad ():
@@ -333,10 +334,10 @@ def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=Non
333334 if not grid :
334335 for x_sample in x_samples :
335336 x_sample = 255. * rearrange (x_sample .cpu ().numpy (), 'c h w -> h w c' )
336- filename = os .path .join (outdir , f"{ base_count :05} .png" )
337+ filename = self ._unique_filename (outdir ,filename ,seed = seed ,isbatch = (batch_size > 1 ))
338+ assert not os .path .exists (filename )
337339 Image .fromarray (x_sample .astype (np .uint8 )).save (filename )
338340 images .append ([filename ,seed ])
339- base_count += 1
340341 else :
341342 all_samples .append (x_samples )
342343 seeds .append (seed )
@@ -357,7 +358,6 @@ def img2img(self,prompt,outdir=None,init_img=None,batch_size=None,iterations=Non
357358
358359 def _make_grid (self ,samples ,seeds ,batch_size ,iterations ,outdir ):
359360 images = list ()
360- base_count = len (os .listdir (outdir ))- 1
361361 n_rows = batch_size if batch_size > 1 else int (math .sqrt (batch_size * iterations ))
362362 # save as grid
363363 grid = torch .stack (samples , 0 )
@@ -366,7 +366,7 @@ def _make_grid(self,samples,seeds,batch_size,iterations,outdir):
366366
367367 # to image
368368 grid = 255. * rearrange (grid , 'c h w -> h w c' ).cpu ().numpy ()
369- filename = os . path . join (outdir , f" { base_count :05 } .png" )
369+ filename = self . _unique_filename (outdir ,seed = seeds [ 0 ], grid_count = batch_size * iterations )
370370 Image .fromarray (grid .astype (np .uint8 )).save (filename )
371371 for s in seeds :
372372 images .append ([filename ,s ])
@@ -430,3 +430,40 @@ def _load_img(self,path):
430430 image = image [None ].transpose (0 , 3 , 1 , 2 )
431431 image = torch .from_numpy (image )
432432 return 2. * image - 1.
433+
434+ def _unique_filename (self ,outdir ,previousname = None ,seed = 0 ,isbatch = False ,grid_count = None ):
435+ revision = 1
436+
437+ if previousname is None :
438+ # count up until we find an unfilled slot
439+ dir_list = [a .split ('.' ,1 )[0 ] for a in os .listdir (outdir )]
440+ uniques = dict .fromkeys (dir_list ,True )
441+ basecount = 1
442+ while f'{ basecount :06} ' in uniques :
443+ basecount += 1
444+ if grid_count is not None :
445+ grid_label = f'grid#1-{ grid_count } '
446+ filename = f'{ basecount :06} .{ seed } .{ grid_label } .png'
447+ elif isbatch :
448+ filename = f'{ basecount :06} .{ seed } .01.png'
449+ else :
450+ filename = f'{ basecount :06} .{ seed } .png'
451+
452+ return os .path .join (outdir ,filename )
453+
454+ else :
455+ previousname = os .path .basename (previousname )
456+ x = re .match ('^(\d+)\..*\.png' ,previousname )
457+ if not x :
458+ return self ._unique_filename (outdir ,previousname ,seed )
459+
460+ basecount = int (x .groups ()[0 ])
461+ series = 0
462+ finished = False
463+ while not finished :
464+ series += 1
465+ filename = f'{ basecount :06} .{ seed } .png'
466+ if isbatch or os .path .exists (os .path .join (outdir ,filename )):
467+ filename = f'{ basecount :06} .{ seed } .{ series :02} .png'
468+ finished = not os .path .exists (os .path .join (outdir ,filename ))
469+ return os .path .join (outdir ,filename )
0 commit comments