
Monkey patch the tensorboard writer to improve performance

Background

At work we rely on tensorboard heavily for experiment tracking and logging. However, we noticed that in some scenarios tensorboard's write performance degrades noticeably. Digging into the tb code, we found that each append write opens the file, writes, and then closes it, so every call to the higher-level writer.add_scalar incurs an extra open/close pair. This is rather wasteful: normally you would open the file handle once and keep writing through it, so the repeated open/close here is pure overhead.
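For reference, the stock append path in tensorboard's tensorflow_stub looks roughly like the sketch below (a paraphrase, not a verbatim copy; details vary across tensorboard versions). The important part is that every append() call opens and closes the file:

# stock_append_sketch.py: rough paraphrase of LocalFileSystem.append from
# tensorboard.compat.tensorflow_stub.io.gfile; illustrative, not verbatim library code.
import io

from tensorboard.compat.tensorflow_stub import compat


class StockLocalFileSystemSketch:
    def append(self, filename, file_content, binary_mode=False):
        self._write(filename, file_content, "ab" if binary_mode else "a")

    def _write(self, filename, file_content, mode):
        encoding = None if "b" in mode else "utf8"
        compatify = compat.as_bytes if "b" in mode else compat.as_text
        # a brand-new open()/close() pair for every single append call
        with io.open(filename, mode, encoding=encoding) as f:
            f.write(compatify(file_content))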

In addition, tb's outer flush drains the local queue cache and waits on _byte_queue.join() until all queued elements have been consumed. That logic is fine, but once we replace the underlying fs with one that writes through a single shared fd, that fd held inside the fs must itself be flushable. Looking at GFile's flush code, however, we found that in our scenario self.fs_supports_append is True, so flush essentially does nothing; the close code behaves similarly. Therefore we also need to monkey patch GFile so that it forwards flush and close to the fs it wraps.
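For context, the flush path of the stub GFile behaves roughly like the sketch below (again a paraphrase; attribute names may differ across versions). When fs_supports_append is True, every write has already gone through fs.append(), nothing is buffered locally, and flush() effectively does nothing to the file on disk:

# gfile_flush_sketch.py: rough paraphrase of GFile.flush from tensorboard's
# tensorflow_stub; illustrative only, not verbatim library code.
class GFileFlushSketch:
    def __init__(self, fs, filename, binary_mode=False):
        self.fs = fs
        self.filename = filename
        self.binary_mode = binary_mode
        # True when the fs exposes an append() method, as LocalFileSystem does
        self.fs_supports_append = hasattr(fs, "append")
        self.write_temp = None  # local buffer, only used without append support
        self.closed = False

    def flush(self):
        if self.closed:
            raise ValueError("flush on a closed file")
        if not self.fs_supports_append and self.write_temp is not None:
            # only the buffered (non-append) path ever writes anything here;
            # with append support this branch is skipped entirely
            self.fs.write(self.filename, self.write_temp, self.binary_mode)
            self.write_temp = None
        return self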

Code changes

The idea is to override the relevant tb code paths with a monkey patch.

So we can write the following append_tb_writer.py:

# append_tb_writer.py
import functools
import typing as t

from tensorboard.compat import tf
from tensorboard.compat.tensorflow_stub import compat
from torch.utils.tensorboard import SummaryWriter


class PatchedLocalFileSystem(tf.io.gfile.LocalFileSystem):
    stream_map: t.Dict[str, t.IO]

    def __init__(self):
        super().__init__()
        self.stream_map = {}

    def _stream(self, filename, binary_mode=False, create=True) -> t.Optional[t.IO]:
        # open the file in append mode
        mode = "ab" if binary_mode else "a"
        # build the cache key from mode and filename
        key = f"{mode}@{filename}"
        if key not in self.stream_map or self.stream_map[key] is None:
            if not create:
                return None
            encoding = None if "b" in mode else "utf8"
            self.stream_map[key] = open(filename, mode, encoding=encoding)
        return self.stream_map[key]

    def append(self, filename, file_content, binary_mode=False):
        compatify = compat.as_bytes if binary_mode else compat.as_text
        # reuse the cached file handle instead of opening and closing on every write
        self._stream(filename, binary_mode).write(compatify(file_content))

    def flush(self, filename, binary_mode=False):
        stream = self._stream(filename, binary_mode, create=False)
        if stream:
            stream.flush()

    def close(self, filename, binary_mode=False):
        stream = self._stream(filename, binary_mode, create=False)
        if stream:
            stream.close()
            # drop the closed handle so a later append reopens the file
            mode = "ab" if binary_mode else "a"
            self.stream_map.pop(f"{mode}@{filename}", None)


_origin_gfile = tf.io.gfile.GFile


class PatchedGFile(_origin_gfile):
    def flush(self):
        super().flush()
        assert hasattr(self.fs, "flush")
        self.fs.flush(self.filename, self.binary_mode)

    def close(self):
        super().close()
        assert hasattr(self.fs, "close")
        self.fs.close(self.filename, self.binary_mode)


@functools.cache
def patch_tb_gfile_and_fs():
    fs = PatchedLocalFileSystem()
    # register our fs as the default ("") filesystem instead of tf.io.gfile.LocalFileSystem
    tf.io.gfile.register_filesystem("", fs)
    tf.io.gfile.GFile = PatchedGFile


def assert_patched(writer: SummaryWriter):
    gfile = writer.file_writer.event_writer._general_file_writer
    assert isinstance(gfile, PatchedGFile)
    assert isinstance(gfile.fs, PatchedLocalFileSystem)

In the main program, the patch can be applied like this:

from append_tb_writer import patch_tb_gfile_and_fs, assert_patched
patch_tb_gfile_and_fs()
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(tensorboard_dir)
assert_patched(writer)

The assert_patched call above makes sure the patch has actually taken effect on the writer; note that patch_tb_gfile_and_fs() must run before the SummaryWriter is created, otherwise the writer still holds a GFile built from the original classes.
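To get a rough sense of the improvement, a small micro-benchmark such as the sketch below can be used. The log directory, step count, and any absolute numbers are purely illustrative and depend heavily on the underlying filesystem:

# bench_tb_patch.py: rough, illustrative micro-benchmark (hypothetical log dir).
import time

from append_tb_writer import patch_tb_gfile_and_fs, assert_patched

patch_tb_gfile_and_fs()  # must run before the SummaryWriter is created

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter("/tmp/tb_patch_bench")  # hypothetical log directory
assert_patched(writer)

start = time.perf_counter()
for step in range(10_000):
    writer.add_scalar("bench/value", float(step), step)
writer.flush()  # wait for the background queue to drain before stopping the clock
print(f"10k add_scalar calls took {time.perf_counter() - start:.3f}s")
writer.close()

Comparing this against the same script without calling patch_tb_gfile_and_fs() (and without assert_patched) shows how much time was spent in the repeated open/close pairs.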