@@ -1,7 +1,9 @@
 import collections
 import datetime
+import functools
 import inspect
 import json
+import math
 import os
 import re
 
@@ -48,6 +50,8 @@
 # Weight file names.
 MODEL_WEIGHTS_FILE = "model.weights.h5"
 TASK_WEIGHTS_FILE = "task.weights.h5"
+SHARDED_MODEL_WEIGHTS_CONFIG_FILE = "model.weights.json"
+MAX_SHARD_SIZE = 5.0  # Maximum size per weights shard, in GB.
 
 # HuggingFace filenames.
 README_FILE = "README.md"
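For orientation: the new `SHARDED_MODEL_WEIGHTS_CONFIG_FILE` is a small JSON index whose `weight_map` entry maps each variable path to the shard file that stores it; this is what `_get_sharded_filenames` (added further down in this diff) consumes. A minimal sketch of that structure, with hypothetical variable paths and shard names:

```python
# Hypothetical model.weights.json contents; the real keys and shard names
# depend on the model and on how Keras splits the weights.
config = {
    "weight_map": {
        "layers/dense/kernel": "model_00000.weights.h5",
        "layers/dense/bias": "model_00000.weights.h5",
        "layers/embedding/embeddings": "model_00001.weights.h5",
    }
}
# The loader reduces the map to the unique shard files it must download.
shard_files = sorted(set(config["weight_map"].values()))
assert shard_files == ["model_00000.weights.h5", "model_00001.weights.h5"]
```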
@@ -647,7 +651,7 @@ def load_backbone(self, cls, load_weights, **kwargs):
         backbone = self._load_serialized_object(self.config, **kwargs)
         if load_weights:
             jax_memory_cleanup(backbone)
-            backbone.load_weights(get_file(self.preset, MODEL_WEIGHTS_FILE))
+            self._load_backbone_weights(backbone)
         return backbone
 
     def load_tokenizer(self, cls, config_file=TOKENIZER_CONFIG_FILE, **kwargs):
@@ -697,8 +701,7 @@ def load_task(self, cls, load_weights, load_task_weights, **kwargs):
                 task.load_task_weights(task_weights)
             else:
                 jax_memory_cleanup(task.backbone)
-            backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
-            task.backbone.load_weights(backbone_weights)
+            self._load_backbone_weights(task.backbone)
         return task
 
     def load_preprocessor(
@@ -726,18 +729,59 @@ def _load_serialized_object(self, config, **kwargs):
         config["config"] = {**config["config"], **kwargs}
         return keras.saving.deserialize_keras_object(config)
 
+    def _get_sharded_filenames(self, config_path):
+        with open(config_path, encoding="utf-8") as config_file:
+            config = json.load(config_file)
+        weight_map = config["weight_map"]
+        return sorted(set(weight_map.values()))
+
+    def _load_backbone_weights(self, backbone):
+        # Detect whether the backbone weights are sharded.
+        has_single_file_weights = check_file_exists(
+            self.preset, MODEL_WEIGHTS_FILE
+        )
+        if has_single_file_weights:
+            filepath = get_file(self.preset, MODEL_WEIGHTS_FILE)
+        else:
+            filepath = get_file(self.preset, SHARDED_MODEL_WEIGHTS_CONFIG_FILE)
+            sharded_filenames = self._get_sharded_filenames(filepath)
+            for sharded_filename in sharded_filenames:
+                # Download the sharded weights.
+                _ = get_file(self.preset, sharded_filename)
+        backbone.load_weights(filepath)
+
 
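With `_load_backbone_weights` in place, loading is transparent to callers: if `model.weights.h5` exists in the preset it is loaded directly; otherwise the loader fetches the sharding config plus every shard it references, then hands the config path to `backbone.load_weights`, which Keras resolves against the now-local shard files. A minimal usage sketch (the preset name here is made up):

```python
import keras_hub

# Same call whether the preset stores single-file or sharded weights; the
# sharded case just downloads model.weights.json and its shards first.
backbone = keras_hub.models.Backbone.from_preset("hypothetical_10b_preset")
```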
 class KerasPresetSaver:
     def __init__(self, preset_dir):
         os.makedirs(preset_dir, exist_ok=True)
         self.preset_dir = preset_dir
 
-    def save_backbone(self, backbone):
+    def save_backbone(self, backbone, max_shard_size=None):
         self._save_serialized_object(backbone, config_file=CONFIG_FILE)
-        backbone_weight_path = os.path.join(self.preset_dir, MODEL_WEIGHTS_FILE)
-        backbone.save_weights(backbone_weight_path)
         self._save_metadata(backbone)
 
+        # Save the weights.
+        backbone_size_in_bytes = self._get_variables_size_in_bytes(
+            backbone.variables
+        )
+        backbone_size_in_gb = backbone_size_in_bytes / (1024**3)
+        # If the backbone exceeds `max_shard_size` (which defaults to
+        # `MAX_SHARD_SIZE`), save sharded weights.
+        max_shard_size = max_shard_size or MAX_SHARD_SIZE
+        if backbone_size_in_gb > max_shard_size:
+            backbone_sharded_weights_config_path = os.path.join(
+                self.preset_dir, SHARDED_MODEL_WEIGHTS_CONFIG_FILE
+            )
+            backbone.save_weights(
+                backbone_sharded_weights_config_path,
+                max_shard_size=max_shard_size,
+            )
+        else:
+            backbone_weight_path = os.path.join(
+                self.preset_dir, MODEL_WEIGHTS_FILE
+            )
+            backbone.save_weights(backbone_weight_path)
+
     def save_tokenizer(self, tokenizer):
         config_file = TOKENIZER_CONFIG_FILE
         if hasattr(tokenizer, "config_file"):
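On the save side, `save_backbone` now shards automatically once the estimated variable size crosses the threshold, and callers can lower the cutoff explicitly. A minimal sketch, assuming an existing `backbone` instance and a hypothetical output directory:

```python
saver = KerasPresetSaver("/tmp/my_preset")

# Default cutoff is MAX_SHARD_SIZE (5 GB); small backbones still produce a
# single model.weights.h5.
saver.save_backbone(backbone)

# Force 1 GB shards, producing model.weights.json plus shard files.
saver.save_backbone(backbone, max_shard_size=1.0)
```

Note that `max_shard_size=None` falls back to `MAX_SHARD_SIZE` through the `or` expression, so the module default applies unless a positive value is passed.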
@@ -823,3 +867,28 @@ def _save_metadata(self, layer):
         metadata_path = os.path.join(self.preset_dir, METADATA_FILE)
         with open(metadata_path, "w") as metadata_file:
             metadata_file.write(json.dumps(metadata, indent=4))
+
+    def _get_variables_size_in_bytes(self, variables):
+        @functools.lru_cache(512)
+        def _compute_memory_size(shape, dtype):
+            weight_counts = math.prod(shape)
+            dtype = keras.backend.standardize_dtype(dtype)
+            dtype_size = int(
+                (
+                    dtype.replace("bfloat", "")
+                    .replace("float", "")
+                    .replace("uint", "")
+                    .replace("int", "")
+                    .replace("bool", "1")
+                )
+            )
+            return weight_counts * dtype_size
+
+        unique_variables = {}
+        for v in variables:
+            if id(v) not in unique_variables:
+                unique_variables[id(v)] = (v.shape, v.dtype)
+        total_memory_size = 0
+        for shape, dtype in unique_variables.values():
+            total_memory_size += _compute_memory_size(shape, dtype)
+        return total_memory_size / 8
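As a sanity check on `_compute_memory_size`: the chained `replace` calls strip the dtype name down to its bit width (`"float32"` to `32`, `"bfloat16"` to `16`, `"bool"` to `1`), so the running total is in bits and the final `/ 8` converts it to bytes. Worked out for a float32 variable of shape `(1024, 1024)`:

```python
import math

shape, dtype_bits = (1024, 1024), 32  # float32 -> 32 bits per element
size_in_bytes = math.prod(shape) * dtype_bits / 8
assert size_in_bytes == 4 * 1024**2  # 4 MiB
```

Deduplicating by `id(v)` before summing also means variables shared across layers are counted once, which keeps the shard-threshold estimate from overshooting.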