Home | History | Annotate | Line # | Download | only in std
      1 // <barrier> -*- C++ -*-
      2 
      3 // Copyright (C) 2020-2024 Free Software Foundation, Inc.
      4 //
      5 // This file is part of the GNU ISO C++ Library.  This library is free
      6 // software; you can redistribute it and/or modify it under the
      7 // terms of the GNU General Public License as published by the
      8 // Free Software Foundation; either version 3, or (at your option)
      9 // any later version.
     10 
     11 // This library is distributed in the hope that it will be useful,
     12 // but WITHOUT ANY WARRANTY; without even the implied warranty of
     13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
     14 // GNU General Public License for more details.
     15 
     16 // Under Section 7 of GPL version 3, you are granted additional
     17 // permissions described in the GCC Runtime Library Exception, version
     18 // 3.1, as published by the Free Software Foundation.
     19 
     20 // You should have received a copy of the GNU General Public License and
     21 // a copy of the GCC Runtime Library Exception along with this program;
     22 // see the files COPYING3 and COPYING.RUNTIME respectively.  If not, see
     23 // <http://www.gnu.org/licenses/>.
     24 
     25 // This implementation is based on libcxx/include/barrier
     26 //===-- barrier.h --------------------------------------------------===//
     27 //
     28 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
     29 // See https://llvm.org/LICENSE.txt for license information.
     30 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
     31 //
     32 //===---------------------------------------------------------------===//
     33 
     34 /** @file include/barrier
     35  *  This is a Standard C++ Library header.
     36  */
     37 
     38 #ifndef _GLIBCXX_BARRIER
     39 #define _GLIBCXX_BARRIER 1
     40 
     41 #pragma GCC system_header
     42 
     43 #include <bits/requires_hosted.h> // threading primitive
     44 
     45 #define __glibcxx_want_barrier
     46 #include <bits/version.h>
     47 
     48 #ifdef __cpp_lib_barrier // C++ >= 20 && __cpp_aligned_new && lib_atomic_wait
     49 #include <bits/atomic_base.h>
     50 #include <bits/std_thread.h>
     51 #include <bits/unique_ptr.h>
     52 
     53 #include <array>
     54 
     55 namespace std _GLIBCXX_VISIBILITY(default)
     56 {
     57 _GLIBCXX_BEGIN_NAMESPACE_VERSION
     58 
  // Default completion function for std::barrier: a stateless functor
  // whose call operator does nothing.  Used when the user supplies no
  // completion, so phase completion incurs no extra work.
  struct __empty_completion
  {
    _GLIBCXX_ALWAYS_INLINE void
    operator()() noexcept
    { }
  };
     65 
     66 /*
     67 
     68 The default implementation of __tree_barrier is a classic tree barrier.
     69 
     70 It looks different from literature pseudocode for two main reasons:
     71  1. Threads that call into std::barrier functions do not provide indices,
     72     so a numbering step is added before the actual barrier algorithm,
     73     appearing as an N+1 round to the N rounds of the tree barrier.
     74  2. A great deal of attention has been paid to avoid cache line thrashing
     75     by flattening the tree structure into cache-line sized arrays, that
     76     are indexed in an efficient way.
     77 
     78 */
     79 
     80   enum class __barrier_phase_t : unsigned char { };
     81 
  // Tree-barrier algorithm backing std::barrier (see the comment block
  // above).  Threads pair up via CAS on per-node "ticket" bytes; each
  // round halves the number of contenders until one thread remains to
  // complete the phase.
  template<typename _CompletionF>
    class __tree_barrier
    {
      // Atomic views over individual phase bytes stored in the heap state;
      // the bytes are plain storage accessed through __atomic_ref.
      using __atomic_phase_ref_t = std::__atomic_ref<__barrier_phase_t>;
      using __atomic_phase_const_ref_t = std::__atomic_ref<const __barrier_phase_t>;
      static constexpr auto __phase_alignment =
		      __atomic_phase_ref_t::required_alignment;

      // 64 ticket bytes per tree node, indexed by round number, sized and
      // aligned so each node occupies one 64-byte cache line and rounds on
      // different nodes never share a line (avoids thrashing).
      using __tickets_t = std::array<__barrier_phase_t, 64>;
      struct alignas(64) /* naturally-align the heap state */ __state_t
      {
	alignas(__phase_alignment) __tickets_t __tickets;
      };

      ptrdiff_t _M_expected;             // current participant count
      unique_ptr<__state_t[]> _M_state;  // one node per pair of participants
      // Deferred decrements recorded by arrive_and_drop(); folded into
      // _M_expected by whichever thread completes the current phase.
      __atomic_base<ptrdiff_t> _M_expected_adjustment;
      _CompletionF _M_completion;        // invoked once per completed phase

      // Global phase value; published with release and waited on in wait().
      alignas(__phase_alignment) __barrier_phase_t  _M_phase;

      // Perform one arrival for phase __old_phase, probing from slot
      // __current.  Returns true iff the caller is the thread that
      // completed the whole phase (and so must run the completion and
      // release waiters); false if its arrival was absorbed by a partner.
      bool
      _M_arrive(__barrier_phase_t __old_phase, size_t __current)
      {
	const auto __old_phase_val = static_cast<unsigned char>(__old_phase);
	// A ticket CASed to __half_step means "one of the pair arrived";
	// __full_step means this node is done for the current phase.
	const auto __half_step =
			   static_cast<__barrier_phase_t>(__old_phase_val + 1);
	const auto __full_step =
			   static_cast<__barrier_phase_t>(__old_phase_val + 2);

	size_t __current_expected = _M_expected;
	// Map the (hashed) starting position into [0, leaf-node count).
	__current %= ((_M_expected + 1) >> 1);

	for (int __round = 0; ; ++__round)
	  {
	    // A single remaining contender trivially completes the phase.
	    if (__current_expected <= 1)
		return true;
	    size_t const __end_node = ((__current_expected + 1) >> 1),
			 __last_node = __end_node - 1;
	    // Probe nodes (wrapping at __end_node) until this arrival
	    // either parks (half) or wins the node (full).
	    for ( ; ; ++__current)
	      {
		if (__current == __end_node)
		  __current = 0;
		auto __expect = __old_phase;
		__atomic_phase_ref_t __phase(_M_state[__current]
						.__tickets[__round]);
		if (__current == __last_node && (__current_expected & 1))
		  {
		    // Odd contender count: the last node is a singleton,
		    // so a lone CAS straight to __full_step wins it.
		    if (__phase.compare_exchange_strong(__expect, __full_step,
						        memory_order_acq_rel))
		      break;     // I'm 1 in 1, go to next __round
		  }
		else if (__phase.compare_exchange_strong(__expect, __half_step,
						         memory_order_acq_rel))
		  {
		    return false; // I'm 1 in 2, done with arrival
		  }
		else if (__expect == __half_step)
		  {
		    // A partner already half-filled this node; try to be
		    // the second arrival and advance to the next round.
		    if (__phase.compare_exchange_strong(__expect, __full_step,
						        memory_order_acq_rel))
		      break;    // I'm 2 in 2, go to next __round
		  }
	      }
	    // Winners of this round contend in a tree of half the size.
	    __current_expected = __last_node + 1;
	    __current >>= 1;
	  }
      }

    public:
      using arrival_token = __barrier_phase_t;

      // Largest supported expected count.
      static constexpr ptrdiff_t
      max() noexcept
      { return __PTRDIFF_MAX__; }

      // Allocate one tree node per pair of expected participants.
      __tree_barrier(ptrdiff_t __expected, _CompletionF __completion)
	  : _M_expected(__expected), _M_expected_adjustment(0),
	    _M_completion(move(__completion)),
	    _M_phase(static_cast<__barrier_phase_t>(0))
      {
	size_t const __count = (_M_expected + 1) >> 1;

	_M_state = std::make_unique<__state_t[]>(__count);
      }

      // Record __update arrivals by the calling thread and return the
      // current phase token for a later wait().  If any of the arrivals
      // completes the phase, this thread runs the completion, applies
      // pending drops, publishes the new phase, and wakes all waiters.
      [[nodiscard]] arrival_token
      arrive(ptrdiff_t __update)
      {
	// Hash the thread id to pick a starting slot, spreading threads
	// across the leaf nodes to reduce CAS contention.
	std::hash<std::thread::id> __hasher;
	size_t __current = __hasher(std::this_thread::get_id());
	__atomic_phase_ref_t __phase(_M_phase);
	const auto __old_phase = __phase.load(memory_order_relaxed);
	const auto __cur = static_cast<unsigned char>(__old_phase);
	for(; __update; --__update)
	  {
	    if(_M_arrive(__old_phase, __current))
	      {
		// Phase completed by this thread: run the completion,
		// fold in drops recorded since the previous phase, then
		// release-publish the new phase and notify waiters.
		_M_completion();
		_M_expected += _M_expected_adjustment.load(memory_order_relaxed);
		_M_expected_adjustment.store(0, memory_order_relaxed);
		auto __new_phase = static_cast<__barrier_phase_t>(__cur + 2);
		__phase.store(__new_phase, memory_order_release);
		__phase.notify_all();
	      }
	  }
	return __old_phase;
      }

      // Block until the phase identified by __old_phase has completed,
      // i.e. until the global phase no longer equals the token.
      void
      wait(arrival_token&& __old_phase) const
      {
	__atomic_phase_const_ref_t __phase(_M_phase);
	auto const __test_fn = [=]
	  {
	    return __phase.load(memory_order_acquire) != __old_phase;
	  };
	std::__atomic_wait_address(&_M_phase, __test_fn);
      }

      // Arrive once and permanently reduce the participant count by one;
      // the reduction takes effect when the current phase completes.
      void
      arrive_and_drop()
      {
	_M_expected_adjustment.fetch_sub(1, memory_order_relaxed);
	(void)arrive(1);
      }
    };
    209 
    210   template<typename _CompletionF = __empty_completion>
    211     class barrier
    212     {
    213       // Note, we may introduce a "central" barrier algorithm at some point
    214       // for more space constrained targets
    215       using __algorithm_t = __tree_barrier<_CompletionF>;
    216       __algorithm_t _M_b;
    217 
    218     public:
    219       class arrival_token final
    220       {
    221       public:
    222 	arrival_token(arrival_token&&) = default;
    223 	arrival_token& operator=(arrival_token&&) = default;
    224 	~arrival_token() = default;
    225 
    226       private:
    227 	friend class barrier;
    228 	using __token = typename __algorithm_t::arrival_token;
    229 	explicit arrival_token(__token __tok) noexcept : _M_tok(__tok) { }
    230 	__token _M_tok;
    231       };
    232 
    233       static constexpr ptrdiff_t
    234       max() noexcept
    235       { return __algorithm_t::max(); }
    236 
    237       explicit
    238       barrier(ptrdiff_t __count, _CompletionF __completion = _CompletionF())
    239       : _M_b(__count, std::move(__completion))
    240       { }
    241 
    242       barrier(barrier const&) = delete;
    243       barrier& operator=(barrier const&) = delete;
    244 
    245       [[nodiscard]] arrival_token
    246       arrive(ptrdiff_t __update = 1)
    247       { return arrival_token{_M_b.arrive(__update)}; }
    248 
    249       void
    250       wait(arrival_token&& __phase) const
    251       { _M_b.wait(std::move(__phase._M_tok)); }
    252 
    253       void
    254       arrive_and_wait()
    255       { wait(arrive()); }
    256 
    257       void
    258       arrive_and_drop()
    259       { _M_b.arrive_and_drop(); }
    260     };
    261 
    262 _GLIBCXX_END_NAMESPACE_VERSION
    263 } // namespace
    264 #endif // __cpp_lib_barrier
    265 #endif // _GLIBCXX_BARRIER
    266