-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimage_transform_utils.py
78 lines (66 loc) · 2.64 KB
/
image_transform_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import sys
import numpy as np
from multiprocessing import Pool
from functools import partial
from matplotlib import pyplot as plt
def mask_random_pixels(imgs, mask_ratio=.2):
    """Build a random per-pixel keep/drop mask matching ``imgs``.

    Every pixel position (all axes except the last) is independently kept
    (value 1) with probability ``1 - mask_ratio`` or dropped (value 0).
    The same decision is copied to every channel of that pixel.

    Parameters
    ----------
    imgs : np.ndarray
        Array whose last axis is the channel axis, e.g. (N, H, W, C) or
        (H, W, C).
    mask_ratio : float
        Probability that a pixel is dropped, in [0, 1].

    Returns
    -------
    np.ndarray
        Integer array of 0s and 1s with the same shape as ``imgs``, ready to
        be multiplied with it.
    """
    imgs = np.asarray(imgs)
    keep = np.random.binomial(1, 1 - mask_ratio, imgs.shape[:-1])
    # Repeat over however many channels the input has (RGB, RGBA, ...) so the
    # mask always broadcasts against imgs; for 3 channels this matches the
    # previous hard-coded behaviour exactly.
    return np.repeat(keep[..., np.newaxis], imgs.shape[-1], axis=-1)
def generate_masks(imgs, n_masks, mask_size):
    """Cut ``n_masks`` rectangular holes out of every image in a batch.

    Parameters
    ----------
    imgs : np.ndarray
        Image batch of shape (N, H, W, C), or a single image of shape
        (H, W, C) which is treated as a batch of one.
    n_masks : int
        Number of rectangles placed per image (they may overlap).
    mask_size : tuple[int, int]
        (height, width) of each rectangle; must fit inside the image.

    Returns
    -------
    mask : np.ndarray
        Same shape and dtype as the batched images: 1 everywhere except
        inside the rectangles, where it is 0.
    masked_patches : np.ndarray
        Float64 array of shape (N, mask_h, mask_w, C * n_masks) holding the
        original pixels under each rectangle; patch ``k`` of image ``i``
        occupies channels ``k*C:(k+1)*C``.
    indices : np.ndarray
        Integer array of shape (N, n_masks, 2) with the (row, col) of each
        rectangle's top-left corner.

    Raises
    ------
    ValueError
        If ``mask_size`` is larger than the image in either dimension.
    """
    imgs = np.asarray(imgs)
    if imgs.ndim == 3:
        imgs = imgs[np.newaxis, :]
    mask_h, mask_w = mask_size
    n_imgs, height, width, n_channels = imgs.shape
    if mask_h > height or mask_w > width:
        raise ValueError(
            f'mask_size {tuple(mask_size)} does not fit in images of size '
            f'{(height, width)}'
        )

    mask = np.ones_like(imgs)
    masked_patches = np.empty((n_imgs, mask_h, mask_w, n_channels * n_masks))
    # randint's `high` is exclusive: valid top-left corners are
    # 0..height-mask_h and 0..width-mask_w inclusive, hence the +1.
    indices = np.random.randint(
        [0, 0],
        [height - mask_h + 1, width - mask_w + 1],
        [n_imgs, n_masks, 2],
    )
    for img_idx, start_points in enumerate(indices):
        for patch_idx, (row, col) in enumerate(start_points):
            rows = slice(row, row + mask_h)
            cols = slice(col, col + mask_w)
            mask[img_idx, rows, cols] = 0
            # Each patch gets its own block of n_channels channels so patches
            # never overwrite one another.
            first = patch_idx * n_channels
            masked_patches[img_idx, :, :, first:first + n_channels] = \
                imgs[img_idx, rows, cols, :]
    return mask, masked_patches, indices
def _generate_masks_with_seed(seed, imgs, n_masks, mask_size):
    """Seed this worker's global NumPy RNG, then run generate_masks on one chunk."""
    # Forked workers inherit an identical RNG state; without reseeding, every
    # worker would draw the same rectangle positions.
    np.random.seed(seed)
    return generate_masks(imgs, n_masks=n_masks, mask_size=mask_size)


def mask_random_areas(imgs, n_masks=10, mask_size=(3, 3), parallelise=False, n_pools=4):
    """Mask random rectangles in a batch of images, optionally in parallel.

    Parameters
    ----------
    imgs : np.ndarray
        Image batch of shape (N, H, W, C), or a single (H, W, C) image.
    n_masks : int
        Number of rectangles per image.
    mask_size : tuple[int, int]
        (height, width) of each rectangle.
    parallelise : bool
        If True, split the batch into chunks and process them in a
        ``multiprocessing.Pool``. The caller's script must be guarded by
        ``if __name__ == '__main__'`` for spawn-based start methods.
    n_pools : int
        Number of worker processes when ``parallelise`` is True.

    Returns
    -------
    tuple[np.ndarray, np.ndarray, np.ndarray]
        ``(mask, masked_patches, indices)`` as returned by ``generate_masks``.
        The parallel path returns the same tuple, concatenated along the
        batch axis.
    """
    if not parallelise:
        return generate_masks(imgs, n_masks=n_masks, mask_size=mask_size)

    imgs = np.asarray(imgs)
    if imgs.ndim == 3:
        # Batch a single image first; otherwise it would be split by rows.
        imgs = imgs[np.newaxis, :]
    if len(imgs) == 0:
        return generate_masks(imgs, n_masks=n_masks, mask_size=mask_size)

    # Ceil division gives at most n_pools chunks and is never 0.
    chunk_size = max(1, -(-len(imgs) // n_pools))
    img_subsets = [imgs[start:start + chunk_size]
                   for start in range(0, len(imgs), chunk_size)]
    # Draw per-chunk seeds from the parent's RNG so results differ across
    # chunks and stay reproducible under np.random.seed in the parent.
    seeds = [int(s) for s in np.random.randint(0, 2**31 - 1, size=len(img_subsets))]
    func = partial(_generate_masks_with_seed, n_masks=n_masks, mask_size=mask_size)
    with Pool(n_pools) as processing_pool:
        results = processing_pool.starmap(func, zip(seeds, img_subsets))

    masks, patches, indices = zip(*results)
    return (np.concatenate(masks, axis=0),
            np.concatenate(patches, axis=0),
            np.concatenate(indices, axis=0))
def visualise_masks(path):
    """Display an image alongside its random-pixel and random-patch masked versions.

    While ``path`` does not exist, prompt on stdin for a new one; an empty
    answer exits the program.
    """
    # Re-prompt in a loop rather than recursing on every bad answer.
    while not os.path.exists(path):
        print(f'Are you sure {path} is a valid file path?')
        print('Type a valid file path below or hit RETURN to exit')
        path = input()
        if path == '':
            sys.exit()

    img = plt.imread(path)
    if img.ndim == 2:
        # Grayscale: add an RGB channel axis so the mask helpers, which treat
        # the last axis as channels, see the layout they expect.
        img = np.stack([img] * 3, axis=-1)
    # Four copies of the image form a small batch for the batched helpers.
    img = np.array([img, img, img, img])
    mask1 = mask_random_pixels(img)
    mask2 = mask_random_areas(img, mask_size=(100, 100))
    img1 = mask1 * img
    # mask_random_areas returns (mask, patches, indices); only the mask is used.
    img2 = mask2[0] * img

    panels = [('Original', img[0]), ('Pixels', img1[0]), ('Patches', img2[0])]
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(3, 1, position)
        plt.title(title)
        plt.imshow(picture)
    plt.show()
if __name__ == '__main__':
    # Take the image path from the command line, or ask for it interactively.
    if len(sys.argv) > 1:
        img_path = sys.argv[1]
    else:
        print('Type a valid file path below or hit RETURN to exit')
        img_path = input()
    # An empty path means the user chose to quit. sys.exit is used because the
    # builtin exit() is a site-module helper for the interactive shell.
    if img_path == '':
        sys.exit()
    visualise_masks(img_path)