Skip to content
This repository was archived by the owner on Jan 3, 2024. It is now read-only.

Commit d86b096

Browse files
committed
Preprocessing PASCAL based datasets is now supported
1 parent 832681e commit d86b096

File tree

3 files changed

+59
-47
lines changed

3 files changed

+59
-47
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from os.path import dirname, basename, isfile, join
2+
import glob
3+
modules = glob.glob(join(dirname(__file__), "*.py"))
4+
__all__ = [basename(f)[:-3] for f in modules if isfile(f)
5+
and not f.endswith('__init__.py')]

pytorch-superpixels/list_loader/list_loader.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,26 @@
22

33

44
class image_list:
5-
def __init__(self, dataset, path, split=None):
6-
datasets = {
7-
'pascal': "ImageSets/Segmentation/",
8-
}
5+
def __init__(self, dataset, path, split='trainval'):
6+
datasets = ['pascal']
97
splits = [None, 'train', 'val', 'trainval']
8+
9+
datasets = {'pascal': {'listPath': 'ImageSets/Segmentation/',
10+
'imagePath': 'JPEGImages',
11+
'targetPath': 'SegmentationClass'}
12+
}
13+
1014
if dataset in datasets and split in splits:
1115
self.split = split
1216
self.dataset = dataset
13-
self.path = join(path, datasets[dataset])
17+
self.path = path
18+
self.listPath = join(path, datasets[dataset]['listPath'])
19+
self.imagePath = join(path, datasets[dataset]['imagePath'])
20+
self.targetPath = join(path, datasets[dataset]['targetPath'])
1421
self.list = []
1522
else:
1623
raise ValueError("Invalid dataset and/or split")
1724

18-
if split is None:
19-
list_path = join(self.path, "trainval.txt")
20-
else:
21-
list_path = join(self.path, split + ".txt")
25+
list_path = join(self.listPath, self.split + ".txt")
2226
self.list = tuple(open(list_path, "r"))
2327
self.list = [id_.rstrip() for id_ in self.list]

pytorch-superpixels/preprocess.py

Lines changed: 41 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,47 @@
1-
import list_loader
1+
from skimage.io import imread
2+
from skimage.segmentation import slic
3+
from skimage.util import img_as_float
4+
from os.path import exists
5+
from os.path import join
6+
from tqdm import tqdm
7+
from os import mkdir
8+
import torch
29

3-
def create_masks(numSegments=100, limOverseg=None):
4-
# Generate image list
5-
image_list = list_loader.image_list('pascal',)
6-
image_list = image_list.list
7-
for image_number in tqdm(image_list):
10+
11+
def create_masks(imageList, numSegments=100, limOverseg=None):
12+
# Iterate through all images
13+
for image_number in tqdm(imageList.list):
814
# Load image/target pair
9-
image_name = image_number + ".jpg"
10-
target_name = image_number + ".png"
11-
image_path = join(root, "JPEGImages", image_name)
12-
target_path = join(root, "SegmentationClass/pre_encoded", target_name)
13-
image = img_as_float(io.imread(image_path))
14-
target = io.imread(target_path)
15+
image_path = join(imageList.imagePath, image_number + ".jpg")
16+
target_path = join(imageList.targetPath, image_number + ".png")
17+
image = img_as_float(imread(image_path))
18+
target = imread(target_path)
1519
target = torch.from_numpy(target)
16-
# Create mask for image/target pair
17-
mask, target_s = create_mask(
18-
image=image,
19-
target=target,
20-
numSegments=numSegments,
21-
limOverseg=limOverseg
22-
)
23-
24-
# Save for later
25-
image_save_dir = join(
26-
root,
27-
"SegmentationClass/{}_sp".format(numSegments)
28-
)
29-
target_s_save_dir = join(
30-
root,
31-
"SegmentationClass/pre_encoded_{}_sp".format(numSegments)
32-
)
33-
if not exists(image_save_dir):
34-
mkdir(image_save_dir)
35-
if not exists(target_s_save_dir):
36-
mkdir(target_s_save_dir)
37-
save_name = image_number + ".pt"
38-
image_save_path = join(image_save_dir, save_name)
39-
target_s_save_path = join(target_s_save_dir, save_name)
40-
torch.save(mask, image_save_path)
41-
torch.save(target_s, target_s_save_path)
20+
# Save paths
21+
saveDir = join(imageList.path, 'SuperPixels')
22+
maskDir = join(saveDir, '{}_sp_mask'.format(numSegments))
23+
targetDir = join(saveDir, '{}_sp_target'.format(numSegments))
24+
# Check that directories exist
25+
if not exists(saveDir):
26+
mkdir(saveDir)
27+
if not exists(maskDir):
28+
mkdir(maskDir)
29+
if not exists(targetDir):
30+
mkdir(targetDir)
31+
# Define save paths
32+
mask_save_path = join(maskDir, image_number + ".pt")
33+
target_save_path = join(targetDir, image_number + ".pt")
34+
# If they haven't already been made, make them
35+
if not exists(mask_save_path) and not exists(target_save_path):
36+
# Create mask for image/target pair
37+
mask, target_s = create_mask(
38+
image=image,
39+
target=target,
40+
numSegments=numSegments,
41+
limOverseg=limOverseg
42+
)
43+
torch.save(mask, mask_save_path)
44+
torch.save(target_s, target_save_path)
4245

4346

4447
def create_mask(image, target, numSegments, limOverseg):

0 commit comments

Comments
 (0)