博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Keras(十二)tf_data基础API使用
阅读量:4202 次
发布时间:2019-05-26

本文共 5665 字,大约阅读时间需要 18 分钟。

本文将介绍如下内容:

  • numpy转化为tf.data.Dataset
  • tf.data.Dataset方法-repeat,batch
  • tf.data.Dataset方法-interleave
  • from_tensor_slices中传入set元组参数
  • from_tensor_slices中传入dict字典参数

一,numpy转化为tf.data.Dataset

from_tensor_slices中传入的参数可以是list,dict,numpy

1,实现numpy转化为tf.data.Dataset代码

将python或numpy的数据类型转化为tf.data.Dataset数据类型

# 1,将python或numpy的数据类型转化为tf.data.Dataset数据类型# from_tensor_slices中传入的参数可以是list,dict,numpynp_data = np.arange(1,13,dtype=np.float32)np_data = np_data.reshape(3,4)dataset = tf.data.Dataset.from_tensor_slices(np_data)print(dataset)# 返回数据类型
# 2,tf.data.Dataset方法-遍历for item in dataset: print(item)
2,总结代码
import matplotlib as mplimport matplotlib.pyplot as pltimport numpy as npimport sklearnimport pandas as pdimport osimport sysimport timeimport tensorflow as tffrom tensorflow import keras# 打印使用的python库的版本信息print(tf.__version__)print(sys.version_info)for module in mpl, np, pd, sklearn, tf, keras:    print(module.__name__, module.__version__)    # 从内存中构建数据集# 1,将python或numpy的数据类型转化为tf.data.Dataset数据类型# from_tensor_slices中传入的参数可以是list,dict,numpynp_data = np.arange(1,13,dtype=np.float32)np_data = np_data.reshape(3,4)dataset = tf.data.Dataset.from_tensor_slices(np_data)print(dataset)# 返回数据类型
#---output---------2.1.0sys.version_info(major=3, minor=6, micro=8, releaselevel='final', serial=0)matplotlib 3.2.1numpy 1.18.5pandas 1.0.3sklearn 0.21.3tensorflow 2.1.0tensorflow_core.keras 2.2.4-tf

二,tf.data.Dataset方法-repeat,batch

1,实现tf.data.Dataset方法-repeat,batch代码
# 3,tf.data.Dataset方法-repeat,batch# 3.1. repeat epoch# 3.2. get batchdataset = dataset.repeat(3).batch(3)for item in dataset:    print(item)#---output---------tf.Tensor([[ 1.  2.  3.  4.] [ 5.  6.  7.  8.] [ 9. 10. 11. 12.]], shape=(3, 4), dtype=float32)tf.Tensor([[ 1.  2.  3.  4.] [ 5.  6.  7.  8.] [ 9. 10. 11. 12.]], shape=(3, 4), dtype=float32)tf.Tensor([[ 1.  2.  3.  4.] [ 5.  6.  7.  8.] [ 9. 10. 11. 12.]], shape=(3, 4), dtype=float32)
2,总结代码
#!/usr/bin/env python3# -*- coding: utf-8 -*-import matplotlib as mplimport matplotlib.pyplot as pltimport numpy as npimport sklearnimport pandas as pdimport osimport sysimport timeimport tensorflow as tffrom tensorflow import keras# 打印使用的python库的版本信息print(tf.__version__)print(sys.version_info)for module in mpl, np, pd, sklearn, tf, keras:    print(module.__name__, module.__version__)    # 从内存中构建数据集# 1,将python或numpy的数据类型转化为tf.data.Dataset数据类型# from_tensor_slices中传入的参数可以是list,dict,numpynp_data = np.arange(1,13,dtype=np.float32)np_data = np_data.reshape(3,4)dataset = tf.data.Dataset.from_tensor_slices(np_data)print(dataset)# 返回数据类型
# 2,tf.data.Dataset方法-遍历for item in dataset: print(item) # 3,tf.data.Dataset方法-repeat,batch# 3.1. repeat epoch# 3.2. get batchdataset = dataset.repeat(3).batch(3)for item in dataset: print(item)#---output--------tf.Tensor([[ 1. 2. 3. 4.] [ 5. 6. 7. 8.] [ 9. 10. 11. 12.]], shape=(3, 4), dtype=float32)tf.Tensor([[ 1. 2. 3. 4.] [ 5. 6. 7. 8.] [ 9. 10. 11. 12.]], shape=(3, 4), dtype=float32)tf.Tensor([[ 1. 2. 3. 4.] [ 5. 6. 7. 8.] [ 9. 10. 11. 12.]], shape=(3, 4), dtype=float32)

三,tf.data.Dataset方法-interleave

实现代码如下:
#!/usr/bin/env python3# -*- coding: utf-8 -*-import matplotlib as mplimport matplotlib.pyplot as pltimport numpy as npimport sklearnimport pandas as pdimport osimport sysimport timeimport tensorflow as tffrom tensorflow import keras# 打印使用的python库的版本信息print(tf.__version__)print(sys.version_info)for module in mpl, np, pd, sklearn, tf, keras:    print(module.__name__, module.__version__)    # 从内存中构建数据集# 1,将python或numpy的数据类型转化为tf.data.Dataset数据类型# from_tensor_slices中传入的参数可以是list,dict,numpynp_data = np.arange(1,13,dtype=np.float32)np_data = np_data.reshape(3,4)dataset = tf.data.Dataset.from_tensor_slices(np_data)print(dataset)# 返回数据类型
# 2,tf.data.Dataset方法-遍历for item in dataset: print(item) # 4,tf.data.Dataset方法-interleave# case: 文件dataset -> 具体数据集dataset2 = dataset.interleave( lambda v: tf.data.Dataset.from_tensor_slices(v), # map_fn 对每个元素做如何的处理 cycle_length = 2, # cycle_length 并行数量 block_length = 3, # block_length 每次取多少数据出来)for item in dataset2: print("------------------") print(item)#----output--------tf.Tensor(1.0, shape=(), dtype=float32)------------------tf.Tensor(2.0, shape=(), dtype=float32)------------------tf.Tensor(3.0, shape=(), dtype=float32)------------------tf.Tensor(5.0, shape=(), dtype=float32)------------------tf.Tensor(6.0, shape=(), dtype=float32)------------------tf.Tensor(7.0, shape=(), dtype=float32)------------------tf.Tensor(4.0, shape=(), dtype=float32)------------------tf.Tensor(8.0, shape=(), dtype=float32)------------------tf.Tensor(9.0, shape=(), dtype=float32)------------------tf.Tensor(10.0, shape=(), dtype=float32)------------------tf.Tensor(11.0, shape=(), dtype=float32)------------------tf.Tensor(12.0, shape=(), dtype=float32)

四,from_tensor_slices中传入set元组参数

# from_tensor_slices中传入的参数可以是set元组x = np.array([[1, 2], [3, 4], [5, 6]])y = np.array(['cat', 'dog', 'fox'])dataset3 = tf.data.Dataset.from_tensor_slices((x, y))print(dataset3)for item_x, item_y in dataset3:    print(item_x.numpy(), item_y.numpy())# ---output----
[1 2] b'cat'[3 4] b'dog'[5 6] b'fox'

五,from_tensor_slices中传入dict字典参数

# 5,from_tensor_slices中传入的参数可以是set元组x = np.array([[1, 2], [3, 4], [5, 6]])y = np.array(['cat', 'dog', 'fox'])dataset3 = tf.data.Dataset.from_tensor_slices((x, y))print(dataset3)for item_x, item_y in dataset3:    print(item_x.numpy(), item_y.numpy())    # 6,from_tensor_slices中传入的参数可以是dict字段dataset4 = tf.data.Dataset.from_tensor_slices({
"feature": x,"label": y})for item in dataset4: print(item["feature"].numpy(), item["label"].numpy())

转载地址:http://bvili.baihongyu.com/

你可能感兴趣的文章
用SpringCloud Alibaba搭建属于自己的微服务(五)~基础搭建~cloud、cloud alibaba和boot的版本选择
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(七)~基础搭建~springboot整合druid和mybatisPlus
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(八)~基础搭建~springboot整合swagger接口文档
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(九)~基础搭建~参数校验框架的使用
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(十)~基础搭建~自定义异常、统一结果集和全局异常处理器
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(十一)~基础搭建~alibaba nacos的安装
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(十二)~基础搭建~alibaba nacos的服务注册和发现
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(十五)~基础搭建~使用openfeign集成ribbon负载均衡.
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(十六)~基础搭建~openfegin+ribbon的rpc调用高可用之重试机制
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(十七)~基础搭建~alibaba sentinel服务端的安装
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(十八)~基础搭建~alibaba sentinel限流
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(十九)~基础搭建~alibaba sentinel熔断
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(二十五)~基础搭建~gateway整合swagger接口文档
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(二十六)~业务开发~用户注册
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(二十七)~业务开发~jwt实现用户登录
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(二十八)~业务开发~gateway实现鉴权
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(三十一)~业务开发~查看商品信息接口开发
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(三十二)~业务开发~扣款接口开发
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(三十三)~业务开发~支付接口开发
查看>>
用SpringCloud Alibaba搭建属于自己的微服务(三十四)~业务开发~下订单核心接口开发
查看>>