33# SPDX-License-Identifier: BSD-3-Clause
44# ---------------------------------------------------------------------
55
6+ import sys
67import os
8+ sys .path .append ("." )
9+ sys .path .append (".." )
10+ import utils .install as install
711import numpy as np
812import torch
913import torchvision .transforms as transforms
1014
1115from PIL import Image
1216from PIL .Image import fromarray as ImageFromArray
13- from torch .nn .functional import interpolate , pad
14- from torchvision import transforms
15- from typing import Callable , Dict , List , Tuple
16-
17+ from utils .image_processing import (
18+ preprocess_inputs
19+ )
1720from qai_appbuilder import (QNNContext , Runtime , LogLevel , ProfilingLevel , PerfProfile , QNNConfig )
1821
19- image_size = 512
20- aotgan = None
21- image_buffer = None
22+ ####################################################################
23+
24+ MODEL_ID = "mn1w65o8m"
25+ MODEL_NAME = "aotgan"
26+ MODEL_HELP_URL = "https://github.com/quic/ai-engine-direct-helper/tree/main/samples/python/" + MODEL_NAME + "#" + MODEL_NAME + "-qnn-models"
27+ IMAGE_SIZE = 512
28+
29+ ####################################################################
30+
31+ execution_ws = os .getcwd ()
32+ qnn_dir = execution_ws + "\\ qai_libs"
2233
34+ if not MODEL_NAME in execution_ws :
35+ execution_ws = execution_ws + "\\ " + MODEL_NAME
36+
37+ model_dir = execution_ws + "\\ models"
38+ madel_path = model_dir + "\\ " + MODEL_NAME + ".bin"
39+
40+ ####################################################################
41+
42+ image_buffer = None
43+ aotgan = None
2344
2445def preprocess_PIL_image (image : Image ) -> torch .Tensor :
2546 """Convert a PIL image into a pyTorch tensor with range [0, 1] and shape NCHW."""
26- transform = transforms .Compose ([transforms .Resize (image_size ), # bgr image
27- transforms .CenterCrop (image_size ),
47+ transform = transforms .Compose ([transforms .Resize (IMAGE_SIZE ), # bgr image
48+ transforms .CenterCrop (IMAGE_SIZE ),
2849 transforms .PILToTensor ()])
2950 img : torch .Tensor = transform (image ) # type: ignore
3051 img = img .float ().unsqueeze (0 ) / 255.0 # int 0 - 255 to float 0.0 - 1.0
@@ -37,49 +58,39 @@ def torch_tensor_to_PIL_image(data: torch.Tensor) -> Image:
3758 out = torch .clip (data , min = 0.0 , max = 1.0 )
3859 np_out = (out .detach ().numpy () * 255 ).astype (np .uint8 )
3960 return ImageFromArray (np_out )
40-
41- def preprocess_inputs (
42- pixel_values_or_image : Image ,
43- mask_pixel_values_or_image : Image ,
44- ) -> Dict [str , torch .Tensor ]:
45-
46- NCHW_fp32_torch_frames = preprocess_PIL_image (pixel_values_or_image )
47- NCHW_fp32_torch_masks = preprocess_PIL_image (mask_pixel_values_or_image )
48-
49- # The number of input images should equal the number of input masks.
50- if NCHW_fp32_torch_masks .shape [0 ] != 1 :
51- NCHW_fp32_torch_masks = NCHW_fp32_torch_masks .tile (
52- (NCHW_fp32_torch_frames .shape [0 ], 1 , 1 , 1 )
53- )
54-
55- # Mask input image
56- image_masked = (
57- NCHW_fp32_torch_frames * (1 - NCHW_fp32_torch_masks ) + NCHW_fp32_torch_masks
58- )
59-
60- return {"image" : image_masked , "mask" : NCHW_fp32_torch_masks }
61-
62- # AotGan class which inherited from the class QNNContext.
61+
62+ # AotGan class which inherits from the class QNNContext.
class AotGan(QNNContext):
    """QNNContext subclass that runs the model on an image/mask input pair."""

    def Inference(self, input_data, input_mask):
        """Run the base-class inference on both inputs and return its first output.

        ``input_data`` and ``input_mask`` are forwarded, in that order, as a
        two-element list to ``QNNContext.Inference``.
        """
        outputs = super().Inference([input_data, input_mask])
        return outputs[0]
68-
68+
def model_download():
    """Fetch the model binary and terminate the process if that fails.

    Calls ``install.download_qai_hubmodel`` with ``MODEL_ID`` and the target
    path ``madel_path``, passing a progress description and a failure message
    that points the user at ``MODEL_HELP_URL``. A falsy return value from the
    helper is treated as failure.

    Raises:
        SystemExit: with status 1 when the helper reports failure.
    """
    desc = f"Downloading { MODEL_NAME } model... "
    fail = f"\n Failed to download { MODEL_NAME } model. Please prepare the model according to the steps in below link:\n { MODEL_HELP_URL } "
    ret = install.download_qai_hubmodel(MODEL_ID, madel_path, desc=desc, fail=fail)

    if not ret:
        # Use sys.exit rather than the site-provided exit() helper, and report
        # a non-zero status so callers can tell that the download failed.
        sys.exit(1)
78+
def Init():
    """Prepare the model file, configure AppBuilder and create the model object.

    Steps, in order:
      1. ``model_download()`` obtains the model binary (it exits the process
         on failure).
      2. ``QNNConfig.Config`` is pointed at the QNN library directory and set
         to the HTP runtime with WARN logging and BASIC profiling.
      3. The module-level ``aotgan`` is bound to a new ``AotGan`` instance
         loaded from ``madel_path``.
    """
    global aotgan

    model_download()

    # Config AppBuilder environment.
    # qnn_dir is computed at module level from the working directory, so it is
    # reused here instead of rebuilding the same path from os.getcwd().
    QNNConfig.Config(qnn_dir, Runtime.HTP, LogLevel.WARN, ProfilingLevel.BASIC)

    # Instance for AotGan objects, named after the shared MODEL_NAME constant.
    aotgan = AotGan(MODEL_NAME, madel_path)
7889
7990def Inference (input_image_path , input_mask_path , output_image_path ):
8091 global image_buffer
8192
82- # Read and preprocess the image& mask.
93+ # Read and preprocess the image & mask.
8394 image = Image .open (input_image_path )
8495 mask = Image .open (input_mask_path )
8596 inputs = preprocess_inputs (image , mask )
@@ -89,26 +100,26 @@ def Inference(input_image_path, input_mask_path, output_image_path):
89100
90101 image_masked = np .transpose (image_masked , (0 , 2 , 3 , 1 ))
91102 mask_torch = np .transpose (mask_torch , (0 , 2 , 3 , 1 ))
92-
103+
93104 # Burst the HTP.
94105 PerfProfile .SetPerfProfileGlobal (PerfProfile .BURST )
95106
96107 # Run the inference.
97108 output_image = aotgan .Inference ([image_masked ], [mask_torch ])
98-
109+
99110 # Reset the HTP.
100111 PerfProfile .RelPerfProfileGlobal ()
101-
102- # show% save the result
103- output_image = torch .from_numpy (output_image )
104- output_image = output_image .reshape (image_size , image_size , 3 )
105- output_image = torch .unsqueeze (output_image , 0 )
112+
113+ # show & save the result
114+ output_image = torch .from_numpy (output_image )
115+ output_image = output_image .reshape (IMAGE_SIZE , IMAGE_SIZE , 3 )
116+ output_image = torch .unsqueeze (output_image , 0 )
106117 output_image = [torch_tensor_to_PIL_image (img ) for img in output_image ]
107118 image_buffer = output_image [0 ]
108119 image_buffer .save (output_image_path )
109- image_buffer .show ()
120+ image_buffer .show ()
121+ image .show ()
110122
111-
112123def Release ():
113124 global aotgan
114125
@@ -118,6 +129,7 @@ def Release():
118129
119130Init ()
120131
121- Inference (" input.png" , " mask.png" , " output.png" )
132+ Inference (execution_ws + " \\ input.png" , execution_ws + " \\ mask.png" , execution_ws + " \\ output.png" )
122133
123134Release ()
135+
0 commit comments