
RecurrentGemma crashes during inference for inputs longer than sliding window width #37219

Open

assafbk opened this issue Apr 2, 2025 · 1 comment

assafbk commented Apr 2, 2025

System Info

  • transformers version: 4.50.3
  • Platform: Linux-6.8.0-52-generic-x86_64-with-glibc2.35
  • Python version: 3.10.16
  • Huggingface_hub version: 0.30.1
  • Safetensors version: 0.5.3
  • Accelerate version: 1.6.0
  • Accelerate config: not found
  • DeepSpeed version: not installed
  • PyTorch version (GPU?): 2.5.1+cu118 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA RTX A6000

Who can help?

@ArthurZucker, @gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Code snippet for reproduction:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-9b-it")
model = AutoModelForCausalLM.from_pretrained("google/recurrentgemma-9b-it", device_map="cuda", torch_dtype=torch.float16)

input_text = "Write me a poem about Machine Learning." * 300    # This string is 2402 tokens long, which is larger than 2048, the sliding window attention width
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
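
For reference, the length mismatch can be checked directly before calling generate(). A minimal sketch, assuming the config exposes the window width as attention_window_size (the attribute the RecurrentGemma modeling code reads):

n_tokens = input_ids["input_ids"].shape[-1]
print(n_tokens, model.config.attention_window_size)  # 2402 vs. 2048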

Error message:

Traceback (most recent call last):
  File "/data2/assaf/tmp/test_rg.py", line 13, in <module>
    outputs = model.generate(**input_ids, max_new_tokens=20)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/generation/utils.py", line 2326, in generate
    result = self._sample(
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/generation/utils.py", line 3289, in _sample
    outputs = model_forward(**model_inputs, return_dict=True)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 852, in forward
    outputs = self.model(
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 717, in forward
    causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
  File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 764, in _update_causal_mask
    padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
RuntimeError: The size of tensor a (2048) must match the size of tensor b (2402) at non-singleton dimension 3
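
For context on the failure: the causal mask is built against the sliding window width (2048) while the padding mask spans the full input (2402), so the elementwise product on the quoted line cannot broadcast. A standalone sketch with shapes taken from the traceback (the actual mask construction in modeling_recurrent_gemma.py is more involved):

import torch

window, seq_len = 2048, 2402
causal_mask = torch.zeros(1, 1, seq_len, window)  # last dim capped at the window width
attention_mask = torch.ones(1, seq_len)           # padding mask spans the full input
mask_length = attention_mask.shape[-1]            # 2402

# Slicing cannot widen the 2048-wide last dim, so the product fails to broadcast:
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (2402)
# at non-singleton dimension 3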

Expected behavior

If the input sequence is longer than the sliding window width (as it is in the script above), generation crashes with the error message shown. If the sequence is shorter than the window (e.g. replace *300 with *200), the script runs fine.

The bug was observed in transformers v4.50.3. It does not reproduce on earlier versions, such as v4.42.4.
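
Until a fix lands, one possible workaround, based only on the observation above, is to pin an earlier release:

pip install "transformers==4.42.4"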

assafbk added the bug label on Apr 2, 2025
Rocketknight1 (Member) commented:

Hi @assafbk, can you use git bisect to identify the commit where this bug appeared? It'll help a lot in identifying the cause and assigning the right person to fix it!
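
For reference, a sketch of that workflow, assuming the repro script is the one at the path in the traceback: it exits non-zero on the crash, so git bisect run can classify each commit automatically, and the editable install means each checked-out revision is picked up without reinstalling.

git clone https://github.com/huggingface/transformers.git
cd transformers
pip install -e .
git bisect start
git bisect bad v4.50.3      # release where the crash reproduces
git bisect good v4.42.4     # release where it does not
git bisect run python /data2/assaf/tmp/test_rg.py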
