Commit 7b33743a authored by Tri Dao's avatar Tri Dao
Browse files

[Gen] Add back num_last_tokens in gpt.py

Showing with 2 additions and 0 deletions
+2 -0
......@@ -634,6 +634,8 @@ class GPTLMHeadModel(GPTPreTrainedModel, GenerationMixin):
input_ids, position_ids=position_ids, inference_params=inference_params
)
assert hidden_states.ndim == 3, "sequence_parallel is not supported in generation mode"
if num_last_tokens > 0:
hidden_states = hidden_states[:, -num_last_tokens:]
if self.project_out is not None:
hidden_states = self.project_out(hidden_states)
lm_logits = self.lm_head(hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment