|
| 1 | +""" Wrapper class for Intel's PixelLink realisation (text segmentation NN) |
| 2 | + text-detection-00[34] |
| 3 | +
|
| 4 | + For text-detection-002 you'll need to uncomment string in detect() |
| 5 | +""" |
| 6 | +import cv2 |
| 7 | +import numpy as np |
| 8 | +from scipy.special import softmax |
| 9 | +from skimage.morphology import label |
| 10 | +from skimage.measure import regionprops |
| 11 | +from typing import List, Tuple |
| 12 | +from skimage.measure._regionprops import RegionProperties |
| 13 | + |
| 14 | + |
| 15 | +class PixelLinkDetector(): |
| 16 | + """ Wrapper class for Intel's version of PixelLink text-detection-0001 |
| 17 | + :param xml_model_path: path to XML file |
| 18 | +
|
| 19 | + **Example:** |
| 20 | +
|
| 21 | + .. code-block:: python |
| 22 | + detector = PixelLinkDetector('text-detection-0002.xml') |
| 23 | + img = cv2.imread('tmp.jpg') |
| 24 | + # ~250ms on i7-6700K |
| 25 | + detector.detect(img) |
| 26 | + # ~2ms |
| 27 | + bboxes = detector.decode() |
| 28 | + """ |
| 29 | + def __init__(self, xml_model_path: str, txt_threshold=0.5): |
| 30 | + """ |
| 31 | + :param xml_model_path: path to model's XML file |
| 32 | + :param txt_threshold: confidence, defaults to ``0.5`` |
| 33 | + """ |
| 34 | + self._net = cv2.dnn.readNet(xml_model_path, xml_model_path[:-3] + 'bin') |
| 35 | + self._txt_threshold = txt_threshold |
| 36 | + |
| 37 | + def detect(self, img: np.ndarray) -> None: |
| 38 | + """ GetPixelLink's outputs |
| 39 | + :param img: image as ``numpy.ndarray`` |
| 40 | + """ |
| 41 | + self._img_shape = img.shape |
| 42 | + blob = cv2.dnn.blobFromImage(img, 1, (1280, 768)) |
| 43 | + self._net.setInput(blob) |
| 44 | + out_layer_names = self._net.getUnconnectedOutLayersNames() |
| 45 | + # for text-detection-002 |
| 46 | + # self.pixels, self.links = self._net.forward(out_layer_names) |
| 47 | + # for text-detection-00[34] |
| 48 | + self.links, self.pixels = self._net.forward(out_layer_names) |
| 49 | + |
| 50 | + def get_mask(self) -> np.array: |
| 51 | + """ Get binary mask of detected text pixels |
| 52 | + """ |
| 53 | + pixel_mask = self._get_pixel_scores() >= self._txt_threshold |
| 54 | + return pixel_mask.astype(np.uint8) |
| 55 | + |
| 56 | + def _get_pixel_scores(self) -> np.array: |
| 57 | + "get softmaxed properly shaped pixel scores" |
| 58 | + tmp = np.transpose(self.pixels, (0, 2, 3, 1)) |
| 59 | + return softmax(tmp, axis=-1)[0, :, :, 1] |
| 60 | + |
| 61 | + def _get_txt_regions(self, pixel_mask: np.array) -> List[RegionProperties]: |
| 62 | + "kernels are class dependent" |
| 63 | + img_h, img_w = self._img_shape[:2] |
| 64 | + _, mask = cv2.threshold(pixel_mask, 0, 1, cv2.THRESH_BINARY) |
| 65 | + # transmutatioins |
| 66 | + # kernel size should be image size dependant (default (21,21)) |
| 67 | + # on small image it will connect separate words |
| 68 | + txt_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2)) |
| 69 | + mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, txt_kernel) |
| 70 | + # label regions on mask of original img size |
| 71 | + mask = cv2.resize(mask, (img_w, img_h), interpolation=cv2.INTER_NEAREST) |
| 72 | + mask = label(mask, background=0, connectivity=2) |
| 73 | + txt_regions = regionprops(mask) |
| 74 | + return txt_regions |
| 75 | + |
| 76 | + def _get_txt_bboxes(self, txt_regions: List[RegionProperties]) -> List[Tuple[int, int, int, int]]: |
| 77 | + """ Filter text area by area and height |
| 78 | +
|
| 79 | + :return: ``[(ymin, xmin, ymax, xmax)]`` |
| 80 | + """ |
| 81 | + min_area = 0 |
| 82 | + min_height = 4 |
| 83 | + boxes = [] |
| 84 | + for p in txt_regions: |
| 85 | + if p.area > min_area: |
| 86 | + bbox = p.bbox |
| 87 | + if (bbox[2] - bbox[0]) > min_height: |
| 88 | + boxes.append(bbox) |
| 89 | + return boxes |
| 90 | + |
| 91 | + def decode(self) -> List[Tuple[int, int, int, int]]: |
| 92 | + """ Decode PixelLink's output |
| 93 | +
|
| 94 | + :return: bounding_boxes |
| 95 | +
|
| 96 | + .. note:: |
| 97 | + bounding_boxes format: [ymin ,xmin ,ymax, xmax] |
| 98 | +
|
| 99 | + """ |
| 100 | + mask = self.get_mask() |
| 101 | + bboxes = self._get_txt_bboxes(self._get_txt_regions(mask)) |
| 102 | + return bboxes |
0 commit comments