Commit 7ebd7f2

Add back support for old Falcon model type/config (#243)
Two days ago Falcon changed their model type and configuration keys. Last night they reverted the changes. Update Falcon loading to support both old-style and new-style models. Fixes #242.
1 parent f368237 commit 7ebd7f2

File tree: 3 files changed, +35 -11 lines changed


curated_transformers/models/auto_model.py

Lines changed: 4 additions & 0 deletions
@@ -116,6 +116,8 @@ class AutoDecoder(AutoModel[DecoderModule]):
         "gpt_neox": GPTNeoXDecoder,
         "llama": LLaMADecoder,
         "falcon": FalconDecoder,
+        "RefinedWeb": FalconDecoder,
+        "RefinedWebModel": FalconDecoder,
     }
 
     @classmethod
@@ -143,6 +145,8 @@ class AutoCausalLM(AutoModel[CausalLMModule[KeyValueCache]]):
         "gpt_neox": GPTNeoXCausalLM,
         "llama": LLaMACausalLM,
         "falcon": FalconCausalLM,
+        "RefinedWeb": FalconCausalLM,
+        "RefinedWebModel": FalconCausalLM,
     }
 
     @classmethod
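With both the old and the new model type strings registered, a checkpoint whose config reports any of "falcon", "RefinedWeb", or "RefinedWebModel" resolves to the same Falcon classes. A minimal usage sketch, based on the from_hf_hub(name=..., device=...) call used in the tests further down; the import path, the "cpu" device, and the assumption that the two explosion-testing repositories cover the old-style and new-style configs are illustrative:

import torch

from curated_transformers.models.falcon import FalconDecoder  # import path assumed

# The same loading call should work for either config style; the repositories
# below are the test checkpoints added to FALCON_TEST_MODELS in this commit.
for name in (
    "explosion-testing/falcon-test",
    "explosion-testing/refined-web-model-test",
):
    decoder = FalconDecoder.from_hf_hub(name=name, device=torch.device("cpu"))
    decoder.eval()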

curated_transformers/models/falcon/_hf.py

Lines changed: 19 additions & 3 deletions
@@ -12,20 +12,26 @@
 EXTRA_KWARG_KEYS = [ATTENTION_DROPOUT, HIDDEN_DROPOUT]
 
 
+# There are multiple versions of Falcon with different names
+# for the same options.
+HF_CONFIG_KEY_MAPPING_WITH_COMPAT = {
+    frozenset({"num_attention_heads", "n_head"}): "num_attention_heads",
+    frozenset({"num_hidden_layers", "n_layer"}): "num_hidden_layers",
+}
+
 HF_CONFIG_KEY_MAPPING = {
     "hidden_size": "hidden_width",
     "layer_norm_epsilon": "layer_norm_eps",
     "multi_query": "multi_query",
-    "num_attention_heads": "num_attention_heads",
-    "num_hidden_layers": "num_hidden_layers",
     "bias": "use_bias",
     "vocab_size": "vocab_size",
 }
 
 
 def convert_hf_config(hf_config: Any) -> FalconConfig:
+    hf_config_keys = set(hf_config.keys())
     missing_keys = tuple(
-        sorted(set(HF_CONFIG_KEY_MAPPING.keys()).difference(set(hf_config.keys())))
+        sorted(set(HF_CONFIG_KEY_MAPPING.keys()).difference(hf_config_keys))
     )
     if len(missing_keys) != 0:
         raise ValueError(f"Missing keys in Hugging Face Falcon config: {missing_keys}")
@@ -34,6 +40,16 @@ def convert_hf_config(hf_config: Any) -> FalconConfig:
     # Handle config options that are not set in all models.
     kwargs.update({k: hf_config[k] for k in EXTRA_KWARG_KEYS if k in hf_config})
 
+    for hf_keys, curated in HF_CONFIG_KEY_MAPPING_WITH_COMPAT.items():
+        key_overlap = list(hf_keys.intersection(hf_config_keys))
+        if not key_overlap:
+            raise ValueError(
+                f"Hugging Face Falcon config must contain one of: {', '.join(sorted(hf_keys))}"
+            )
+        # Ideally, we'd check that we only have one overlapping key, but
+        # I bet that someone will then add both keys 'just to be sure'.
+        kwargs[curated] = hf_config[key_overlap[0]]
+
     parallel_attention = hf_config.get("parallel_attn", True)
 
     # When new_decoder_architecture is set, the multi_query and parallel_attn
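To make the compat handling concrete, here is a small standalone sketch that applies the same resolution loop to two hand-written configs, one using the old key names (n_head/n_layer) and one using the new ones (num_attention_heads/num_hidden_layers). The config dicts and their values are made up for illustration; the mapping and loop mirror the hunk above:

# Mirrors HF_CONFIG_KEY_MAPPING_WITH_COMPAT from the diff above.
COMPAT_MAPPING = {
    frozenset({"num_attention_heads", "n_head"}): "num_attention_heads",
    frozenset({"num_hidden_layers", "n_layer"}): "num_hidden_layers",
}

old_style = {"n_head": 8, "n_layer": 4}  # old-style key names (values made up)
new_style = {"num_attention_heads": 8, "num_hidden_layers": 4}  # new-style key names

for hf_config in (old_style, new_style):
    kwargs = {}
    hf_config_keys = set(hf_config.keys())
    for hf_keys, curated in COMPAT_MAPPING.items():
        key_overlap = list(hf_keys.intersection(hf_config_keys))
        if not key_overlap:
            raise ValueError(
                f"Hugging Face Falcon config must contain one of: {', '.join(sorted(hf_keys))}"
            )
        # Take whichever spelling is present; both map to the curated key.
        kwargs[curated] = hf_config[key_overlap[0]]
    # Both configs resolve to the same curated kwargs:
    # {'num_attention_heads': 8, 'num_hidden_layers': 4}
    print(kwargs)

Either spelling fills the same curated_transformers keyword argument, so convert_hf_config no longer depends on which config version a checkpoint ships with.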

curated_transformers/tests/models/falcon/test_decoder.py

Lines changed: 12 additions & 8 deletions
@@ -25,9 +25,16 @@
 # against output without caching.
 
 
+FALCON_TEST_MODELS = [
+    "explosion-testing/falcon-test",
+    "explosion-testing/refined-web-model-test",
+]
+
+
 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
 @pytest.mark.parametrize("torch_device", TORCH_DEVICES)
-def test_decoder(torch_device):
+@pytest.mark.parametrize("model", FALCON_TEST_MODELS)
+def test_decoder(torch_device, model):
     hf_model = transformers.AutoModel.from_pretrained(
         "explosion-testing/falcon-test",
         # Safe because it is under our control.
@@ -38,9 +45,7 @@ def test_decoder(torch_device):
     hf_model.to(torch_device)
     hf_model.eval()
 
-    model = FalconDecoder.from_hf_hub(
-        name="explosion-testing/falcon-test", device=torch_device
-    )
+    model = FalconDecoder.from_hf_hub(name=model, device=torch_device)
     model.eval()
 
     torch.manual_seed(0)
@@ -55,10 +60,9 @@
 
 @pytest.mark.skipif(not has_hf_transformers, reason="requires huggingface transformers")
 @pytest.mark.parametrize("torch_device", TORCH_DEVICES)
-def test_decoder_with_cache(torch_device):
-    model = FalconDecoder.from_hf_hub(
-        name="explosion-testing/falcon-test", device=torch_device
-    )
+@pytest.mark.parametrize("model", FALCON_TEST_MODELS)
+def test_decoder_with_cache(torch_device, model):
+    model = FalconDecoder.from_hf_hub(name=model, device=torch_device)
     model.eval()
 
     torch.manual_seed(0)
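Because the decoder tests are now parametrized over FALCON_TEST_MODELS, each test runs once per checkpoint, covering the old-style and new-style configs in one test module. A single variant can still be selected by keyword when running pytest from the repository root (command illustrative):

pytest curated_transformers/tests/models/falcon/test_decoder.py -k refined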
