I want to get the neuron contributions within a layer of Llama 2 using Captum's LayerConductance with the following code:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from captum.attr import LayerConductance
import bitsandbytes as bnb

def load_model(model_name, bnb_config):
    n_gpus = torch.cuda.device_count()
    max_memory = "10000MB"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        # device_map="cpu"
        device_map="auto",  # dispatch the model efficiently across the available resources
        max_memory={i: max_memory for i in range(n_gpus)},
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=True)
    # Needed for the LLaMA tokenizer
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer

def create_bnb_config():
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    return bnb_config

model_name = "meta-llama/Llama-2-7b-chat-hf"
bnb_config = create_bnb_config()
model, tokenizer = load_model(model_name, bnb_config)

layer = model.model.layers[-1]
input_test = "The president of the USA is named"
inputs = tokenizer(input_test, return_tensors="pt")
input_ids = inputs["input_ids"].to("cuda:0").long()

layer_cond = LayerConductance(model, layer)
llama_att = layer_cond.attribute(input_ids, target=0)  # first token
But I get the following error: RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)
Stack trace:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[29], line 2
1 layer_cond = LayerConductance(model, layer)
----> 2 llama_att = layer_cond.attribute(input_ids, target=target)
File ~/venv/lib/python3.10/site-packages/captum/log/__init__.py:42, in log_usage.<locals>._log_usage.<locals>.wrapper(*args, **kwargs)
40 @wraps(func)
41 def wrapper(*args, **kwargs):
---> 42 return func(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/captum/attr/_core/layer/layer_conductance.py:292, in LayerConductance.attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, internal_batch_size, return_convergence_delta, attribute_to_layer_input)
277 attrs = _batch_attribution(
278 self,
279 num_examples,
(...)
288 attribute_to_layer_input=attribute_to_layer_input,
289 )
291 else:
--> 292 attrs = self._attribute(
293 inputs=inputs,
294 baselines=baselines,
295 target=target,
296 additional_forward_args=additional_forward_args,
297 n_steps=n_steps,
298 method=method,
299 attribute_to_layer_input=attribute_to_layer_input,
300 )
302 is_layer_tuple = isinstance(attrs, tuple)
303 attributions = attrs if is_layer_tuple else (attrs,)
File ~/venv/lib/python3.10/site-packages/captum/attr/_core/layer/layer_conductance.py:360, in LayerConductance._attribute(self, inputs, baselines, target, additional_forward_args, n_steps, method, attribute_to_layer_input, step_sizes_and_alphas)
356 expanded_target = _expand_target(target, n_steps + 1)
358 # Conductance Gradients - Returns gradient of output with respect to
359 # hidden layer and hidden layer evaluated at each input.
--> 360 (layer_gradients, layer_evals,) = compute_layer_gradients_and_eval(
361 forward_fn=self.forward_func,
362 layer=self.layer,
363 inputs=scaled_features_tpl,
364 additional_forward_args=input_additional_args,
365 target_ind=expanded_target,
366 device_ids=self.device_ids,
367 attribute_to_layer_input=attribute_to_layer_input,
368 )
370 # Compute differences between consecutive evaluations of layer_eval.
371 # This approximates the total input gradient of each step multiplied
372 # by the step size.
373 grad_diffs = tuple(
374 layer_eval[num_examples:] - layer_eval[:-num_examples]
375 for layer_eval in layer_evals
376 )
File ~/venv/lib/python3.10/site-packages/captum/_utils/gradient.py:592, in compute_layer_gradients_and_eval(forward_fn, layer, inputs, target_ind, additional_forward_args, gradient_neuron_selector, device_ids, attribute_to_layer_input, output_fn)
541 r"""
542 Computes gradients of the output with respect to a given layer as well
543 as the output evaluation of the layer for an arbitrary forward function
(...)
587 Target layer output for given input.
588 """
589 with torch.autograd.set_grad_enabled(True):
590 # saved_layer is a dictionary mapping device to a tuple of
591 # layer evaluations on that device.
--> 592 saved_layer, output = _forward_layer_distributed_eval(
593 forward_fn,
594 inputs,
595 layer,
596 target_ind=target_ind,
597 additional_forward_args=additional_forward_args,
598 attribute_to_layer_input=attribute_to_layer_input,
599 forward_hook_with_return=True,
600 require_layer_grads=True,
601 )
602 assert output[0].numel() == 1, (
603 "Target not provided when necessary, cannot"
604 " take gradient with respect to multiple outputs."
605 )
607 device_ids = _extract_device_ids(forward_fn, saved_layer, device_ids)
File ~/venv/lib/python3.10/site-packages/captum/_utils/gradient.py:294, in _forward_layer_distributed_eval(forward_fn, inputs, layer, target_ind, additional_forward_args, attribute_to_layer_input, forward_hook_with_return, require_layer_grads)
290 else:
291 all_hooks.append(
292 single_layer.register_forward_hook(hook_wrapper(single_layer))
293 )
--> 294 output = _run_forward(
295 forward_fn,
296 inputs,
297 target=target_ind,
298 additional_forward_args=additional_forward_args,
299 )
300 finally:
301 for hook in all_hooks:
File ~/venv/lib/python3.10/site-packages/captum/_utils/common.py:531, in _run_forward(forward_func, inputs, target, additional_forward_args)
528 inputs = _format_inputs(inputs)
529 additional_forward_args = _format_additional_forward_args(additional_forward_args)
--> 531 output = forward_func(
532 *(*inputs, *additional_forward_args)
533 if additional_forward_args is not None
534 else inputs
535 )
536 return _select_targets(output, target)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
167 output = module._old_forward(*args, **kwargs)
168 else:
--> 169 output = module._old_forward(*args, **kwargs)
170 return module._hf_hook.post_forward(module, output)
File ~/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1174, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1171 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1173 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1174 outputs = self.model(
1175 input_ids=input_ids,
1176 attention_mask=attention_mask,
1177 position_ids=position_ids,
1178 past_key_values=past_key_values,
1179 inputs_embeds=inputs_embeds,
1180 use_cache=use_cache,
1181 output_attentions=output_attentions,
1182 output_hidden_states=output_hidden_states,
1183 return_dict=return_dict,
1184 cache_position=cache_position,
1185 )
1187 hidden_states = outputs[0]
1188 if self.config.pretraining_tp > 1:
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/venv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:931, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
928 use_cache = False
930 if inputs_embeds is None:
--> 931 inputs_embeds = self.embed_tokens(input_ids)
933 return_legacy_cache = False
934 if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File ~/venv/lib/python3.10/site-packages/accelerate/hooks.py:169, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
167 output = module._old_forward(*args, **kwargs)
168 else:
--> 169 output = module._old_forward(*args, **kwargs)
170 return module._hf_hook.post_forward(module, output)
File ~/venv/lib/python3.10/site-packages/torch/nn/modules/sparse.py:163, in Embedding.forward(self, input)
162 def forward(self, input: Tensor) -> Tensor:
--> 163 return F.embedding(
164 input, self.weight, self.padding_idx, self.max_norm,
165 self.norm_type, self.scale_grad_by_freq, self.sparse)
File ~/venv/lib/python3.10/site-packages/torch/nn/functional.py:2264, in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
2258 # Note [embedding_renorm set_grad_enabled]
2259 # XXX: equivalent to
2260 # with torch.no_grad():
2261 # torch.embedding_renorm_
2262 # remove once script supports set_grad_enabled
2263 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2264 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)
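The error is raised by F.embedding at the bottom of the trace. For comparison, the same kind of RuntimeError can be reproduced in isolation by feeding a floating-point tensor to an embedding layer (a standalone sketch with arbitrary sizes; on CPU it reports torch.FloatTensor instead of torch.cuda.FloatTensor):

import torch
import torch.nn as nn

emb = nn.Embedding(10, 8)                    # vocabulary size and hidden size are arbitrary here
float_ids = torch.tensor([[1.0, 2.0, 3.0]])  # float "indices" instead of a Long/Int tensor
emb(float_ids)  # RuntimeError: Expected tensor for argument #1 'indices' to have ... Long, Int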
I made sure that the model and its input are on the same device (cuda:0) and that the input is an integer tensor (see the sanity check below). Maybe the problem is not related to the input at all. Does somebody have an idea?
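For reference, a minimal sketch of that sanity check (it assumes model, tokenizer and input_ids are defined exactly as in the code above; the plain forward pass at the end is just an extra confirmation):

# Verify dtype and device of the input ids and of the embedding table.
print(input_ids.dtype)                         # torch.int64
print(input_ids.device)                        # cuda:0
print(model.model.embed_tokens.weight.device)  # device that holds the embedding layer

# Plain forward pass with the same integer ids, as an extra sanity check.
with torch.no_grad():
    out = model(input_ids)
print(out.logits.shape)  # (1, sequence_length, vocab_size)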