source: internals/2016/aptoideimagesdetector/trunk/explicit_content_detector/i2v/chainer_i2v.py @ 16448

Last change on this file since 16448 was 16448, checked in by dferreira, 3 years ago

File organization totally changed

  • Property svn:executable set to *
File size: 3.2 KB
Line 
1from i2v.base import Illustration2VecBase
2import json
3import warnings
4import numpy as np
5from scipy.ndimage import zoom
6from skimage.transform import resize
7from chainer import Variable
8from chainer.functions import average_pooling_2d, sigmoid
9from chainer.functions.caffe import CaffeFunction
10
11
class ChainerI2V(Illustration2VecBase):
    """Illustration2Vec estimator evaluated through a Chainer CaffeFunction net."""

    def __init__(self, *args, **kwargs):
        super(ChainerI2V, self).__init__(*args, **kwargs)
        # Per-channel pixel mean subtracted from every input before the
        # forward pass (order matches the net's expected channel order —
        # presumably BGR after the flip in _forward; TODO confirm upstream).
        self.mean = np.array([164.76139251, 167.47864617, 181.13838569])

    def resize_image(self, im, new_dims, interp_order=1):
        """Resize ``im`` to ``new_dims`` and return it as a float32 array.

        Adapted from caffe.io.resize_image().
        """
        n_channels = im.shape[-1]
        if n_channels not in (1, 3):
            # skimage only understands 1- or 3-channel images; fall back to
            # the slower but more general scipy zoom for anything else.
            scale = tuple(np.array(new_dims) / np.array(im.shape[:2]))
            return zoom(im, scale + (1,), order=interp_order).astype(np.float32)

        lo, hi = im.min(), im.max()
        if hi <= lo:
            # Constant image: resizing is trivial, and the normalization
            # below would divide by zero.
            out = np.empty((new_dims[0], new_dims[1], n_channels),
                           dtype=np.float32)
            out.fill(lo)
            return out

        # skimage.transform.resize expects intensities scaled into [0, 1];
        # rescale, resize, then map back to the original range.
        normalized = (im - lo) / (hi - lo)
        resized = resize(normalized, new_dims, order=interp_order)
        return (resized * (hi - lo) + lo).astype(np.float32)

    def _forward(self, inputs, layername):
        """Run the net on ``inputs`` and return the ``layername`` activations."""
        dims = (224, 224, 3)
        batch = np.zeros((len(inputs),) + dims, dtype=np.float32)
        for ix, image in enumerate(inputs):
            batch[ix] = self.resize_image(image, dims)
        batch = batch[:, :, :, ::-1]  # RGB -> BGR channel order
        batch -= self.mean            # in-place keeps the batch float32
        batch = batch.transpose((0, 3, 1, 2))  # NHWC -> NCHW
        x = Variable(batch)
        y, = self.net(inputs={'data': x}, outputs=[layername], train=False)
        return y

    def _extract(self, inputs, layername):
        """Extract features for ``inputs`` from ``layername`` as raw arrays."""
        if layername == 'prob':
            # Tag probabilities: pool conv6_4 over its 7x7 spatial extent,
            # then squash to (0, 1).
            pooled = average_pooling_2d(
                self._forward(inputs, layername='conv6_4'), ksize=7)
            return sigmoid(pooled).data
        if layername == 'encode1neuron':
            return sigmoid(self._forward(inputs, layername='encode1')).data
        return self._forward(inputs, layername).data
66
def make_i2v_with_chainer(param_path, tag_path=None, threshold_path=None):
    """Build a ChainerI2V from a Caffe parameter file.

    Parameters
    ----------
    param_path : str
        Path to the Caffe model parameters loaded by CaffeFunction.
    tag_path : str, optional
        Path to a JSON file with the 1539 tag names.
    threshold_path : str, optional
        Path to an .npz file containing a 'threshold' array of
        per-tag f-score thresholds.

    Returns
    -------
    ChainerI2V
        The wrapped network, with tags/thresholds attached when given.
    """
    # Loading old Caffe models triggers UserWarnings from chainer; ignore them.
    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        net = CaffeFunction(param_path)

    kwargs = {}
    if tag_path is not None:
        # was: json.loads(open(...).read()) — leaked the file handle.
        with open(tag_path, 'r') as f:
            tags = json.load(f)
        assert(len(tags) == 1539)
        kwargs['tags'] = tags

    if threshold_path is not None:
        # NpzFile keeps the archive open lazily; close it deterministically.
        with np.load(threshold_path) as npz:
            kwargs['threshold'] = npz['threshold']

    return ChainerI2V(net, **kwargs)
Note: See TracBrowser for help on using the repository browser.