Text Generation
Transformers
Safetensors
PyTorch
nemotron_h
nvidia
nemotron-3
latent-moe
mtp
conversational
custom_code
Eval Results
dmax123 committed on
Commit
40b9aef
·
verified ·
1 Parent(s): 101e33c

Add files using upload-large-folder tool

Browse files
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. __init__.py +0 -0
  3. chat_template.jinja +209 -0
  4. config.json +69 -0
  5. configuration_nemotron_h.py +410 -0
  6. generation_config.json +13 -0
  7. model-00001-of-00050.safetensors +3 -0
  8. model-00002-of-00050.safetensors +3 -0
  9. model-00003-of-00050.safetensors +3 -0
  10. model-00004-of-00050.safetensors +3 -0
  11. model-00005-of-00050.safetensors +3 -0
  12. model-00006-of-00050.safetensors +3 -0
  13. model-00007-of-00050.safetensors +3 -0
  14. model-00008-of-00050.safetensors +3 -0
  15. model-00009-of-00050.safetensors +3 -0
  16. model-00010-of-00050.safetensors +3 -0
  17. model-00011-of-00050.safetensors +3 -0
  18. model-00012-of-00050.safetensors +3 -0
  19. model-00013-of-00050.safetensors +3 -0
  20. model-00014-of-00050.safetensors +3 -0
  21. model-00015-of-00050.safetensors +3 -0
  22. model-00016-of-00050.safetensors +3 -0
  23. model-00017-of-00050.safetensors +3 -0
  24. model-00019-of-00050.safetensors +3 -0
  25. model-00020-of-00050.safetensors +3 -0
  26. model-00021-of-00050.safetensors +3 -0
  27. model-00022-of-00050.safetensors +3 -0
  28. model-00023-of-00050.safetensors +3 -0
  29. model-00024-of-00050.safetensors +3 -0
  30. model-00026-of-00050.safetensors +3 -0
  31. model-00033-of-00050.safetensors +3 -0
  32. model-00036-of-00050.safetensors +3 -0
  33. model-00038-of-00050.safetensors +3 -0
  34. model-00040-of-00050.safetensors +3 -0
  35. model-00041-of-00050.safetensors +3 -0
  36. model-00042-of-00050.safetensors +3 -0
  37. model-00043-of-00050.safetensors +3 -0
  38. model-00044-of-00050.safetensors +3 -0
  39. model-00045-of-00050.safetensors +3 -0
  40. model-00046-of-00050.safetensors +3 -0
  41. model-00047-of-00050.safetensors +3 -0
  42. model-00048-of-00050.safetensors +3 -0
  43. model-00049-of-00050.safetensors +3 -0
  44. model-00050-of-00050.safetensors +3 -0
  45. model.safetensors.index.json +0 -0
  46. modeling_nemotron_h.py +1754 -0
  47. special_tokens_map.json +30 -0
  48. super_v3_reasoning_parser.py +29 -0
  49. tokenizer.json +3 -0
  50. tokenizer_config.json +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
__init__.py ADDED
File without changes
chat_template.jinja ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {% macro render_extra_keys(json_dict, handled_keys) %}
2
+ {%- if json_dict is mapping %}
3
+ {%- for json_key in json_dict if json_key not in handled_keys %}
4
+ {%- if json_dict[json_key] is mapping or (json_dict[json_key] is sequence and json_dict[json_key] is not string) %}
5
+ {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '</' ~ json_key ~ '>' }}
6
+ {%- else %}
7
+ {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '</' ~ json_key ~ '>' }}
8
+ {%- endif %}
9
+ {%- endfor %}
10
+ {%- endif %}
11
+ {% endmacro %}
12
+ {%- set enable_thinking = enable_thinking if enable_thinking is defined else True %}
13
+ {%- set low_effort = low_effort if low_effort is defined else False %}
14
+ {%- set truncate_history_thinking = truncate_history_thinking if truncate_history_thinking is defined else True %}
15
+
16
+ {%- set ns = namespace(last_user_idx = -1) %}
17
+ {%- set loop_messages = messages %}
18
+ {%- for m in loop_messages %}
19
+ {%- if m["role"] == "user" %}
20
+ {%- set ns.last_user_idx = loop.index0 %}
21
+ {%- endif %}
22
+ {%- endfor %}
23
+
24
+ {%- if messages[0]["role"] == "system" %}
25
+ {%- set system_message = messages[0]["content"] %}
26
+ {%- set loop_messages = messages[1:] %}
27
+ {%- else %}
28
+ {%- set system_message = "" %}
29
+ {%- set loop_messages = messages %}
30
+ {%- endif %}
31
+ {%- if not tools is defined %}
32
+ {%- set tools = [] %}
33
+ {%- endif %}
34
+ {# Recompute last_user_idx relative to loop_messages after handling system #}
35
+ {%- set ns = namespace(last_user_idx = -1) %}
36
+ {%- for m in loop_messages %}
37
+ {%- if m["role"] == "user" %}
38
+ {%- set ns.last_user_idx = loop.index0 %}
39
+ {%- endif %}
40
+ {%- endfor %}
41
+ {%- if system_message is defined %}
42
+ {{- "<|im_start|>system\n" + system_message }}
43
+ {%- else %}
44
+ {%- if tools is iterable and tools | length > 0 %}
45
+ {{- "<|im_start|>system\n" }}
46
+ {%- endif %}
47
+ {%- endif %}
48
+ {%- if tools is iterable and tools | length > 0 %}
49
+ {%- if system_message is defined and system_message | length > 0 %}
50
+ {{- "\n\n" }}
51
+ {%- endif %}
52
+ {{- "# Tools\n\nYou have access to the following functions:\n\n" }}
53
+ {{- "<tools>" }}
54
+ {%- for tool in tools %}
55
+ {%- if tool.function is defined %}
56
+ {%- set tool = tool.function %}
57
+ {%- endif %}
58
+ {{- "\n<function>\n<name>" ~ tool.name ~ "</name>" }}
59
+ {%- if tool.description is defined %}
60
+ {{- '\n<description>' ~ (tool.description | trim) ~ '</description>' }}
61
+ {%- endif %}
62
+ {{- '\n<parameters>' }}
63
+ {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %}
64
+ {%- for param_name, param_fields in tool.parameters.properties|items %}
65
+ {{- '\n<parameter>' }}
66
+ {{- '\n<name>' ~ param_name ~ '</name>' }}
67
+ {%- if param_fields.type is defined %}
68
+ {{- '\n<type>' ~ (param_fields.type | string) ~ '</type>' }}
69
+ {%- endif %}
70
+ {%- if param_fields.description is defined %}
71
+ {{- '\n<description>' ~ (param_fields.description | trim) ~ '</description>' }}
72
+ {%- endif %}
73
+ {%- if param_fields.enum is defined %}
74
+ {{- '\n<enum>' ~ (param_fields.enum | tojson | safe) ~ '</enum>' }}
75
+ {%- endif %}
76
+ {%- set handled_keys = ['name', 'type', 'description', 'enum'] %}
77
+ {{- render_extra_keys(param_fields, handled_keys) }}
78
+ {{- '\n</parameter>' }}
79
+ {%- endfor %}
80
+ {%- endif %}
81
+ {% set handled_keys = ['type', 'properties', 'required'] %}
82
+ {{- render_extra_keys(tool.parameters, handled_keys) }}
83
+ {%- if tool.parameters is defined and tool.parameters.required is defined %}
84
+ {{- '\n<required>' ~ (tool.parameters.required | tojson | safe) ~ '</required>' }}
85
+ {%- endif %}
86
+ {{- '\n</parameters>' }}
87
+ {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %}
88
+ {{- render_extra_keys(tool, handled_keys) }}
89
+ {{- '\n</function>' }}
90
+ {%- endfor %}
91
+ {{- "\n</tools>" }}
92
+
93
+ {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>' }}
94
+ {%- endif %}
95
+
96
+
97
+ {%- if system_message is defined %}
98
+ {{- '<|im_end|>\n' }}
99
+ {%- else %}
100
+ {%- if tools is iterable and tools | length > 0 %}
101
+ {{- '<|im_end|>\n' }}
102
+ {%- endif %}
103
+ {%- endif %}
104
+
105
+ {%- for message in loop_messages %}
106
+ {%- if message.role == "assistant" %}
107
+ {# Add reasoning content into the content field for unified processing below. #}
108
+ {%- if message.reasoning_content is defined and message.reasoning_content is string and message.reasoning_content | trim | length > 0 %}
109
+ {%- set content = "<think>\n" ~ message.reasoning_content ~ "\n</think>\n" ~ (message.content | default('', true)) %}
110
+ {%- else %}
111
+ {%- set content = message.content | default('', true) %}
112
+ {%- if content is string -%}
113
+ {# Allow downstream logic to take care of broken thoughts; only handle coherent reasoning here. #}
114
+ {%- if '<think>' not in content and '</think>' not in content -%}
115
+ {%- set content = "<think></think>" ~ content -%}
116
+ {%- endif -%}
117
+ {%- else -%}
118
+ {%- set content = content -%}
119
+ {%- endif -%}
120
+ {%- endif %}
121
+ {%- if message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
122
+ {# Assistant message has tool calls. #}
123
+ {{- '<|im_start|>assistant\n' }}
124
+ {%- set include_content = not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
125
+ {%- if content is string and content | trim | length > 0 %}
126
+ {%- if include_content %}
127
+ {{- (content | trim) ~ '\n' -}}
128
+ {%- else %}
129
+ {%- set c = (content | string) %}
130
+ {%- if '</think>' in c %}
131
+ {# Keep only content after the last closing think. Also generation prompt causes this. #}
132
+ {%- set c = c.split('</think>')[-1] %}
133
+ {%- elif '<think>' in c %}
134
+ {# If <think> was opened but never closed, drop the trailing think segment #}
135
+ {%- set c = c.split('<think>')[0] %}
136
+ {%- endif %}
137
+ {%- set c = "<think></think>" ~ c | trim %}
138
+ {%- if c | length > 0 %}
139
+ {{- c ~ '\n' -}}
140
+ {%- endif %}
141
+ {%- endif %}
142
+ {%- else %}
143
+ {{- "<think></think>" -}}
144
+ {%- endif %}
145
+ {%- for tool_call in message.tool_calls %}
146
+ {%- if tool_call.function is defined %}
147
+ {%- set tool_call = tool_call.function %}
148
+ {%- endif %}
149
+ {{- '<tool_call>\n<function=' ~ tool_call.name ~ '>\n' -}}
150
+ {%- if tool_call.arguments is defined %}
151
+ {%- for args_name, args_value in tool_call.arguments|items %}
152
+ {{- '<parameter=' ~ args_name ~ '>\n' -}}
153
+ {%- set args_value = args_value | tojson | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %}
154
+ {{- args_value ~ '\n</parameter>\n' -}}
155
+ {%- endfor %}
156
+ {%- endif %}
157
+ {{- '</function>\n</tool_call>\n' -}}
158
+ {%- endfor %}
159
+ {{- '<|im_end|>\n' }}
160
+ {%- else %}
161
+ {# Assistant message doesn't have tool calls. #}
162
+ {%- if not (truncate_history_thinking and loop.index0 < ns.last_user_idx) %}
163
+ {{- '<|im_start|>assistant\n' ~ (content | default('', true) | string | trim) ~ '<|im_end|>\n' }}
164
+ {%- else %}
165
+ {%- set c = (content | default('', true) | string) %}
166
+ {%- if '<think>' in c and '</think>' in c %}
167
+ {%- set c = "<think></think>" ~ c.split('</think>')[-1] %}
168
+ {%- endif %}
169
+ {%- set c = c | trim %}
170
+ {%- if c | length > 0 %}
171
+ {{- '<|im_start|>assistant\n' ~ c ~ '<|im_end|>\n' }}
172
+ {%- else %}
173
+ {{- '<|im_start|>assistant\n<|im_end|>\n' }}
174
+ {%- endif %}
175
+ {%- endif %}
176
+ {%- endif %}
177
+ {%- elif message.role == "user" or message.role == "system" %}
178
+ {{- '<|im_start|>' + message.role + '\n' }}
179
+ {%- set content = message.content | string %}
180
+ {%- if message.role == "user" and loop.index0 == ns.last_user_idx and low_effort %}
181
+ {{- content + '\n\n{reasoning effort: low}' }}
182
+ {%- else %}
183
+ {{- content }}
184
+ {%- endif %}
185
+ {{- '<|im_end|>\n' }}
186
+ {%- elif message.role == "tool" %}
187
+ {%- if loop.previtem and loop.previtem.role != "tool" %}
188
+ {{- '<|im_start|>user\n' }}
189
+ {%- endif %}
190
+ {{- '<tool_response>\n' }}
191
+ {{- message.content }}
192
+ {{- '\n</tool_response>\n' }}
193
+ {%- if not loop.last and loop.nextitem.role != "tool" %}
194
+ {{- '<|im_end|>\n' }}
195
+ {%- elif loop.last %}
196
+ {{- '<|im_end|>\n' }}
197
+ {%- endif %}
198
+ {%- else %}
199
+ {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
200
+ {%- endif %}
201
+ {%- endfor %}
202
+
203
+ {%- if add_generation_prompt %}
204
+ {%- if enable_thinking %}
205
+ {{- '<|im_start|>assistant\n<think>\n' }}
206
+ {%- else %}
207
+ {{- '<|im_start|>assistant\n<think></think>' }}
208
+ {%- endif %}
209
+ {%- endif %}
config.json ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "NemotronHForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_nemotron_h.NemotronHConfig",
9
+ "AutoModelForCausalLM": "modeling_nemotron_h.NemotronHForCausalLM"
10
+ },
11
+ "bos_token_id": 1,
12
+ "chunk_size": 128,
13
+ "conv_kernel": 4,
14
+ "dtype": "bfloat16",
15
+ "eos_token_id": 2,
16
+ "expand": 2,
17
+ "head_dim": 128,
18
+ "hidden_dropout": 0.0,
19
+ "hidden_size": 4096,
20
+ "hybrid_override_pattern": "MEMEMEM*EMEMEMEM*EMEMEMEM*EMEMEMEMEM*EMEMEMEMEM*EMEMEMEMEM*EMEMEMEMEM*EMEMEMEM*EMEMEMEME",
21
+ "initializer_range": 0.02,
22
+ "intermediate_size": 2688,
23
+ "layer_norm_epsilon": 1e-05,
24
+ "mamba_head_dim": 64,
25
+ "mamba_hidden_act": "silu",
26
+ "mamba_num_heads": 128,
27
+ "mamba_proj_bias": false,
28
+ "mamba_ssm_cache_dtype": "float32",
29
+ "max_position_embeddings": 262144,
30
+ "mlp_bias": false,
31
+ "mlp_hidden_act": "relu2",
32
+ "model_type": "nemotron_h",
33
+ "moe_intermediate_size": 2688,
34
+ "moe_latent_size": 1024,
35
+ "moe_shared_expert_intermediate_size": 5376,
36
+ "moe_shared_expert_overlap": false,
37
+ "mtp_hybrid_override_pattern": "*E",
38
+ "n_group": 1,
39
+ "n_groups": 8,
40
+ "n_routed_experts": 512,
41
+ "n_shared_experts": 1,
42
+ "norm_eps": 1e-05,
43
+ "norm_topk_prob": true,
44
+ "num_attention_heads": 32,
45
+ "num_experts_per_tok": 22,
46
+ "num_hidden_layers": 88,
47
+ "num_key_value_heads": 2,
48
+ "num_logits_to_keep": 1,
49
+ "num_nextn_predict_layers": 1,
50
+ "pad_token_id": 0,
51
+ "partial_rotary_factor": 1.0,
52
+ "rescale_prenorm_residual": true,
53
+ "residual_in_fp32": false,
54
+ "rope_theta": 10000,
55
+ "routed_scaling_factor": 5.0,
56
+ "sliding_window": null,
57
+ "ssm_state_size": 128,
58
+ "tie_word_embeddings": false,
59
+ "time_step_floor": 0.0001,
60
+ "time_step_max": 0.1,
61
+ "time_step_min": 0.001,
62
+ "topk_group": 1,
63
+ "transformers_version": "4.57.6",
64
+ "use_bias": false,
65
+ "use_cache": true,
66
+ "use_conv_bias": true,
67
+ "use_mamba_kernels": true,
68
+ "vocab_size": 131072
69
+ }
configuration_nemotron_h.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024-2025 NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """NemotronH model configuration"""
15
+
16
+ from transformers.configuration_utils import PretrainedConfig
17
+ from transformers.utils import logging
18
+
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+
23
class NemotronHConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`NemotronHModel`]. It is used to instantiate a
    NemotronH model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of NVIDIA-Nemotron-3-Nano-30B-A3B-BF16 [nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16](https://huggingface.co/nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16).

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 131072):
            Vocabulary size of the NemotronH model. Defines the number of different tokens that can be represented by
            the `input_ids` passed when calling [`NemotronHModel`].
        hidden_size (`int`, *optional*, defaults to 4096):
            Dimension of the hidden representations.
        layers_block_type (`list`, *optional*):
            Explicit list of layer types for each layer. Each element must be one of: "mamba", "attention", or "moe".
            The number of layers is determined by the length of this list. When omitted (and no
            `hybrid_override_pattern` keyword is given), defaults to `["mamba", "moe", "attention", "moe"]`.
        num_hidden_layers (`int`, *optional*):
            Number of hidden layers in the Transformer encoder. This parameter is deprecated and only kept for
            backward compatibility. The number of layers is now determined by the length of `layers_block_type`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions.
        num_logits_to_keep (`int`, *optional*, defaults to 1):
            Number of prompt logits to calculate during generation. If `None`, all logits will be calculated.
        pad_token_id (`int`, *optional*, defaults to 0):
            The id of the padding token.
        bos_token_id (`int`, *optional*, defaults to 1):
            The id of the "beginning-of-sequence" token.
        eos_token_id (`int`, *optional*, defaults to 2):
            The id of the "end-of-sequence" token.
        num_attention_heads (`int`, *optional*, defaults to 32):
            Number of attention heads for each attention layer in the Transformer encoder.
        num_key_value_heads (`int`, *optional*, defaults to 8):
            This is the number of key_value heads that should be used to implement Grouped Query Attention.
            If `None`, falls back to `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 128):
            Dimension of each attention head.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length that this model might ever be used with.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in attention layers.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.
        sliding_window (`int`, *optional*):
            Sliding window attention window size.
        intermediate_size (`int`, *optional*, defaults to 21504):
            Dimension of the MLP representations.
        mlp_hidden_act (`str`, *optional*, defaults to `"relu2"`):
            The non-linear activation function in the MLP layers.
        mlp_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in MLP layers.
        use_mamba_kernels (`bool`, *optional*, defaults to `True`):
            Flag indicating whether or not to use the fast mamba kernels.
        ssm_state_size (`int`, *optional*, defaults to 128):
            The dimension of the mamba state space latents.
        mamba_num_heads (`int`, *optional*, defaults to 128):
            Number of heads in Mamba layers.
        mamba_n_groups (`int`, *optional*, defaults to 8):
            Number of groups in Mamba layers. Stored on the config as `n_groups`.
        mamba_head_dim (`int`, *optional*, defaults to 64):
            Dimension of each Mamba head.
        mamba_d_conv (`int`, *optional*, defaults to 4):
            The size of the mamba convolution kernel. Stored on the config as `conv_kernel`.
        mamba_expand (`int`, *optional*, defaults to 2):
            Expanding factor used to determine the mamba intermediate size. Stored on the config as `expand`.
        mamba_hidden_act (`str`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the Mamba layers.
        mamba_dt_min (`float`, *optional*, defaults to 0.001):
            Minimum value for the time step in Mamba. Stored on the config as `time_step_min`.
        mamba_dt_max (`float`, *optional*, defaults to 0.1):
            Maximum value for the time step in Mamba. Stored on the config as `time_step_max`.
        mamba_dt_limit (`tuple`, *optional*, defaults to `(0.0, inf)`):
            Limits for the time step in Mamba. Stored on the config as `time_step_limit`.
        mamba_dt_init_floor (`float`, *optional*, defaults to 0.0001):
            Floor value for time step initialization in Mamba. Stored on the config as `time_step_floor`.
        mamba_conv_bias (`bool`, *optional*, defaults to `True`):
            Whether to use bias in the convolution layer of the mamba mixer block. Stored on the config as
            `use_conv_bias`.
        mamba_proj_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the input and output projections of the mamba mixer block.
        mamba_chunk_size (`int`, *optional*, defaults to 128):
            Size of chunks for Mamba processing. Stored on the config as `chunk_size`.
        mamba_ssm_cache_dtype (`str`, *optional*, defaults to `"float32"`):
            Data type for Mamba SSM cache states.
        n_routed_experts (`int`, *optional*, defaults to 8):
            Number of routed experts in MoE layers.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts that are always activated in MoE layers.
        moe_intermediate_size (`int`, *optional*, defaults to 7688):
            Dimension of the MLP representations in routed experts.
        moe_shared_expert_intermediate_size (`int`, *optional*, defaults to 7688):
            Dimension of the MLP representations in shared experts.
        moe_latent_size (`int`, *optional*):
            Latent size for MoE expert projections. If `None`, uses `hidden_size`.
        moe_shared_expert_overlap (`bool`, *optional*, defaults to `True`):
            Whether shared experts overlap with routed experts.
        num_experts_per_tok (`int`, *optional*, defaults to 2):
            The number of experts to route per token (top-k routing parameter).
        routed_scaling_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor applied to routed expert outputs.
        n_group (`int`, *optional*, defaults to 1):
            Number of groups for expert routing.
        topk_group (`int`, *optional*, defaults to 1):
            Top-k group parameter for expert selection.
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize top-k probabilities in expert routing.
        num_nextn_predict_layers (`int`, *optional*, defaults to 0):
            Number of additional layers for multi-token prediction. If 0, multi-token prediction is disabled.
        mtp_layers_block_type (`list`, *optional*, defaults to `['attention', 'moe']`):
            Explicit list of layer types for multi-token prediction layers when `num_nextn_predict_layers` > 0.
        use_bias (`bool`, *optional*, defaults to `False`):
            Whether to use bias in the model.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_epsilon (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the layer normalization layers.
        residual_in_fp32 (`bool`, *optional*, defaults to `False`):
            Whether or not residuals should be in `float32`.
        hidden_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the hidden states.
        rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
            Whether to rescale the pre-normalization residual connections.
        hybrid_override_pattern (`str`, *optional*, keyword-only via `**kwargs`):
            Legacy pattern string (`"M"` = mamba, `"E"` = moe, `"*"` = attention). Only used when
            `layers_block_type` is not given; it is always removed from `kwargs`.
        mtp_hybrid_override_pattern (`str`, *optional*, keyword-only via `**kwargs`):
            Legacy pattern string for the multi-token prediction layers. Only used when `mtp_layers_block_type` is
            `None` or left at its default; it is always removed from `kwargs`.

    ```python
    >>> from transformers import NemotronHModel, NemotronHConfig

    >>> # Initializing a NemotronH configuration
    >>> configuration = NemotronHConfig()

    >>> # Initializing a model (with random weights) from the configuration
    >>> model = NemotronHModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "nemotron_h"
    keys_to_ignore_at_inference = ["past_key_values"]

    @staticmethod
    def _validate_layers_block_type(layers_block_type, expected_length=None, param_name="layers_block_type"):
        """
        Validate a list of layer types.

        Args:
            layers_block_type: List of layer types to validate
            expected_length: If provided, validate the list has this length
            param_name: Parameter name for error messages

        Raises:
            ValueError: If the value is not a list, has the wrong length, or contains an entry that is not one
                of "mamba", "attention" or "moe".
        """
        if not isinstance(layers_block_type, list):
            raise ValueError(f"{param_name} must be a list of strings. Got type: {type(layers_block_type)}")

        if expected_length is not None and len(layers_block_type) != expected_length:
            raise ValueError(f"{param_name} must have length {expected_length}. Got length {len(layers_block_type)}.")

        valid_types = {"mamba", "attention", "moe"}
        # The isinstance guard keeps unhashable entries (e.g. nested lists) from raising a raw TypeError
        # on the set lookup; they are reported as invalid instead. A list keeps the report order stable.
        invalid = [
            block_type
            for block_type in layers_block_type
            if not (isinstance(block_type, str) and block_type in valid_types)
        ]
        if invalid:
            raise ValueError(f"{param_name} contains invalid types: {invalid}. Must be one of: {sorted(valid_types)}")

    def __init__(
        self,
        # General model config
        vocab_size=131072,
        hidden_size=4096,
        layers_block_type=None,
        num_hidden_layers=None,  # Deprecated, only for backward compatibility
        tie_word_embeddings=False,
        use_cache=True,
        num_logits_to_keep=1,
        # Token IDs
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        # Attention layer config
        num_attention_heads=32,
        num_key_value_heads=8,
        head_dim=128,
        max_position_embeddings=4096,
        attention_bias=False,
        attention_dropout=0.0,
        sliding_window=None,
        # MLP layer config
        intermediate_size=21504,
        mlp_hidden_act="relu2",
        mlp_bias=False,
        # Mamba layer config
        use_mamba_kernels=True,
        ssm_state_size=128,
        mamba_num_heads=128,
        mamba_n_groups=8,
        mamba_head_dim=64,
        mamba_d_conv=4,
        mamba_expand=2,
        mamba_hidden_act="silu",
        mamba_dt_min=0.001,
        mamba_dt_max=0.1,
        mamba_dt_limit=(0.0, float("inf")),
        mamba_dt_init_floor=1e-4,
        mamba_conv_bias=True,
        mamba_proj_bias=False,
        mamba_chunk_size=128,
        mamba_ssm_cache_dtype="float32",
        # MoE config
        n_routed_experts=8,
        n_shared_experts=1,
        moe_intermediate_size=7688,
        moe_shared_expert_intermediate_size=7688,
        moe_latent_size=None,
        moe_shared_expert_overlap=True,
        num_experts_per_tok=2,
        routed_scaling_factor=1.0,
        n_group=1,
        topk_group=1,
        norm_topk_prob=True,
        # Multi-token prediction config
        num_nextn_predict_layers=0,
        # NOTE: the default list is kept in the signature for interface compatibility; it is never stored
        # directly -- a copy is taken below so instances cannot mutate the shared default.
        mtp_layers_block_type=["attention", "moe"],
        # General training config
        use_bias=False,
        initializer_range=0.02,
        layer_norm_epsilon=1e-5,
        residual_in_fp32=False,
        hidden_dropout=0.0,
        rescale_prenorm_residual=True,
        **kwargs,
    ):
        # Backward compatibility: convert hybrid_override_pattern to layers_block_type.
        # Always pop hybrid_override_pattern from kwargs to prevent it from being set as an attribute
        # (it is exposed as a read-only property on this class).
        pattern = kwargs.pop("hybrid_override_pattern", None)
        if layers_block_type is None:
            if pattern is not None:
                layers_block_type = self._pattern_to_list(pattern)
            else:
                # Default layers_block_type if neither the list nor a pattern is provided
                layers_block_type = ["mamba", "moe", "attention", "moe"]

        # Validate before anything takes len() of the value, so a wrong type surfaces as the intended
        # ValueError (no longer checking length against num_hidden_layers).
        self._validate_layers_block_type(layers_block_type, expected_length=None, param_name="layers_block_type")

        # Note: num_hidden_layers is deprecated and ignored; the layer count always comes from
        # layers_block_type. It's only kept for backward compatibility when loading old configs.
        if num_hidden_layers is not None and len(layers_block_type) != num_hidden_layers:
            logger.warning(
                f"num_hidden_layers ({num_hidden_layers}) is deprecated and doesn't match "
                f"layers_block_type length ({len(layers_block_type)}). Using layers_block_type length."
            )

        # Backward compatibility: convert mtp_hybrid_override_pattern to mtp_layers_block_type.
        # Always pop mtp_hybrid_override_pattern from kwargs to prevent it from being set as an attribute.
        # The pattern only wins when the caller did not pass a non-default mtp_layers_block_type.
        mtp_pattern = kwargs.pop("mtp_hybrid_override_pattern", None)
        if mtp_pattern is not None and (
            mtp_layers_block_type is None or mtp_layers_block_type == ["attention", "moe"]
        ):
            mtp_layers_block_type = self._pattern_to_list(mtp_pattern)

        self.vocab_size = vocab_size
        self.tie_word_embeddings = tie_word_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.sliding_window = sliding_window
        self.max_position_embeddings = max_position_embeddings
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout

        # Store a copy so later in-place edits of the caller's list do not silently change this config.
        self.layers_block_type = list(layers_block_type)

        # for backward compatibility
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.mlp_hidden_act = mlp_hidden_act
        self.attention_bias = attention_bias
        self.mlp_bias = mlp_bias
        self.use_bias = use_bias
        self.initializer_range = initializer_range
        self.layer_norm_epsilon = layer_norm_epsilon
        self.residual_in_fp32 = residual_in_fp32

        self.use_cache = use_cache
        self.num_logits_to_keep = num_logits_to_keep

        # Mamba settings: several constructor names are stored under different attribute names.
        self.use_mamba_kernels = use_mamba_kernels
        self.n_groups = mamba_n_groups
        self.mamba_head_dim = mamba_head_dim
        self.ssm_state_size = ssm_state_size
        self.mamba_num_heads = mamba_num_heads
        self.conv_kernel = mamba_d_conv
        self.expand = mamba_expand
        self.mamba_hidden_act = mamba_hidden_act
        self.time_step_min = mamba_dt_min
        self.time_step_max = mamba_dt_max
        self.time_step_limit = mamba_dt_limit
        self.time_step_floor = mamba_dt_init_floor
        self.use_conv_bias = mamba_conv_bias
        self.mamba_proj_bias = mamba_proj_bias
        self.chunk_size = mamba_chunk_size
        self.rescale_prenorm_residual = rescale_prenorm_residual

        # MoE settings
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.moe_intermediate_size = moe_intermediate_size
        self.moe_shared_expert_intermediate_size = moe_shared_expert_intermediate_size
        self.moe_latent_size = moe_latent_size
        self.moe_shared_expert_overlap = moe_shared_expert_overlap
        self.num_experts_per_tok = num_experts_per_tok
        self.routed_scaling_factor = routed_scaling_factor
        self.n_group = n_group
        self.topk_group = topk_group
        self.norm_topk_prob = norm_topk_prob
        self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype

        # MTP config
        self.num_nextn_predict_layers = num_nextn_predict_layers

        # Validate mtp_layers_block_type is provided when MTP is enabled
        if self.num_nextn_predict_layers > 0:
            if mtp_layers_block_type is None:
                raise ValueError(
                    "mtp_layers_block_type is required when num_nextn_predict_layers > 0. "
                    "Please provide an explicit list of layer types for MTP layers. "
                    "Example: mtp_layers_block_type=['attention', 'moe']"
                )
            self._validate_layers_block_type(mtp_layers_block_type, None, "mtp_layers_block_type")

        # Copy lists (this includes the shared signature default) so no two configs alias the same object.
        # Non-list values are only reachable when MTP is disabled and are stored untouched, as before.
        if isinstance(mtp_layers_block_type, list):
            mtp_layers_block_type = list(mtp_layers_block_type)
        self.mtp_layers_block_type = mtp_layers_block_type

        # Must run last: the base class assigns any remaining kwargs as attributes.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

    @property
    def num_hidden_layers(self) -> int:
        """
        Number of hidden layers derived from the length of layers_block_type.
        This property replaces the deprecated num_hidden_layers parameter.
        """
        return len(self.layers_block_type)

    @num_hidden_layers.setter
    def num_hidden_layers(self, value):
        """
        Setter for backward compatibility when loading configs.
        The value is ignored since num_hidden_layers is computed from layers_block_type.
        """
        # Ignore the value - num_hidden_layers is always derived from layers_block_type
        pass

    @property
    def hybrid_override_pattern(self) -> str:
        """
        Backward compatibility property.
        Returns the pattern string representation of layers_block_type.
        """
        return self._list_to_pattern(self.layers_block_type)

    @property
    def mtp_hybrid_override_pattern(self) -> str:
        """
        Backward compatibility property.
        Returns the pattern string representation of mtp_layers_block_type.
        """
        return self._list_to_pattern(self.mtp_layers_block_type)

    @staticmethod
    def _list_to_pattern(layers_list: list) -> str:
        """
        Convert list of layer types back to pattern string (for backward compatibility).

        Raises:
            KeyError: If an entry is not one of "mamba", "moe" or "attention".
        """
        reverse_mapping = {"mamba": "M", "moe": "E", "attention": "*"}
        return "".join(reverse_mapping[layer_type] for layer_type in layers_list)

    @staticmethod
    def _pattern_to_list(pattern: str) -> list:
        """
        Convert pattern string to list of layer types (for backward compatibility).

        Raises:
            KeyError: If the pattern contains a character other than "M", "E" or "*".
        """
        pattern_mapping = {"M": "mamba", "E": "moe", "*": "attention"}
        return [pattern_mapping[char] for char in pattern]
408
+
409
+
410
+ __all__ = ["NemotronHConfig"]
generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "do_sample": true,
4
+ "bos_token_id": 1,
5
+ "eos_token_id": [
6
+ 2,
7
+ 11
8
+ ],
9
+ "pad_token_id": 0,
10
+ "temperature": 1.0,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.57.1"
13
+ }
model-00001-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:105bd2c4f20e8f68f7114a11a0a8ead5a38b3c8d1cb962ba88fb8ee58d21a294
3
+ size 4997999760
model-00002-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec4d02dbe117562653395d0699fd4c8c8d5a7caf5a9dc23d3adb71f878f25f55
3
+ size 4996714072
model-00003-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:137a7acc3c5648a8834491328670eb551dc9becdac1e10b7ddb0a07eba27181d
3
+ size 4996714072
model-00004-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0a198f937fa3d7b3ba1c3ea4684e3cd32887d52b6ca2905817bbf1d8c142326
3
+ size 4996459064
model-00005-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e74a3aed4c6fbbcc9cce227c1daa24ab0c95a662b9d0aed496ec28f7a5c0a7b
3
+ size 4958833808
model-00006-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6ccac6d4cb59f7c5e4c523e08e2e376876772c5cb359a63a1e26c26175eb4604
3
+ size 4998024408
model-00007-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:068306f784fb43b69b15ee91d949d9e333d7c7821167b84d5f845000266967d5
3
+ size 4996714952
model-00008-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8f3b25f04a9549f13870966d05682602b22428023df2430c82381d33fd07fcc
3
+ size 4996714936
model-00009-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1bc45e88a5fbe207974ac5104b91e83a32d2af26b480886b595eab4e7a52063a
3
+ size 4996459848
model-00010-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b9e6fe453356fb2a21eca51ff828376ed997da9de0614a1ea7e836791250490
3
+ size 4996714992
model-00011-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ee3d39ef8833d339831581e2755539154a3963533f51b12da83ccaa945588a7e
3
+ size 4975723264
model-00012-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ec01d95016c0e44e549bb2c5583b578a77aa4b1402fa2ce919c1cdbf1665fca
3
+ size 4997651448
model-00013-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7aac706af7ea0d3e36255482b8529b6dc5933b92099da283e668ba849c694b75
3
+ size 4996714960
model-00014-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:eaa949d7ac1b1a602ee2a3897a32976bd67d6500226982465a9773f26b2f5928
3
+ size 4996459848
model-00015-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0f16316eaa892fe5442538a21857935297eb098d292750bd8721455048e7a311
3
+ size 4996714936
model-00016-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:de2791b665ad8a3fe24b41b827bc509b09608ddade928c039b60adf20b61e8fa
3
+ size 4996714984
model-00017-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e8ea9d674374409b5bc27b8e03b08dca851962497a60bac4c6cd92605b943ae8
3
+ size 4915166568
model-00019-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e5d3ceedf3ab01c869ec856c483ac6e8a069567572887d524936796571f2eea
3
+ size 4996714960
model-00020-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20bbcb0cc5b14f026b9cf095de2944976005435ee468d895c0656e4ef44cffcc
3
+ size 4996459848
model-00021-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53d503f4921496ce2f369945c729ee9383fb6f7c3d9d6b26fa380ed902cc3b64
3
+ size 4996714936
model-00022-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e92c37f824fd6d00afd7a7c327574118027acc504d8ee7b7a202a970b184de82
3
+ size 4996714984
model-00023-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:faa4fcdb4a8bd316ecf4a58e073f232dfa112fd77537986690c5b6e25d896847
3
+ size 4915166568
model-00024-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:023e6b268664f373450d1897531654d8c64bafdbcbda4e5c004f9b329b3a0adc
3
+ size 4997651448
model-00026-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4df92010a5ddde290aac8d9b9d8455e65ca24e51437432d87b03a19525540d5
3
+ size 4996459848
model-00033-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65496d854228f5c9eb9427eb86029c7ff3b43fe2dfb00ff52af6b3f5e829392e
3
+ size 4996714936
model-00036-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4405f66cdcbfe80ee9c2c8cb3a4341b2c9a05c44076f7e1548103007a3be1f6
3
+ size 4997651448
model-00038-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a255376e905fac0e42e54b6863ba293599e3d29132feba287b939b3d47b7900
3
+ size 4996459848
model-00040-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc9293969492e19269dfe83ebe79dae3a92dd5cb15c62b95f5f74124b0407656
3
+ size 4996714984
model-00041-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:305019d2d661b4a9d346179c07f25293497b45d1c981a40e57b4841281ea8cdc
3
+ size 4915166568
model-00042-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ab6b38d9b9b3b354566a2e8aeb52363e18d9de9b9edf292fde8847d525980ed
3
+ size 4997651448
model-00043-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:610d06cbd5be56fc677020d8f363c35715b8bb69e5656e1f9e74764303ba99bd
3
+ size 4996459880
model-00044-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a0c17be3f3ea0daad1d8d57808cbaa33707af1247c21fc13c1b5b30f44328849
3
+ size 4996714936
model-00045-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60eb1cf2b2d5367d237ac12f3b009b34e391fd8c2b57e0f954aa683d3de5ff4e
3
+ size 4996714936
model-00046-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:81feea7dde8337637bbf0c85da0f9d9d1effca86a5c91244957046b084cc55e2
3
+ size 4996714984
model-00047-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b3c105086bbaa733dc6dc8be52fed5c2ba64ae960c4e8384247e1cd37ea6b3a2
3
+ size 4915166568
model-00048-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d48f4f962bc0aa20216ce487025ef4beffd7a85241691142d549e244eba2696e
3
+ size 4997651448
model-00049-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cb70fda7f2ea7e7fa58210bc0eafc619e3d5b531bb9741bedc6bfac71508ae04
3
+ size 4998690536
model-00050-of-00050.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7f9b7cea8a184b457f70d63678b69dd8d452bbae0cbf9780071e6ad8c1e9028d
3
+ size 2927698848
model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
modeling_nemotron_h.py ADDED
@@ -0,0 +1,1754 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import annotations
18
+
19
+ import contextlib
20
+ import math
21
+ from dataclasses import dataclass
22
+ from typing import Any, Optional, Union
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import CrossEntropyLoss
29
+
30
+ from transformers.activations import ACT2FN
31
+ from transformers.generation import GenerationMixin
32
+ from transformers.modeling_attn_mask_utils import AttentionMaskConverter
33
+ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
34
+ from transformers.utils import (
35
+ ModelOutput,
36
+ add_code_sample_docstrings,
37
+ add_start_docstrings,
38
+ add_start_docstrings_to_model_forward,
39
+ logging,
40
+ )
41
+ from transformers.utils.import_utils import (
42
+ is_causal_conv1d_available,
43
+ is_flash_attn_2_available,
44
+ is_mamba_2_ssm_available,
45
+ )
46
+
47
+ from .configuration_nemotron_h import NemotronHConfig
48
+
49
+
50
+ logger = logging.get_logger(__name__)
51
+
52
+
53
+ # Copied from transformers.models.mamba2.modeling_mamba2
54
+ if is_mamba_2_ssm_available():
55
+ from mamba_ssm.ops.triton.selective_state_update import selective_state_update
56
+ from mamba_ssm.ops.triton.ssd_combined import mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined
57
+ else:
58
+ mamba_chunk_scan_combined, mamba_split_conv1d_scan_combined, selective_state_update = None, None, None
59
+
60
+ try:
61
+ from mamba_ssm.ops.triton.layernorm_gated import rmsnorm_fn
62
+ except ImportError:
63
+ raise ImportError("mamba-ssm is required by the Mamba model but cannot be imported")
64
+
65
+ if is_causal_conv1d_available():
66
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
67
+ else:
68
+ causal_conv1d_update, causal_conv1d_fn = None, None
69
+
70
+ if is_flash_attn_2_available():
71
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
72
+
73
+ is_fast_path_available = all(
74
+ (
75
+ selective_state_update,
76
+ mamba_chunk_scan_combined,
77
+ mamba_split_conv1d_scan_combined,
78
+ causal_conv1d_fn,
79
+ causal_conv1d_update,
80
+ )
81
+ )
82
+
83
+ # TODO: Update with correct checkpoint when model is published to HuggingFace Hub
84
+ _CHECKPOINT_FOR_DOC = "nvidia/nemotron-h-placeholder"
85
+ _CONFIG_FOR_DOC = "NemotronHConfig"
86
+
87
+
88
+ # Helper methods for segment sum computation
89
+
90
+
91
def pad_tensor_by_size(input_tensor: torch.Tensor, pad_size: int):
    """
    Right-pad `input_tensor` with `pad_size` zeros along the seq_len dim (dim=1).

    Only rank-3 and rank-4 inputs are supported; anything that is not rank 4 is padded with the rank-3 layout.
    """
    # The pad spec lists (before, after) pairs starting from the LAST dimension, so the
    # `(0, pad_size)` pair has to land on the slot that corresponds to dim 1.
    if input_tensor.dim() == 4:
        pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0)
    else:
        pad_shape = (0, 0, 0, pad_size, 0, 0)

    return torch.nn.functional.pad(input_tensor, pad_shape, mode="constant", value=0)
100
+
101
+
102
def reshape_into_chunks(input_tensor, pad_size, chunk_size):
    """
    Pad `input_tensor` by `pad_size` on the seq_len dim (dim=1), then split that dim into chunks
    of length `chunk_size`.

    Only rank-3 and rank-4 inputs are supported; anything that is not rank 3 is reshaped with the rank-4 layout.
    """
    # [bsz, seq_len, ...] -> [bsz, seq_len multiple of chunk_size, ...]
    padded = pad_tensor_by_size(input_tensor, pad_size)
    batch_size = padded.shape[0]

    if padded.dim() == 3:
        # [bsz, seq_len multiple of chunk_size, num_heads] -> [bsz, -1, chunk_size, num_heads]
        target_shape = (batch_size, -1, chunk_size, padded.shape[2])
    else:
        # [bsz, seq_len multiple of chunk_size, num_heads, head_dim or state_size]
        #   -> [bsz, -1, chunk_size, num_heads, head_dim or state_size]
        target_shape = (batch_size, -1, chunk_size, padded.shape[2], padded.shape[3])

    return padded.reshape(*target_shape)
120
+
121
+
122
def segment_sum(input_tensor):
    """
    More stable segment sum calculation, built from a masked cumulative sum instead of direct subtractions.

    For an input `[..., chunk_size]` the result is `[..., chunk_size, chunk_size]` where entry `(i, j)` holds
    `sum(input[j + 1 : i + 1])` for `i >= j` and `-inf` above the diagonal.
    """
    chunk_size = input_tensor.size(-1)
    device = input_tensor.device

    # Repeat the input along a new trailing dimension: [..., chunk_size] -> [..., chunk_size, chunk_size]
    expanded = input_tensor[..., None].expand(*input_tensor.size(), chunk_size)

    all_true = torch.ones(chunk_size, chunk_size, device=device, dtype=torch.bool)
    strictly_lower = torch.tril(all_true, diagonal=-1)
    lower_with_diagonal = torch.tril(all_true, diagonal=0)

    # Zero everything on or above the diagonal, then accumulate down the rows
    tensor_segsum = torch.cumsum(expanded.masked_fill(~strictly_lower, 0), dim=-2)

    # Keep the lower triangle (diagonal included this time); everything above it becomes -inf
    return tensor_segsum.masked_fill(~lower_with_diagonal, -torch.inf)
140
+
141
+
142
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Zero out the hidden states of padding tokens, see https://github.com/state-spaces/mamba/issues/66

    `hidden_states` is returned untouched when there is no mask or when every mask entry equals 1.
    """
    if attention_mask is None or torch.all(attention_mask == 1):
        return hidden_states

    # Multiplying by the mask may promote the dtype, so cast back to the incoming one.
    original_dtype = hidden_states.dtype
    masked_states = hidden_states * attention_mask[:, :, None]
    return masked_states.to(original_dtype)
151
+
152
+ # Adapted from transformers.models.zamba2.modeling_zamba2.Zamba2HybridDynamicCache for the v2 mixer
153
class NemotronHHybridDynamicCache:
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    Storage, per layer index `i` in `range(config.num_hidden_layers)`:

    - `key_cache[i]` / `value_cache[i]` (lists): start as empty `(batch_size, 0)` tensors. `update` replaces the
      empty placeholder with the first key/value states it receives and afterwards concatenates along dim 2, so
      the states are expected to be laid out as `(batch_size, num_heads, seq_len, head_dim)`.
    - `conv_states[i]` (dict keyed by layer index): for mamba layers a zero tensor of shape
      `(batch_size, mamba_num_heads * mamba_head_dim + 2 * n_groups * ssm_state_size, conv_kernel)`;
      for attention and moe layers an empty `(batch_size, 0)` tensor.
    - `ssm_states[i]` (dict keyed by layer index): for mamba layers a zero tensor of shape
      `(batch_size, mamba_num_heads, mamba_head_dim, ssm_state_size)`; for attention and moe layers an empty
      `(batch_size, 0)` tensor.

    Args:
        config (`NemotronHConfig`):
            Supplies `layers_block_type`, `num_hidden_layers` and the mamba dimensions listed above.
        batch_size (`int`):
            Leading dimension of every allocated tensor.
        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
            Dtype of the mamba state tensors (the empty placeholders use the torch default dtype).
        device (`str`, *optional*):
            Device on which all tensors are allocated.
    """

    is_compileable = False

    def __init__(
        self, config: NemotronHConfig, batch_size: int, dtype: torch.dtype = torch.float16, device: str | None = None
    ):
        self.dtype = dtype
        self.layers_block_type = config.layers_block_type
        self.has_previous_state = False
        self.intermediate_size = int(config.mamba_num_heads * config.mamba_head_dim)
        self.ssm_state_size = config.ssm_state_size
        self.conv_kernel_size = config.conv_kernel
        self.n_mamba_heads = config.mamba_num_heads
        # Indices of the attention layers, i.e. the only layers whose key/value cache ever gets filled
        self.transformer_layers = []
        # NOTE(review): these empty dicts mirror `nn.Module` internals, presumably so external tooling can
        # treat the cache like a module - confirm before removing.
        self._modules = {}
        self._parameters = {}
        self._buffers = {}
        self.conv_states = {}
        self.ssm_states = {}
        for i in range(config.num_hidden_layers):
            if self.layers_block_type[i] == "mamba":
                # Only allocate mamba cache for mamba layers
                self.conv_states[i] = torch.zeros(
                    batch_size,
                    self.intermediate_size + 2 * config.n_groups * self.ssm_state_size,
                    self.conv_kernel_size,
                    device=device,
                    dtype=dtype,
                )
                self.ssm_states[i] = torch.zeros(
                    batch_size,
                    self.n_mamba_heads,
                    config.mamba_head_dim,
                    self.ssm_state_size,
                    device=device,
                    dtype=dtype,
                )
            else:
                # For attention and moe layers, use empty tensors
                self.conv_states[i] = torch.tensor([[]] * batch_size, device=device)
                self.ssm_states[i] = torch.tensor([[]] * batch_size, device=device)

            if self.layers_block_type[i] == "attention":
                self.transformer_layers.append(i)
        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

    def __len__(self):
        """Number of layers tracked by the cache (one key-cache slot per layer)."""
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: dict[str, Any] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Store new key/value states for `layer_idx` and return the full cached key/value tensors.

        The first call for a layer replaces the empty placeholder; later calls append along dim 2.
        `cache_kwargs` is accepted for interface compatibility and is not used.
        """
        if self.key_cache[layer_idx].shape[-1] == 0:
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        if self.get_seq_length() > 0:
            for layer_idx in range(len(self.key_cache)):
                # Each tensor may live on a different device, so move `beam_idx` per tensor
                device = self.key_cache[layer_idx].device
                self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
                device = self.value_cache[layer_idx].device
                self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

                device = self.conv_states[layer_idx].device
                self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
                device = self.ssm_states[layer_idx].device
                self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

    def get_seq_length(self, layer_idx: int | None = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Only attention layers hold a non-empty key cache, so fall back to the first one
        # whenever `layer_idx` is not an attention layer.
        layer_idx = self.transformer_layers[0] if layer_idx not in self.transformer_layers else layer_idx
        if len(self.key_cache) <= layer_idx or self.key_cache[layer_idx].numel() == 0:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the mask"""
        kv_offset = 0
        query_length = cache_position.shape[0]
        kv_length = self.get_seq_length(layer_idx) + query_length
        return kv_length, kv_offset

    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_position: torch.LongTensor
    ) -> torch.Tensor:
        """
        Shift the convolution state of `layer_idx` left by one step along the last dim and write
        `new_conv_state` at `cache_position` (clamped to the kernel window).

        The stored tensor is updated in place (zeroed, then accumulated into) and returned.
        """
        conv_state = self.conv_states[layer_idx]
        cache_position = cache_position.clamp(0, self.conv_kernel_size - 1)

        conv_state = conv_state.roll(shifts=-1, dims=-1)
        conv_state[:, :, cache_position] = new_conv_state.to(conv_state.device)
        # Write back in place so the tensor object held in `self.conv_states` stays the same
        self.conv_states[layer_idx].zero_()
        self.conv_states[layer_idx] += conv_state
        return self.conv_states[layer_idx]

    def reset(self):
        """Zero every convolution and ssm state tensor in place."""
        # `conv_states` / `ssm_states` are dicts keyed by layer index, so each stored tensor has to be
        # zeroed individually (a dict has no `zero_()` method).
        for conv_state in self.conv_states.values():
            conv_state.zero_()
        for ssm_state in self.ssm_states.values():
            ssm_state.zero_()
277
+
278
class MambaRMSNormGated(torch.nn.Module):
    """
    Gated, grouped RMS normalization used by the Mamba layers.

    The module owns a single learnable scale vector and forwards everything else to the `rmsnorm_fn`
    kernel, optionally together with a gating tensor.

    Args:
        hidden_size (`int`):
            The dimension of the hidden states to normalize; also the length of the `weight` parameter.
        group_size (`int`):
            Size of each group for grouped normalization.
        eps (`float`, *optional*, defaults to 1e-5):
            A small value added to the variance for numerical stability.
    """

    def __init__(self, hidden_size, group_size, eps=1e-5):
        super().__init__()
        self.group_size = group_size
        self.variance_epsilon = eps
        # Learnable scale, initialised to ones
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states, gate=None):
        """Normalize `hidden_states` with `rmsnorm_fn`, passing `gate` as its `z` argument (no bias)."""
        kernel_kwargs = {
            "x": hidden_states,
            "weight": self.weight,
            "bias": None,
            "z": gate,
            "eps": self.variance_epsilon,
            "group_size": self.group_size,
            "norm_before_gate": False,
        }
        return rmsnorm_fn(**kernel_kwargs)
309
+
310
+ # Adapted from transformers.models.zamba2.modeling_zamba2.Zamba2MambaMixer
311
+ class NemotronHMamba2Mixer(nn.Module):
312
+ """
313
+ Compute ∆, A, B, C, and D the state space parameters and compute the `contextualized_states`.
314
+ A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
315
+ ∆, B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
316
+ and is why Mamba is called **selective** state spaces)
317
+ """
318
+
319
+ def __init__(self, config: NemotronHConfig, layer_idx: int | None = None):
320
+ super().__init__()
321
+ self.config = config
322
+ self.hidden_size = config.hidden_size
323
+ self.ssm_state_size = config.ssm_state_size
324
+ self.conv_kernel_size = config.conv_kernel
325
+ self.intermediate_size = config.mamba_num_heads * config.mamba_head_dim
326
+ self.layer_idx = layer_idx
327
+ self.use_conv_bias = config.use_conv_bias
328
+ self.activation = config.mamba_hidden_act
329
+ self.act = ACT2FN[config.mamba_hidden_act]
330
+ self.use_mem_eff_path = True
331
+
332
+ self.n_groups = config.n_groups
333
+ self.head_dim = config.mamba_head_dim
334
+ self.num_heads = config.mamba_num_heads
335
+ self.chunk_size = config.chunk_size
336
+
337
+ self.time_step_limit = config.time_step_limit
338
+ self.time_step_min = config.time_step_min
339
+ self.time_step_max = config.time_step_max
340
+
341
+ self.conv_dim = self.intermediate_size + 2 * self.n_groups * self.ssm_state_size
342
+
343
+ self.conv1d = nn.Conv1d(
344
+ in_channels=self.conv_dim,
345
+ out_channels=self.conv_dim,
346
+ bias=config.use_conv_bias,
347
+ kernel_size=self.conv_kernel_size,
348
+ groups=self.conv_dim,
349
+ padding=self.conv_kernel_size - 1,
350
+ )
351
+
352
+ # projection of the input hidden states
353
+ projection_size = self.intermediate_size + self.conv_dim + self.num_heads
354
+
355
+ self.in_proj = nn.Linear(
356
+ self.hidden_size,
357
+ projection_size,
358
+ bias=config.use_bias,
359
+ )
360
+ # selective projection used to make dt, B and C input dependent
361
+
362
+ # time step projection (discretization)
363
+ # instantiate once and copy inv_dt in init_weights of PretrainedModel
364
+ self.dt_bias = nn.Parameter(torch.ones(self.num_heads))
365
+
366
+ # S4D real initialization. These are not discretized!
367
+ # The core is to load them, compute the discrete states, then write the updated state. Keeps the memory bounded
368
+ A = torch.arange(1, self.num_heads + 1)
369
+ self.A_log = nn.Parameter(torch.log(A))
370
+
371
+ self.norm = MambaRMSNormGated(self.intermediate_size, eps=config.layer_norm_epsilon, group_size=self.intermediate_size // self.n_groups)
372
+ self.D = nn.Parameter(torch.ones(self.num_heads))
373
+
374
+ self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
375
+
376
+ if not is_fast_path_available:
377
+ logger.warning_once(
378
+ "The fast path is not available because one of `(selective_state_update, causal_conv1d_fn, causal_conv1d_update)`"
379
+ " is None. Falling back to the naive implementation. To install follow https://github.com/state-spaces/mamba/#installation and"
380
+ " https://github.com/Dao-AILab/causal-conv1d"
381
+ )
382
+
383
+
384
+ def cuda_kernels_forward(
385
+ self,
386
+ hidden_states: torch.Tensor,
387
+ cache_params: Optional[NemotronHHybridDynamicCache] = None,
388
+ attention_mask: Optional[torch.Tensor] = None,
389
+ ):
390
+ # set up dimensions for reshapes later
391
+
392
+ batch_size, seq_len, _ = hidden_states.shape
393
+ groups_time_state_size = self.n_groups * self.ssm_state_size
394
+ d_to_remove = 2 * self.intermediate_size + 2 * self.n_groups * self.ssm_state_size + self.num_heads
395
+
396
+ # getting projected states from cache if it exists
397
+ if cache_params is not None and cache_params.has_previous_state:
398
+ in_projected_states = self.in_proj(hidden_states.squeeze(1)) # (B 2D)
399
+ d_mlp = (in_projected_states.shape[-1] - d_to_remove) // 2
400
+ split_projection_dim = [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads]
401
+ _, _, gate, hidden_states_B_C, dt = torch.split(in_projected_states, split_projection_dim, dim=-1)
402
+
403
+ hidden_states_B_C = causal_conv1d_update(
404
+ hidden_states_B_C,
405
+ cache_params.conv_states[self.layer_idx],
406
+ self.conv1d.weight.squeeze(1),
407
+ self.conv1d.bias,
408
+ self.activation,
409
+ )
410
+
411
+ hidden_states, B, C = torch.split(
412
+ hidden_states_B_C,
413
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
414
+ dim=-1,
415
+ )
416
+ A = -torch.exp(self.A_log.float()) # (nheads,)
417
+
418
+ A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
419
+ dt = dt[:, :, None].expand(-1, -1, self.head_dim)
420
+ dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
421
+ D = self.D[:, None, ...].expand(-1, self.head_dim)
422
+ B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
423
+ C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
424
+ hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
425
+ hidden_states = selective_state_update(
426
+ cache_params.ssm_states[self.layer_idx],
427
+ hidden_states_reshaped,
428
+ dt,
429
+ A,
430
+ B,
431
+ C,
432
+ D,
433
+ z=None,
434
+ dt_bias=dt_bias,
435
+ dt_softplus=True,
436
+ )
437
+ hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
438
+ hidden_states = self.norm(hidden_states, gate)
439
+ out = self.out_proj(hidden_states)[:, None, ...]
440
+ # if no cache is found, calling the kernel
441
+ else:
442
+ if attention_mask is not None and not torch.all(attention_mask == 1):
443
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
444
+ dtype = hidden_states.dtype
445
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
446
+ # 1. Gated MLP's linear projection
447
+ projected_states = self.in_proj(hidden_states)
448
+ A = -torch.exp(self.A_log.float()) # (num_heads) or (intermediate_size, state_size)
449
+ dt_limit_kwargs = {} if self.time_step_limit is None else {"dt_limit": self.time_step_limit}
450
+ if attention_mask is not None:
451
+ input_not_masked = torch.all(attention_mask == 1)
452
+ else:
453
+ input_not_masked = True
454
+
455
+ if self.use_mem_eff_path and self.training and cache_params is None and input_not_masked:
456
+ out, ssm_state = mamba_split_conv1d_scan_combined(
457
+ projected_states,
458
+ self.conv1d.weight.squeeze(1),
459
+ self.conv1d.bias,
460
+ self.dt_bias,
461
+ A,
462
+ D=self.D,
463
+ chunk_size=self.chunk_size,
464
+ seq_idx=None,
465
+ activation=self.activation,
466
+ rmsnorm_weight=self.norm.weight,
467
+ rmsnorm_eps=self.norm.variance_epsilon,
468
+ outproj_weight=self.out_proj.weight,
469
+ outproj_bias=self.out_proj.bias,
470
+ headdim=self.head_dim,
471
+ ngroups=self.n_groups,
472
+ norm_before_gate=False,
473
+ return_final_states=True,
474
+ **dt_limit_kwargs,
475
+ )
476
+
477
+ else:
478
+ gate, hidden_states_B_C, time_step = torch.split(
479
+ projected_states,
480
+ [self.intermediate_size, self.conv_dim, self.num_heads],
481
+ dim=-1,
482
+ )
483
+
484
+ # 1D Convolution
485
+ if cache_params is not None:
486
+ hidden_states_B_C_t = hidden_states_B_C.transpose(1, 2)
487
+ conv_state = nn.functional.pad(
488
+ hidden_states_B_C_t, (self.conv_kernel_size - hidden_states_B_C_t.shape[-1], 0)
489
+ )
490
+ cache_params.conv_states[self.layer_idx].copy_(conv_state)
491
+ if causal_conv1d_fn is None or self.activation not in ["silu", "swish"]:
492
+ hidden_states_B_C = self.act(
493
+ self.conv1d(hidden_states_B_C.transpose(1, 2)).transpose(1, 2)[:, :seq_len]
494
+ ) # (B, L, self.d_inner + 2 * ngroups * d_state)
495
+ else:
496
+ hidden_states_B_C = causal_conv1d_fn(
497
+ x=hidden_states_B_C.transpose(1, 2),
498
+ weight=self.conv1d.weight.squeeze(1),
499
+ bias=self.conv1d.bias,
500
+ activation=self.activation,
501
+ ).transpose(1, 2)[:, :seq_len]
502
+ hidden_states, B, C = torch.split(
503
+ hidden_states_B_C,
504
+ [self.intermediate_size, groups_time_state_size, groups_time_state_size],
505
+ dim=-1,
506
+ )
507
+ if attention_mask is not None and not torch.all(attention_mask == 1):
508
+ # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
509
+ dtype = hidden_states.dtype
510
+ hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
511
+ scan_output, ssm_state = mamba_chunk_scan_combined(
512
+ hidden_states.view(batch_size, seq_len, -1, self.head_dim),
513
+ time_step,
514
+ A,
515
+ B.view(batch_size, seq_len, self.n_groups, -1),
516
+ C.view(batch_size, seq_len, self.n_groups, -1),
517
+ chunk_size=self.chunk_size,
518
+ D=self.D,
519
+ z=None,
520
+ seq_idx=None,
521
+ return_final_states=True,
522
+ dt_bias=self.dt_bias,
523
+ dt_softplus=True,
524
+ **dt_limit_kwargs,
525
+ )
526
+ if ssm_state is not None and cache_params is not None:
527
+ cache_params.ssm_states[self.layer_idx].copy_(ssm_state)
528
+ scan_output = scan_output.view(batch_size, seq_len, -1)
529
+ # Multiply "gate" branch and apply extra normalization layer
530
+ scan_output = self.norm(scan_output, gate)
531
+ out = self.out_proj(scan_output)
532
+ return out
533
+
534
+ # fmt: off
535
def torch_forward(self, input_states, cache_params: Optional[NemotronHHybridDynamicCache] = None, attention_mask: Optional[torch.Tensor] = None):
    """
    Run the mixer with plain PyTorch ops (no fused kernels).

    Two regimes are handled:

    - If `cache_params` is given and `cache_params.has_previous_state` is true, a single recurrent
      step is performed: the cached convolution window is rolled forward, the cached SSM state is
      updated in place and one output position is produced.
    - Otherwise the whole sequence is processed with the chunked ("SSD naive") formulation starting
      from a zero state. If `cache_params` is given, the convolution window and the final SSM state
      are written into it.

    Args:
        input_states (`torch.Tensor`):
            Input tensor, unpacked as `(batch_size, seq_len, _)`.
        cache_params (`NemotronHHybridDynamicCache`, *optional*):
            Cache whose `conv_states[self.layer_idx]` and `ssm_states[self.layer_idx]` are read and updated.
        attention_mask (`torch.Tensor`, *optional*):
            Indexed as `attention_mask[:, :, None]` and multiplied into the activations to zero out
            padding positions. Not used on the single-step path, and skipped when it is all ones.

    Returns:
        `torch.Tensor`: the output of `self.out_proj`.
    """
    batch_size, seq_len, _ = input_states.shape
    dtype = input_states.dtype
    # Gated MLP's linear projection
    if cache_params is not None and cache_params.has_previous_state:
        projected_states = self.in_proj(input_states.squeeze(1))
    else:
        if attention_mask is not None and not torch.all(attention_mask == 1):
            # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
            input_states = (input_states * attention_mask[:, :, None]).to(dtype)
        projected_states = self.in_proj(input_states)
    d_mlp = (projected_states.shape[-1] - 2 * self.intermediate_size - 2 * self.n_groups * self.ssm_state_size - self.num_heads) // 2
    _, _, gate, hidden_states, dt = projected_states.split(
        [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
    )

    # Convolution sequence transformation
    if cache_params is not None:
        if cache_params.has_previous_state:
            gate = gate.unsqueeze(1)
            conv_state = cache_params.conv_states[self.layer_idx]  # [batch, intermediate_size, conv_kernel_size]
            conv_state = torch.roll(conv_state, shifts=-1, dims=-1)
            # handle batched generation - states are copied through
            conv_state[:, :, -1] = hidden_states[:, 0, :] if hidden_states.ndim == 3 else hidden_states
            cache_params.conv_states[self.layer_idx].copy_(conv_state)
            hidden_states = torch.sum(conv_state.to(projected_states.device) * self.conv1d.weight[:, 0, :], dim=-1)
            if self.use_conv_bias:
                hidden_states += self.conv1d.bias
            hidden_states = self.act(hidden_states).to(dtype)[:, None, ...]  # [batch, 1, intermediate_size] : decoding
        else:
            hidden_states = hidden_states.transpose(1, 2)
            # Left-pad (or crop) the sequence to the kernel width and store it as the conv window.
            conv_state = nn.functional.pad(
                hidden_states,
                (self.conv_kernel_size - hidden_states.shape[-1], 0)
            )
            cache_params.conv_states[self.layer_idx].copy_(conv_state)
            hidden_states = self.act(self.conv1d(hidden_states).transpose(1, 2))[:, :seq_len, :]  # [batch, seq_len, conv_dim]
            if attention_mask is not None and not torch.all(attention_mask == 1):
                dtype = hidden_states.dtype
                # tune out hidden states for pad tokens, see https://github.com/state-spaces/mamba/issues/66
                hidden_states = (hidden_states * attention_mask[:, :, None]).to(dtype)
    else:
        hidden_states = self.act(self.conv1d(hidden_states.transpose(1, 2))[..., :seq_len].transpose(1, 2))
    hidden_states, B, C = torch.split(hidden_states, [self.intermediate_size, self.n_groups * self.ssm_state_size, self.n_groups * self.ssm_state_size], dim=-1)
    A = -torch.exp(self.A_log.float())  # [num_heads]
    if cache_params is not None and cache_params.has_previous_state:
        # Note: there is no need to pad parameter matrices here, as there is just one new token
        # for batched generation
        dt = dt[:, None, ...] if dt.ndim == 2 else dt[:, 0, :][:, None, ...]
        dt = dt.transpose(1, 2).expand(batch_size, dt.shape[-1], self.head_dim)
        # [num_heads] -> [num_heads, head_dim]
        dt_bias = self.dt_bias[..., None].expand(self.dt_bias.shape[0], self.head_dim)

        dt = torch.nn.functional.softplus(dt + dt_bias.to(dt.dtype))
        # Only the lower bound is enforced; no upper clamp is applied.
        dt = torch.clamp(dt, self.time_step_min)
        A = A[..., None, None].expand(self.num_heads, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
        # [bsz, num_heads, head_dim, state_size]
        dA = torch.exp(dt[..., None] * A)

        # Discretize B
        # [bsz, n_groups * state_size] -> [bsz, n_groups, 1, state_size] ->
        # -> [bsz, n_groups, group to head repetition factor, state_size] -> [bsz, num_heads, state_size]
        B = B.reshape(batch_size, self.n_groups, -1)[..., None, :]
        B = B.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, B.shape[-1]).contiguous()
        B = B.reshape(batch_size, -1, B.shape[-1])
        # [bsz, num_heads, head_dim, state_size]
        dB = dt[..., None] * B[..., None, :]

        # Discretize x into dB
        # [bsz, intermediate_size] -> [bsz, num_heads, head_dim]
        hidden_states = hidden_states.reshape(batch_size, -1, self.head_dim)
        dBx = dB * hidden_states[..., None]

        # State calculation
        cache_params.ssm_states[self.layer_idx].copy_(
            cache_params.ssm_states[self.layer_idx] * dA + dBx
        )

        # Subsequent output
        # [bsz, n_groups * state_size] -> [bsz, num_heads, state_size]
        C = C.reshape(batch_size, self.n_groups, -1)[..., None, :]
        C = C.expand(batch_size, self.n_groups, self.num_heads // self.n_groups, C.shape[-1]).contiguous()
        C = C.reshape(batch_size, -1, C.shape[-1])
        # [bsz, num_heads, head_dim]

        ssm_states = cache_params.ssm_states[self.layer_idx].to(C.dtype)  # Shape: [b, h, d, n]
        # Reshape ssm_states to merge the first two dimensions
        ssm_states_reshaped = ssm_states.view(batch_size * self.num_heads, self.head_dim, self.ssm_state_size)  # Shape: [b*h, d, n]
        C_reshaped = C.view(batch_size * self.num_heads, self.ssm_state_size, 1)  # Shape: [b*h, n, 1]
        y = torch.bmm(ssm_states_reshaped, C_reshaped)
        y = y.view(batch_size, self.num_heads, self.head_dim)

        # D skip connection
        # [num_heads] -> [num_heads, head_dim]
        D = self.D[..., None].expand(self.D.shape[0], self.head_dim)
        y = (y + hidden_states * D).to(y.dtype)

        # [bsz, num_heads, head_dim] -> [bsz, 1, intermediate_size]
        y = y.reshape(batch_size, -1)[:, None, ...]
    else:
        # begin ssd naive implementation without einsums
        dt = nn.functional.softplus(dt + self.dt_bias)
        dt = torch.clamp(dt, self.time_step_min)
        hidden_states = hidden_states.reshape(batch_size, seq_len, -1, self.head_dim).float()
        B = B.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
        C = C.reshape(batch_size, seq_len, -1, self.ssm_state_size).float()
        B = B.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
        C = C.repeat_interleave(self.num_heads // self.n_groups, dim=2, output_size=self.num_heads)
        # Number of positions needed to round seq_len up to a multiple of chunk_size.
        pad_size = (self.chunk_size - seq_len % self.chunk_size) % self.chunk_size

        D_residual = self.D[..., None] * pad_tensor_by_size(hidden_states, pad_size)

        # Discretize x and A
        hidden_states = hidden_states * dt[..., None]
        A = A.to(hidden_states.dtype) * dt

        # Rearrange into blocks/chunks
        hidden_states, A, B, C = [reshape_into_chunks(t, pad_size, self.chunk_size) for t in (hidden_states, A, B, C)]

        # [bsz, -1, chunk_size, num_heads] -> [bsz, num_heads, -1, chunk_size]
        A = A.permute(0, 3, 1, 2)
        A_cumsum = torch.cumsum(A, dim=-1)

        # 1. Compute the output for each intra-chunk (diagonal blocks)
        # This is the analog of a causal mask
        L = torch.exp(segment_sum(A))

        # First, contraction of C and B to get G (attention-weights like)
        G_intermediate = C[:, :, :, None, :, :] * B[:, :, None, :, :, :]  # shape: (b, c, l, s, h, n)
        G = G_intermediate.sum(dim=-1)  # shape: (b, c, l, s, h)

        # Step 2: Compute M, equivalent to applying attention mask to weights
        M_intermediate = G[..., None] * L.permute(0, 2, 3, 4, 1)[..., None]
        M = M_intermediate.sum(dim=-1)

        # Step 3: Compute Y_diag (apply to values)
        Y_diag = (M[..., None] * hidden_states[:, :, None]).sum(3)

        # (right term of low-rank factorization of off-diagonal blocks; B terms)
        decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
        B_decay_contraction = B * decay_states.permute(0, 2, 3, 1)[..., None]
        # permute back B * decay states
        states = (B_decay_contraction.permute(0, 1, 3, 2, 4)[..., None] * hidden_states.permute(0, 1, 3, 2, 4)[..., None, :]).sum(dim=3).permute(0, 1, 2, 4, 3)
        # This branch is only reached when there is no previous cached state, so the
        # recurrence always starts from zeros.
        previous_states = torch.zeros_like(states[:, :1])
        states = torch.cat([previous_states, states], dim=1)
        decay_chunk = torch.exp(segment_sum(nn.functional.pad(A_cumsum[:, :, :, -1], (1, 0))))

        states_permuted = states.permute(0, 2, 1, 3, 4)
        result = (decay_chunk[..., None, None] * states_permuted[:, :, None, ...]).sum(dim=2)
        new_states = result.permute(0, 2, 1, 3, 4)
        # The last entry is the state after the final chunk; the others feed each chunk's output.
        states, ssm_state = new_states[:, :-1], new_states[:, -1]

        # Compute state -> output conversion per chunk
        # (left term of low-rank factorization of off-diagonal blocks; C terms)
        state_decay_out = torch.exp(A_cumsum)
        # compute Yoff
        C_times_states = (C[..., None, :] * states[:, :, None, ...])
        state_decay_out_permuted = state_decay_out.permute(0, 2, 3, 1)
        Y_off = (C_times_states.sum(-1) * state_decay_out_permuted[..., None])

        # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
        y = Y_diag + Y_off
        # [bsz, -1, self.chunk_size, num_heads, head_dim] -> [bsz, (padded) seq_len, num_heads, head_dim]
        y = y.reshape(batch_size, -1, self.num_heads, self.head_dim)

        y = y + D_residual
        # Cutting off padded chunks
        if pad_size > 0:
            y = y[:, :seq_len, :, :]
        y = y.reshape(batch_size, seq_len, -1)
        if cache_params is not None:
            cache_params.ssm_states[self.layer_idx].copy_(ssm_state)

    # Multiply by the "gate" branch and apply the extra normalization layer
    scan_output = self.norm(y, gate)

    # end ssd naive

    # 4. Final linear projection
    contextualized_states = self.out_proj(scan_output.to(dtype))  # [batch, seq_len, hidden_size]
    return contextualized_states
727
+ # fmt: on
728
+
729
def forward(
    self,
    hidden_states,
    cache_params: Optional[NemotronHHybridDynamicCache] = None,
    attention_mask: Optional[torch.Tensor] = None,
):
    """
    Dispatch to the fused-kernel implementation or the pure-PyTorch one.

    The kernel path is taken only when `is_fast_path_available` is truthy and the
    weights of `self.in_proj` live on a CUDA device; otherwise `torch_forward` is used.
    """
    use_kernels = is_fast_path_available and "cuda" in self.in_proj.weight.device.type
    implementation = self.cuda_kernels_forward if use_kernels else self.torch_forward
    return implementation(hidden_states, cache_params, attention_mask)
739
+
740
+
741
class NemotronHRMSNorm(nn.Module):
    """
    RMS layer normalization used by NemotronH blocks.

    The input is divided by the root mean square of its last dimension and then
    multiplied element-wise by a learned weight. The computation is carried out in
    float32 and the result is cast back to the input dtype.

    Args:
        hidden_size (`int`):
            Size of the last dimension; also the number of learned scale weights.
        eps (`float`, *optional*, defaults to 1e-6):
            Constant added to the mean square before taking the reciprocal square root.
    """
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        # Mean of squares over the last dimension (kept for broadcasting).
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        scaled = self.weight.to(torch.float32) * normalized
        return scaled.to(original_dtype)
766
+
767
class NemotronHBlock(nn.Module):
    """
    A single residual block of the NemotronH model.

    The block applies pre-normalization, runs one mixer and adds the result to the
    residual. The mixer type is selected by `config.layers_block_type[layer_idx]` and is
    one of `"mamba"`, `"attention"`, `"mlp"` or `"moe"`.

    Args:
        config (`NemotronHConfig`):
            Model configuration. This block reads `residual_in_fp32`, `hidden_size`,
            `layer_norm_epsilon` and `layers_block_type`.
        layer_idx (`int`):
            Index of this block; selects the block type and is forwarded to the mixer.

    Raises:
        ValueError: if the configured block type is not one of the supported values.
    """
    def __init__(self, config, layer_idx):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.residual_in_fp32 = config.residual_in_fp32
        self.norm = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

        # Block types are spelled out as strings: "mamba", "attention", "mlp", "moe".
        self.block_type = config.layers_block_type[layer_idx]
        if self.block_type == "mamba":
            self.mixer = NemotronHMamba2Mixer(config, layer_idx=layer_idx)
        elif self.block_type == "attention":
            self.mixer = NemotronHAttention(config, layer_idx=layer_idx)
        elif self.block_type == "mlp":
            self.mixer = NemotronHMLP(config, layer_idx=layer_idx)
        elif self.block_type == "moe":
            self.mixer = NemotronHMoE(config, layer_idx=layer_idx)
        else:
            # Report the value that was actually inspected, so the error path cannot itself
            # fail on a missing or shorter `hybrid_override_pattern`.
            raise ValueError(f"Invalid block_type: {self.block_type}")

    def forward(
        self,
        hidden_states,
        past_key_values: Optional[NemotronHHybridDynamicCache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """
        Apply norm -> mixer -> residual add.

        Args:
            hidden_states: Input to the block.
            past_key_values: Cache passed to the Mamba mixer (as `cache_params`) or to the
                attention mixer; not used by MLP/MoE mixers.
            cache_position: Accepted for interface compatibility; not used in this method.
            attention_mask: Forwarded to the Mamba and attention mixers.
            output_attentions: Forwarded to the attention mixer.

        Returns:
            The residual plus the mixer output.
        """
        # Run on the device's default CUDA stream; on other devices this is a no-op context.
        if hidden_states.device.type == "cuda":
            stream_context = torch.cuda.stream(torch.cuda.default_stream(hidden_states.device))
        else:
            stream_context = contextlib.nullcontext()

        with stream_context:
            residual = hidden_states
            hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

            if self.block_type == "mamba":
                hidden_states = self.mixer(
                    hidden_states, cache_params=past_key_values, attention_mask=attention_mask
                )
            elif self.block_type == "attention":
                # The attention mixer returns (output, weights, cache); only the output is used.
                hidden_states, _, _ = self.mixer(
                    hidden_states=hidden_states,
                    past_key_values=past_key_values,
                    attention_mask=attention_mask,
                    output_attentions=output_attentions,
                )
            elif self.block_type in ["mlp", "moe"]:
                hidden_states = self.mixer(hidden_states)
            else:
                raise ValueError(f"Invalid block_type: {self.block_type}")

            hidden_states = residual + hidden_states
            return hidden_states
841
+
842
+
843
+ # Copied from transformers.models.nemotron.modeling_nemotron Nemotron->NemotronH
844
class NemotronHMLP(nn.Module):
    """
    Feed-forward module for NemotronH: up projection, activation, down projection.

    When the module is built as a routed expert (`is_expert=True`) and
    `config.moe_latent_size` is set, its input/output width is the latent size;
    otherwise it is `config.hidden_size`.

    Args:
        config (`NemotronHConfig`):
            Model configuration. Reads `hidden_size`, `intermediate_size`, `moe_latent_size`,
            `mlp_bias` and `mlp_hidden_act`.
        intermediate_size (`int`, *optional*):
            Width of the hidden layer. Falls back to `config.intermediate_size` when falsy.
        layer_idx (`int`, *optional*):
            Index of the layer in the model. A one-time warning is logged when omitted.
        is_expert (`bool`, *optional*, defaults to `False`):
            Whether this MLP is a routed expert inside a Mixture-of-Experts layer.
    """
    def __init__(self, config, intermediate_size=None, layer_idx: Optional[int] = None, is_expert=False):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        # Routed experts operate in the latent space when one is configured.
        if is_expert and self.config.moe_latent_size is not None:
            io_width = config.moe_latent_size
        else:
            io_width = self.hidden_size

        self.intermediate_size = intermediate_size or config.intermediate_size
        self.up_proj = nn.Linear(io_width, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, io_width, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.mlp_hidden_act]

    def forward(self, x):
        expanded = self.up_proj(x)
        activated = self.act_fn(expanded)
        return self.down_proj(activated)
882
+
883
+
884
class NemotronHMoE(nn.Module):
    """
    Mixture-of-Experts (MoE) module for NemotronH.

    Each token is sent to the experts chosen by the router and the weighted expert
    outputs are summed. A shared expert MLP is applied to every token and added on top.
    When `config.moe_latent_size` is set, tokens are projected to that size before the
    routed experts and projected back afterwards; otherwise both projections are identity.

    Args:
        config (`NemotronHConfig`):
            Model configuration. Reads `n_routed_experts`, `moe_intermediate_size`,
            `moe_shared_expert_intermediate_size`, `moe_latent_size`, `hidden_size` and `mlp_bias`.
        layer_idx (`int`, *optional*):
            Index of the layer in the model; forwarded to the expert MLPs.
    """
    def __init__(self, config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.experts = nn.ModuleList(
            [
                NemotronHMLP(config, intermediate_size=config.moe_intermediate_size, layer_idx=layer_idx, is_expert=True)
                for _ in range(config.n_routed_experts)
            ]
        )
        self.gate = NemotronHTopkRouter(config)
        self.shared_experts = NemotronHMLP(
            config=config, intermediate_size=config.moe_shared_expert_intermediate_size, layer_idx=layer_idx, is_expert=False
        )

        if config.moe_latent_size is not None:
            self.fc1_latent_proj = nn.Linear(config.hidden_size, config.moe_latent_size, bias=config.mlp_bias)
            self.fc2_latent_proj = nn.Linear(config.moe_latent_size, config.hidden_size, bias=config.mlp_bias)
        else:
            self.fc1_latent_proj = nn.Identity()
            self.fc2_latent_proj = nn.Identity()

    def moe(self, hidden_states: torch.Tensor, topk_indices: torch.Tensor, topk_weights: torch.Tensor):
        """
        Apply the routed experts to a flat batch of tokens.

        Args:
            hidden_states: Token features, indexed by token along dim 0.
            topk_indices: Expert index chosen for each (token, slot).
            topk_weights: Routing weight for each (token, slot).

        Returns:
            Sum over selected experts of `weight * expert(token)`, cast to the dtype of
            `hidden_states`. Accumulation happens in the dtype of `topk_weights`.
        """
        final_hidden_states = torch.zeros_like(hidden_states, dtype=topk_weights.dtype)
        # (tokens, slots, experts) -> (experts, tokens, slots)
        expert_mask = torch.nn.functional.one_hot(topk_indices, num_classes=len(self.experts))
        expert_mask = expert_mask.permute(2, 0, 1)

        for expert_idx, expert in enumerate(self.experts):
            token_indices, weight_indices = torch.where(expert_mask[expert_idx])

            if token_indices.numel() > 0:
                expert_weights = topk_weights[token_indices, weight_indices]
                expert_output = expert(hidden_states[token_indices])
                weighted_output = expert_output * expert_weights.unsqueeze(-1)
                final_hidden_states.index_add_(0, token_indices, weighted_output)
            else:
                # Local empty expert: run it on a dummy token so its parameters still take part
                # in the graph, but scale the result by zero. Without the scaling, any expert
                # with biases (or an activation with act(0) != 0) would leak a non-zero vector
                # into every token.
                expert_dtype = expert.down_proj.weight.dtype
                dummy_in = hidden_states.new_zeros((1, hidden_states.shape[-1]), dtype=expert_dtype)
                final_hidden_states = final_hidden_states + expert(dummy_in) * 0.0

        # in original deepseek, the output of the experts are gathered once we leave this module
        # thus the moe module is itself an IsolatedParallel module
        # and all expert are "local" meaning we shard but we don't gather
        return final_hidden_states.type(hidden_states.dtype)

    def forward(self, hidden_states):
        """Route tokens through the experts and add the shared-expert output."""
        residuals = hidden_states
        orig_shape = hidden_states.shape
        topk_indices, topk_weights = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

        hidden_states = self.fc1_latent_proj(hidden_states)
        hidden_states = self.moe(hidden_states, topk_indices, topk_weights)
        hidden_states = self.fc2_latent_proj(hidden_states)

        hidden_states = hidden_states.view(*orig_shape)

        hidden_states = hidden_states + self.shared_experts(residuals)
        return hidden_states
966
+
967
+
968
class NemotronHTopkRouter(nn.Module):
    """
    Grouped top-K router for the Mixture-of-Experts layer.

    Sigmoid routing scores are computed per expert. For the purpose of *choosing*
    experts only, a correction bias is added, experts are partitioned into groups, the
    best groups are kept and the top-K experts are taken among them. The returned
    weights are the unbiased scores of the chosen experts.

    Args:
        config (`NemotronHConfig`):
            Model configuration. Reads:
            - `num_experts_per_tok`: number of experts selected per token (K)
            - `n_routed_experts`: total number of routed experts
            - `routed_scaling_factor`: multiplier applied to the returned weights
            - `n_group`: number of expert groups
            - `topk_group`: number of groups kept per token
            - `norm_topk_prob`: whether the K weights are normalized to sum to one
            - `hidden_size`: width of the router input
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.n_routed_experts = config.n_routed_experts
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        self.norm_topk_prob = config.norm_topk_prob

        # Allocated uninitialized here; values are expected to be set elsewhere.
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, config.hidden_size)))
        self.register_buffer("e_score_correction_bias", torch.zeros(self.n_routed_experts, dtype=torch.float32))

    @torch.no_grad()
    def get_topk_indices(self, scores):
        """Return the indices of the `top_k` experts per token using grouped selection."""
        experts_per_group = self.n_routed_experts // self.n_group
        biased_scores = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0)

        # A group is ranked by the sum of its two best (biased) expert scores.
        best_two, _ = biased_scores.view(-1, self.n_group, experts_per_group).topk(2, dim=-1)
        group_scores = best_two.sum(dim=-1)
        _, kept_groups = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)

        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, kept_groups, 1)
        # Broadcast the per-group keep flag to every expert of that group.
        expert_keep = group_mask.unsqueeze(-1).expand(-1, self.n_group, experts_per_group)
        expert_keep = expert_keep.reshape(-1, self.n_routed_experts)

        biased_scores = biased_scores.masked_fill(~expert_keep.bool(), 0.0)
        _, topk_indices = torch.topk(biased_scores, k=self.top_k, dim=-1, sorted=False)
        return topk_indices

    def forward(self, hidden_states):
        """
        Compute expert routing for each token in the input.

        Steps:
            1. Flatten the input to `(-1, hidden_size)` and compute float32 routing logits.
            2. Apply a sigmoid to obtain routing scores.
            3. Pick the top-K experts with the grouped selection strategy.
            4. Gather the scores of the chosen experts and optionally normalize them.
            5. Multiply by `routed_scaling_factor`.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
                Input hidden states to be routed to experts.

        Returns:
            `tuple` containing:
                - topk_indices (`torch.Tensor` of shape `(batch_size * sequence_length, num_experts_per_tok)`):
                  Indices of the selected experts for each token.
                - topk_weights (`torch.Tensor` of shape `(batch_size * sequence_length, num_experts_per_tok)`):
                  Routing weights of the selected experts, scaled by `routed_scaling_factor`.
        """
        self._maintain_float32_expert_bias()

        flat_tokens = hidden_states.view(-1, self.config.hidden_size)
        router_logits = F.linear(flat_tokens.type(torch.float32), self.weight.type(torch.float32))
        scores = router_logits.sigmoid()

        topk_indices = self.get_topk_indices(scores)
        topk_weights = scores.gather(1, topk_indices)
        if self.norm_topk_prob:
            # Small constant guards against division by zero.
            topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + 1e-20
        return topk_indices, topk_weights * self.routed_scaling_factor

    def _maintain_float32_expert_bias(self):
        """
        Keep `e_score_correction_bias` in float32.

        Called at the start of `forward()`; if the buffer has been cast to another dtype
        (e.g. via `model.to(torch.bfloat16)`), it is cast back to float32.
        """
        bias = self.e_score_correction_bias
        if bias.dtype != torch.float32:
            bias.data = bias.data.to(torch.float32)
1064
+
1065
+ # Copied from transformers.models.llama.modeling_llama.repeat_kv
1066
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to `torch.repeat_interleave(hidden_states, dim=1, repeats=n_rep)`: the input goes from
    (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    The input is returned unchanged when `n_rep == 1`.
    """
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
1076
+
1077
def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    """
    Eager attention: materialize the attention weights and apply them to the values.

    Key/value heads are repeated `module.num_key_value_groups` times to match the query
    heads. The optional additive mask is cropped to the key length on its last dimension.
    Softmax runs in float32 and is cast back to the query dtype; dropout is active only
    when `module.training` is true.

    Returns:
        `(attn_output, attn_weights)`, where `attn_output` has the head and sequence
        dimensions swapped (`transpose(1, 2)`) and is contiguous.
    """
    groups = module.num_key_value_groups
    expanded_keys = repeat_kv(key, groups)
    expanded_values = repeat_kv(value, groups)

    scores = torch.matmul(query, expanded_keys.transpose(2, 3)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask[:, :, :, : expanded_keys.shape[-2]]

    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    probs = F.dropout(probs, p=dropout, training=module.training)

    context = torch.matmul(probs, expanded_values).transpose(1, 2).contiguous()
    return context, probs
1102
+
1103
+
1104
class NemotronHAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Args:
        config (`NemotronHConfig`):
            Model configuration containing attention parameters like num_attention_heads, num_key_value_heads,
            hidden_size, head_dim, attention_dropout, and attention_bias.
        layer_idx (`int`, *optional*):
            Index of the layer in the model. Required for proper caching during generation. If not provided,
            a warning is emitted and caching may fail.
    """

    def __init__(self, config: NemotronHConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        # An explicit head_dim takes precedence over hidden_size // num_attention_heads.
        if config.head_dim is not None:
            self.head_dim = config.head_dim
        else:
            self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.scaling = self.head_dim ** -0.5
        self.is_causal = True

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[NemotronHHybridDynamicCache] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        """
        Project to queries/keys/values, update the cache, run attention and project back.

        Args:
            hidden_states: Input of size `(bsz, q_len, _)`.
            attention_mask: Optional mask passed to the attention implementation. When omitted
                and `q_len > 1`, an additive causal mask is built here.
            past_key_values: Optional cache; `update()` is called with this layer's new keys and
                values and its return value is used as the keys/values for attention.
            **kwargs: Forwarded unchanged to the attention implementation.

        Returns:
            `(attn_output, attn_weights, past_key_values)`.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if past_key_values is not None:
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx)

        # Select attention implementation based on config
        attention_interface = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if attention_mask is None and q_len > 1:
            # Size the fallback causal mask by the actual key length so it stays valid when
            # the cache has prepended earlier keys. Query i sits at absolute position
            # kv_len - q_len + i and may attend to keys up to and including that position.
            kv_len = key_states.shape[-2]
            mask = torch.triu(
                torch.full((q_len, kv_len), float("-inf"), device=hidden_states.device),
                diagonal=kv_len - q_len + 1,
            )
            attention_mask = mask.view(1, 1, q_len, kv_len)

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output, attn_weights, past_key_values
1189
+
1190
+
1191
# Adapted from transformers.models.mamba2.modeling_mamba2.Mamba2PreTrainedModel
class NemotronHPreTrainedModel(PreTrainedModel):
    """
    Common base for the NemotronH model classes: carries the class-level loading flags and the
    weight-initialization routine shared by all of them.
    """

    config_class = NemotronHConfig
    base_model_prefix = "model"
    _no_split_modules = ["NemotronHBlock"]
    supports_gradient_checkpointing = True
    _is_stateful = True
    _supports_sdpa = True
    _supports_flash_attn_2 = True
    _checkpoint_conversion_mapping = {"backbone": "model"}

    def _init_weights(self, module):
        """Initialize the weights of a single `module`."""
        cfg = self.config

        if isinstance(module, NemotronHMamba2Mixer):
            # Already initialized (e.g. loaded from a checkpoint): leave the whole module alone.
            if getattr(module.dt_bias, "_is_hf_initialized", False):
                return
            module.A_log._no_weight_decay = True
            module.D._no_weight_decay = True

            # Sample dt log-uniformly in [time_step_min, time_step_max], then floor it.
            log_lo = math.log(cfg.time_step_min)
            log_hi = math.log(cfg.time_step_max)
            dt = torch.exp(torch.rand(cfg.mamba_num_heads) * (log_hi - log_lo) + log_lo)
            dt = dt.clamp(min=cfg.time_step_floor)

            # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
            inv_dt = dt + torch.log(-torch.expm1(-dt))
            with torch.no_grad():
                module.dt_bias.copy_(inv_dt)
            # Flag checked by the bias reset below so this value is not zeroed out.
            module.dt_bias._no_reinit = True
        elif isinstance(module, NemotronHTopkRouter):
            nn.init.normal_(module.weight, mean=0.0, std=cfg.initializer_range)
            nn.init.zeros_(module.e_score_correction_bias)

        if isinstance(module, nn.Linear):
            bias = module.bias
            if bias is not None and not getattr(bias, "_no_reinit", False):
                nn.init.zeros_(bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, std=cfg.initializer_range)

        if not cfg.rescale_prenorm_residual:
            return

        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, param in module.named_parameters():
            if getattr(param, "_is_hf_initialized", False):
                continue
            if name != "out_proj.weight":
                continue
            # The weight is re-drawn before scaling because this method may run more than once;
            # scaling in place alone would shrink it a little further on every call.
            nn.init.kaiming_uniform_(param, a=math.sqrt(5))
            with torch.no_grad():
                param /= math.sqrt(cfg.num_hidden_layers)
1255
+
1256
+
1257
@dataclass
# Copied from transformers.models.mamba2.modeling_mamba2.Mamba2Output with MAMBA2->NemotronH,Mamba2->NemotronH
class NemotronHOutput(ModelOutput):
    """
    Class for the NemotronH model outputs.

    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        past_key_values (`NemotronHHybridDynamicCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Attention tensors, when the producing model collects them; `None` otherwise.
    """

    # Every field defaults to `None`, so an output can be built from any subset of them.
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[NemotronHHybridDynamicCache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
1282
+
1283
+
1284
@dataclass
# Copied from transformers.models.mamba2.modeling_mamba2.MambaCausalLMOutput with Mamba2->NemotronH
class NemotronHCausalLMOutput(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`NemotronHHybridDynamicCache`):
            The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
            avoid providing the old `input_ids`.

            Includes both the State space model state matrices after the selective scan, and the Convolutional states
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Attention tensors, when the producing model collects them; `None` otherwise.
    """

    # Every field defaults to `None`, so an output can be built from any subset of them.
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[NemotronHHybridDynamicCache] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
1312
+
1313
+
1314
+ NEMOTRONH_START_DOCSTRING = r"""
1315
+
1316
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1317
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1318
+ etc.)
1319
+
1320
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1321
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1322
+ and behavior.
1323
+
1324
+ Parameters:
1325
+ config ([`NemotronHConfig`]): Model configuration class with all the parameters of the model.
1326
+ Initializing with a config file does not load the weights associated with the model, only the
1327
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1328
+ """
1329
+
1330
+ NEMOTRONH_INPUTS_DOCSTRING = r"""
1331
+ Args:
1332
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`, *optional*):
1333
+ Indices of input sequence tokens in the vocabulary.
1334
+
1335
+ If `past_key_values.seqlen_offset>0`, only `input_ids` that do not have their past calculated should be passed as
1336
+ `input_ids`.
1337
+
1338
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
1339
+ [`PreTrainedTokenizer.__call__`] for details.
1340
+
1341
+ [What are input IDs?](../glossary#input-ids)
1342
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1343
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
1344
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
1345
+ model's internal embedding lookup matrix.
1346
+ position_ids (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1347
+ Indices of positions of each input sequence tokens in the position embeddings.
1348
+ past_key_values (`NemotronHHybridDynamicCache`, *optional*):
1349
+ If passed along, the model uses the previous state in all the blocks (which will give the output for the
1350
+ `input_ids` provided as if the model add `state_input_ids + input_ids` as context).
1351
+ use_cache (`bool`, *optional*):
1352
+ If set to `True`, the `past_key_values` is returned and can be used to quickly generate the next logits.
1353
+ output_attentions (`bool`, *optional*):
1354
+ Whether or not to return the attentions tensors of all attention layers.
1355
+ output_hidden_states (`bool`, *optional*):
1356
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
1357
+ more detail.
1358
+ return_dict (`bool`, *optional*):
1359
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
1360
+ cache_position (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1361
+ The position of the current input in the cache. This is used to ensure that the cache is correctly updated.
1362
+ If `past_key_values` is passed, `cache_position` should also be passed.
1363
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1364
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1365
+
1366
+ - 1 for tokens that are **not masked**,
1367
+ - 0 for tokens that are **masked**.
1368
+
1369
+ [What are attention masks?](../glossary#attention-mask)
1370
+ """
1371
+
1372
+
1373
@add_start_docstrings(
    "The bare NemotronH Model transformer outputting raw hidden-states without any specific head on top.",
    NEMOTRONH_START_DOCSTRING,
)
class NemotronHModel(NemotronHPreTrainedModel):
    # Backbone: token embedding -> `config.num_hidden_layers` NemotronHBlock layers -> final RMSNorm.

    def __init__(self, config):
        """Create the embedding table, the block stack and the final norm, then run `post_init`."""
        super().__init__(config)

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([NemotronHBlock(config, layer_idx=idx) for idx in range(config.num_hidden_layers)])

        self.gradient_checkpointing = False
        self.norm_f = NemotronHRMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Lets `load_state_dict` accept checkpoints that still use the `embedding.` key spelling.
        self._register_load_state_dict_pre_hook(self.load_hook)
        # Initialize weights and apply final processing
        self.post_init()

    def load_hook(self, state_dict, prefix, *args):
        """
        State-dict pre-hook: rename the first key containing `embedding.` so that it uses `embeddings.`.

        `state_dict` is modified in place. Only the first matching key is renamed; the loop stops
        right after the rename, which is also what makes mutating the dict during iteration safe.
        """
        for k in state_dict:
            if "embedding." in k:
                state_dict[k.replace("embedding.", "embeddings.")] = state_dict.pop(k)
                break

    def get_input_embeddings(self):
        """Return the input embedding module."""
        return self.embeddings

    def set_input_embeddings(self, new_embeddings):
        """Replace the input embedding module."""
        self.embeddings = new_embeddings

    @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NemotronHOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[NemotronHHybridDynamicCache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Union[tuple, NemotronHOutput]:
        # Unset flags fall back to the config; caching defaults to off while training.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):  # ^ is python for xor
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        # From zamba_modeling.py
        if use_cache and past_key_values is None:
            logger.warning_once(
                "NemotronH requires an initialized `NemotronHHybridDynamicCache` to return a cache. None was "
                "provided, so no cache will be returned."
            )

        hidden_states = inputs_embeds

        if cache_position is None:
            # Continue numbering positions after whatever the cache has already seen.
            past_seen_tokens = (
                past_key_values.get_seq_length()
                if past_key_values is not None
                else 0
            )
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
        mamba_mask = self._update_mamba_mask(attention_mask, cache_position)

        all_hidden_states = () if output_hidden_states else None
        # NOTE(review): nothing in this method appends to `all_self_attns`; with
        # `output_attentions=True` the `attentions` field comes back as an empty tuple.
        all_self_attns = () if output_attentions else None
        # Until HERE

        for mixer_block in self.layers:
            # Depending on the layer type we opt for 2D base attention mask (Mamba) or 4D causal mask (Attention)
            if mixer_block.block_type == "mamba":
                layer_mask = mamba_mask
            elif mixer_block.block_type == "attention":
                layer_mask = causal_mask
            elif mixer_block.block_type in ["mlp", "moe"]:
                layer_mask = None
            else:
                # The type belongs to the block, not to this model, so report the block's value.
                raise ValueError(f"Invalid block_type: {mixer_block.block_type}")

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    mixer_block.__call__, hidden_states, past_key_values, cache_position, layer_mask
                )
            else:
                hidden_states = mixer_block(
                    hidden_states,
                    past_key_values=past_key_values,
                    cache_position=cache_position,
                    attention_mask=layer_mask,
                    output_attentions=output_attentions,
                )

        hidden_states = self.norm_f(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        # Mark the cache as populated so later calls know earlier state exists.
        if past_key_values is not None and not past_key_values.has_previous_state:
            past_key_values.has_previous_state = True

        if not return_dict:
            return tuple(v for v in [hidden_states, past_key_values, all_hidden_states] if v is not None)

        return NemotronHOutput(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """
        Build the mask handed to attention blocks.

        For `flash_attention_2` the 2-D `attention_mask` is returned unchanged when it contains a
        zero, otherwise `None`. For every other implementation a 4-D additive mask of shape
        `(batch, 1, sequence_length, target_length)` is returned, holding the dtype's minimum value
        at positions that must not be attended to; a 2-D `attention_mask` is folded in as padding.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if cache_position is None:
            target_length = sequence_length
        else:
            target_length = cache_position[-1] + 1

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        if cache_position is not None:
            # Keep the blocking value only for key positions beyond each query's cache position.
            causal_mask *= (torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)).to(torch.bool)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                # Positions that are causally visible (0.0) but marked as padding get blocked too.
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    def _update_mamba_mask(self, attention_mask, cache_position):
        """
        Return the mask handed to Mamba blocks, or `None` when no masking is needed.

        No need for zeroing states when
            1. Cached forward (`cache_position[0] > 0`)
            2. Attending to all inputs (`attention_mask` is all ones)
        """
        mamba_mask = attention_mask
        if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
            mamba_mask = None
        return mamba_mask
1564
+
1565
+
1566
def register_nemotron_h_conversion_mapping():
    """
    Register the NemotronH checkpoint key renamings with Transformers, when that API is available.

    If `transformers.conversion_mapping` (or either of the two names needed from it) cannot be
    imported, the function returns without doing anything. Otherwise it registers, under the
    `"nemotron_h"` key and with `overwrite=True`, the renamings `backbone.` -> `model.` and
    `embedding.weight` -> `embeddings.weight`.
    """
    try:
        from transformers.conversion_mapping import WeightRenaming, register_checkpoint_conversion_mapping
    except ImportError:
        # Transformers build without the conversion-mapping API: nothing to register.
        return

    renamings = [
        WeightRenaming("backbone.", "model."),
        WeightRenaming("embedding.weight", "embeddings.weight"),
    ]
    register_checkpoint_conversion_mapping("nemotron_h", renamings, overwrite=True)
1584
+
1585
+
1586
+
1587
@add_start_docstrings(
    """
    The NEMOTRONH Model transformer with a language modeling head on top (linear layer with weights not tied to the input
    embeddings).
    """,
    NEMOTRONH_START_DOCSTRING,
)
class NemotronHForCausalLM(NemotronHPreTrainedModel, GenerationMixin):
    # Checkpoint keys matching this pattern are ignored when reported as unexpected at load time.
    _keys_to_ignore_on_load_unexpected = [r"mtp.*"]

    def __init__(self, config):
        """Create the backbone and the bias-free LM head, and register the checkpoint key renamings."""
        super().__init__(config)
        self.model = NemotronHModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        register_nemotron_h_conversion_mapping()

        # Initialize weights and apply final processing
        self.post_init()

    def _get_key_renaming_mapping(
        self,
        checkpoint_keys: list[str],
        key_mapping: Optional[dict[str, str]] = None,
        loading_base_model_from_task_state_dict: bool = False,
        loading_task_model_from_base_state_dict: bool = False,
    ):
        """Convert backbone.* keys to model.* keys for backward compatibility."""
        # Caller-supplied entries are merged last, so they override the default on a key clash.
        merged_mapping = {"^backbone": "model", **(key_mapping or {})}

        # Keys starting with `backbone` already carry a base-model prefix, so the
        # task-from-base loading path is switched off for them.
        if any(key.startswith("backbone") for key in checkpoint_keys):
            loading_task_model_from_base_state_dict = False

        return super()._get_key_renaming_mapping(
            checkpoint_keys,
            merged_mapping,
            loading_base_model_from_task_state_dict=loading_base_model_from_task_state_dict,
            loading_task_model_from_base_state_dict=loading_task_model_from_base_state_dict,
        )

    def get_input_embeddings(self):
        """Return the backbone's input embedding module."""
        return self.model.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        """Replace the backbone's input embedding module."""
        return self.model.set_input_embeddings(new_embeddings)

    def get_output_embeddings(self):
        """Return the LM head."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """Replace the LM head."""
        self.lm_head = new_embeddings

    def get_decoder(self):
        """Return the backbone model."""
        return self.model

    def set_decoder(self, decoder):
        """Replace the backbone model."""
        self.model = decoder

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        is_first_iteration=False,
        **kwargs,
    ):
        """
        Overridden because this model uses its own cache type, `NemotronHHybridDynamicCache`.

        A fresh cache sized for the batch is created when none is supplied, `logits_to_keep` is
        set from `config.num_logits_to_keep`, and the rest is delegated to the parent class.
        """
        if past_key_values is None:
            batch_size = input_ids.shape[0]
            past_key_values = NemotronHHybridDynamicCache(
                self.config, batch_size, dtype=self.dtype, device=self.device
            )

        kwargs["logits_to_keep"] = self.config.num_logits_to_keep
        return super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            position_ids=position_ids,
            use_cache=use_cache,
            is_first_iteration=is_first_iteration,
            **kwargs,
        )

    @add_start_docstrings_to_model_forward(NEMOTRONH_INPUTS_DOCSTRING)
    @add_code_sample_docstrings(
        checkpoint=_CHECKPOINT_FOR_DOC,
        output_type=NemotronHCausalLMOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[NemotronHHybridDynamicCache] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,  # for now we need this for generation
    ) -> Union[tuple, NemotronHCausalLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
        cfg = self.config
        if output_attentions is None:
            output_attentions = cfg.output_attentions
        if output_hidden_states is None:
            output_hidden_states = cfg.output_hidden_states
        if return_dict is None:
            return_dict = cfg.use_return_dict

        # `position_ids` and `**kwargs` are accepted by this method but not handed to the backbone.
        backbone_out = self.model(
            input_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            use_cache=use_cache,
            cache_position=cache_position,
            attention_mask=attention_mask,
        )
        hidden_states = backbone_out[0]

        # Project in the head's own dtype, then upcast the logits to float32.
        head_dtype = self.lm_head.weight.dtype
        logits = self.lm_head(hidden_states.to(head_dtype)).float()

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            flat_logits = shift_logits.view(-1, shift_logits.size(-1))
            flat_labels = shift_labels.view(-1)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        if not return_dict:
            output = (logits,) + backbone_out[1:]
            if loss is None:
                return output
            return (loss,) + output

        return NemotronHCausalLMOutput(
            loss=loss,
            logits=logits,
            past_key_values=backbone_out.past_key_values,
            hidden_states=backbone_out.hidden_states,
            attentions=backbone_out.attentions,
        )
special_tokens_map.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|im_end|>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "<|im_end|>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "unk_token": {
24
+ "content": "<unk>",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ }
30
+ }
super_v3_reasoning_parser.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from vllm.reasoning.abs_reasoning_parsers import ReasoningParserManager
2
+ from vllm.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
3
+
4
+
5
@ReasoningParserManager.register_module("nano_v3")
class NanoV3ReasoningParser(DeepSeekR1ReasoningParser):
    def extract_reasoning(self, model_output, request):
        """
        Split `model_output` into `(reasoning, content)`, moving the text into `content` when asked to.

        The parent `deepseek_r1` parser puts everything into the reasoning part when it cannot parse
        out reasoning. That was fine for DeepSeek R1, which was not intended to be used without
        reasoning, but it is wrong in two situations handled here:

        1. Nemotron 3 Nano and Super have a thinking-off mode, selected with `enable_thinking=False`
           in the chat template kwargs. In that mode the text is the answer and belongs in `content`.
        2. In rare cases the model outputs only reasoning without the end-think token `</think>`
           (e.g. the reasoning exceeds the maximum length), which leaves `content` empty. Callers who
           always want a content response can pass `force_nonempty_content=True`.

        In either situation, and only when the parent returned `None` as content, the parent's
        reasoning text is returned as the content and the reasoning becomes `None`.
        """
        reasoning, content = super().extract_reasoning(model_output, request)

        template_kwargs = getattr(request, "chat_template_kwargs", None)
        if not template_kwargs:
            return reasoning, content

        # Identity checks on purpose: only the literal booleans trigger the swap.
        wants_content = (
            template_kwargs.get("enable_thinking") is False
            or template_kwargs.get("force_nonempty_content") is True
        )
        if wants_content and content is None:
            return None, reasoning

        return reasoning, content
tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:623c34567aebb18582765289fbe23d901c62704d6518d71866e0e58db892b5b7
3
+ size 17077484
tokenizer_config.json ADDED
The diff for this file is too large to render. See raw diff