import hashlib
from io import BytesIO
from typing import Optional

import safetensors.torch
import torch


def model_hash(filename):
    """Old model hash used by stable-diffusion-webui"""
    try:
        with open(filename, "rb") as file:
            m = hashlib.sha256()
            # The legacy webui hash samples 64 KiB (0x10000 bytes) starting at
            # the 1 MiB offset and keeps only the first 8 hex digits.
            file.seek(0x100000)
            m.update(file.read(0x10000))
            return m.hexdigest()[0:8]
    except FileNotFoundError:
        return "NOFILE"
    except IsADirectoryError:  # raised on Linux when the path is a directory
        return "IsADirectory"
    except PermissionError:  # Windows raises PermissionError for directories
        return "IsADirectory"


def calculate_sha256(filename):
    """New model hash used by stable-diffusion-webui"""
    try:
        hash_sha256 = hashlib.sha256()
        blksize = 1024 * 1024

        with open(filename, "rb") as f:
            # Hash the entire file in 1 MiB chunks.
            for chunk in iter(lambda: f.read(blksize), b""):
                hash_sha256.update(chunk)

        return hash_sha256.hexdigest()
    except FileNotFoundError:
        return "NOFILE"
    except IsADirectoryError:  # raised on Linux when the path is a directory
        return "IsADirectory"
    except PermissionError:  # Windows raises PermissionError for directories
        return "IsADirectory"


def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files"""
    # Same sampling scheme as model_hash above, but over an in-memory buffer.
    m = hashlib.sha256()
    b.seek(0x100000)
    m.update(b.read(0x10000))
    return m.hexdigest()[0:8]


def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files"""
    hash_sha256 = hashlib.sha256()
    blksize = 1024 * 1024

    # A .safetensors file starts with an 8-byte little-endian integer giving the
    # length n of the JSON header that follows it.
    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")

    # Skip past the header and hash only the raw tensor data, so that metadata
    # edits do not change the resulting hash.
    offset = n + 8
    b.seek(offset)
    for chunk in iter(lambda: b.read(blksize), b""):
        hash_sha256.update(chunk)

    return hash_sha256.hexdigest()
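

# Layout sketch (an illustrative addition): a .safetensors file is an 8-byte
# little-endian header length n, an n-byte JSON header, then raw tensor data.
# addnet_hash_safetensors hashes only the bytes from offset n + 8 onward, which
# is why rewriting the metadata header leaves its result unchanged.
def _demo_safetensors_layout(b: BytesIO) -> None:
    import json

    b.seek(0)
    n = int.from_bytes(b.read(8), "little")
    header = json.loads(b.read(n))  # tensor names, plus "__metadata__" if present
    print("header entries:", list(header.keys()))
    print("tensor data starts at offset:", n + 8)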


def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later."""

    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata (the "ss_" keys)
    # for the purpose of calculating the hash, as those are meant to be immutable.
    metadata = {k: v for k, v in metadata.items() if k.startswith("ss_")}

    model_bytes = safetensors.torch.save(tensors, metadata)  # avoid shadowing the bytes builtin
    b = BytesIO(model_bytes)

    model_hash = addnet_hash_safetensors(b)
    legacy_hash = addnet_hash_legacy(b)
    return model_hash, legacy_hash
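

# Usage sketch (an illustrative addition): hash a tiny in-memory model. Only the
# "ss_" keys survive the metadata filter above, so editing the "title" value here
# would leave both returned hashes unchanged.
def _demo_precalculate_hashes() -> None:
    tensors = {"weight": torch.zeros(2, 2)}
    metadata = {"ss_network_dim": "4", "title": "ignored for hashing"}
    new_hash, legacy_hash = precalculate_safetensors_hashes(tensors, metadata)
    print(f"model hash: {new_hash}, legacy hash: {legacy_hash}")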


def dtype_to_str(dtype: torch.dtype) -> str:
    # str(torch.float16) is "torch.float16"; keep only the name after the dot
    dtype_name = str(dtype).split(".")[-1]
    return dtype_name


def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> Optional[torch.dtype]:
    """
    Convert a string to a torch.dtype

    Args:
        s: string representation of the dtype
        default_dtype: default dtype to return if s is None

    Returns:
        torch.dtype: the corresponding torch.dtype

    Raises:
        ValueError: if the dtype is not supported

    Examples:
        >>> str_to_dtype("float32")
        torch.float32
        >>> str_to_dtype("fp32")
        torch.float32
        >>> str_to_dtype("float16")
        torch.float16
        >>> str_to_dtype("fp16")
        torch.float16
        >>> str_to_dtype("bfloat16")
        torch.bfloat16
        >>> str_to_dtype("bf16")
        torch.bfloat16
        >>> str_to_dtype("fp8")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fn")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fnuz")
        torch.float8_e4m3fnuz
        >>> str_to_dtype("fp8_e5m2")
        torch.float8_e5m2
        >>> str_to_dtype("fp8_e5m2fnuz")
        torch.float8_e5m2fnuz
    """
    if s is None:
        return default_dtype
    if s in ["bf16", "bfloat16"]:
        return torch.bfloat16
    elif s in ["fp16", "float16"]:
        return torch.float16
    elif s in ["fp32", "float32", "float"]:
        return torch.float32
    elif s in ["fp8_e4m3fn", "e4m3fn", "float8_e4m3fn"]:
        return torch.float8_e4m3fn
    elif s in ["fp8_e4m3fnuz", "e4m3fnuz", "float8_e4m3fnuz"]:
        return torch.float8_e4m3fnuz
    elif s in ["fp8_e5m2", "e5m2", "float8_e5m2"]:
        return torch.float8_e5m2
    elif s in ["fp8_e5m2fnuz", "e5m2fnuz", "float8_e5m2fnuz"]:
        return torch.float8_e5m2fnuz
    elif s in ["fp8", "float8"]:
        return torch.float8_e4m3fn  # default fp8 flavor
    else:
        raise ValueError(f"Unsupported dtype: {s}")
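

# Round-trip sketch (an illustrative addition): dtype_to_str and str_to_dtype are
# inverses for the names listed above, and the bare "fp8" alias resolves to the
# e4m3fn default. The float8 dtypes require a PyTorch version that defines them.
def _demo_dtype_roundtrip() -> None:
    for name in ["fp32", "fp16", "bf16", "fp8"]:
        dtype = str_to_dtype(name)
        print(f"{name} -> {dtype} -> {dtype_to_str(dtype)}")
    assert str_to_dtype(None, torch.float16) is torch.float16  # default fallback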