I am trying to build a RAG-based chatbot with chain-of-thought prompting for a WordPress site. I am not very experienced with Hugging Face. I am using an API to retrieve the data for this.
Problem description:
I have built a custom retriever for the RAG model. I know that I can use RagRetriever; I tried it, but I ran into severe version-compatibility problems because it requires me to import datasets.
After running the program, the error says that my 'CustomRetriever' object is not callable. Is there a way around this?
I'm using:
Python v3.10.12
Transformers v4.42.0.dev0
My Custom Retriever Class
# Custom Retriever Class
class CustomRetriever:
    """Thin wrapper that looks up nearest-neighbour document indices.

    The wrapped ``index`` only needs a ``search(query_embedding, n_docs)``
    method returning a ``(distances, indices)`` pair.

    NOTE: instances are not callable -- the class defines ``retrieve`` but no
    ``__call__``. Code that invokes the retriever as ``retriever(...)`` will
    raise ``TypeError: 'CustomRetriever' object is not callable``.
    """

    def __init__(self, index) -> None:
        """Store the search index used by :meth:`retrieve`."""
        self.index = index

    def retrieve(self, query_embedding, n_docs: int = 5):
        """Return the indices of the ``n_docs`` best matches.

        Args:
            query_embedding: Query passed straight through to
                ``self.index.search``.
            n_docs: Number of results to request from the index.

        Returns:
            The ``indices`` half of the ``(distances, indices)`` pair returned
            by ``self.index.search``; the distances are discarded.
        """
        _distances, indices = self.index.search(query_embedding, n_docs)
        return indices
# Initialize RAG components
# The tokenizer and the generation model are both loaded from the same
# "facebook/rag-sequence-nq" checkpoint.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
# Initialize Custom Retriever
# `index` is not defined in the visible code; CustomRetriever only calls its
# `search` method.
custom_retriever = CustomRetriever(index=index)
# Set the retriever for RAG model
# NOTE(review): the traceback below shows the model invoking its retriever as
# `self.retriever(...)`, while CustomRetriever defines no `__call__` -- this
# is where the "'CustomRetriever' object is not callable" TypeError comes from.
rag_model.set_retriever(custom_retriever)
My implementation (passing input IDs and attention mask):
# Maintain a conversation history
# Module-level list shared by every call to
# process_query_with_chain_of_thought, which appends one "User: ..." and one
# "Bot: ..." entry per answered query.
conversation_history = []
def process_query_with_chain_of_thought(user_query: str, previous_context: str = "") -> tuple[str, str]:
    """Answer ``user_query`` from retrieved posts and record the exchange.

    The query is embedded, the nearest posts are looked up through
    ``custom_retriever``, their HTML is cleaned, and the cleaned text is
    handed to ``rag_model.generate`` as the context.

    Args:
        user_query: The user's question.
        previous_context: Conversation context from earlier turns. It is only
            returned unchanged when no document could be retrieved.

    Returns:
        A ``(response_text, context)`` tuple. On success ``context`` is the
        whole ``conversation_history`` joined with newlines; when nothing was
        retrieved it is ``previous_context`` and ``response_text`` is an
        apology message.

    Side effects:
        Appends the user query and the bot response to the module-level
        ``conversation_history`` list, and prints a message for every
        retrieved index that falls outside ``posts``.
    """
    # Tokenize the input query
    inputs = tokenizer(user_query, return_tensors="pt")

    # Embed the query as a single-row matrix and retrieve the document indices
    query_embedding = model.encode([user_query]).reshape(1, -1)
    doc_indices = custom_retriever.retrieve(query_embedding, n_docs=5)

    # Fetch the actual documents, skipping indices that do not map to a post
    retrieved_docs = []
    for i in doc_indices[0]:
        if 0 <= i < len(posts):
            retrieved_docs.append(posts[i]['content']['rendered'])
        else:
            print(f"Index {i} is out of range.")

    # Guard on the documents themselves: a context string that contains the
    # user query is never empty, so testing it could never reach this branch.
    if not retrieved_docs:
        print("No documents retrieved")
        return "Sorry, I couldn't find any relevant information.", previous_context

    # Clean HTML and convert to plain text, one blank line after each document
    full_content = "".join(clean_html(doc) + "\n\n" for doc in retrieved_docs)

    # Tokenize the retrieved documents
    context_inputs = tokenizer(full_content, return_tensors="pt", padding=True, truncation=True, max_length=512)

    # NOTE(review): the reported traceback shows generate() re-entering
    # forward() with input_ids, which calls the model's retriever -- so the
    # object given to set_retriever must be callable. The library source in
    # that traceback mentions an alternative branch ("input_ids is None, need
    # context_input_ids/mask and doc_scores"); presumably supplying doc_scores
    # and omitting input_ids avoids the retriever -- confirm against the
    # Transformers RAG documentation before changing this call.
    response_ids = rag_model.generate(
        input_ids=inputs['input_ids'],
        context_input_ids=context_inputs['input_ids'],
        context_attention_mask=context_inputs['attention_mask']
    )
    response_text = tokenizer.decode(response_ids[0], skip_special_tokens=True)

    # Append this turn to the conversation history
    conversation_history.append(f"User: {user_query}")
    conversation_history.append(f"Bot: {response_text}")

    # Create a new context by combining the conversation history
    new_context = "\n".join(conversation_history)
    return response_text, new_context
Queries:
# Example multi-turn conversation
# Each query is answered in turn; the context returned by one call is passed
# to the next call as `previous_context`.
user_queries = [
    "Are there any news on fiction stories?",
    "Which one is the most populer",
    "What are the reviews."
]
context = ""
for query in user_queries:
    response, context = process_query_with_chain_of_thought(query, previous_context=context)
    # Trailing "\n" leaves a blank line between responses.
    print(f"Processed response: {response}\n")
Error message:
TypeError Traceback (most recent call last)
<ipython-input-22-c7d6867da07f> in <cell line: 2>()
1 context = ""
2 for query in user_queries:
----> 3 response, context = process_query_with_chain_of_thought(query, previous_context=context)
4 print(f"Processed response: {response}n")
8 frames
<ipython-input-20-3284408ddf21> in process_query_with_chain_of_thought(user_query, previous_context)
41
42 # Ensure that context_input_ids are passed correctly
---> 43 response_ids = rag_model.generate(
44 input_ids=inputs['input_ids'],
45 context_input_ids=context_inputs['input_ids'],
/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
116
117 return decorate_context
/usr/local/lib/python3.10/dist-packages/transformers/models/rag/modeling_rag.py in generate(self, input_ids, attention_mask, context_input_ids, context_attention_mask, doc_scores, do_deduplication, num_return_sequences, num_beams, n_docs, **model_kwargs)
1019 if input_ids is not None:
1020 new_input_ids = input_ids[index : index + 1].repeat(num_candidates, 1)
-> 1021 outputs = self(new_input_ids, labels=output_sequences, exclude_bos_score=True)
1022 else: # input_ids is None, need context_input_ids/mask and doc_scores
1023 assert context_attention_mask is not None, (
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/transformers/models/rag/modeling_rag.py in forward(self, input_ids, attention_mask, encoder_outputs, decoder_input_ids, decoder_attention_mask, past_key_values, context_input_ids, context_attention_mask, doc_scores, use_cache, output_attentions, output_hidden_states, output_retrieved, exclude_bos_score, reduce_loss, labels, n_docs, **kwargs)
843 use_cache = False
844
--> 845 outputs = self.rag(
846 input_ids=input_ids,
847 attention_mask=attention_mask,
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
1533
1534 def _call_impl(self, *args, **kwargs):
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1542
1543 try:
/usr/local/lib/python3.10/dist-packages/transformers/models/rag/modeling_rag.py in forward(self, input_ids, attention_mask, encoder_outputs, decoder_input_ids, decoder_attention_mask, past_key_values, doc_scores, context_input_ids, context_attention_mask, use_cache, output_attentions, output_hidden_states, output_retrieved, n_docs)
591 question_encoder_last_hidden_state = question_enc_outputs[0] # hidden states of question encoder
592
--> 593 retriever_outputs = self.retriever(
594 input_ids,
595 question_encoder_last_hidden_state.cpu().detach().to(torch.float32).numpy(),
TypeError: 'CustomRetriever' object is not callable
If you want any more details, please do let me know. Thanks.