@BlGene
Created January 31, 2025 10:42
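For context, a minimal sketch of the setup these tracebacks point to, reconstructed from the stack frames: a PaliGemma model whose language model is Gemma2 (i.e. PaliGemma 2), loaded with `device_map="auto"` so it ends up sharded across `cuda:0` and `cuda:1` (accelerate hooks appear in every frame), and wrapped in a `Seq2SeqTrainer`. The checkpoint id, datasets, collator, and the `inputs` batch are placeholders, not taken from the gist.

```python
# Sketch of the assumed setup (checkpoint id, datasets, collator and `inputs`
# are placeholders; the gist itself does not show them).
import torch
from transformers import (
    PaliGemmaForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

model_id = "google/paligemma2-3b-pt-224"  # assumed PaliGemma 2 checkpoint
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",  # shards the model across the available GPUs
)

args = Seq2SeqTrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    eval_strategy="steps",  # evaluation is what triggers the second traceback
    eval_steps=10,
)
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,  # placeholder dataset
    eval_dataset=valid_dataset,   # placeholder dataset
    data_collator=collate_fn,     # placeholder collator
)

# First traceback: compute_loss called by hand with the model in eval mode.
trainer.model.train(False)
trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)

# Second traceback: the same failure reached through trainer.train() -> evaluate().
trainer.train()
```

First traceback, from calling `trainer.compute_loss` directly with the model switched to eval mode: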
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[13], line 8
6 print("works")
7 trainer.model.train(False)
----> 8 trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
9 print("fails.")
12 orig_context_manager = trainer.compute_loss_context_manager
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3729 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3730 inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
3732 # Save past state if it exists
3733 # TODO: this needs to be fixed and made cleaner later.
3734 if self.args.past_index >= 0:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
525 labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
527 causal_mask = self._update_causal_mask(
528 attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
529 )
--> 530 outputs = self.language_model(
531 attention_mask=causal_mask,
532 position_ids=position_ids,
533 past_key_values=past_key_values,
534 inputs_embeds=inputs_embeds,
535 use_cache=use_cache,
536 output_attentions=output_attentions,
537 output_hidden_states=output_hidden_states,
538 return_dict=return_dict,
539 cache_position=cache_position,
540 num_logits_to_keep=num_logits_to_keep,
541 )
543 logits = outputs.logits
544 loss = None
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
843 input_ids=input_ids,
844 attention_mask=attention_mask,
845 position_ids=position_ids,
846 past_key_values=past_key_values,
847 inputs_embeds=inputs_embeds,
848 use_cache=use_cache,
849 output_attentions=output_attentions,
850 output_hidden_states=output_hidden_states,
851 return_dict=return_dict,
852 cache_position=cache_position,
853 )
855 hidden_states = outputs[0]
856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
617 layer_outputs = self._gradient_checkpointing_func(
618 decoder_layer.__call__,
619 hidden_states,
(...)
626 cache_position,
627 )
628 else:
--> 629 layer_outputs = decoder_layer(
630 hidden_states,
631 position_embeddings=position_embeddings,
632 attention_mask=causal_mask,
633 position_ids=position_ids,
634 past_key_value=past_key_values,
635 output_attentions=output_attentions,
636 use_cache=use_cache,
637 cache_position=cache_position,
638 **flash_attn_kwargs,
639 )
641 hidden_states = layer_outputs[0]
643 if output_attentions:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
296 hidden_states = self.input_layernorm(hidden_states)
298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
300 hidden_states=hidden_states,
301 position_embeddings=position_embeddings,
302 attention_mask=attention_mask,
303 position_ids=position_ids,
304 past_key_value=past_key_value,
305 output_attentions=output_attentions,
306 use_cache=use_cache,
307 cache_position=cache_position,
308 )
309 hidden_states = self.post_attention_layernorm(hidden_states)
310 hidden_states = residual + hidden_states
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
221 if past_key_value is not None:
222 # sin and cos are specific to RoPE models; cache_position needed for the static cache
223 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 224 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
226 attention_interface: Callable = eager_attention_forward
227 if self.config._attn_implementation != "eager":
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
1714 else:
1715 update_fn = self._static_update
-> 1717 return update_fn(
1718 cache_position,
1719 layer_idx,
1720 key_states,
1721 value_states,
1722 k_out,
1723 v_out,
1724 k_out.shape[2],
1725 )
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694 k_out[:, :, cache_position] = key_states
1695 v_out[:, :, cache_position] = value_states
1697 self.key_cache[layer_idx] = k_out
```
Second traceback, from `trainer.train()` once evaluation runs:
```
"name": "RuntimeError",
"message": "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!",
"stack": "---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[35], line 27
1 # batch = [valid_dataset[0], valid_dataset[1],]
2 # batch = [valid_dataset[i] for i in range(8)]
3 # inputs = collate_fn(batch)
(...)
24 # self.orig_context_manager.__exit__(type, value, traceback)
25 # trainer.compute_loss_context_manager = TempTrainContext(trainer)
---> 27 trainer.train()
28 #raise ValueError
29 # image, labels = valid_dataset.entries[0]
30 #print(\"xxx\")
31 #print(generate_ids)
32 #model(**inputs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:2171, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2169 hf_hub_utils.enable_progress_bars()
2170 else:
-> 2171 return inner_training_loop(
2172 args=args,
2173 resume_from_checkpoint=resume_from_checkpoint,
2174 trial=trial,
2175 ignore_keys_for_eval=ignore_keys_for_eval,
2176 )
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:2598, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2596 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
2597 self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2598 self._maybe_log_save_evaluate(
2599 tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
2600 )
2601 else:
2602 self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3071, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
3069 metrics = None
3070 if self.control.should_evaluate:
-> 3071 metrics = self._evaluate(trial, ignore_keys_for_eval)
3072 is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
3074 if self.args.save_strategy == SaveStrategy.BEST:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3025, in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
3024 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 3025 metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
3026 self._report_to_hp_search(trial, self.state.global_step, metrics)
3028 # Run delayed LR scheduler now that metrics are populated
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer_seq2seq.py:197, in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
195 self.gather_function = self.accelerator.gather
196 self._gen_kwargs = gen_kwargs
--> 197 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:4073, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
4070 start_time = time.time()
4072 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 4073 output = eval_loop(
4074 eval_dataloader,
4075 description=\"Evaluation\",
4076 # No point gathering the predictions if there are no metrics, otherwise we defer to
4077 # self.args.prediction_loss_only
4078 prediction_loss_only=True if self.compute_metrics is None else None,
4079 ignore_keys=ignore_keys,
4080 metric_key_prefix=metric_key_prefix,
4081 )
4083 total_batch_size = self.args.eval_batch_size * self.args.world_size
4084 if f\"{metric_key_prefix}_jit_compilation_time\" in output.metrics:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:4267, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
4264 batch_size = observed_batch_size
4266 # Prediction step
-> 4267 losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
4268 main_input_name = getattr(self.model, \"main_input_name\", \"input_ids\")
4269 inputs_decode = (
4270 self._prepare_input(inputs[main_input_name]) if \"inputs\" in args.include_for_metrics else None
4271 )
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer_seq2seq.py:295, in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
271 \"\"\"
272 Perform an evaluation step on `model` using `inputs`.
273
(...)
291 labels (each being optional).
292 \"\"\"
294 if not self.args.predict_with_generate or prediction_loss_only:
--> 295 return super().prediction_step(
296 model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
297 )
299 has_labels = \"labels\" in inputs
300 inputs = self._prepare_inputs(inputs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:4484, in Trainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
4481 if has_labels or loss_without_labels:
4482 with self.compute_loss_context_manager():
4483 #model.train() # XXX: Max
-> 4484 loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
4485 loss = loss.mean().detach()
4487 if isinstance(outputs, dict):
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3729 loss_kwargs[\"num_items_in_batch\"] = num_items_in_batch
3730 inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
3732 # Save past state if it exists
3733 # TODO: this needs to be fixed and made cleaner later.
3734 if self.args.past_index >= 0:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
818 def forward(*args, **kwargs):
--> 819 return model_forward(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
806 def __call__(self, *args, **kwargs):
--> 807 return convert_to_fp32(self.model_forward(*args, **kwargs))
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
41 @functools.wraps(func)
42 def decorate_autocast(*args, **kwargs):
43 with autocast_instance:
---> 44 return func(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
525 labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
527 causal_mask = self._update_causal_mask(
528 attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
529 )
--> 530 outputs = self.language_model(
531 attention_mask=causal_mask,
532 position_ids=position_ids,
533 past_key_values=past_key_values,
534 inputs_embeds=inputs_embeds,
535 use_cache=use_cache,
536 output_attentions=output_attentions,
537 output_hidden_states=output_hidden_states,
538 return_dict=return_dict,
539 cache_position=cache_position,
540 num_logits_to_keep=num_logits_to_keep,
541 )
543 logits = outputs.logits
544 loss = None
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
843 input_ids=input_ids,
844 attention_mask=attention_mask,
845 position_ids=position_ids,
846 past_key_values=past_key_values,
847 inputs_embeds=inputs_embeds,
848 use_cache=use_cache,
849 output_attentions=output_attentions,
850 output_hidden_states=output_hidden_states,
851 return_dict=return_dict,
852 cache_position=cache_position,
853 )
855 hidden_states = outputs[0]
856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
617 layer_outputs = self._gradient_checkpointing_func(
618 decoder_layer.__call__,
619 hidden_states,
(...)
626 cache_position,
627 )
628 else:
--> 629 layer_outputs = decoder_layer(
630 hidden_states,
631 position_embeddings=position_embeddings,
632 attention_mask=causal_mask,
633 position_ids=position_ids,
634 past_key_value=past_key_values,
635 output_attentions=output_attentions,
636 use_cache=use_cache,
637 cache_position=cache_position,
638 **flash_attn_kwargs,
639 )
641 hidden_states = layer_outputs[0]
643 if output_attentions:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
296 hidden_states = self.input_layernorm(hidden_states)
298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
300 hidden_states=hidden_states,
301 position_embeddings=position_embeddings,
302 attention_mask=attention_mask,
303 position_ids=position_ids,
304 past_key_value=past_key_value,
305 output_attentions=output_attentions,
306 use_cache=use_cache,
307 cache_position=cache_position,
308 )
309 hidden_states = self.post_attention_layernorm(hidden_states)
310 hidden_states = residual + hidden_states
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
221 if past_key_value is not None:
222 # sin and cos are specific to RoPE models; cache_position needed for the static cache
223 cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}
--> 224 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
226 attention_interface: Callable = eager_attention_forward
227 if self.config._attn_implementation != \"eager\":
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
1714 else:
1715 update_fn = self._static_update
-> 1717 return update_fn(
1718 cache_position,
1719 layer_idx,
1720 key_states,
1721 value_states,
1722 k_out,
1723 v_out,
1724 k_out.shape[2],
1725 )
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694 k_out[:, :, cache_position] = key_states
1695 v_out[:, :, cache_position] = value_states
1697 self.key_cache[layer_idx] = k_out
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!"
```
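Both stacks end at the same place, `HybridCache._static_update`, on the line `k_out[:, :, cache_position] = key_states`, and the second run names the failure explicitly: the tensors involved sit on two devices, `cuda:0` and `cuda:1`. A quick check to confirm the model really spans two GPUs (a sketch, assuming `model` is the sharded PaliGemma model from the setup above; the attribute path follows the modules visible in the traceback):

```python
# Sketch: confirm the model is dispatched across several GPUs.
# Assumes `model` was loaded with device_map="auto"; accelerate then sets
# `hf_device_map`, which maps submodule names to the device of each shard.
from collections import Counter

print(Counter(model.hf_device_map.values()))

# Compare where the first and last Gemma2 decoder layers live.
layers = model.language_model.model.layers
print(layers[0].self_attn.q_proj.weight.device,
      layers[-1].self_attn.q_proj.weight.device)

# In HybridCache._static_update, key_states are produced on whichever GPU
# holds the current decoder layer, while k_out is the cache buffer being
# written into; the RuntimeError suggests these end up on different devices
# when the model is sharded like this.
```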