source: internals/2016/aptoideimagesdetector/trunk/explicit_content_detector/i2v/caffe_i2v.py @ 16448

Last change on this file since 16448 was 16448, checked in by dferreira, 3 years ago

File organization totally changed

  • Property svn:executable set to *
File size: 1.8 KB
Line 
1from i2v.base import Illustration2VecBase
2import json
3import numpy as np
4from caffe import Classifier
5from caffe.io import resize_image
6
7
class CaffeI2V(Illustration2VecBase):
    """Illustration2Vec backend that runs its network through pycaffe.

    ``self.net`` is presumably stored by the base-class constructor; in this
    module ``make_i2v_with_caffe`` passes a ``caffe.Classifier`` as that net.
    """

    def _extract(self, inputs, layername):
        """Run the network on *inputs* and return the output of *layername*.

        Each image is resized to ``net.image_dims``, center-cropped to
        ``net.crop_dims``, preprocessed by the network's transformer and then
        fed through ``net.forward_all``.

        Args:
            inputs: non-empty sequence of images. They are presumably
                ``(rows, cols, channels)`` float arrays; the channel count is
                taken from ``inputs[0].shape[2]``.
            layername: name of the blob whose activations are returned.

        Returns:
            The ``layername`` blob from ``net.forward_all`` (one row per
            input image).
        """
        net = self.net
        image_dims = net.image_dims
        # NOTE: the following code is adapted from caffe.Classifier.
        shape = (len(inputs), image_dims[0], image_dims[1], inputs[0].shape[2])
        input_ = np.zeros(shape, dtype=np.float32)
        for ix, in_ in enumerate(inputs):
            input_[ix] = resize_image(in_, image_dims)
        # Take center crop. The bounds come out as floats (``/ 2.0``), and
        # NumPy rejects float slice indices, so cast them to int first
        # (caffe.Classifier does the same).
        center = np.array(image_dims) / 2.0
        crop = np.tile(center, (1, 2))[0] + np.concatenate([
            -net.crop_dims / 2.0,
            net.crop_dims / 2.0,
        ])
        crop = crop.astype(int)
        input_ = input_[:, crop[0]:crop[2], crop[1]:crop[3], :]
        # Classify: the network input is laid out (N, channels, rows, cols),
        # hence the [0, 3, 1, 2] reordering of the (N, rows, cols, channels)
        # batch shape.
        caffe_in = np.zeros(
            np.array(input_.shape)[[0, 3, 1, 2]], dtype=np.float32)
        for ix, in_ in enumerate(input_):
            caffe_in[ix] = net.transformer.preprocess(net.inputs[0], in_)
        out = net.forward_all(
            blobs=[layername], **{net.inputs[0]: caffe_in})[layername]
        return out
34
35
def make_i2v_with_caffe(net_path, param_path, tag_path=None, threshold_path=None):
    """Build a :class:`CaffeI2V` from Caffe model files.

    Args:
        net_path: model definition file, passed as the first argument of
            ``caffe.Classifier``.
        param_path: trained weights file, passed as the second argument of
            ``caffe.Classifier``.
        tag_path: optional path to a JSON file containing the tag list; it
            must hold exactly 1539 entries.
        threshold_path: optional path to a file loadable by ``np.load`` that
            exposes a ``'threshold'`` entry (e.g. an ``.npz`` archive).

    Returns:
        A ``CaffeI2V`` wrapping the classifier. The ``tags`` and
        ``threshold`` keyword arguments are supplied only when the matching
        path is given.

    Raises:
        ValueError: if the tag file does not contain exactly 1539 tags.
    """
    # Per-channel mean handed to caffe.Classifier. Presumably the training-set
    # mean of the pretrained model, in BGR order given channel_swap=(2, 1, 0)
    # -- TODO confirm.
    mean = np.array([164.76139251, 167.47864617, 181.13838569])
    net = Classifier(
        net_path, param_path, mean=mean, channel_swap=(2, 1, 0))

    kwargs = {}
    if tag_path is not None:
        # Context manager so the file handle is closed deterministically.
        with open(tag_path, 'r') as tag_file:
            tags = json.load(tag_file)
        # Explicit check rather than `assert`, which is stripped under -O.
        if len(tags) != 1539:
            raise ValueError(
                'expected 1539 tags in %r, got %d' % (tag_path, len(tags)))
        kwargs['tags'] = tags

    if threshold_path is not None:
        loaded = np.load(threshold_path)
        try:
            fscore_threshold = loaded['threshold']
        finally:
            # For .npz archives np.load returns an NpzFile that keeps the
            # file open; a plain ndarray has no close() and needs none.
            close = getattr(loaded, 'close', None)
            if close is not None:
                close()
        kwargs['threshold'] = fscore_threshold

    return CaffeI2V(net, **kwargs)
Note: See TracBrowser for help on using the repository browser.