1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71
| import functools import typing as t
from tensorboard.compat import tf from tensorboard.compat.tensorflow_stub import compat from torch.utils.tensorboard import SummaryWriter
class PatchedLocalFileSystem(tf.io.gfile.LocalFileSystem):
    """Local filesystem that keeps a cached, open stream per file for appends.

    Instead of delegating ``append`` to the base class, this subclass opens the
    target file once in append mode and caches the handle in ``stream_map``.
    Later appends reuse it. ``flush`` and ``close`` act on that cached handle.
    """

    # Cached append streams keyed by "<mode>@<filename>" (see _stream_key).
    # Mode "a" entries are UTF-8 text streams, and mode "ab" entries are binary streams.
    stream_map: t.Dict[str, t.IO]

    def __init__(self):
        super().__init__()
        self.stream_map = {}

    @staticmethod
    def _stream_key(filename, binary_mode=False) -> str:
        """Return the ``stream_map`` key for ``filename`` in the given mode."""
        mode = "ab" if binary_mode else "a"
        return f"{mode}@{filename}"

    def _stream(self, filename, binary_mode=False, create=True) -> t.Optional[t.IO]:
        """Return the cached append stream for ``filename``, opening it if needed.

        Args:
            filename: Path of the file to append to.
            binary_mode: Use a binary ("ab") stream instead of a UTF-8 text
                ("a") stream.
            create: When False, never open a new stream. Return None if no
                usable stream is cached.

        Returns:
            The open stream, or None when ``create`` is False and no open
            stream is cached.
        """
        mode = "ab" if binary_mode else "a"
        key = self._stream_key(filename, binary_mode)
        stream = self.stream_map.get(key)
        # A closed handle can never be written again, so treat it as missing.
        if stream is None or stream.closed:
            if not create:
                return None
            encoding = None if binary_mode else "utf8"
            stream = open(filename, mode, encoding=encoding)
            self.stream_map[key] = stream
        return stream

    def append(self, filename, file_content, binary_mode=False):
        """Append ``file_content`` to ``filename`` through the cached stream."""
        compatify = compat.as_bytes if binary_mode else compat.as_text
        self._stream(filename, binary_mode).write(compatify(file_content))

    def flush(self, filename, binary_mode=False):
        """Flush the cached stream for ``filename``. Do nothing if none is open."""
        stream = self._stream(filename, binary_mode, create=False)
        if stream is not None:
            stream.flush()

    def close(self, filename, binary_mode=False):
        """Close and forget the cached stream for ``filename``, if any."""
        # Remove the entry (not just close it) so that a later append on the
        # same path reopens the file instead of hitting a closed handle.
        stream = self.stream_map.pop(self._stream_key(filename, binary_mode), None)
        if stream is not None:
            stream.close()
# Reference to the GFile class as it was at import time. PatchedGFile
# subclasses it, and patch_tb_gfile_and_fs rebinds tf.io.gfile.GFile afterwards.
_origin_gfile = tf.io.gfile.GFile


class PatchedGFile(_origin_gfile):
    """GFile that forwards ``flush``/``close`` to its filesystem.

    After the base-class operation runs, the call is forwarded to the
    filesystem's ``flush(filename, binary_mode)`` / ``close(filename,
    binary_mode)`` hook, as provided by PatchedLocalFileSystem. That keeps the
    filesystem's cached append stream in sync. This class replaces
    ``tf.io.gfile.GFile`` globally, so filesystems without such hooks must
    keep working. For them, only the base-class behavior runs.
    """

    def _forward_to_fs(self, hook_name):
        """Call ``self.fs.<hook_name>(filename, binary_mode)`` if the fs defines it."""
        hook = getattr(self.fs, hook_name, None)
        if hook is not None:
            hook(self.filename, self.binary_mode)

    def flush(self):
        super().flush()
        self._forward_to_fs("flush")

    def close(self):
        super().close()
        self._forward_to_fs("close")
@functools.lru_cache(maxsize=None)
def patch_tb_gfile_and_fs():
    """Install PatchedLocalFileSystem and PatchedGFile into ``tf.io.gfile``.

    A PatchedLocalFileSystem instance is registered under the empty prefix
    (presumably the one used for plain local paths; confirm against the
    tensorboard registry). ``tf.io.gfile.GFile`` is then rebound to
    PatchedGFile. The memoizing decorator makes this a one-shot operation:
    only the first call has any effect.
    """
    local_fs = PatchedLocalFileSystem()
    tf.io.gfile.register_filesystem("", local_fs)
    tf.io.gfile.GFile = PatchedGFile
def assert_patched(writer: SummaryWriter):
    """Assert that ``writer`` writes its events through the patched gfile stack.

    Checks two things about the writer's underlying general file writer: it is
    a PatchedGFile, and its filesystem is a PatchedLocalFileSystem.
    """
    event_writer = writer.file_writer.event_writer
    general_writer = event_writer._general_file_writer
    assert isinstance(general_writer, PatchedGFile)
    assert isinstance(general_writer.fs, PatchedLocalFileSystem)
|