source: internals/2016/aptoideimagesdetector/trunk/Source Code/Illustration2Vector/illustration2vec-master/i2v/base.py @ 16289

Last change on this file since 16289 was 16289, checked in by dferreira, 3 years ago

Initial content. All the tests done to three open-source platrofms.

  • Property svn:executable set to *
File size: 5.1 KB
Line 
1from abc import ABCMeta, abstractmethod
2import numpy as np
3
4
5class Illustration2VecBase(object):
6
7    __metaclass__ = ABCMeta
8
9    def __init__(self, net, tags=None, threshold=None):
10        self.net = net
11        if tags is not None:
12            self.tags = np.array(tags)
13            self.index = {t: i for i, t in enumerate(tags)}
14        else:
15            self.tags = None
16
17        if threshold is not None:
18            self.threshold = threshold
19        else:
20            self.threshold = None
21
22    @abstractmethod
23    def _extract(self, inputs, layername):
24        pass
25
26    def _convert_image(self, image):
27        arr = np.asarray(image, dtype=np.float32)
28        if arr.ndim == 2:
29            # convert a monochrome image to a color one
30            ret = np.empty((arr.shape[0], arr.shape[1], 3), dtype=np.float32)
31            ret[:] = arr.reshape(arr.shape[0], arr.shape[1], 1)
32            return ret
33        elif arr.ndim == 3:
34            # if arr contains alpha channel, remove it
35            return arr[:,:,:3]
36        else:
37            raise TypeError('unsupported image specified')
38
39    def _estimate(self, images):
40        assert(self.tags is not None)
41        imgs = [self._convert_image(img) for img in images]
42        prob = self._extract(imgs, layername='prob')
43        prob = prob.reshape(prob.shape[0], -1)
44        return prob
45
46    def estimate_specific_tags(self, images, tags):
47        prob = self._estimate(images)
48        return [{t: float(prob[i, self.index[t]]) for t in tags}
49                for i in range(prob.shape[0])]
50
51    def estimate_top_tags(self, images, n_tag=10):
52        prob = self._estimate(images)
53        general_prob = prob[:, :512]
54        character_prob = prob[:, 512:1024]
55        copyright_prob = prob[:, 1024:1536]
56        rating_prob = prob[:, 1536:]
57        general_arg = np.argsort(-general_prob, axis=1)[:, :n_tag]
58        character_arg = np.argsort(-character_prob, axis=1)[:, :n_tag]
59        copyright_arg = np.argsort(-copyright_prob, axis=1)[:, :n_tag]
60        rating_arg = np.argsort(-rating_prob, axis=1)
61        result = []
62        for i in range(prob.shape[0]):
63            result.append({
64                'general': zip(
65                    self.tags[general_arg[i]],
66                    general_prob[i, general_arg[i]].tolist()),
67                'character': zip(
68                    self.tags[512 + character_arg[i]],
69                    character_prob[i, character_arg[i]].tolist()),
70                'copyright': zip(
71                    self.tags[1024 + copyright_arg[i]],
72                    copyright_prob[i, copyright_arg[i]].tolist()),
73                'rating': zip(
74                    self.tags[1536 + rating_arg[i]],
75                    rating_prob[i, rating_arg[i]].tolist()),
76            })
77        return result
78
79    def __extract_plausible_tags(self, preds, f):
80        result = []
81        for pred in preds:
82            general = [(t, p) for t, p in pred['general'] if f(t, p)]
83            character = [(t, p) for t, p in pred['character'] if f(t, p)]
84            copyright = [(t, p) for t, p in pred['copyright'] if f(t, p)]
85            result.append({
86                'general': general,
87                'character': character,
88                'copyright': copyright,
89                'rating': pred['rating'],
90            })
91        return result
92
93    def estimate_plausible_tags(
94            self, images, threshold=0.25, threshold_rule='constant'):
95        preds = self.estimate_top_tags(images, n_tag=512)
96        result = []
97        if threshold_rule == 'constant':
98            return self.__extract_plausible_tags(
99                preds, lambda t, p: p > threshold)
100        elif threshold_rule == 'f0.5':
101            if self.threshold is None:
102                raise TypeError(
103                    'please specify threshold option during init.')
104            return self.__extract_plausible_tags(
105                preds, lambda t, p: p > self.threshold[self.index[t], 0])
106        elif threshold_rule == 'f1':
107            if self.threshold is None:
108                raise TypeError(
109                    'please specify threshold option during init.')
110            return self.__extract_plausible_tags(
111                preds, lambda t, p: p > self.threshold[self.index[t], 1])
112        elif threshold_rule == 'f2':
113            if self.threshold is None:
114                raise TypeError(
115                    'please specify threshold option during init.')
116            return self.__extract_plausible_tags(
117                preds, lambda t, p: p > self.threshold[self.index[t], 2])
118        else:
119            raise TypeError('unknown rule specified')
120        return result
121
122    def extract_feature(self, images):
123        imgs = [self._convert_image(img) for img in images]
124        feature = self._extract(imgs, layername='encode1')
125        feature = feature.reshape(feature.shape[0], -1)
126        return feature
127
128    def extract_binary_feature(self, images):
129        imgs = [self._convert_image(img) for img in images]
130        feature = self._extract(imgs, layername='encode1neuron')
131        feature = feature.reshape(feature.shape[0], -1)
132        binary_feature = np.zeros_like(feature, dtype=np.uint8)
133        binary_feature[feature > 0.5] = 1
134        return np.packbits(binary_feature, axis=1)
Note: See TracBrowser for help on using the repository browser.