55"""
66import cv2
77import numpy as np
8- from scipy .special import softmax
98from skimage .morphology import label
109from skimage .measure import regionprops
1110from typing import List , Tuple
1211from skimage .measure ._regionprops import RegionProperties
1312
1413
1514class PixelLinkDetector ():
16- """ Wrapper class for Intel's version of PixelLink text-detection-0001
15+ """ Wrapper class for Intel's version of PixelLink text detector
16+
17+ See https://github.com/openvinotoolkit/open_model_zoo/blob/master/models/intel/ \
18+ text-detection-0004/description/text-detection-0004.md
19+
1720 :param xml_model_path: path to XML file
1821
1922 **Example:**
2023
2124 .. code-block:: python
22- detector = PixelLinkDetector('text-detection-0002 .xml')
25+ detector = PixelLinkDetector('text-detection-0004 .xml')
2326 img = cv2.imread('tmp.jpg')
2427 # ~250ms on i7-6700K
2528 detector.detect(img)
@@ -35,7 +38,15 @@ def __init__(self, xml_model_path: str, txt_threshold=0.5):
3538 self ._txt_threshold = txt_threshold
3639
3740 def detect (self , img : np .ndarray ) -> None :
38- """ GetPixelLink's outputs
41+ """ GetPixelLink's outputs (BxCxHxW):
42+ + [1x16x192x320] - logits related to linkage between pixels and their neighbors
43+ + [1x2x192x320] - logits related to text/no-text classification for each pixel
44+
45+ B - batch size
46+ C - number of channels
47+ H - image height
48+ W - image width
49+
3950 :param img: image as ``numpy.ndarray``
4051 """
4152 self ._img_shape = img .shape
@@ -47,29 +58,48 @@ def detect(self, img: np.ndarray) -> None:
4758 # for text-detection-00[34]
4859 self .links , self .pixels = self ._net .forward (out_layer_names )
4960
def get_mask(self) -> np.ndarray:
    """ Get binary mask of detected text pixels

    A pixel is marked as text when its score returned by
    ``_get_pixel_scores`` is greater than or equal to the text
    threshold stored on the instance.

    :return: ``numpy.ndarray`` of dtype ``uint8`` holding 1 for text
        pixels and 0 for everything else
    """
    pixel_mask = self._get_pixel_scores() >= self._txt_threshold
    # bool -> uint8, the mask is later handed to OpenCV routines
    return pixel_mask.astype(np.uint8)
5566
56- def _get_pixel_scores (self ) -> np .array :
57- "get softmaxed properly shaped pixel scores"
67+ def _logsumexp (self , a : np .ndarray , axis = - 1 ) -> np .ndarray :
68+ """ Castrated function from scipy
69+ https://github.com/scipy/scipy/blob/v1.6.2/scipy/special/_logsumexp.py
70+
71+ Compute the log of the sum of exponentials of input elements.
72+ """
73+ a_max = np .amax (a , axis = axis , keepdims = True )
74+ tmp = np .exp (a - a_max )
75+ s = np .sum (tmp , axis = axis , keepdims = True )
76+ out = np .log (s )
77+ out += a_max
78+ return out
79+
80+ def _get_pixel_scores (self ) -> np .ndarray :
81+ """ get softmaxed properly shaped pixel scores """
82+ # move channels to the end
5883 tmp = np .transpose (self .pixels , (0 , 2 , 3 , 1 ))
59- return softmax (tmp , axis = - 1 )[0 , :, :, 1 ]
84+ # softmax from scipy
85+ tmp = np .exp (tmp - self ._logsumexp (tmp , axis = - 1 ))
86+ # select single batch, single chanel values
87+ return tmp [0 , :, :, 1 ]
6088
def _get_txt_regions(self, pixel_mask: np.ndarray) -> List[RegionProperties]:
    """ Get connected text regions from a mask of text pixels

    Kernels are class dependent.

    :param pixel_mask: mask of detected text pixels, e.g. the result
        of ``get_mask``
    :return: list of region properties measured on the mask after it
        was resized to the shape of the image passed to ``detect``
    """
    img_h, img_w = self._img_shape[:2]
    # force the mask to 0/1: every value above 0 becomes 1
    _, mask = cv2.threshold(pixel_mask, 0, 1, cv2.THRESH_BINARY)
    # transmutations
    # kernel size should be image size dependent (default (21,21))
    # on small image it will connect separate words
    txt_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, txt_kernel)
    # bring the mask to the original img size (dsize is width, height)
    mask = cv2.resize(mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST)
    # label connected regions of an integer array
    mask = label(mask, background=0, connectivity=2)
    # measure properties of labeled image regions
    txt_regions = regionprops(mask)
    return txt_regions
75105
@@ -99,4 +129,6 @@ def decode(self) -> List[Tuple[int, int, int, int]]:
99129 """
100130 mask = self .get_mask ()
101131 bboxes = self ._get_txt_bboxes (self ._get_txt_regions (mask ))
132+ # sort by xmin, ymin
133+ bboxes = sorted (bboxes , key = lambda x : (x [1 ], x [0 ]))
102134 return bboxes
0 commit comments