
Commit 151d42a

Update sample code
1 parent 159b7db commit 151d42a

8 files changed: +293, -248 lines


samples/python/README.md

Lines changed: 4 additions & 1 deletion

@@ -18,7 +18,7 @@ Download and install [git](https://github.com/dennisameling/git/releases/downloa
 ### Step 2: Install basic Python dependencies:
 Run below commands in Windows terminal:
 ```
-pip install requests wget tqdm importlib-metadata qai-hub qai_hub_models huggingface_hub Pillow numpy opencv-python torch torchvision torchaudio transformers diffusers
+pip install requests wget tqdm importlib-metadata qai-hub qai_hub_models huggingface_hub Pillow numpy opencv-python torch torchvision torchaudio transformers diffusers ultralytics==8.0.193
 ```

 ### Step 3: Download QAI AppBuilder repository:
@@ -60,5 +60,8 @@ python stable_diffusion_v2_1\stable_diffusion_v2_1.py --prompt "spectacular view
 | inception_v3 | 2.28 | python inception_v3\inception_v3.py |
 | yolov8_det | 2.28 | python yolov8_det\yolov8_det.py |
 | unet_segmentation | 2.28 | python unet_segmentation\unet_segmentation.py |
+| openpose | 2.28 | python openpose\openpose.py |
+| lama_dilated | 2.28 | python lama_dilated\lama_dilated.py |
+| aotgan | 2.28 | python aotgan\aotgan.py |

 *More models will be supported soon!*
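Step 2 now also pins `ultralytics==8.0.193`. A small, hypothetical sanity check (not part of the repository) to confirm the Step 2 packages are importable before running the samples:

```python
# Hypothetical helper: confirm the packages from the Step 2 pip command are
# installed, including the newly pinned ultralytics==8.0.193.
from importlib.metadata import version, PackageNotFoundError

for pkg in ("qai-hub", "qai_hub_models", "ultralytics", "torch", "transformers", "diffusers"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: MISSING - re-run the pip install command above")
```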

samples/python/aotgan/aotgan.py

Lines changed: 58 additions & 46 deletions

@@ -3,28 +3,49 @@
 # SPDX-License-Identifier: BSD-3-Clause
 # ---------------------------------------------------------------------

+import sys
 import os
+sys.path.append(".")
+sys.path.append("..")
+import utils.install as install
 import numpy as np
 import torch
 import torchvision.transforms as transforms

 from PIL import Image
 from PIL.Image import fromarray as ImageFromArray
-from torch.nn.functional import interpolate, pad
-from torchvision import transforms
-from typing import Callable, Dict, List, Tuple
-
+from utils.image_processing import (
+    preprocess_inputs
+)
 from qai_appbuilder import (QNNContext, Runtime, LogLevel, ProfilingLevel, PerfProfile, QNNConfig)

-image_size = 512
-aotgan = None
-image_buffer = None
+####################################################################
+
+MODEL_ID = "mn1w65o8m"
+MODEL_NAME = "aotgan"
+MODEL_HELP_URL = "https://github.com/quic/ai-engine-direct-helper/tree/main/samples/python/" + MODEL_NAME + "#" + MODEL_NAME + "-qnn-models"
+IMAGE_SIZE = 512
+
+####################################################################
+
+execution_ws = os.getcwd()
+qnn_dir = execution_ws + "\\qai_libs"

+if not MODEL_NAME in execution_ws:
+    execution_ws = execution_ws + "\\" + MODEL_NAME
+
+model_dir = execution_ws + "\\models"
+madel_path = model_dir + "\\" + MODEL_NAME + ".bin"
+
+####################################################################
+
+image_buffer = None
+aotgan = None

 def preprocess_PIL_image(image: Image) -> torch.Tensor:
     """Convert a PIL image into a pyTorch tensor with range [0, 1] and shape NCHW."""
-    transform = transforms.Compose([transforms.Resize(image_size), # bgr image
-                                    transforms.CenterCrop(image_size),
+    transform = transforms.Compose([transforms.Resize(IMAGE_SIZE), # bgr image
+                                    transforms.CenterCrop(IMAGE_SIZE),
                                     transforms.PILToTensor()])
     img: torch.Tensor = transform(image) # type: ignore
     img = img.float().unsqueeze(0) / 255.0 # int 0 - 255 to float 0.0 - 1.0
@@ -37,49 +58,39 @@ def torch_tensor_to_PIL_image(data: torch.Tensor) -> Image:
     out = torch.clip(data, min=0.0, max=1.0)
     np_out = (out.detach().numpy() * 255).astype(np.uint8)
     return ImageFromArray(np_out)
-
-def preprocess_inputs(
-    pixel_values_or_image: Image,
-    mask_pixel_values_or_image: Image,
-) -> Dict[str, torch.Tensor]:
-
-    NCHW_fp32_torch_frames = preprocess_PIL_image(pixel_values_or_image)
-    NCHW_fp32_torch_masks = preprocess_PIL_image(mask_pixel_values_or_image)
-
-    # The number of input images should equal the number of input masks.
-    if NCHW_fp32_torch_masks.shape[0] != 1:
-        NCHW_fp32_torch_masks = NCHW_fp32_torch_masks.tile(
-            (NCHW_fp32_torch_frames.shape[0], 1, 1, 1)
-        )
-
-    # Mask input image
-    image_masked = (
-        NCHW_fp32_torch_frames * (1 - NCHW_fp32_torch_masks) + NCHW_fp32_torch_masks
-    )
-
-    return {"image": image_masked, "mask": NCHW_fp32_torch_masks}
-
-# AotGan class which inherited from the class QNNContext.
+
+# LamaDilated class which inherited from the class QNNContext.
 class AotGan(QNNContext):
     def Inference(self, input_data, input_mask):
         input_datas=[input_data, input_mask]
         output_data = super().Inference(input_datas)[0]
         return output_data
-
+
+def model_download():
+    ret = True
+
+    desc = f"Downloading {MODEL_NAME} model... "
+    fail = f"\nFailed to download {MODEL_NAME} model. Please prepare the model according to the steps in below link:\n{MODEL_HELP_URL}"
+    ret = install.download_qai_hubmodel(MODEL_ID, madel_path, desc=desc, fail=fail)
+
+    if not ret:
+        exit()
+
 def Init():
     global aotgan

+    model_download()
+
     # Config AppBuilder environment.
     QNNConfig.Config(os.getcwd() + "\\qai_libs", Runtime.HTP, LogLevel.WARN, ProfilingLevel.BASIC)

     # Instance for AotGan objects.
-    aotgan_model = "models\\aotgan.bin"
-    aotgan = AotGan("aotgan", aotgan_model)
+    aotgan = AotGan("aotgan", madel_path)

 def Inference(input_image_path, input_mask_path, output_image_path):
     global image_buffer

-    # Read and preprocess the image&mask.
+    # Read and preprocess the image & mask.
     image = Image.open(input_image_path)
     mask = Image.open(input_mask_path)
     inputs = preprocess_inputs(image, mask)
@@ -89,26 +100,26 @@ def Inference(input_image_path, input_mask_path, output_image_path):

     image_masked = np.transpose(image_masked, (0, 2, 3, 1))
     mask_torch = np.transpose(mask_torch, (0, 2, 3, 1))
-
+
     # Burst the HTP.
     PerfProfile.SetPerfProfileGlobal(PerfProfile.BURST)

     # Run the inference.
     output_image = aotgan.Inference([image_masked], [mask_torch])
-
+
     # Reset the HTP.
     PerfProfile.RelPerfProfileGlobal()
-
-    # show%save the result
-    output_image = torch.from_numpy(output_image)
-    output_image = output_image.reshape(image_size, image_size, 3)
-    output_image = torch.unsqueeze(output_image, 0)
+
+    # show & save the result
+    output_image = torch.from_numpy(output_image)
+    output_image = output_image.reshape(IMAGE_SIZE, IMAGE_SIZE, 3)
+    output_image = torch.unsqueeze(output_image, 0)
     output_image = [torch_tensor_to_PIL_image(img) for img in output_image]
     image_buffer = output_image[0]
     image_buffer.save(output_image_path)
-    image_buffer.show()
+    image_buffer.show()
+    image.show()

-
 def Release():
     global aotgan

@@ -118,6 +129,7 @@ def Release():

 Init()

-Inference("input.png", "mask.png", "output.png")
+Inference(execution_ws + "\\input.png", execution_ws + "\\mask.png", execution_ws + "\\output.png")

 Release()
+
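The local `preprocess_inputs` helper deleted above is now imported from `utils.image_processing`. Based purely on the removed code, the shared helper presumably behaves like the sketch below (names such as `_to_nchw_tensor` are illustrative; the real utils module may differ, e.g. in how it parameterizes the resize size):

```python
# Sketch of the image/mask preprocessing now expected from utils.image_processing,
# reconstructed from the code deleted in aotgan.py above.
from typing import Dict

import torch
import torchvision.transforms as transforms
from PIL import Image


def _to_nchw_tensor(image: Image.Image, size: int = 512) -> torch.Tensor:
    """Resize/center-crop a PIL image and return an NCHW float tensor in [0, 1]."""
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.CenterCrop(size),
        transforms.PILToTensor(),
    ])
    return transform(image).float().unsqueeze(0) / 255.0


def preprocess_inputs(
    pixel_values_or_image: Image.Image,
    mask_pixel_values_or_image: Image.Image,
) -> Dict[str, torch.Tensor]:
    frames = _to_nchw_tensor(pixel_values_or_image)
    masks = _to_nchw_tensor(mask_pixel_values_or_image)

    # The number of masks should match the number of input frames.
    if masks.shape[0] != 1:
        masks = masks.tile((frames.shape[0], 1, 1, 1))

    # Blank out the masked region of the image before inpainting.
    image_masked = frames * (1 - masks) + masks
    return {"image": image_masked, "mask": masks}
```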

samples/python/fastsam_x/fastsam_x.py

Lines changed: 48 additions & 11 deletions

@@ -4,13 +4,15 @@
 # ---------------------------------------------------------------------

 from __future__ import annotations
-
+import sys
 import os
+sys.path.append(".")
+sys.path.append("..")
+import utils.install as install
 import numpy as np
 import math
 import torch
 import torchvision.transforms as transforms
-
 from typing import Callable, Dict, List, Tuple
 from PIL import Image
 from PIL.Image import fromarray as ImageFromArray
@@ -23,7 +25,27 @@

 from qai_appbuilder import (QNNContext, Runtime, LogLevel, ProfilingLevel, PerfProfile, QNNConfig)

+####################################################################
+
+MODEL_ID = "mn7x79pvq"
+MODEL_NAME = "fastsam_x"
+MODEL_HELP_URL = "https://github.com/quic/ai-engine-direct-helper/tree/main/samples/python/" + MODEL_NAME + "#" + MODEL_NAME + "-qnn-models"
+
+####################################################################
+
+execution_ws = os.getcwd()
+qnn_dir = execution_ws + "\\qai_libs"
+
+if not MODEL_NAME in execution_ws:
+    execution_ws = execution_ws + "\\" + MODEL_NAME
+
+model_dir = execution_ws + "\\models"
+madel_path = model_dir + "\\" + MODEL_NAME + ".bin"
+
+####################################################################
+
 fastsam = None
+
 confidence: float = 0.4,
 iou_threshold: float = 0.9,
 retina_masks: bool = True,
@@ -146,16 +168,27 @@ def Inference(self, input_data):
         input_datas=[input_data]
         output_data = super().Inference(input_datas)
         return output_data
-
+
+def model_download():
+    ret = True
+
+    desc = f"Downloading {MODEL_NAME} model... "
+    fail = f"\nFailed to download {MODEL_NAME} model. Please prepare the model according to the steps in below link:\n{MODEL_HELP_URL}"
+    ret = install.download_qai_hubmodel(MODEL_ID, madel_path, desc=desc, fail=fail)
+
+    if not ret:
+        exit()
+
 def Init():
     global fastsam

+    model_download()
+
     # Config AppBuilder environment.
     QNNConfig.Config(os.getcwd() + "\\qai_libs", Runtime.HTP, LogLevel.WARN, ProfilingLevel.BASIC)

     # Instance for FastSam_x objects.
-    fastsam_model = "models\\fastsam_x.bin"
-    fastsam = FastSam("fastsam", fastsam_model)
+    fastsam = FastSam("fastsam", madel_path)

 def Inference(input_image_path, output_image_path):
     global confidence, iou_threshold, retina_masks, model_image_input_shape
@@ -188,11 +221,11 @@ def Inference(input_image_path, output_image_path):
         torch.tensor(preds[4]).reshape(1, 105, 20, 20),
         torch.tensor(preds[5]).reshape(1, 37, 8400)
     ]
-
+
     preds = tuple(
         (preds[5], tuple(([preds[2], preds[3], preds[4]], preds[1], preds[0])))
     )
-
+
     p = ops.non_max_suppression(
         preds[0],
         0.4,
@@ -202,7 +235,7 @@ def Inference(input_image_path, output_image_path):
         nc=1, # set to 1 class since SAM has no class predictions
         classes=None,
     )
-
+
     full_box = torch.zeros(p[0].shape[1], device=p[0].device)
     full_box[2], full_box[3], full_box[4], full_box[6:] = (
         Img.shape[3],
@@ -266,9 +299,12 @@ def Inference(input_image_path, output_image_path):
     binary_mask = segmented_result[0].masks.data.squeeze().cpu().numpy().astype(np.uint8)
     binary_mask = binary_mask * 255
    mask_image = Image.fromarray(binary_mask)
-    mask_image.show()
+
+    #save and display the output_image
     mask_image.save(output_image_path)
-
+    mask_image.show()
+
+
 def Release():
     global fastsam

@@ -278,6 +314,7 @@ def Release():

 Init()

-Inference("input.jpg", "output.jpg")
+Inference(execution_ws + "\\input.jpg", execution_ws + "\\output.jpg")

 Release()
+
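Both updated samples now share the same bootstrap flow: resolve paths relative to the sample folder, fetch the `.bin` model from AI Hub via `utils.install.download_qai_hubmodel`, then configure the AppBuilder runtime. A condensed sketch of that flow, using only names shown in the diffs above (so it is only meaningful inside the repository's samples/python tree, on Windows paths):

```python
# Condensed view of the bootstrap pattern shared by aotgan.py and fastsam_x.py
# in this commit. utils.install and qai_appbuilder are repository/SDK modules;
# the values below are the fastsam_x ones taken from the diff.
import os
import sys

sys.path.append(".")
sys.path.append("..")
import utils.install as install
from qai_appbuilder import QNNConfig, Runtime, LogLevel, ProfilingLevel

MODEL_ID = "mn7x79pvq"
MODEL_NAME = "fastsam_x"
MODEL_HELP_URL = ("https://github.com/quic/ai-engine-direct-helper/tree/main/"
                  "samples/python/" + MODEL_NAME + "#" + MODEL_NAME + "-qnn-models")

# Run either from samples/python or from inside the model's own folder.
execution_ws = os.getcwd()
if MODEL_NAME not in execution_ws:
    execution_ws = execution_ws + "\\" + MODEL_NAME
madel_path = execution_ws + "\\models\\" + MODEL_NAME + ".bin"  # spelling follows the diff

# Download the QNN model, then point the runtime at the QNN libraries and the
# HTP backend before constructing the model's QNNContext subclass.
if not install.download_qai_hubmodel(
        MODEL_ID, madel_path,
        desc=f"Downloading {MODEL_NAME} model... ",
        fail=f"\nFailed to download {MODEL_NAME} model. See:\n{MODEL_HELP_URL}"):
    raise SystemExit(1)

QNNConfig.Config(os.getcwd() + "\\qai_libs", Runtime.HTP, LogLevel.WARN, ProfilingLevel.BASIC)
```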
