System Info

transformers version: 4.50.3

Who can help?

@ArthurZucker, @gante

Information

Tasks

[ ] An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
[x] My own task or dataset (give details below)
Reproduction
Code snippet for reproduction:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-9b-it")
model = AutoModelForCausalLM.from_pretrained("google/recurrentgemma-9b-it", device_map="cuda", torch_dtype=torch.float16)
input_text = "Write me a poem about Machine Learning." * 300 # This string is 2402 tokens long, which is larger than 2048, the sliding window attention width
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))
Error message:
Traceback (most recent call last):
File "/data2/assaf/tmp/test_rg.py", line 13, in
outputs = model.generate(**input_ids, max_new_tokens=20)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/generation/utils.py", line 2326, in generate
result = self._sample(
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/generation/utils.py", line 3289, in _sample
outputs = model_forward(**model_inputs, return_dict=True)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 852, in forward
outputs = self.model(
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 717, in forward
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
File "/data2/assaf/conda_envs/recurrent_gemma_tmp/lib/python3.10/site-packages/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py", line 764, in _update_causal_mask
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
RuntimeError: The size of tensor a (2048) must match the size of tensor b (2402) at non-singleton dimension 3
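For context, the failing line multiplies a causal mask whose last dimension is capped at the 2048-token attention window against a padding mask that spans the full 2402-token prompt. Below is a minimal sketch that reproduces just the failing broadcast, with shapes taken from the traceback (this is not the actual model code, only the tensor arithmetic):

import torch

window, prompt_len = 2048, 2402                      # sliding window width vs. prompt length
causal_mask = torch.zeros(1, 1, prompt_len, window)  # last dim capped at the window
attention_mask = torch.ones(1, prompt_len)           # one entry per prompt token

mask_length = attention_mask.shape[-1]  # 2402
# Mirrors the line that raises in _update_causal_mask: slicing to mask_length
# cannot grow the 2048-wide last dimension, so the broadcast on dim 3 fails.
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
# RuntimeError: The size of tensor a (2048) must match the size of tensor b (2402)
# at non-singleton dimension 3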
Expected behavior
If the prompt is longer than the sliding window width (as it is in the script above), generation crashes with the error message above.
If the prompt is shorter than the sliding window width (e.g. replace * 300 with * 200), the script runs fine.
The bug was seen in transformers v4.50.3.
It does not reproduce on earlier transformers versions, such as v4.42.4, so this looks like a regression: generation should complete without error regardless of prompt length.
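Until this is fixed, one workaround is to pin an unaffected release (pip install transformers==4.42.4). Another is to keep the prompt at or below the window; a sketch is below. Note that attention_window_size is the config attribute I'm assuming holds the 2048 window, and that truncation changes which tokens the model sees, so outputs will differ:

inputs = tokenizer(input_text, return_tensors="pt").to("cuda")
window = model.config.attention_window_size  # 2048 for this checkpoint (assumed attribute name)
# Keep only the most recent `window` tokens; avoids the crash but drops the
# start of the prompt.
inputs["input_ids"] = inputs["input_ids"][:, -window:]
inputs["attention_mask"] = inputs["attention_mask"][:, -window:]
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0]))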
Hi @assafbk, can you use git bisect to identify the commit where this bug appeared? It'll help a lot in identifying the cause and assigning the right person to fix it!
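For reference, a possible bisect session over the two versions mentioned above (a sketch: it assumes an editable install succeeds at every bisected commit and that the reproduction script exits non-zero exactly when the crash occurs):

git clone https://github.com/huggingface/transformers.git && cd transformers
git bisect start
git bisect bad v4.50.3    # crash reproduces here
git bisect good v4.42.4   # reported as working
git bisect run bash -c "pip install -e . --quiet && python /data2/assaf/tmp/test_rg.py"
git bisect reset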