I am trying to re-implement the model.generate() function of Hugging Face transformers models so that I can add logit bias, which the standard generate() does not support. Before I even got to that part, I ran into problems with my top-p sampling.
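For context, what I ultimately want to add is a logit-bias step like the one below, applied to the logits right before sampling. This is only a rough sketch of the idea; apply_logit_bias and the logit_bias dict (token id -> offset) are my own names, not anything from transformers:

import torch

def apply_logit_bias(logits, logit_bias):
    # logits: tensor of shape (batch, vocab_size)
    # logit_bias: hypothetical {token_id: offset} dict
    for token_id, offset in logit_bias.items():
        logits[:, token_id] += offset
    return logits

# e.g. strongly discourage one token id and mildly encourage another
# logits = apply_logit_bias(logits, {50256: -100.0, 314: 2.0})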
Here’s the code snippet:
generation_args = {
    "max_new_tokens": 500,
    "temperature": 0.4,  # Adjust temperature if needed for more or less randomness
    "do_sample": True,   # Enable sampling
    "top_p": 0.5,        # Set the cumulative probability for nucleus sampling
    "top_k": None,       # Optionally, set top_k to use it alongside or instead of top_p
}
def top_p_filtering(logits, top_p):
    """Filter the logits using top-p (nucleus) sampling."""
    # Sort logits in descending order and get the sorted indices
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    # Compute the cumulative probabilities of the sorted logits
    cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
    # Create a mask for the tokens to keep
    sorted_indices_to_keep = cumulative_probs <= top_p
    # Ensure that at least one token is kept (the first token, which has the highest logit)
    sorted_indices_to_keep[..., 0] = True
    # Filter out the tokens to remove by setting their logits to negative infinity
    logits[sorted_indices[~sorted_indices_to_keep]] = float('-inf')
    return logits
def custom_generate(input_ids, streamer, max_new_tokens, temperature, top_p):
    past_key_values = None
    attention_mask = torch.ones(input_ids.shape, device=input_ids.device)
    for _ in range(max_new_tokens):
        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                past_key_values=past_key_values,
                attention_mask=attention_mask,
                use_cache=True
            )
        logits = outputs.logits[:, -1, :]  # Get logits of the last token

        # Apply temperature to logits
        if temperature != 1.0:
            logits = logits / temperature

        # Apply top-p sampling
        if top_p is not None and top_p < 1.0:
            logits = top_p_filtering(logits, top_p)

        print("1")
        next_token_probs = torch.nn.functional.softmax(logits, dim=-1)
        print("2")
        # Check if next_token_probs contains valid probabilities
        next_token_id = torch.multinomial(next_token_probs, num_samples=1)
        print("3")

        streamer.put(next_token_id)  # Pass the tensor directly to the streamer
        input_ids = next_token_id    # Set the next input to the last generated token
        attention_mask = torch.cat(
            [attention_mask, torch.ones((attention_mask.shape[0], 1), device=attention_mask.device)], dim=1)
        past_key_values = outputs.past_key_values

        if next_token_id.item() == tokenizer.eos_token_id:
            break

with torch.no_grad():
    custom_generate(input_ids, streamer, generation_args["max_new_tokens"],
                    generation_args["temperature"], generation_args["top_p"])
The error I'm getting:
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [10,0,0], thread: [63,0,0] Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.
Exception in thread Thread-18 (generate):
Traceback (most recent call last):
File "/usr/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
self.run()
File "/usr/lib/python3.10/threading.py", line 953, in run
self._target(*self._args, **self._kwargs)
File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 130, in generate
custom_generate(input_ids, streamer, generation_args["max_new_tokens"], generation_args["temperature"], generation_args["top_p"])
File "/mnt/c/Users/User/Documents/EmpatheticChatBot/Inference-Server.py", line 108, in custom_generate
next_token_id = torch.multinomial(next_token_probs,
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
The whole problem only appeared after I added top-p sampling.
I expected the sampling to just work; I have gone over the code maybe 30 times already, and ChatGPT insists the code is correct and that the error is hard to debug. My hypothesis is that the filtering step is masking the wrong entries, or is setting the logits to values that break the sampling afterwards.
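Besides re-running with CUDA_LAUNCH_BLOCKING=1 as the message suggests, the only other idea I have is to reproduce just the filtering step on CPU with a tiny fake batch, since an indexing problem there should surface as a plain Python IndexError instead of an asynchronous CUDA assert. A minimal sketch of what I mean (the 5-token vocab is made up):

import torch

# Tiny CPU repro of the filtering step: batch of 1, made-up vocab of 5 tokens
logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])  # shape (batch=1, vocab=5)
try:
    filtered = top_p_filtering(logits.clone(), top_p=0.5)
    print("filtered logits:", filtered)
except Exception as e:
    print("filtering failed on CPU:", e)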