Error for direct `trainer.compute_loss` call:
```
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[13], line 8
6 print("works")
7 trainer.model.train(False)
----> 8 trainer.compute_loss(model, inputs, return_outputs=False, num_items_in_batch=416)
9 print("fails.")
12 orig_context_manager = trainer.compute_loss_context_manager
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3729 loss_kwargs["num_items_in_batch"] = num_items_in_batch
3730 inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
3732 # Save past state if it exists
3733 # TODO: this needs to be fixed and made cleaner later.
3734 if self.args.past_index >= 0:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
525 labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
527 causal_mask = self._update_causal_mask(
528 attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
529 )
--> 530 outputs = self.language_model(
531 attention_mask=causal_mask,
532 position_ids=position_ids,
533 past_key_values=past_key_values,
534 inputs_embeds=inputs_embeds,
535 use_cache=use_cache,
536 output_attentions=output_attentions,
537 output_hidden_states=output_hidden_states,
538 return_dict=return_dict,
539 cache_position=cache_position,
540 num_logits_to_keep=num_logits_to_keep,
541 )
543 logits = outputs.logits
544 loss = None
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
843 input_ids=input_ids,
844 attention_mask=attention_mask,
845 position_ids=position_ids,
846 past_key_values=past_key_values,
847 inputs_embeds=inputs_embeds,
848 use_cache=use_cache,
849 output_attentions=output_attentions,
850 output_hidden_states=output_hidden_states,
851 return_dict=return_dict,
852 cache_position=cache_position,
853 )
855 hidden_states = outputs[0]
856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
617 layer_outputs = self._gradient_checkpointing_func(
618 decoder_layer.__call__,
619 hidden_states,
(...)
626 cache_position,
627 )
628 else:
--> 629 layer_outputs = decoder_layer(
630 hidden_states,
631 position_embeddings=position_embeddings,
632 attention_mask=causal_mask,
633 position_ids=position_ids,
634 past_key_value=past_key_values,
635 output_attentions=output_attentions,
636 use_cache=use_cache,
637 cache_position=cache_position,
638 **flash_attn_kwargs,
639 )
641 hidden_states = layer_outputs[0]
643 if output_attentions:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
296 hidden_states = self.input_layernorm(hidden_states)
298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
300 hidden_states=hidden_states,
301 position_embeddings=position_embeddings,
302 attention_mask=attention_mask,
303 position_ids=position_ids,
304 past_key_value=past_key_value,
305 output_attentions=output_attentions,
306 use_cache=use_cache,
307 cache_position=cache_position,
308 )
309 hidden_states = self.post_attention_layernorm(hidden_states)
310 hidden_states = residual + hidden_states
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
221 if past_key_value is not None:
222 # sin and cos are specific to RoPE models; cache_position needed for the static cache
223 cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
--> 224 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
226 attention_interface: Callable = eager_attention_forward
227 if self.config._attn_implementation != "eager":
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
1714 else:
1715 update_fn = self._static_update
-> 1717 return update_fn(
1718 cache_position,
1719 layer_idx,
1720 key_states,
1721 value_states,
1722 k_out,
1723 v_out,
1724 k_out.shape[2],
1725 )
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694 k_out[:, :, cache_position] = key_states
1695 v_out[:, :, cache_position] = value_states
1697 self.key_cache[layer_idx] = k_out
```
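Both tracebacks bottom out in `HybridCache._static_update`, where `k_out[:, :, cache_position] = key_states` mixes a pre-allocated cache tensor with freshly computed key states. With the model sharded across GPUs by accelerate hooks (visible in the `accelerate/hooks.py` frames), the cache layer and the incoming keys can end up on different devices. The snippet below is only an illustration of that failure mode, not code from this notebook; the tensor shapes and names are made up, and it assumes a machine with at least two CUDA devices.

```python
# Illustrative only (assumed shapes/names, not from the gist): reproduce the
# cross-device indexed assignment that HybridCache._static_update performs.
import torch

if torch.cuda.device_count() >= 2:
    k_out = torch.zeros(1, 8, 32, 64, device="cuda:1")      # cache slot allocated on GPU 1
    key_states = torch.randn(1, 8, 3, 64, device="cuda:0")  # new keys computed on GPU 0
    cache_position = torch.arange(3, device="cuda:0")       # write positions, also on GPU 0
    try:
        k_out[:, :, cache_position] = key_states            # mirrors the failing cache update
    except RuntimeError as err:
        print(err)  # a device-mismatch RuntimeError; exact wording depends on the torch version
```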
Error for trainer with evaluation:
```
"name": "RuntimeError",
"message": "Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!",
"stack": "---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[35], line 27
1 # batch = [valid_dataset[0], valid_dataset[1],]
2 # batch = [valid_dataset[i] for i in range(8)]
3 # inputs = collate_fn(batch)
(...)
24 # self.orig_context_manager.__exit__(type, value, traceback)
25 # trainer.compute_loss_context_manager = TempTrainContext(trainer)
---> 27 trainer.train()
28 #raise ValueError
29 # image, labels = valid_dataset.entries[0]
30 #print(\"xxx\")
31 #print(generate_ids)
32 #model(**inputs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:2171, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
2169 hf_hub_utils.enable_progress_bars()
2170 else:
-> 2171 return inner_training_loop(
2172 args=args,
2173 resume_from_checkpoint=resume_from_checkpoint,
2174 trial=trial,
2175 ignore_keys_for_eval=ignore_keys_for_eval,
2176 )
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:2598, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2596 self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
2597 self.control = self.callback_handler.on_step_end(args, self.state, self.control)
-> 2598 self._maybe_log_save_evaluate(
2599 tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time
2600 )
2601 else:
2602 self.control = self.callback_handler.on_substep_end(args, self.state, self.control)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3071, in Trainer._maybe_log_save_evaluate(self, tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval, start_time)
3069 metrics = None
3070 if self.control.should_evaluate:
-> 3071 metrics = self._evaluate(trial, ignore_keys_for_eval)
3072 is_new_best_metric = self._determine_best_metric(metrics=metrics, trial=trial)
3074 if self.args.save_strategy == SaveStrategy.BEST:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3025, in Trainer._evaluate(self, trial, ignore_keys_for_eval, skip_scheduler)
3024 def _evaluate(self, trial, ignore_keys_for_eval, skip_scheduler=False):
-> 3025 metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
3026 self._report_to_hp_search(trial, self.state.global_step, metrics)
3028 # Run delayed LR scheduler now that metrics are populated
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer_seq2seq.py:197, in Seq2SeqTrainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix, **gen_kwargs)
195 self.gather_function = self.accelerator.gather
196 self._gen_kwargs = gen_kwargs
--> 197 return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:4073, in Trainer.evaluate(self, eval_dataset, ignore_keys, metric_key_prefix)
4070 start_time = time.time()
4072 eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
-> 4073 output = eval_loop(
4074 eval_dataloader,
4075 description=\"Evaluation\",
4076 # No point gathering the predictions if there are no metrics, otherwise we defer to
4077 # self.args.prediction_loss_only
4078 prediction_loss_only=True if self.compute_metrics is None else None,
4079 ignore_keys=ignore_keys,
4080 metric_key_prefix=metric_key_prefix,
4081 )
4083 total_batch_size = self.args.eval_batch_size * self.args.world_size
4084 if f\"{metric_key_prefix}_jit_compilation_time\" in output.metrics:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:4267, in Trainer.evaluation_loop(self, dataloader, description, prediction_loss_only, ignore_keys, metric_key_prefix)
4264 batch_size = observed_batch_size
4266 # Prediction step
-> 4267 losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
4268 main_input_name = getattr(self.model, \"main_input_name\", \"input_ids\")
4269 inputs_decode = (
4270 self._prepare_input(inputs[main_input_name]) if \"inputs\" in args.include_for_metrics else None
4271 )
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer_seq2seq.py:295, in Seq2SeqTrainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys, **gen_kwargs)
271 \"\"\"
272 Perform an evaluation step on `model` using `inputs`.
273
(...)
291 labels (each being optional).
292 \"\"\"
294 if not self.args.predict_with_generate or prediction_loss_only:
--> 295 return super().prediction_step(
296 model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
297 )
299 has_labels = \"labels\" in inputs
300 inputs = self._prepare_inputs(inputs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:4484, in Trainer.prediction_step(self, model, inputs, prediction_loss_only, ignore_keys)
4481 if has_labels or loss_without_labels:
4482 with self.compute_loss_context_manager():
4483 #model.train() # XXX: Max
-> 4484 loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
4485 loss = loss.mean().detach()
4487 if isinstance(outputs, dict):
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
3729 loss_kwargs[\"num_items_in_batch\"] = num_items_in_batch
3730 inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
3732 # Save past state if it exists
3733 # TODO: this needs to be fixed and made cleaner later.
3734 if self.args.past_index >= 0:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
818 def forward(*args, **kwargs):
--> 819 return model_forward(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
806 def __call__(self, *args, **kwargs):
--> 807 return convert_to_fp32(self.model_forward(*args, **kwargs))
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
41 @functools.wraps(func)
42 def decorate_autocast(*args, **kwargs):
43 with autocast_instance:
---> 44 return func(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/paligemma/modeling_paligemma.py:530, in PaliGemmaForConditionalGeneration.forward(self, input_ids, pixel_values, attention_mask, position_ids, past_key_values, token_type_ids, cache_position, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, num_logits_to_keep)
525 labels = torch.where(input_ids == self.pad_token_id, self.config.ignore_index, labels)
527 causal_mask = self._update_causal_mask(
528 attention_mask, token_type_ids, past_key_values, cache_position, input_ids, inputs_embeds, is_training
529 )
--> 530 outputs = self.language_model(
531 attention_mask=causal_mask,
532 position_ids=position_ids,
533 past_key_values=past_key_values,
534 inputs_embeds=inputs_embeds,
535 use_cache=use_cache,
536 output_attentions=output_attentions,
537 output_hidden_states=output_hidden_states,
538 return_dict=return_dict,
539 cache_position=cache_position,
540 num_logits_to_keep=num_logits_to_keep,
541 )
543 logits = outputs.logits
544 loss = None
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:842, in Gemma2ForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
840 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
841 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 842 outputs = self.model(
843 input_ids=input_ids,
844 attention_mask=attention_mask,
845 position_ids=position_ids,
846 past_key_values=past_key_values,
847 inputs_embeds=inputs_embeds,
848 use_cache=use_cache,
849 output_attentions=output_attentions,
850 output_hidden_states=output_hidden_states,
851 return_dict=return_dict,
852 cache_position=cache_position,
853 )
855 hidden_states = outputs[0]
856 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:629, in Gemma2Model.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, **flash_attn_kwargs)
617 layer_outputs = self._gradient_checkpointing_func(
618 decoder_layer.__call__,
619 hidden_states,
(...)
626 cache_position,
627 )
628 else:
--> 629 layer_outputs = decoder_layer(
630 hidden_states,
631 position_embeddings=position_embeddings,
632 attention_mask=causal_mask,
633 position_ids=position_ids,
634 past_key_value=past_key_values,
635 output_attentions=output_attentions,
636 use_cache=use_cache,
637 cache_position=cache_position,
638 **flash_attn_kwargs,
639 )
641 hidden_states = layer_outputs[0]
643 if output_attentions:
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:299, in Gemma2DecoderLayer.forward(self, hidden_states, position_embeddings, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position)
296 hidden_states = self.input_layernorm(hidden_states)
298 # Self Attention
--> 299 hidden_states, self_attn_weights = self.self_attn(
300 hidden_states=hidden_states,
301 position_embeddings=position_embeddings,
302 attention_mask=attention_mask,
303 position_ids=position_ids,
304 past_key_value=past_key_value,
305 output_attentions=output_attentions,
306 use_cache=use_cache,
307 cache_position=cache_position,
308 )
309 hidden_states = self.post_attention_layernorm(hidden_states)
310 hidden_states = residual + hidden_states
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1736, in Module._wrapped_call_impl(self, *args, **kwargs)
1734 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1735 else:
-> 1736 return self._call_impl(*args, **kwargs)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/torch/nn/modules/module.py:1747, in Module._call_impl(self, *args, **kwargs)
1742 # If we don't have any hooks, we want to skip the rest of the logic in
1743 # this function, and just call forward.
1744 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1745 or _global_backward_pre_hooks or _global_backward_hooks
1746 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1747 return forward_call(*args, **kwargs)
1749 result = None
1750 called_always_called_hooks = set()
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/models/gemma2/modeling_gemma2.py:224, in Gemma2Attention.forward(self, hidden_states, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
221 if past_key_value is not None:
222 # sin and cos are specific to RoPE models; cache_position needed for the static cache
223 cache_kwargs = {\"sin\": sin, \"cos\": cos, \"cache_position\": cache_position}
--> 224 key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
226 attention_interface: Callable = eager_attention_forward
227 if self.config._attn_implementation != \"eager\":
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1717, in HybridCache.update(self, key_states, value_states, layer_idx, cache_kwargs)
1714 else:
1715 update_fn = self._static_update
-> 1717 return update_fn(
1718 cache_position,
1719 layer_idx,
1720 key_states,
1721 value_states,
1722 k_out,
1723 v_out,
1724 k_out.shape[2],
1725 )
File ~/local/miniconda3/envs/paligemma/lib/python3.12/site-packages/transformers/cache_utils.py:1694, in HybridCache._static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len)
1693 def _static_update(self, cache_position, layer_idx, key_states, value_states, k_out, v_out, max_cache_len):
-> 1694 k_out[:, :, cache_position] = key_states
1695 v_out[:, :, cache_position] = value_states
1697 self.key_cache[layer_idx] = k_out
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1!"
```
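The evaluation failure goes through `Trainer.prediction_step` → `compute_loss`, so the forward pass still builds a `HybridCache` even though no tokens are being generated. One possible workaround (an assumption on my part, not something verified in this gist) is to switch the KV cache off for these teacher-forced forward passes:

```python
# Hedged workaround sketch (assumption, not verified here): the KV cache is not
# needed when computing the evaluation loss from full labels, so disabling it
# avoids allocating HybridCache layers that may sit on a different GPU than the
# incoming key/value states when the model is sharded with a device_map.
trainer.model.config.use_cache = False
# PaliGemma delegates to a Gemma2 language model that reads its own config:
trainer.model.language_model.config.use_cache = False

trainer.train()  # evaluation steps now run the forward pass without a cache
```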