当前位置:Gxlcms > 数据库问题 > LMDB数据库加速Pytorch文件读取速度

LMDB数据库加速Pytorch文件读取速度

时间:2021-07-01 10:21:17 帮助过:32人阅读

import lmdb 2 3 env = lmdb.open(D:/desktop/lmdb, map_size=10*1024**2)

 

指定存放生成的lmdb数据库的文件夹路径,如果没有该文件夹则自动创建。

map_size 指定创建的新数据库所需磁盘空间的最小值,1099511627776B=1T。可以在这里进行 存储单位换算。

会在指定路径下创建 data.mdb 和 lock.mdb 两个文件,一是个数据文件,一个是锁文件。

 

修改数据库内容:

 1 # 创建一个事务Transaction对象
 2 txn = env.begin(write=True)
 3 
 4 # insert/modify
 5 # txn.put(key, value)
 6 txn.put(str(1).encode(), "Alice".encode()) # .encode()编码为字节bytes格式
 7 txn.put(str(2).encode(), "Bob".encode())
 8 txn.put(str(3).encode(), "Jack".encode())
 9 
10 # delete
11 # txn.delete(key)
12 txn.delete(str(1).encode())
13 
14 # 提交待处理的事务
15 txn.commit()

先创建一个事务(transaction) 对象 txn,所有的操作都必须经过这个事务对象。因为我们要对数据库进行写入操作,所以将 write 参数置为 True,默认其为 False

使用 .put(key, value) 对数据库进行插入和修改操作,传入的参数为键值对。

值得注意的是,需要在键值字符串后加 .encode() 改变其编码格式,将 str 转换为 bytes 格式,否则会报该错误:TypeError: Won‘t implicitly convert Unicode to bytes; use .encode()。在后面使用 .decode() 对其进行解码得到原数据。

使用 .delete(key) 删除指定键值对。

 对LMDB的读写操作在事务中执行,需要使用 commit 方法提交待处理的事务。

 

查询数据库内容:

1 # 数据库查询
2 txn = env.begin() # 每个commit()之后都需要使用begin()方法更新txn得到最新数据库
3 
4 print(txn.get(str(2).encode()))
5 
6 for key, value in txn.cursor():
7     print(key, value)
8 
9 env.close

每次 commit() 之后都要用 env.begin() 更新 txn(得到最新的lmdb数据库)。

使用 .get(key) 查询数据库中的单条记录。

使用 .cursor() 遍历数据库中的所有记录,其返回一个可迭代对象,相当于关系数据库中的游标,每读取一次,游标下移一位。

 

也可以想文件一样使用 with 语法:

1 # 可以像文件一样使用with语法
2 with env.begin() as txn:
3     print(txn.get(str(2).encode()))
4 
5     for key, value in txn.cursor():
6         print(key, value)
7 env.close

 

完整的demo如下:

技术图片
 1 import lmdb
 2 import os, sys
 3 
 4 def initialize(lmdb_dir, map_size):
 5     # map_size: bytes
 6     env = lmdb.open(lmdb_dir, map_size)
 7     return env
 8 
 9 def insert(env, key, value):
10     txn = env.begin(write=True)
11     txn.put(str(key).encode(), value.encode())
12     txn.commit()
13 
14 def delete(env, key):
15     txn = env.begin(write=True)
16     txn.delete(str(key).encode())
17     txn.commit()
18 
19 def update(env, key, value):
20     txn = env.begin(write=True)
21     txn.put(str(key).encode(), value.encode())
22     txn.commit()
23 
24 def search(env, key):
25     txn = env.begin()
26     value = txn.get(str(key).encode())
27     return value
28 
29 def display(env):
30     txn = env.begin()
31     cursor = txn.cursor()
32     for key, value in cursor:
33         print(key, value)
34 
35 
36 if __name__ == __main__:
37     path = D:/desktop/lmdb
38     env = initialize(path, 10*1024*1024)
39 
40     print("Insert 3 records.")
41     insert(env, 1, "Alice")
42     insert(env, 2, "Bob")
43     insert(env, 3, "Peter")
44     display(env)
45 
46     print("Delete the record where key = 1")
47     delete(env, 1)
48     display(env)
49 
50     print("Update the record where key = 3")
51     update(env, 3, "Mark")
52     display(env)
53 
54     print("Get the value whose key = 3")
55     name = search(env, 3)
56     print(name)
57 
58     # 最后需要关闭lmdb数据库
59     env.close()
View Code

图片数据示例

在图像深度学习训练中我们一般都会把大量原始数据集转化为lmdb格式以方便后续的网络训练。因此我们也需要对该数据集进行lmdb格式转化。

将图片和对应的文本标签存放到lmdb数据库:

技术图片
 1 import lmdb
 2 
 3 image_path = ./cat.jpg
 4 label = cat
 5 
 6 env = lmdb.open(lmdb_dir)
 7 cache = {}  # 存储键值对
 8 
 9 with open(image_path, rb) as f:
10     # 读取图像文件的二进制格式数据
11     image_bin = f.read()
12 
13 # 用两个键值对表示一个数据样本
14 cache[image_000] = image_bin
15 cache[label_000] = label
16 
17 with env.begin(write=True) as txn:
18     for k, v in cache.items():
19         if isinstance(v, bytes):
20             # 图片类型为bytes
21             txn.put(k.encode(), v)
22         else:
23             # 标签类型为str, 转为bytes
24             txn.put(k.encode(), v.encode())  # 编码
25 
26 env.close()
View Code

这里需要获取图像文件的二进制格式数据,然后用两个键值对保存一个数据样本,即分开保存图片和其标签。

然后分别将图像和标签写入到lmdb数据库中,和上面例子一样都需要将键值转换为 bytes 格式。因为此处读取的图片格式本身就为 bytes,所以不需要转换,标签格式为 str,写入数据库之前需要先进行编码将其转换为 bytes

 

从lmdb数据库中读取图片数据:

技术图片
 1 import cv2
 2 import lmdb
 3 import numpy as np
 4 
 5 env = lmdb.open(lmdb_dir)
 6 
 7 with env.begin(write=False) as txn:
 8     # 获取图像数据
 9     image_bin = txn.get(image_000.encode())
10     label = txn.get(label_000.encode()).decode()  # 解码
11 
12     # 将二进制文件转为十进制文件(一维数组)
13     image_buf = np.frombuffer(image_bin, dtype=np.uint8)
14     # 将数据转换(解码)成图像格式
15     # cv2.IMREAD_GRAYSCALE为灰度图,cv2.IMREAD_COLOR为彩色图
16     img = cv2.imdecode(image_buf, cv2.IMREAD_COLOR)
17     cv2.imshow(image, img)
18     cv2.waitKey(0)
View Code

先通过 lmdb.open() 获取之前创建的lmdb数据库。

这里通过键得到图片和其标签,因为写入数据库之前进行了编码,所以这里需要先解码。

  • 标签通过 .decode() 进行解码重新得到字符串格式。
  • 读取到的图片数据为二进制格式,所以先使用 np.frombuffer() 将其转换为十进制格式的文件,这是一维数组。然后可以使用 cv2.imdecode() 将其转换为灰度图(二维数组)或者彩色图(三维数组)。

 

leveldb

leveldb的使用与lmdb差不多,然而LevelDB 是单进程的服务。

https://www.jianshu.com/p/66496c8726a1

https://github.com/liquidconv/py4db

https://github.com/google/leveldb

技术图片
 1 #!/usr/bin/env python
 2 
 3 import leveldb
 4 import os, sys
 5 
 6 def initialize():
 7     db = leveldb.LevelDB("students");
 8     return db;
 9 
10 def insert(db, sid, name):
11     db.Put(str(sid), name);
12 
13 def delete(db, sid):
14     db.Delete(str(sid));
15 
16 def update(db, sid, name):
17     db.Put(str(sid), name);
18 
19 def search(db, sid):
20     name = db.Get(str(sid));
21     return name;
22 
23 def display(db):
24     for key, value in db.RangeIter():
25         print (key, value);
26 
27 db = initialize();
28 
29 print "Insert 3 records."
30 insert(db, 1, "Alice");
31 insert(db, 2, "Bob");
32 insert(db, 3, "Peter");
33 display(db);
34 
35 print "Delete the record where sid = 1."
36 delete(db, 1);
37 display(db);
38 
39 print "Update the record where sid = 3."
40 update(db, 3, "Mark");
41 display(db);
42 
43 print "Get the name of student whose sid = 3."
44 name = search(db, 3);
45 print name;
View Code

 

 pytorch从lmdb中加载数据

这里给出一种pytorch从lmdb中加载数据的参考示例,来自:https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977

 1 from __future__ import print_function
 2 import torch.utils.data as data
 3 # import h5py
 4 import numpy as np
 5 import lmdb
 6 
 7 
 8 class onlineHCCR(data.Dataset):
 9     def __init__(self, train=True):
10         # self.root = root
11         self.train = train
12 
13         if self.train:
14             datalmdb_path = traindata_lmdb
15             labellmdb_path = trainlabel_lmdb
16             self.data_env = lmdb.open(datalmdb_path, readonly=True)
17             self.label_env = lmdb.open(labellmdb_path, readonly=True)
18 
19         else:
20             datalmdb_path = testdata_lmdb
21             labellmdb_path = testlabel_lmdb
22             self.data_env = lmdb.open(datalmdb_path, readonly=True)
23             self.label_env = lmdb.open(labellmdb_path, readonly=True)
24 
25 
26     def __getitem__(self, index):
27 
28         Data = []
29         Target = []
30 
31         if self.train:
32             with self.data_env.begin() as f:
33                 key = {:08}.format(index)
34                 data = f.get(key)
35                 flat_data = np.fromstring(data, dtype=float)
36                 data = flat_data.reshape(150, 6).astype(float32)
37                 Data = data
38 
39             with self.label_env.begin() as f:
40                 key = {:08}.format(index)
41                 data = f.get(key)
42                 label = np.fromstring(data, dtype=int)
43                 Target = label[0]
44 
45         else:
46 
47             with self.data_env.begin() as f:
48                 key = {:08}.format(index)
49                 data = f.get(key)
50                 flat_data = np.fromstring(data, dtype=float)
51                 data = flat_data.reshape(150, 6).astype(float32)
52                 Data = data
53 
54             with self.label_env.begin() as f:
55                 key = {:08}.format(index)
56                 data = f.get(key)
57                 label = np.fromstring(data, dtype=int)
58                 Target = label[0]
59 
60         return Data, Target
61         
62 
63     def __len__(self):
64         if self.train:
65             return 2693931
66         else:
67             return 224589

 

 

 

参考:

lmdb 数据库

Python操作SQLite/MySQL/LMDB/LevelDB

https://github.com/liquidconv/py4db

https://discuss.pytorch.org/t/whats-the-best-way-to-load-large-data/2977

 

LMDB数据库加速Pytorch文件读取速度

标签:tla   自动   turn   student   mysq   online   指针   磁盘io   jpg   

人气教程排行