-
Notifications
You must be signed in to change notification settings - Fork 2
/
alt_i2v_V2.py
89 lines (81 loc) · 2.83 KB
/
alt_i2v_V2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import keras
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential, Model, load_model
from keras.layers import Input, Activation, Dropout, Flatten, Dense, Reshape, merge
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator
from keras.layers.normalization import BatchNormalization as BN
from keras.layers.core import Dropout
from keras.applications.vgg16 import VGG16
import numpy as np
import os
from PIL import Image
import glob
import pickle
import sys
import random
import msgpack
import numpy as np
import json
# Module-level model, built at import time and shared by train() and pred().
# Architecture: the VGG16 convolutional base (include_top=False, ImageNet
# weights) feeds a new dense head ending in 5000 sigmoid units trained with
# binary cross-entropy, i.e. independent per-output probabilities.
# NOTE(review): presumably one output unit per tag in tag_index.pkl -- confirm
# that 5000 matches len(tag_index).
input_tensor = Input(shape=(224, 224, 3))
vgg16_model = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

# Freeze the first 12 backbone layers so only the later backbone layers and the
# new head are updated. The trailing note records 15 as the previous default.
for layer in vgg16_model.layers[:12]:  # default 15
    layer.trainable = False

x = vgg16_model.layers[-1].output
x = Flatten()(x)
x = BN()(x)
x = Dense(5000, activation='relu')(x)
x = Dropout(0.3)(x)
x = Dense(5000, activation='sigmoid')(x)

# Keras 2 functional-API keywords. The Keras 1 spelling `input=`/`output=`
# only works through a deprecated compatibility path and is rejected by tf.keras.
model = Model(inputs=vgg16_model.input, outputs=x)
model.compile(loss='binary_crossentropy', optimizer='adam')
def _load_shuffled_samples(pattern, max_files):
    """Load (X, y) pairs from up to ``max_files`` randomly ordered pickle files.

    The file list from ``glob.glob(pattern)`` is reshuffled on every call.
    A file whose unpickling raises ``EOFError`` (e.g. an empty or truncated
    file) is skipped, but it still counts toward ``max_files``.

    Returns:
        (Xs, ys): numpy arrays built from the loaded X values and y values.
    """
    names = glob.glob(pattern)
    random.shuffle(names)
    Xs, ys = [], []
    # Slicing up front enforces the scan limit without reading an extra file.
    for idx, name in enumerate(names[:max_files]):
        if idx % 100 == 0:
            print('now scan iter', idx)
        # NOTE(review): pickle executes arbitrary code embedded in the file;
        # only point this at dataset files you generated yourself.
        with open(name, 'rb') as f:
            try:
                X, y = pickle.load(f)
            except EOFError:
                continue
        Xs.append(X)
        ys.append(y)
    return np.array(Xs), np.array(ys)


def train(iterations=500, max_files=15000, dataset_pattern='../dataset/*.pkl',
          model_dir='models'):
    """Fine-tune the module-level ``model`` on pickled (X, y) samples.

    Each iteration draws a fresh random subset of at most ``max_files`` files,
    trains for one epoch on it, and saves the weights to
    ``<model_dir>/<iteration:09d>.h5``. The zero-padding makes lexical sort
    order match iteration order.

    Args:
        iterations: number of load/fit/save rounds.
        max_files: maximum number of dataset files scanned per round.
        dataset_pattern: glob pattern matching the pickled (X, y) files.
        model_dir: directory for weight checkpoints; created if missing.

    Raises:
        RuntimeError: if a round loads no samples at all.
    """
    # Create the checkpoint directory up front so the first save does not
    # fail after an entire epoch of training.
    if not os.path.isdir(model_dir):
        os.makedirs(model_dir)

    for i in range(iterations):
        print('now iter {} load pickled dataset...'.format(i))
        Xs, ys = _load_shuffled_samples(dataset_pattern, max_files)
        if len(Xs) == 0:
            raise RuntimeError(
                'no loadable samples match {!r}'.format(dataset_pattern))
        model.fit(Xs, ys, epochs=1)
        print('now iter {} '.format(i))
        model.save_weights(os.path.join(model_dir, '{:09d}.h5'.format(i)))
def pred(top_k=30, weights_pattern='models/*.h5',
         tag_index_path='make_datapair/tag_index.pkl',
         dataset_pattern='./make_datapair/dataset/*'):
    """Print the ``top_k`` highest-scoring tags for every pickled sample.

    The module-level ``model`` is loaded with the lexically last weights file
    matching ``weights_pattern``. ``tag_index_path`` holds a pickled
    ``{tag: index}`` mapping, which is inverted to turn output indices back
    into tags. Each file matching ``dataset_pattern`` must unpickle to an
    ``(X, y)`` pair. X is scored and ``y`` is ignored.

    Output lines have the form ``"<file> tag=<tag> prob=<score>"`` and are
    ordered by descending score.

    Args:
        top_k: number of tags printed per sample.
        weights_pattern: glob pattern for weight checkpoints.
        tag_index_path: path to the pickled ``{tag: index}`` dict.
        dataset_pattern: glob pattern for the pickled samples to score.

    Raises:
        RuntimeError: if no weights file matches ``weights_pattern``.
    """
    weight_files = sorted(glob.glob(weights_pattern))
    if not weight_files:
        raise RuntimeError(
            'no weight files match {!r}'.format(weights_pattern))
    model.load_weights(weight_files[-1])

    # NOTE(review): pickle executes arbitrary code embedded in the file;
    # only load files you generated yourself.
    with open(tag_index_path, 'rb') as f:
        tag_index = pickle.load(f)
    index_tag = {index: tag for tag, index in tag_index.items()}

    for name in glob.glob(dataset_pattern):
        with open(name, 'rb') as f:
            try:
                X, _ = pickle.load(f)
            except EOFError:
                # Skip empty or truncated pickles instead of aborting the run.
                continue
        scores = model.predict(np.array([X]))[0].tolist()
        # reverse=True keeps ties in index order (sort is stable), the same
        # order the original negated-key sort produced.
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)
        for i, w in ranked[:top_k]:
            print("{name} tag={tag} prob={prob}".format(name=name, tag=index_tag[i], prob=w))
if __name__ == '__main__':
    # Command-line dispatch. Each flag independently triggers its action.
    # When both are given, training runs before prediction.
    for flag, action in (('--train', train), ('--pred', pred)):
        if flag in sys.argv:
            action()