-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
123 lines (100 loc) · 3.57 KB
/
utils.py
File metadata and controls
123 lines (100 loc) · 3.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import queue
import threading
import zipfile
import cv2
import numpy as np
import pickle
import tqdm_utils
def image_center_crop(img):
    """Return the largest centered square crop of an image array.

    The first two axes of ``img`` are treated as height and width. The longer
    of the two is trimmed symmetrically so that both equal the shorter one;
    when the excess is odd, the extra row/column is removed from the
    top/left side. Any further axes (e.g. channels) are kept untouched, so
    both 2-D (grayscale) and 3-D (colour) arrays are accepted.

    Args:
        img: array-like object exposing ``shape`` and supporting 2-D slicing
            (e.g. a ``numpy.ndarray``) with at least two dimensions.

    Returns:
        A square slice of ``img`` with side ``min(height, width)``.
    """
    h, w = img.shape[0], img.shape[1]
    side = min(h, w)
    # Offsets of the crop window; the longer axis gets the larger half of an
    # odd excess trimmed from its start, the shorter axis gets offset 0.
    excess_h = h - side
    excess_w = w - side
    top = excess_h - excess_h // 2
    left = excess_w - excess_w // 2
    # No explicit trailing ``:`` so arrays without a channel axis also work.
    return img[top:top + side, left:left + side]
def decode_image_from_buf(buf):
    """Decode an in-memory encoded image (e.g. JPEG bytes) into an RGB array.

    Args:
        buf: raw encoded image bytes (anything ``bytearray()`` accepts).

    Returns:
        The decoded image after converting OpenCV's BGR channel order to RGB.

    NOTE(review): ``cv2.imdecode`` is documented to yield an empty result
    (``None`` in Python) for data it cannot decode, in which case the
    ``cvtColor`` call below fails -- confirm callers only pass valid images.
    """
    encoded = np.asarray(bytearray(buf), dtype=np.uint8)
    # Flag 1 requests a 3-channel colour decode (the value of cv2.IMREAD_COLOR).
    bgr_pixels = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
def crop_and_preprocess(img, input_shape, preprocess_for_model):
    """Turn a decoded image into a model-ready float32 array.

    Steps, in order: center square crop, resize to ``input_shape``, cast to
    ``float32``, then apply the model-specific preprocessing callable.

    Args:
        img: decoded image array, as accepted by ``image_center_crop``.
        input_shape: target size handed straight to ``cv2.resize`` as its
            size argument (OpenCV reads it as ``(width, height)``).
        preprocess_for_model: callable applied to the float32 image; its
            return value is returned unchanged.

    Returns:
        Whatever ``preprocess_for_model`` returns for the resized image.
    """
    square = image_center_crop(img)
    resized = cv2.resize(square, input_shape)
    as_float = resized.astype("float32")  # cast before model normalization
    return preprocess_for_model(as_float)
def apply_model(zip_fn, model, preprocess_for_model, extensions=(".jpg",), input_shape=(224, 224), batch_size=32):
    """Run ``model`` over every matching image stored in a zip archive.

    A background thread reads, decodes and preprocesses images from the
    archive and feeds them through a bounded queue; the calling thread
    groups them into batches and calls ``model.predict`` on each batch.

    Args:
        zip_fn: archive location, passed to ``zipfile.ZipFile``.
        model: object exposing ``predict(batch)``, where ``batch`` is the
            ``np.stack`` of up to ``batch_size`` preprocessed images.
        preprocess_for_model: callable forwarded to ``crop_and_preprocess``.
        extensions: file suffixes (as returned by ``os.path.splitext``,
            compared exactly and case-sensitively) selecting which archive
            members are processed.
        input_shape: target size forwarded to ``crop_and_preprocess``.
        batch_size: number of images per ``model.predict`` call; the queue
            holds at most ``batch_size * 10`` pending images.

    Returns:
        Tuple ``(img_embeddings, img_fns)``: the ``np.vstack`` of all batch
        predictions and the list of image base names in the same order.

    Raises:
        Exception: any error raised inside the reading thread (bad archive,
            undecodable image, ...) is re-raised in the calling thread.
        ValueError: from ``np.vstack`` when no image matched ``extensions``
            (nothing to stack).
    """
    # queue for cropped images
    q = queue.Queue(maxsize=batch_size * 10)
    # set when the read thread has stopped producing (done, killed or failed)
    read_thread_completed = threading.Event()
    # time for read thread to die
    kill_read_thread = threading.Event()
    # exception captured in the read thread, re-raised by the consumer
    reader_errors = []

    def reading_thread(zip_fn):
        try:
            # context manager guarantees the archive handle is closed
            with zipfile.ZipFile(zip_fn) as zf:
                for fn in tqdm_utils.tqdm_notebook_failsafe(zf.namelist()):
                    if kill_read_thread.is_set():
                        break
                    if os.path.splitext(fn)[-1] in extensions:
                        buf = zf.read(fn)  # read raw bytes from zip for fn
                        img = decode_image_from_buf(buf)  # decode raw bytes
                        img = crop_and_preprocess(img, input_shape, preprocess_for_model)
                        while True:
                            try:
                                # short timeout so a kill request is noticed
                                # even while the queue stays full
                                q.put((os.path.split(fn)[-1], img), timeout=1)
                            except queue.Full:
                                if kill_read_thread.is_set():
                                    break
                                continue
                            break
        except Exception as exc:  # thread boundary: hand the error to the caller
            reader_errors.append(exc)
        finally:
            # Always signal completion; otherwise a failure in this thread
            # would leave the consumer polling an empty queue forever.
            read_thread_completed.set()

    # start reading thread
    t = threading.Thread(target=reading_thread, args=(zip_fn,))
    t.daemon = True
    t.start()

    img_fns = []
    img_embeddings = []
    batch_imgs = []

    def process_batch(batch_imgs):
        batch_imgs = np.stack(batch_imgs, axis=0)
        batch_embeddings = model.predict(batch_imgs)
        img_embeddings.append(batch_embeddings)

    try:
        while True:
            try:
                fn, img = q.get(timeout=1)
            except queue.Empty:
                # The reader sets the event only after its final put, so
                # "event set AND queue empty" means nothing more will come.
                # Re-checking emptiness covers items that arrived between
                # the timeout and the event check.
                if read_thread_completed.is_set() and q.empty():
                    break
                continue
            img_fns.append(fn)
            batch_imgs.append(img)
            if len(batch_imgs) == batch_size:
                process_batch(batch_imgs)
                batch_imgs = []

        if reader_errors:
            raise reader_errors[0]

        # process last batch
        if batch_imgs:
            process_batch(batch_imgs)
    finally:
        # Stop the reader and wait for it. No q.join() here: on an early
        # exit the queue may still hold unprocessed items and joining it
        # would block forever.
        kill_read_thread.set()
        t.join()

    img_embeddings = np.vstack(img_embeddings)
    return img_embeddings, img_fns
def save_pickle(obj, fn):
    """Serialize ``obj`` to the file ``fn`` using the highest pickle protocol.

    Args:
        obj: any picklable object.
        fn: destination path; the file is opened in binary write mode, so
            existing content is overwritten.
    """
    newest_protocol = pickle.HIGHEST_PROTOCOL
    with open(fn, "wb") as sink:
        pickle.dump(obj, sink, protocol=newest_protocol)
def read_pickle(fn):
    """Load and return the object pickled in the file ``fn``.

    WARNING: unpickling can execute arbitrary code; only use this on files
    from a trusted source.

    Args:
        fn: path of a pickle file, opened in binary read mode.

    Returns:
        The object reconstructed by ``pickle.load``.
    """
    with open(fn, "rb") as source:
        restored = pickle.load(source)
    return restored