|
1 | | -import argparse |
2 | | -from PIL import Image |
3 | 1 | import torch |
4 | | -from diffusers.utils import check_min_version |
| 2 | +from diffusers.utils import load_image, check_min_version |
5 | 3 | from controlnet_flux import FluxControlNetModel |
| 4 | +from transformer_flux import FluxTransformer2DModel |
6 | 5 | from pipeline_flux_controlnet_inpaint import FluxControlNetInpaintingPipeline |
7 | 6 |
|
8 | 7 | check_min_version("0.30.2") |
9 | 8 |
|
10 | | -def main(image_path, mask_path, prompt): |
11 | | - controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16) |
12 | | - pipe = FluxControlNetInpaintingPipeline.from_pretrained( |
13 | | - "black-forest-labs/FLUX.1-dev", |
14 | | - controlnet=controlnet, |
15 | | - torch_dtype=torch.bfloat16 |
16 | | - ).to("cuda") |
17 | | - |
18 | | - size = (768, 768) |
19 | | - image = Image.open(image_path).convert("RGB").resize(size) |
20 | | - mask = Image.open(mask_path).convert("RGB").resize(size) |
| 9 | +# Set image path , mask path and prompt |
| 10 | +image_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket.png', |
| 11 | +mask_path='https://huggingface.co/alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha/resolve/main/images/bucket_mask.jpeg', |
| 12 | +prompt='a person wearing a white shoe, carrying a white bucket with text "FLUX" on it' |
21 | 13 |
|
22 | | - generator = torch.Generator(device="cuda").manual_seed(48) |
23 | | - result = pipe( |
24 | | - prompt=prompt, |
25 | | - height=size[1], |
26 | | - width=size[0], |
27 | | - control_image=image, |
28 | | - control_mask=mask, |
29 | | - num_inference_steps=28, |
30 | | - generator=generator, |
31 | | - controlnet_conditioning_scale=0.95, |
32 | | - guidance_scale=3.5, |
33 | | - negative_prompt="", |
34 | | - true_guidance_scale=3.5 |
35 | | - ).images[0] |
36 | | - |
37 | | - result.save('flux_inpaint.png') |
38 | | - print("Successfully inpainted image") |
| 14 | +# Build pipeline |
| 15 | +controlnet = FluxControlNetModel.from_pretrained("alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Alpha", torch_dtype=torch.bfloat16) |
| 16 | +transformer = FluxTransformer2DModel.from_pretrained( |
| 17 | + "black-forest-labs/FLUX.1-dev", subfolder='transformer', torch_dytpe=torch.bfloat16 |
| 18 | + ) |
| 19 | +pipe = FluxControlNetInpaintingPipeline.from_pretrained( |
| 20 | + "black-forest-labs/FLUX.1-dev", |
| 21 | + controlnet=controlnet, |
| 22 | + transformer=transformer, |
| 23 | + torch_dtype=torch.bfloat16 |
| 24 | +).to("cuda") |
| 25 | +pipe.transformer.to(torch.bfloat16) |
| 26 | +pipe.controlnet.to(torch.bfloat16) |
39 | 27 |
|
40 | | -if __name__ == "__main__": |
41 | | - parser = argparse.ArgumentParser(description="Image Inpainting Program") |
42 | | - parser.add_argument("-i", "--image", required=True, help="Path to the input image") |
43 | | - parser.add_argument("-m", "--mask", required=True, help="Path to the input mask") |
44 | | - parser.add_argument("-p", "--prompt", required=True, help="Prompt for inpainting") |
45 | | - args = parser.parse_args() |
46 | | - main(args.image, args.mask, args.prompt) |
| 28 | +# Load image and mask |
| 29 | +size = (768, 768) |
| 30 | +image = load_image(image_path).convert("RGB").resize(size) |
| 31 | +mask = load_image(mask_path).convert("RGB").resize(size) |
| 32 | +generator = torch.Generator(device="cuda").manual_seed(24) |
| 33 | + |
| 34 | +# Inpaint |
| 35 | +result = pipe( |
| 36 | + prompt=prompt, |
| 37 | + height=size[1], |
| 38 | + width=size[0], |
| 39 | + control_image=image, |
| 40 | + control_mask=mask, |
| 41 | + num_inference_steps=28, |
| 42 | + generator=generator, |
| 43 | + controlnet_conditioning_scale=0.9, |
| 44 | + guidance_scale=3.5, |
| 45 | + negative_prompt="", |
| 46 | + true_guidance_scale=3.5 |
| 47 | +).images[0] |
| 48 | + |
| 49 | +result.save('flux_inpaint.png') |
| 50 | +print("Successfully inpaint image") |
0 commit comments