Skip to content

Commit e3285a8

Browse files
authored
Abstracting model code (#36)
* Initial commit for refactoring skip model code using class factories and abstract classes. Signed-off-by: Kira Selby <kaselby@uwaterloo.ca> * Minor fixes Signed-off-by: Kira Selby <kaselby@uwaterloo.ca> * Add missing lines Signed-off-by: Kira Selby <kaselby@uwaterloo.ca> --------- Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent afb8311 commit e3285a8

7 files changed

Lines changed: 644 additions & 958 deletions

run_benchmark.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def main():
397397
checkpoint = config._name_or_path
398398
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
399399
tokenizer.pad_token = tokenizer.eos_token
400-
model_name = checkpoint.split("-")[0].capitalize()
400+
model_name = checkpoint.split("/")[1].split("-")[0].capitalize()
401401

402402
# Get test prompts
403403
test_prompts = get_diverse_test_prompts()

src/configuration_skip.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
from transformers import PretrainedConfig
2+
import os
3+
from typing import Union, Any, Type
4+
5+
6+
7+
def build_skip_config(base_config_class: type[PretrainedConfig], model_type_name: str) -> type[PretrainedConfig]:
8+
class SkipConnectionConfig(base_config_class):
9+
model_type: str = model_type_name
10+
has_no_defaults_at_init: bool = True
11+
12+
def __init__(self,
13+
sparsity: float,
14+
predictor_loss_type: str = "bce",
15+
predictor_temperature: float = 1.0,
16+
predictor_loss_alpha: float = 1.0,
17+
predictor_loss_weight: float = 0.1,
18+
use_optimized_weight_cache: bool = True,
19+
**kwargs):
20+
self._sparsity = sparsity
21+
self.predictor_loss_type = predictor_loss_type
22+
self.predictor_temperature = predictor_temperature
23+
self.predictor_loss_alpha = predictor_loss_alpha
24+
self.predictor_loss_weight = predictor_loss_weight
25+
self.use_optimized_weight_cache = use_optimized_weight_cache
26+
super().__init__(**kwargs)
27+
28+
@property
29+
def sparsity(self):
30+
return self._sparsity
31+
32+
@sparsity.setter
33+
def sparsity(self, value):
34+
self._sparsity = value
35+
36+
@classmethod
37+
def from_json_file(cls, json_file: Union[str, os.PathLike]):
38+
"""
39+
Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
40+
41+
Args:
42+
json_file (`str` or `os.PathLike`):
43+
Path to the JSON file containing the parameters.
44+
45+
Returns:
46+
[`PretrainedConfig`]: The configuration object instantiated from that JSON file.
47+
48+
"""
49+
config_dict = cls._dict_from_json_file(json_file)
50+
return cls(**config_dict)
51+
52+
@classmethod
53+
def from_dict(cls, config_dict: dict[str, Any], **kwargs):
54+
if "name_or_path" in kwargs and ("name_or_path" in config_dict or "_name_or_path" in config_dict):
55+
del kwargs["name_or_path"]
56+
return super().from_dict(config_dict, **kwargs)
57+
return SkipConnectionConfig

0 commit comments

Comments
 (0)