import tensorflow as tf  # TensorFlow 1.x API (Session / Saver / placeholders)

name_dataset = './name.csv'

# name.csv: a header line, then one "name,gender" pair per line.
train_x = []
train_y = []
with open(name_dataset, 'r', encoding='UTF-8') as f:
    first_line = True
    for line in f:
        # Skip the CSV header.
        if first_line:
            first_line = False
            continue
        sample = line.strip().split(',')
        if len(sample) == 2:
            train_x.append(sample[0])
            if sample[1] == '男':
                train_y.append([0, 1])  # male
            else:
                train_y.append([1, 0])  # female
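
# Sanity check: every name must have a matching one-hot label.
assert len(train_x) == len(train_y)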
max_name_length = max([len(name) for name in train_x])
# print("Longest name (in characters): ", max_name_length)
# Fix the input length at 8 characters; shorter names are padded below.
max_name_length = 8
# Count how often each character appears across all names.
vocabulary = {}
for name in train_x:
    for word in name:
        if word in vocabulary:
            vocabulary[word] += 1
        else:
            vocabulary[word] = 1
# Convert each name string to a vector of character ids. Id 0 is reserved
# for padding, so a placeholder ' ' goes first; the remaining characters are
# sorted by frequency (any fixed ordering works, as long as training and
# inference share it).
vocabulary_list = [' '] + sorted(vocabulary, key=vocabulary.get, reverse=True)
vocab = dict([(x, y) for (y, x) in enumerate(vocabulary_list)])
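
# Quick check of the mapping: the padding placeholder gets id 0 and every
# real character gets an id >= 1.
assert vocab[' '] == 0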
train_x_vec = []
for name in train_x:
    name_vec = []
    for word in name[:max_name_length]:  # truncate names longer than 8 chars
        name_vec.append(vocab.get(word))
    while len(name_vec) < max_name_length:
        name_vec.append(0)  # pad short names with the reserved id 0
    train_x_vec.append(name_vec)
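
# batch_size, num_batch, X, Y, dropout_keep_prob, train_op and loss are
# assumed to come from the model definition (neural_network) earlier in the
# article. If running this block standalone, a typical setup would be:
#   batch_size = 64
#   num_batch = len(train_x_vec) // batch_size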
saver = tf.train.Saver(tf.global_variables())
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(1):  # use range(201) for a full training run
        for i in range(num_batch):
            batch_x = train_x_vec[i * batch_size: (i + 1) * batch_size]
            batch_y = train_y[i * batch_size: (i + 1) * batch_size]
            _, loss_ = sess.run([train_op, loss],
                                feed_dict={X: batch_x, Y: batch_y, dropout_keep_prob: 0.5})
            print(e, i, loss_)
        # Save a checkpoint every 50 epochs.
        if e % 50 == 0:
            saver.save(sess, './name2sex.model', global_step=e)
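
# tf.train.Saver writes name2sex.model-<epoch>.* files plus a 'checkpoint'
# index file into the current directory; detect_sex() below reloads the
# latest of these.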
# Inference
def detect_sex(name_list):
    x = []
    for name in name_list:
        name_vec = []
        for word in name[:max_name_length]:
            # Characters unseen during training fall back to the padding id 0.
            name_vec.append(vocab.get(word, 0))
        while len(name_vec) < max_name_length:
            name_vec.append(0)
        x.append(name_vec)

    output = neural_network(len(vocabulary_list))

    saver = tf.train.Saver(tf.global_variables())
    with tf.Session() as sess:
        # Restore the most recent training checkpoint.
        ckpt = tf.train.get_checkpoint_state('./')
        if ckpt is not None:
            print(ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('Model loaded...')
        else:
            print('No model found')