Skip to content

Commit 7146a75

Browse files
committed
: Merge branch 'master' of https://github.com/zsdonghao/SRGAN
2 parents 2169587 + fa402df commit 7146a75

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

download_imagenet.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,55 @@
1+
import argparse
12
import os
23
import urllib
34
import numpy as np
5+
from PIL import Image
46

57
from joblib import Parallel, delayed
68

79

810
def download_image(download_str, save_dir):
911
img_name, img_url = download_str.strip().split('\t')
1012
save_img = os.path.join(save_dir, "{}.jpg".format(img_name))
13+
downloaded = False
1114
try:
1215
if not os.path.isfile(save_img):
1316
print("Downloading {} to {}.jpg".format(img_url, img_name))
1417
urllib.urlretrieve(img_url, save_img)
18+
19+
# Check size of the images
20+
downloaded = True
21+
with Image.open(save_img) as img:
22+
width, height = img.size
23+
if width < 500 or height < 500:
24+
os.remove(save_img)
25+
print("Remove downloaded images (w:{}, h:{})".format(width, height))
1526
else:
1627
print("Already downloaded {}".format(save_img))
1728
except Exception:
18-
print("File not exists.")
29+
if not downloaded:
30+
print("Cannot download.")
31+
else:
32+
os.remove(save_img)
33+
print("Remove failed, downloaded images.")
1934

2035

2136
def main():
37+
parser = argparse.ArgumentParser()
38+
parser.add_argument("--img_url_file", type=str, required=True,
39+
help="File that contains list of image IDs and urls.")
40+
parser.add_argument("--output_dir", type=str, required=True,
41+
help="Directory where to save outputs.")
42+
parser.add_argument("--n_download_urls", type=int, default=20000,
43+
help="Directory where to save outputs.")
44+
args = parser.parse_args()
45+
2246
np.random.seed(123456)
23-
url_file = "/data/imagenet/fall11_urls.txt"
24-
save_dir = "/data/imagenet/"
25-
n_download_imgs = 20000
2647

27-
with open(url_file) as f:
48+
with open(args.img_url_file) as f:
2849
lines = f.readlines()
29-
lines = np.random.choice(lines, size=n_download_imgs, replace=False)
50+
lines = np.random.choice(lines, size=args.n_download_urls, replace=False)
3051

31-
Parallel(n_jobs=12)(delayed(download_image)(line, save_dir) for line in lines)
52+
Parallel(n_jobs=12)(delayed(download_image)(line, args.output_dir) for line in lines)
3253

3354

3455
if __name__ == "__main__":

0 commit comments

Comments
 (0)