Spaces:
Running
Running
File size: 4,860 Bytes
b72ab63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 |
"""Helper functions for a standard streaming compression API"""
from zipfile import ZipFile
import fsspec.utils
from fsspec.spec import AbstractBufferedFile
def noop_file(file, mode, **kwargs):
    """Identity "compression": hand back *file* untouched.

    ``mode`` and any keyword arguments are accepted only so this function has
    the same call shape as the other compression callbacks; they are ignored.
    """
    passthrough = file
    return passthrough
# Registry mapping a compression name to its callback. Callbacks should be
# functions of the form func(infile, mode=, **kwargs) -> file-like, i.e. they
# wrap an already-open file object. Entries are added via register_compression.
# The ``None`` key means "no compression" and returns the file unchanged.
# TODO: files should also be available as contexts
compr = {None: noop_file}
def register_compression(name, callback, extensions, force=False):
    """Register an "inferable" file compression type.

    Registers transparent file compression type for use with fsspec.open.
    Compression can be specified by name in open, or "infer"-ed for any files
    ending with the given extensions.

    Args:
        name: (str) The compression type name. Eg. "gzip".
        callback: A callable of form (infile, mode, **kwargs) -> file-like.
            Accepts an input file-like object, the target mode and kwargs.
            Returns a wrapped file-like object.
        extensions: (str, Iterable[str]) A file extension, or list of file
            extensions for which to infer this compression scheme. Eg. "gz".
            Any iterable is accepted, including a one-shot generator.
        force: (bool) Force re-registration of compression type or extensions.

    Raises:
        ValueError: If name or extensions already registered, and not force.
            All validation happens before anything is registered, so a
            failed call leaves both registries unchanged.
    """
    # Materialize exactly once: the extensions are walked twice (validate,
    # then register). A generator would otherwise be exhausted by the
    # validation pass and silently register no extensions at all.
    if isinstance(extensions, str):
        extensions = [extensions]
    else:
        extensions = list(extensions)

    # Validate everything before mutating either registry.
    if name in compr and not force:
        raise ValueError(f"Duplicate compression registration: {name}")
    for ext in extensions:
        if ext in fsspec.utils.compressions and not force:
            raise ValueError(f"Duplicate compression file extension: {ext} ({name})")

    compr[name] = callback
    for ext in extensions:
        fsspec.utils.compressions[ext] = name
def unzip(infile, mode="rb", filename=None, **kwargs):
    """Open a single member of a zip archive that wraps *infile*.

    Read mode (any *mode* containing "r"): open the member *filename*, or the
    first entry listed in the archive when *filename* is None. ``kwargs`` are
    forwarded to ``ZipFile.open``.

    Any other mode: start a new archive on *infile* holding one member named
    *filename* (or "file" if *filename* is falsy). ``kwargs`` are forwarded to
    the ``ZipFile`` constructor. Closing the returned member also closes the
    archive.
    """
    if "r" in mode:
        archive = ZipFile(infile)
        member = filename if filename is not None else archive.namelist()[0]
        return archive.open(member, mode="r", **kwargs)

    archive = ZipFile(infile, mode="w", **kwargs)
    member_file = archive.open(filename or "file", mode="w")
    close_member = member_file.close

    def _close_member_then_archive():
        # The member's close() yields a falsy result, so ``or`` always goes
        # on to close the archive as well.
        return close_member() or archive.close()

    member_file.close = _close_member_then_archive
    return member_file
# zip support is built on the stdlib ZipFile imported at the top of the
# module, so it is registered unconditionally.
register_compression(name="zip", callback=unzip, extensions="zip")

# bz2 is treated as optional: it is registered only if it can be imported.
try:
    from bz2 import BZ2File
except ImportError:
    pass
else:
    register_compression(name="bz2", callback=BZ2File, extensions="bz2")
try:  # pragma: no cover
    from isal import igzip
except ImportError:
    # Without the optional ``isal`` package, fall back to the stdlib gzip.
    from gzip import GzipFile

    register_compression(
        "gzip", lambda f, **kwargs: GzipFile(fileobj=f, **kwargs), "gz"
    )
else:  # pragma: no cover

    def isal(infile, mode="rb", **kwargs):
        """Open *infile* with ``isal.igzip.IGzipFile`` in the given *mode*."""
        return igzip.IGzipFile(fileobj=infile, mode=mode, **kwargs)

    register_compression("gzip", isal, "gz")
# The same LZMAFile class serves both the "lzma" and "xz" names/extensions;
# neither is registered if the lzma module cannot be imported.
try:
    from lzma import LZMAFile
except ImportError:
    pass
else:
    register_compression("lzma", LZMAFile, "lzma")
    register_compression("xz", LZMAFile, "xz")
# When lzmaffi is installed its LZMAFile takes over both names. force=True is
# required because "lzma"/"xz" may already be registered, and
# register_compression would otherwise raise on the duplicates.
try:
    import lzmaffi
except ImportError:
    pass
else:
    register_compression("lzma", lzmaffi.LZMAFile, "lzma", force=True)
    register_compression("xz", lzmaffi.LZMAFile, "xz", force=True)
class SnappyFile(AbstractBufferedFile):
    """File-like wrapper that streams snappy-framed data to/from ``infile``.

    Built on ``AbstractBufferedFile``: on write, buffered data is pushed
    through ``snappy.StreamCompressor`` in ``_upload_chunk``; on read, raw
    bytes pulled from ``infile`` go through ``snappy.StreamDecompressor`` in
    ``_fetch_range``. Seeking is not supported.
    """

    def __init__(self, infile, mode, **kwargs):
        """Wrap *infile* for snappy stream (de)compression.

        Args:
            infile: Underlying file-like object to read from or write to.
            mode: Open mode. Any "b" is stripped and re-appended, so the
                buffered file is always binary. A mode containing "r" selects
                decompression; anything else selects compression.
            **kwargs: Forwarded to ``AbstractBufferedFile.__init__``.

        Raises:
            ImportError: If the ``snappy`` package is not installed.
        """
        # Imported lazily so this module can be imported without snappy.
        import snappy

        # NOTE(review): ``fs=None`` and ``path="snappy"`` look like
        # placeholders. The large fixed ``size`` presumably stands in for the
        # unknown stream length; confirm against AbstractBufferedFile's
        # read-ahead logic.
        super().__init__(
            fs=None, path="snappy", mode=mode.strip("b") + "b", size=999999999, **kwargs
        )
        self.infile = infile
        if "r" in mode:
            self.codec = snappy.StreamDecompressor()
        else:
            self.codec = snappy.StreamCompressor()

    def _upload_chunk(self, final=False):
        """Compress the entire write buffer and write the result to ``infile``.

        ``final`` is accepted for the base-class interface but not used.
        Always returns True.
        """
        self.buffer.seek(0)
        out = self.codec.add_chunk(self.buffer.read())
        self.infile.write(out)
        # NOTE(review): presumably True tells the base class the buffer was
        # consumed and may be reset; confirm in AbstractBufferedFile.
        return True

    def seek(self, loc, whence=0):
        """Always raises: a snappy stream cannot be repositioned."""
        raise NotImplementedError("SnappyFile is not seekable")

    def seekable(self):
        """Report that ``seek`` is unsupported."""
        return False

    def _fetch_range(self, start, end):
        """Get the specified set of bytes from remote.

        Reads ``end - start`` compressed bytes sequentially from ``infile``
        (``start`` is not used for positioning) and returns them decompressed.
        """
        # NOTE(review): the decompressed output can differ in length from
        # ``end - start``; confirm the base-class cache tolerates that.
        data = self.infile.read(end - start)
        return self.codec.decompress(data)
# Touching ``snappy.compress`` checks that the imported module exposes the
# expected API (presumably to reject an unrelated package also importable as
# ``snappy``). Snappy may use the .sz file extension, but this is not part of
# the standard implementation, so no extension is registered for inference.
try:
    import snappy

    snappy.compress
except (ImportError, NameError, AttributeError):
    pass
else:
    register_compression("snappy", SnappyFile, [])
# Use ``lz4.frame.open`` directly as the callback, but only when the optional
# ``lz4`` package is importable.
try:
    import lz4.frame
except ImportError:
    pass
else:
    register_compression("lz4", lz4.frame.open, "lz4")
try:
    import zstandard as zstd
except ImportError:
    pass
else:

    def zstandard_file(infile, mode="rb", level=10, **kwargs):
        """Wrap *infile* in a zstandard streaming reader or writer.

        Args:
            infile: Underlying file-like object.
            mode: A mode containing "r" returns a decompressing stream reader;
                any other mode returns a compressing stream writer.
            level: Compression level used when writing; ignored when reading.
                Defaults to 10.
            **kwargs: Forwarded to ``stream_reader`` / ``stream_writer``. This
                makes the callback honour the ``(infile, mode, **kwargs)``
                shape that ``register_compression`` documents, instead of
                raising TypeError on any extra keyword.
        """
        if "r" in mode:
            dctx = zstd.ZstdDecompressor()
            return dctx.stream_reader(infile, **kwargs)
        cctx = zstd.ZstdCompressor(level=level)
        return cctx.stream_writer(infile, **kwargs)

    register_compression("zstd", zstandard_file, "zst")
def available_compressions():
    """Return the names of all registered compressions as a new list.

    Names appear in the registry dict's insertion order, and the ``None``
    ("no compression") entry is included.
    """
    return [*compr]
|