十年网站开发经验 + 多家企业客户 + 靠谱的建站团队
量身定制 + 运营维护+专业推广+无忧售后,网站问题一站解决
这篇文章将为大家详细讲解有关如何实现TensorFlow微信跳一跳的AI,小编觉得挺实用的,因此分享给大家做个参考,希望大家阅读完这篇文章后可以有所收获。
成都创新互联的客户来自各行各业,为了共同目标,我们在工作上密切配合,从创业型小企业到企事业单位,感谢他们对我们的要求,感谢他们从不同领域给我们带来的挑战,让我们激情的团队有机会用头脑与智慧不断的给客户带来惊喜。专业领域包括网站建设、网站制作、电商网站开发、微信营销、系统平台开发。1.需要设备:
Android手机,数据线
ADB环境
Python环境(本例使用3.6.1)
TensorFlow(本例使用1.0.0)
2.大致原理
使用adb模拟点击和截屏,使用两层卷积神经网络作为训练模型,截屏图片作为输入,按压毫秒数直接作为为输出。
3.训练过程
最开始想的用强化学习,然后发现让它自己去玩成功率太!低!了!,加上每次截屏需要大量时间,就放弃了这个方法,于是考虑用自己玩的数据作为样本喂给它,这样就需要知道每次按压的时间。
我是这样做的,找一个手机写个app监听按压屏幕时间,另一个手机玩游戏,然后两个手指同时按两个手机o(╯□╰)o
4.上代码
首先,搭建模型:
第一层卷积:5*5的卷积核,12个featuremap,此时形状为96*96*12
池化层:4*4 max pooling,此时形状为24*24*12
第二层卷积:5*5的卷积核,24个featuremap,此时形状为20*20*24
池化层:4*4 max pooling,此时形状为5*5*24
全连接层:5*5*24连接到32个节点,使用relu激活函数和0.4的dropout率
输出:32个节点连接到1个节点,此节点就代表按压的时间(单位s)
# 输入:100*100的灰度图片,前面的None是batch size,这里都为1 x = tf.placeholder(tf.float32, shape=[None, 100, 100, 1]) # 输出:一个浮点数,就是按压时间,单位s y_ = tf.placeholder(tf.float32, shape=[None, 1]) # 第一层卷积 12个feature map W_conv1 = weight_variable([5, 5, 1, 12], 0.1) b_conv1 = bias_variable([12], 0.1) # 卷积后为96*96*12 h_conv1 = tf.nn.relu(conv2d(x, W_conv1) + b_conv1) h_pool1 = max_pool_4x4(h_conv1) # 池化后为24*24*12 # 第二层卷积 24个feature map W_conv2 = weight_variable([5, 5, 12, 24], 0.1) b_conv2 = bias_variable([24], 0.1) # 卷积后为20*20*24 h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2) h_pool2 = max_pool_4x4(h_conv2) # 池化后为5*5*24 # 全连接层5*5*24 --> 32 W_fc1 = weight_variable([5 * 5 * 24, 32], 0.1) b_fc1 = bias_variable([32], 0.1) h_pool2_flat = tf.reshape(h_pool2, [-1, 5 * 5 * 24]) h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1) # drapout,play时为1训练时为0.6 keep_prob = tf.placeholder(tf.float32) h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob) # 学习率 learn_rate = tf.placeholder(tf.float32) # 32 --> 1 W_fc2 = weight_variable([32, 1], 0.1) b_fc2 = bias_variable([1], 0.1) y_fc2 = tf.matmul(h_fc1_drop, W_fc2) + b_fc2 # 因输出直接是时间值,而不是分类概率,所以用平方损失 cross_entropy = tf.reduce_mean(tf.square(y_fc2 - y_)) train_step = tf.train.AdamOptimizer(learn_rate).minimize(cross_entropy)
其次,获取屏幕截图并转换为模型输入:
# 获取屏幕截图并转换为模型的输入 def get_screen_shot(): # 使用adb命令截图并获取图片,这里如果把后缀改成jpg会导致TensorFlow读不出来 os.system('adb shell screencap -p /sdcard/jump_temp.png') os.system('adb pull /sdcard/jump_temp.png .') # 使用PIL处理图片,并转为jpg im = Image.open(r"./jump_temp.png") w, h = im.size # 将图片压缩,并截取中间部分,截取后为100*100 im = im.resize((108, 192), Image.ANTIALIAS) region = (4, 50, 104, 150) im = im.crop(region) # 转换为jpg bg = Image.new("RGB", im.size, (255, 255, 255)) bg.paste(im, im) bg.save(r"./jump_temp.jpg") img_data = tf.image.decode_jpeg(tf.gfile.FastGFile('./jump_temp.jpg', 'rb').read()) # 使用TensorFlow转为只有1通道的灰度图 img_data_gray = tf.image.rgb_to_grayscale(img_data) x_in = np.asarray(img_data_gray.eval(), dtype='float32') # [0,255]转为[0,1]浮点 for i in range(len(x_in)): for j in range(len(x_in[i])): x_in[i][j][0] /= 255 # 因为输入shape有batch维度,所以还要套一层 return [x_in]
以上代码过程大概是这样:
最后,开始训练:
while True: ………… # 每训练100个保存一次 if train_count % 100 == 0: saver_init.save(sess, "./save/mode.mod") ………… sess.run(train_step, feed_dict={x: x_in, y_: y_out, keep_prob: 0.6, learn_rate: 0.00005})
训练所用数据是直接从采集好的文件中读取的,由于样本有限(目前采集了800张图和对应800个按压时间,在github上train_data文件夹里),并且学习率太大又会震荡,只能用较小学习率反复学习这些图片。
关于“如何实现TensorFlow微信跳一跳的AI”这篇文章就分享到这里了,希望以上内容可以对大家有一定的帮助,使各位可以学到更多知识,如果觉得文章不错,请把它分享出去让更多的人看到。
另外有需要云服务器可以了解下创新互联scvps.cn,海内外云服务器15元起步,三天无理由+7*72小时售后在线,公司持有idc许可证,提供“云服务器、裸金属服务器、高防服务器、香港服务器、美国服务器、虚拟主机、免备案服务器”等云主机租用服务以及企业上云的综合解决方案,具有“安全稳定、简单易用、服务可用性高、性价比高”等特点与优势,专为企业上云打造定制,能够满足用户丰富、多元化的应用场景需求。