I’m working through a textbook that suggests the following implementation of a concurrent queue. I am wondering why the two if statements marked in the code are necessary. After the snippet I include some reasoning about why I don’t think they are needed, but I’m still hesitant to go against the word of my textbook, so I would love an explanation of why we need them.
Please feel free to ignore my “proof” of why we don’t need the if statements; I’m just trying to give some insight into how I’m thinking about the problem.
The implementation:
#include <atomic>
#include <optional>

template<typename T>
struct Node
{
    Node(T val, Node* next = nullptr)
        : value(val), next(next)
    {}
    T value;
    std::atomic<Node*> next;
};
template<typename T>
struct UnboundedQueue
{
    std::atomic<Node<T>*> head;
    std::atomic<Node<T>*> tail;
    void enq(T val);
    std::optional<T> deq(void);
};
template<typename T>
void UnboundedQueue<T>::enq(T val)
{
    Node<T>* node = new Node<T>(val);
    while(true)
    {
        Node<T>* last = tail.load();
        Node<T>* next = last->next.load();
        // ********* THIS if statement
        if(last == tail.load())
        {
            if(next == nullptr)
            {
                // last has no successor yet: try to link the new node after it
                if(last->next.compare_exchange_weak(next, node))
                {
                    // try to swing tail forward; a failure here is tolerated
                    tail.compare_exchange_weak(last, node);
                    return;
                }
            }
            else
            {
                // tail is lagging behind: help advance it, then retry
                tail.compare_exchange_weak(last, next);
            }
        }
    }
}
template<typename T>
std::optional<T> UnboundedQueue<T>::deq(void)
{
    while(true)
    {
        Node<T>* first = head.load();
        Node<T>* last = tail.load();
        Node<T>* next = first->next.load();
        // ********* and THIS if statement
        if(first == head.load())
        {
            if(last == first)
            {
                // test the value already loaded; re-reading first->next here
                // could disagree with next and let the CAS below store nullptr
                if(next == nullptr)
                {
                    return std::nullopt;
                }
                else
                {
                    tail.compare_exchange_weak(last, next);
                }
            }
            else
            {
                // first is the sentinel, so the dequeued value lives in next
                T res = next->value;
                if(head.compare_exchange_weak(first, next))
                {
                    // this is UB but question isn't about that
                    delete first;
                    return res;
                }
            }
        }
    }
}
Intuitively, these checks feel unnecessary.
First, note that we have the following two invariants:
- The head node never moves past the tail node.
  Informal proof: For head to move past tail, a deq call would have to execute its CAS on head while head == tail, and the CAS only succeeds if head == first. But we load last after loading first, so this would mean that last == first, and we would take the last == first branch and not move head past tail.
- If a node was ever the tail node and its next value is nullptr, then it is currently the tail node.
  Informal proof: We never set the next value of a node to nullptr. So for a node to lose its status as the tail, we would first need to add another node after it, setting its next value to something other than nullptr.
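The first invariant is a reachability statement, so on a quiescent list (no operation in flight) it can be checked directly. A minimal sketch, assuming an int payload; `SnapNode`, `head_not_past_tail` and `demo_invariant_check` are hypothetical names, not from the textbook. The second invariant is a statement about history, so a snapshot check like this does not capture it.

```cpp
#include <atomic>
#include <cassert>   // for the assert in the usage note

// Hypothetical node type for this sketch (Node<T> with T = int).
struct SnapNode
{
    int value;
    std::atomic<SnapNode*> next;
    explicit SnapNode(int v) : value(v), next(nullptr) {}
};

// Returns true if tail can be reached from head by following next
// pointers, i.e. head has not moved past tail. Only meaningful while
// no other thread is modifying the list.
bool head_not_past_tail(SnapNode* head, SnapNode* tail)
{
    for (SnapNode* n = head; n != nullptr; n = n->next.load())
    {
        if (n == tail) return true;
    }
    return false;
}

// Builds sentinel -> a -> b on the stack and checks three head/tail pairs.
bool demo_invariant_check()
{
    SnapNode sentinel(0), a(1), b(2);
    sentinel.next.store(&a);
    a.next.store(&b);

    bool ok = head_not_past_tail(&sentinel, &b);  // two elements queued
    ok = ok && head_not_past_tail(&b, &b);        // empty queue: head == tail
    ok = ok && !head_not_past_tail(&b, &a);       // head past tail: violation
    return ok;
}
```

Calling `assert(demo_invariant_check());` from a test exercises the three hand-built states.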
Now suppose we remove the if statements and enter each method in a state where the removed check would have failed. In each case, nothing incorrect happens to our state.
For enq, assume that last != tail.load().
Case 1: next == nullptr. By our second invariant, last is still the tail, so our assumption must be false.
Case 2: next != nullptr. Then we perform tail.compare_exchange_weak(last, next), which will do nothing because tail != last.
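The two enq cases above can be exercised with a version of enq that has the re-check removed. This is only a sketch under stated assumptions: the payload is fixed to int, head and tail start at a sentinel node, and nodes are freed only when the queue is destroyed. `EnqNode`, `EnqOnlyQueue` and `enq_stress` are hypothetical names. A passing run does not prove the check is unnecessary; it only covers the interleavings that happen to occur.

```cpp
#include <atomic>
#include <cassert>   // for the assert in the usage note
#include <cstddef>
#include <thread>
#include <vector>

struct EnqNode
{
    int value;
    std::atomic<EnqNode*> next;
    explicit EnqNode(int v) : value(v), next(nullptr) {}
};

struct EnqOnlyQueue
{
    std::atomic<EnqNode*> head;
    std::atomic<EnqNode*> tail;

    // Assumption: head and tail both start at a sentinel node.
    EnqOnlyQueue()
    {
        EnqNode* sentinel = new EnqNode(0);
        head.store(sentinel);
        tail.store(sentinel);
    }

    // Assumption: runs only after every producer thread has joined.
    ~EnqOnlyQueue()
    {
        EnqNode* n = head.load();
        while (n != nullptr) { EnqNode* after = n->next.load(); delete n; n = after; }
    }

    // enq with the `last == tail.load()` re-check removed.
    void enq(int val)
    {
        EnqNode* node = new EnqNode(val);
        while (true)
        {
            EnqNode* last = tail.load();
            EnqNode* next = last->next.load();
            if (next == nullptr)
            {
                // Case 1: try to link the new node after last.
                if (last->next.compare_exchange_weak(next, node))
                {
                    tail.compare_exchange_weak(last, node);
                    return;
                }
            }
            else
            {
                // Case 2: help a lagging tail; does nothing if tail has moved.
                tail.compare_exchange_weak(last, next);
            }
        }
    }
};

// Runs `threads` producers that each enqueue `per_thread` items, then
// counts the nodes after the sentinel once all producers have joined.
std::size_t enq_stress(int threads, int per_thread)
{
    EnqOnlyQueue q;
    std::vector<std::thread> pool;
    for (int t = 0; t < threads; ++t)
        pool.emplace_back([&q, per_thread] {
            for (int i = 0; i < per_thread; ++i) q.enq(i);
        });
    for (auto& th : pool) th.join();

    std::size_t count = 0;
    for (EnqNode* n = q.head.load()->next.load(); n != nullptr; n = n->next.load())
        ++count;
    return count;
}
```

A call such as `assert(enq_stress(4, 1000) == 4000u);` checks that no enqueued node was lost or linked twice.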
Now, for deq, assume that first != head.load().
Case 1: last == first.
    Case 1.1: first->next == nullptr. Since last == first, first was once equal to tail, and therefore by our second invariant first == tail. By our first invariant, head never moves past tail, so we must have first == head, and our assumption is false.
    Case 1.2: first->next != nullptr.
        Case 1.2.1: first != tail. Then tail.compare_exchange_weak(last, next) does nothing.
        Case 1.2.2: first == tail. Then first == head, as the head can only move forward but cannot move forward past the tail, by our first invariant.
Case 2: last != first. We will perform head.compare_exchange_weak(first, next), which will do nothing because head != first (by assumption).
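The deq cases can be exercised the same way. The sketch below removes only the `first == head.load()` re-check, under these assumptions: the payload is int, the queue starts with a sentinel node, and no node is freed until the queue is destroyed, which sidesteps the `delete first` issue. `DeqNode`, `UncheckedDeqQueue` and `deq_stress` are hypothetical names. As with any stress test, a passing run is evidence for the case analysis, not a proof of it.

```cpp
#include <atomic>
#include <cassert>   // for the assert in the usage note
#include <optional>
#include <thread>
#include <vector>

struct DeqNode
{
    int value;
    std::atomic<DeqNode*> next;
    explicit DeqNode(int v) : value(v), next(nullptr) {}
};

struct UncheckedDeqQueue
{
    DeqNode* const origin;  // first sentinel, kept so cleanup can reach every node
    std::atomic<DeqNode*> head;
    std::atomic<DeqNode*> tail;

    UncheckedDeqQueue() : origin(new DeqNode(0)), head(origin), tail(origin) {}

    // Assumption: runs only after every worker thread has joined.
    ~UncheckedDeqQueue()
    {
        DeqNode* n = origin;
        while (n != nullptr) { DeqNode* after = n->next.load(); delete n; n = after; }
    }

    // enq keeps its re-check, as in the question.
    void enq(int val)
    {
        DeqNode* node = new DeqNode(val);
        while (true)
        {
            DeqNode* last = tail.load();
            DeqNode* next = last->next.load();
            if (last != tail.load()) continue;
            if (next == nullptr)
            {
                if (last->next.compare_exchange_weak(next, node))
                {
                    tail.compare_exchange_weak(last, node);
                    return;
                }
            }
            else
            {
                tail.compare_exchange_weak(last, next);
            }
        }
    }

    // deq with the `first == head.load()` re-check removed.
    std::optional<int> deq()
    {
        while (true)
        {
            DeqNode* first = head.load();
            DeqNode* last = tail.load();
            DeqNode* next = first->next.load();
            if (last == first)
            {
                if (next == nullptr) return std::nullopt;  // Case 1.1
                tail.compare_exchange_weak(last, next);     // Case 1.2
            }
            else
            {
                int value = next->value;                    // Case 2
                if (head.compare_exchange_weak(first, next)) return value;
            }
        }
    }
};

// Two producers enqueue 1..per_producer each while two consumers drain.
// Returns the sum of everything dequeued.
long long deq_stress(int per_producer)
{
    UncheckedDeqQueue q;
    std::atomic<long long> sum{0};
    std::atomic<int> taken{0};
    const int total = 2 * per_producer;
    std::vector<std::thread> pool;
    for (int p = 0; p < 2; ++p)
        pool.emplace_back([&] { for (int i = 1; i <= per_producer; ++i) q.enq(i); });
    for (int c = 0; c < 2; ++c)
        pool.emplace_back([&] {
            while (taken.load() < total)
            {
                if (auto v = q.deq()) { sum += *v; ++taken; }
                else std::this_thread::yield();
            }
        });
    for (auto& t : pool) t.join();
    return sum.load();
}
```

If every value is dequeued exactly once, `deq_stress(1000)` returns 2 * (1 + 2 + ... + 1000) = 1001000, so `assert(deq_stress(1000) == 1001000LL);` is a reasonable check.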
I’ll note that I don’t have a ton of experience with concurrency, so I certainly could have made a mistake somewhere.