
Commit 10a6368

Add support for sharded weights.
1 parent 0b63ce0 commit 10a6368

File tree: 3 files changed, +131 -6 lines changed


keras_hub/src/utils/keras_utils.py (+14)

@@ -1,3 +1,4 @@
+import inspect
 import sys
 
 import keras
@@ -147,3 +148,16 @@ def get_gpu_names():
         ]
     else:
         return [""]
+
+
+def sharded_weights_available():
+    """Whether sharded weights serialization is available.
+
+    Returns:
+        `True` if sharded weights are available, `False` otherwise.
+    """
+    save_weights_signature = inspect.signature(keras.saving.save_weights)
+    if "max_shard_size" in save_weights_signature.parameters:
+        return True
+    else:
+        return False
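The helper only inspects the installed Keras API, so callers can feature-gate sharded saving. A minimal usage sketch, not part of this commit: the toy model and output filenames are placeholders, and the `max_shard_size` behavior assumes a Keras build new enough to support it.

import keras

from keras_hub.src.utils.keras_utils import sharded_weights_available

# Toy model purely for illustration.
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(8)])

if sharded_weights_available():
    # Newer Keras builds accept `max_shard_size` (in GB) and write a
    # model.weights.json index plus model_XXXXX.weights.h5 shards.
    keras.saving.save_weights(model, "model.weights.json", max_shard_size=0.5)
else:
    # Fall back to a single-file checkpoint on older Keras versions.
    keras.saving.save_weights(model, "model.weights.h5")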

keras_hub/src/utils/preset_utils.py (+75 -6)

@@ -1,7 +1,9 @@
 import collections
 import datetime
+import functools
 import inspect
 import json
+import math
 import os
 import re
 
@@ -48,6 +50,8 @@
 # Weight file names.
 MODEL_WEIGHTS_FILE = "model.weights.h5"
 TASK_WEIGHTS_FILE = "task.weights.h5"
+SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
+MAX_SHARD_SIZE = 5.0  # This means 5GB.
 
 # HuggingFace filenames.
 README_FILE = "README.md"
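For orientation, `model.weights.json` is the sharding index Keras writes next to the `model_XXXXX.weights.h5` shard files; the loader added further down only relies on its `weight_map` values. An illustrative sketch of that layout, with made-up variable paths:

# Illustrative sharding index; the real keys are Keras variable paths.
sharding_index = {
    "weight_map": {
        "/layers/embedding/vars/0": "model_00000.weights.h5",
        "/layers/dense/vars/0": "model_00001.weights.h5",
    }
}
# This is all `_get_sharded_filenames` extracts from the index:
shard_files = sorted(set(sharding_index["weight_map"].values()))
# -> ["model_00000.weights.h5", "model_00001.weights.h5"]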
@@ -647,7 +651,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
+            self._load_backbone_weights(backbone)
         return backbone
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +701,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
                 task.load_task_weights(task_weights)
             else:
                 jax_memory_cleanup(task.backbone)
-                backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
-                task.backbone.load_weights(backbone_weights)
+                self._load_backbone_weights(task.backbone)
         return task
 
     def load_preprocessor(
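With `_load_backbone_weights` in place, both `load_backbone` and `load_task` resolve weights through one code path, and callers keep using the same public API. A hedged sketch of the caller-facing behavior; the preset directory name is hypothetical:

from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone

# Works the same whether ./my_gemma_preset holds a single model.weights.h5
# or model.weights.json plus model_XXXXX.weights.h5 shards.
backbone = GemmaBackbone.from_preset("./my_gemma_preset", load_weights=True)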
@@ -726,18 +729,59 @@ def _load_serialized_object(self, config, **kwargs):
         config["config"] = {**config["config"], **kwargs}
         return keras.saving.deserialize_keras_object(config)
 
+    def _get_sharded_filenames(self, config_path):
+        with open(config_path, encoding="utf-8") as config_file:
+            config = json.load(config_file)
+        weight_map = config["weight_map"]
+        return sorted(set(weight_map.values()))
+
+    def _load_backbone_weights(self, backbone):
+        # Detect if the backbone is sharded or not.
+        has_single_file_weights = check_file_exists(
+            self.preset, MODEL_WEIGHTS_FILE
+        )
+        if has_single_file_weights:
+            filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
+        else:
+            filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
+            sharded_filenames = self._get_sharded_filenames(filepath)
+            for sharded_filename in sharded_filenames:
+                # Download the sharded weights.
+                _ = get_file(self.preset, sharded_filename)
+        backbone.load_weights(filepath)
+
 
 class KerasPresetSaver:
     def __init__(self, preset_dir):
         os.makedirs(preset_dir, exist_ok=True)
         self.preset_dir = preset_dir
 
-    def save_backbone(self, backbone):
+    def save_backbone(self, backbone, max_shard_size=None):
         self._save_serialized_object(backbone, config_file=CONFIG_FILE)
-        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
-        backbone.save_weights(backbone_weight_path)
         self._save_metadata(backbone)
 
+        # Save the weights.
+        backbone_size_in_bytes = self._get_variables_size_in_bytes(
+            backbone.variables
+        )
+        backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
+        # If the size of the backbone is larger than `MAX_SHARD_SIZE`, save
+        # sharded weights.
+        max_shard_size = max_shard_size or MAX_SHARD_SIZE
+        if backbone_size_in_gb > max_shard_size:
+            backbone_sharded_weights_config_path = os.path.join(
+                self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
+            )
+            backbone.save_weights(
+                backbone_sharded_weights_config_path,
+                max_shard_size=max_shard_size,
+            )
+        else:
+            backbone_weight_path = os.path.join(
+                self.preset_dir, MODEL_WEIGHTS_FILE
+            )
+            backbone.save_weights(backbone_weight_path)
+
     def save_tokenizer(self, tokenizer):
         config_file = TOKENIZER_CONFIG_FILE
         if hasattr(tokenizer, "config_file"):
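To make the default threshold concrete: both `MAX_SHARD_SIZE` and the `max_shard_size` argument are expressed in GB, so any backbone whose variables total more than 5 GB is sharded unless the caller overrides the limit. A back-of-the-envelope check with a hypothetical parameter count:

# A hypothetical 9B-parameter backbone stored in bfloat16 (2 bytes/param).
params = 9_000_000_000
size_in_bytes = params * 2
size_in_gb = size_in_bytes / (1024**3)  # ~16.76 GB
print(size_in_gb > 5.0)                 # True -> sharded weights are written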
@@ -823,3 +867,28 @@ def _save_metadata(self, layer):
         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
         with open(metadata_path, "w") as metadata_file:
             metadata_file.write(json.dumps(metadata, indent=4))
+
+    def _get_variables_size_in_bytes(self, variables):
+        @functools.lru_cache(512)
+        def _compute_memory_size(shape, dtype):
+            weight_counts = math.prod(shape)
+            dtype = keras.backend.standardize_dtype(dtype)
+            dtype_size = int(
+                (
+                    dtype.replace("bfloat", "")
+                    .replace("float", "")
+                    .replace("uint", "")
+                    .replace("int", "")
+                    .replace("bool", "1")
+                )
+            )
+            return weight_counts * dtype_size
+
+        unique_variables = {}
+        for v in variables:
+            if id(v) not in unique_variables:
+                unique_variables[id(v)] = (v.shape, v.dtype)
+        total_memory_size = 0
+        for shape, dtype in unique_variables.values():
+            total_memory_size += _compute_memory_size(shape, dtype)
+        return total_memory_size / 8
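The string surgery in `_compute_memory_size` strips the dtype family name to leave the bit width (with `bool` mapped to 1 bit); the method then sums bits over unique variables and divides by 8 to report bytes. A quick standalone check of that parsing logic, with `dtype_bits` being a renamed copy for illustration only:

def dtype_bits(dtype):
    # Same replacement chain as in `_compute_memory_size` above.
    return int(
        dtype.replace("bfloat", "")
        .replace("float", "")
        .replace("uint", "")
        .replace("int", "")
        .replace("bool", "1")
    )

assert dtype_bits("bfloat16") == 16
assert dtype_bits("float32") == 32
assert dtype_bits("int8") == 8
assert dtype_bits("bool") == 1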

keras_hub/src/utils/preset_utils_test.py (+42)

@@ -10,12 +10,54 @@
 )
 from keras_hub.src.models.bert.bert_backbone import BertBackbone
 from keras_hub.src.models.bert.bert_tokenizer import BertTokenizer
+from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
 from keras_hub.src.tests.test_case import TestCase
 from keras_hub.src.utils.preset_utils import CONFIG_FILE
+from keras_hub.src.utils.preset_utils import get_preset_saver
 from keras_hub.src.utils.preset_utils import upload_preset
 
 
 class PresetUtilsTest(TestCase):
+    @pytest.mark.large
+    def test_sharded_weights(self):
+        # Gemma2 config.
+        init_kwargs = {
+            "vocabulary_size": 4096,  # 256128
+            "num_layers": 24,  # 46
+            "num_query_heads": 16,  # 32
+            "num_key_value_heads": 8,  # 16
+            "hidden_dim": 64,  # 4608
+            "intermediate_dim": 128,  # 73728
+            "head_dim": 8,  # 128
+            "sliding_window_size": 5,  # 4096
+            "attention_logit_soft_cap": 50,
+            "final_logit_soft_cap": 30,
+            "layer_norm_epsilon": 1e-6,
+            "query_head_dim_normalize": False,
+            "use_post_ffw_norm": True,
+            "use_post_attention_norm": True,
+            "use_sliding_window_attention": True,
+        }
+        backbone = GemmaBackbone(**init_kwargs)  # ~4.4MB
+
+        # Save the sharded weights.
+        preset_dir = self.get_temp_dir()
+        preset_saver = get_preset_saver(preset_dir)
+        preset_saver.save_backbone(backbone, max_shard_size=0.002)
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model.weights.json"))
+        )
+        self.assertTrue(
+            os.path.exists(os.path.join(preset_dir, "model_00000.weights.h5"))
+        )
+
+        # Load the sharded weights.
+        revived_backbone = GemmaBackbone.from_preset(preset_dir)
+        for v1, v2 in zip(
+            backbone.trainable_variables, revived_backbone.trainable_variables
+        ):
+            self.assertAllClose(v1, v2)
+
     @pytest.mark.large
     def test_preset_errors(self):
         with self.assertRaisesRegex(ValueError, "must be a string"):
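Rough arithmetic behind the `test_sharded_weights` expectations above (illustrative only; exact packing is up to Keras): a ~4.4 MB backbone saved with `max_shard_size=0.002` GB (about 2 MB per shard) cannot fit in a single shard, so `model_00000.weights.h5` is guaranteed to exist alongside `model.weights.json`.

import math

backbone_size_gb = 4.4 / 1024                    # ~0.0043 GB, per the comment above
min_shards = math.ceil(backbone_size_gb / 0.002) # 3 -> at least shard 00000 exists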
