/* Copyright 2016, Ableton AG, Berlin. All rights reserved.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*
* If you would like to incorporate Link into a proprietary software application,
* please contact <link-devs@ableton.com>.
*/
#pragma once
#include
#include
#include
#include
namespace ableton
{
namespace discovery
{
// Gateway that tracks the peers seen through a messenger and reports them to
// an observer. A peer state message (re)registers the sending peer with an
// expiry of "now + ttl seconds"; a bye-bye message removes it. Peers whose
// expiry has passed are reported as timed out when the prune timer fires.
// All state lives in a shared Impl object; this class only holds the pointer
// to it. The gateway can be move-constructed but not copied.
template
class PeerGateway
{
public:
// The peer types are defined by the observer but must match with those
// used by the Messenger
using ObserverT = typename util::Injected::type;
using NodeState = typename ObserverT::GatewayObserverNodeState;
using NodeId = typename ObserverT::GatewayObserverNodeId;
using Timer = typename util::Injected::type::Timer;
using TimerError = typename Timer::ErrorCode;
// Builds the Impl from the injected messenger, observer and io context and
// immediately starts receiving messages by calling listen() on it.
PeerGateway(util::Injected messenger,
util::Injected observer,
util::Injected io)
: mpImpl(new Impl(std::move(messenger), std::move(observer), std::move(io)))
{
mpImpl->listen();
}
PeerGateway(const PeerGateway&) = delete;
PeerGateway& operator=(const PeerGateway&) = delete;
// Takes over rhs's Impl pointer. rhs.mpImpl is moved from, so calling
// updateState() on rhs afterwards would dereference an empty pointer.
PeerGateway(PeerGateway&& rhs)
: mpImpl(std::move(rhs.mpImpl))
{
}
// Forwards the new local node state to the Impl (see Impl::updateState).
void updateState(NodeState state)
{
mpImpl->updateState(std::move(state));
}
private:
// One entry per known peer: (expiry time, peer id).
using PeerTimeout = std::pair;
using PeerTimeouts = std::vector;
// Holds the injected dependencies, the prune timer and the list of peer
// timeouts. It is also the message handler that gets registered with the
// messenger, via the operator() overloads below.
struct Impl : std::enable_shared_from_this
{
// Stores the injected dependencies and creates the prune timer from the
// io context. mIo is initialized before mPruneTimer because it is declared
// first, which is what makes the mIo->makeTimer() call here valid.
Impl(util::Injected messenger,
util::Injected observer,
util::Injected io)
: mMessenger(std::move(messenger))
, mObserver(std::move(observer))
, mIo(std::move(io))
, mPruneTimer(mIo->makeTimer())
{
}
// Hands the new state to the messenger and then asks it to broadcast.
// A std::runtime_error thrown by broadcastState() is logged at info level
// and swallowed; other exception types propagate to the caller.
void updateState(NodeState state)
{
mMessenger->updateState(std::move(state));
try
{
mMessenger->broadcastState();
}
catch (const std::runtime_error& err)
{
info(mIo->log()) << "State broadcast failed on gateway: " << err.what();
}
}
// Registers this Impl (wrapped by util::makeAsyncSafe) as the messenger's
// receive handler. Requires that this object is owned by a shared_ptr,
// since it calls shared_from_this().
// NOTE(review): both message handlers call listen() again, so receive()
// presumably delivers a single message per registration - confirm in the
// Messenger implementation.
void listen()
{
mMessenger->receive(util::makeAsyncSafe(this->shared_from_this()));
}
// Operators for handling incoming messages
// Peer state message: record/refresh the peer, then keep listening.
void operator()(const PeerState& msg)
{
onPeerState(msg.peerState, msg.ttl);
listen();
}
// Bye-bye message: drop the peer if it is known, then keep listening.
void operator()(const ByeBye& msg)
{
onByeBye(msg.peerId);
listen();
}
// Records that the peer described by nodeState was seen. Any previous
// timeout entry for the peer is replaced by one expiring ttl seconds from
// the timer's current time, the observer is notified through sawPeer(),
// and the prune timer is rescheduled.
void onPeerState(const NodeState& nodeState, const int ttl)
{
using namespace std;
const auto peerId = nodeState.ident();
const auto existing = findPeer(peerId);
if (existing != end(mPeerTimeouts))
{
// If the peer is already present in our timeout list, remove it
// as it will be re-inserted below.
mPeerTimeouts.erase(existing);
}
auto newTo = make_pair(mPruneTimer.now() + std::chrono::seconds(ttl), peerId);
// upper_bound keeps the list sorted by expiry and places the new entry
// after any existing entries with the same expiry time.
mPeerTimeouts.insert(
upper_bound(begin(mPeerTimeouts), end(mPeerTimeouts), newTo, TimeoutCompare{}),
move(newTo));
sawPeer(*mObserver, nodeState);
scheduleNextPruning();
}
// Handles a peer announcing its departure. If the peer is known, the
// observer is notified through peerLeft() and the timeout entry is removed;
// unknown peer ids are ignored. The prune timer is not rescheduled here.
void onByeBye(const NodeId& peerId)
{
const auto it = findPeer(peerId);
if (it != mPeerTimeouts.end())
{
// Notify before erasing: it->second is still valid during the call.
peerLeft(*mObserver, it->second);
mPeerTimeouts.erase(it);
}
}
// Reports every peer whose expiry time is earlier than the timer's current
// time through peerTimedOut(), removes those entries and reschedules the
// prune timer for the remaining peers.
void pruneExpiredPeers()
{
using namespace std;
// Probe entry carrying the current time; TimeoutCompare only looks at the
// time, so the default-constructed NodeId is irrelevant.
const auto test = make_pair(mPruneTimer.now(), NodeId{});
debug(mIo->log()) << "pruning peers @ " << test.first.time_since_epoch().count();
// The list is sorted by expiry, so everything before the first entry that
// is not earlier than "now" has expired.
const auto endExpired =
lower_bound(begin(mPeerTimeouts), end(mPeerTimeouts), test, TimeoutCompare{});
for_each(begin(mPeerTimeouts), endExpired, [this](const PeerTimeout& pto) {
info(mIo->log()) << "pruning peer " << pto.second;
peerTimedOut(*mObserver, pto.second);
});
mPeerTimeouts.erase(begin(mPeerTimeouts), endExpired);
scheduleNextPruning();
}
// Arms the prune timer for the earliest expiry in the list (plus one second
// of padding). Does nothing when no peers are being tracked.
void scheduleNextPruning()
{
// Find the next peer to expire and set the timer based on it
if (!mPeerTimeouts.empty())
{
// Add a second of padding to the timer to avoid over-eager timeouts
const auto t = mPeerTimeouts.front().first + std::chrono::seconds(1);
debug(mIo->log()) << "scheduling next pruning for "
<< t.time_since_epoch().count() << " because of peer "
<< mPeerTimeouts.front().second;
mPruneTimer.expires_at(t);
// Pruning only runs when the wait completes without an error code.
// NOTE(review): the handler captures `this` as a raw pointer; this
// assumes the Timer never invokes the handler after this Impl (and with
// it mPruneTimer) has been destroyed - confirm against the Timer type.
mPruneTimer.async_wait([this](const TimerError e) {
if (!e)
{
pruneExpiredPeers();
}
});
}
}
// Strict ordering of timeout entries by expiry time only; the peer id is
// not part of the comparison.
struct TimeoutCompare
{
bool operator()(const PeerTimeout& lhs, const PeerTimeout& rhs) const
{
return lhs.first < rhs.first;
}
};
// Linear search for the timeout entry belonging to peerId. Returns
// end(mPeerTimeouts) when the peer is not tracked.
typename PeerTimeouts::iterator findPeer(const NodeId& peerId)
{
return std::find_if(begin(mPeerTimeouts), end(mPeerTimeouts),
[&peerId](const PeerTimeout& pto) { return pto.second == peerId; });
}
util::Injected mMessenger;
util::Injected mObserver;
util::Injected mIo;
// Declared after mIo because it is created from mIo in the constructor.
Timer mPruneTimer;
PeerTimeouts mPeerTimeouts; // Invariant: sorted by time_point
};
// Shared so that Impl can hand out shared_from_this() in listen().
std::shared_ptr mpImpl;
};
// Convenience factory: builds a PeerGateway by passing the injected
// messenger, observer and io context straight to its constructor. Each
// argument is moved into the new gateway.
template
PeerGateway makePeerGateway(
util::Injected messenger,
util::Injected observer,
util::Injected io)
{
return {std::move(messenger), std::move(observer), std::move(io)};
}
// IpV4 gateway types
// Messenger type used by IpV4 gateways: a UdpMessenger built on an
// IpV4Interface that uses v1::kMaxMessageSize, together with the StateQuery
// and IoContext types.
template
using IpV4Messenger = UdpMessenger<
IpV4Interface::type&, v1::kMaxMessageSize>,
StateQuery,
IoContext>;
// PeerGateway specialization for IpV4, parameterized on the PeerObserver and
// IoContext types.
// NOTE(review): presumably the messenger argument is the IpV4Messenger alias
// declared above - confirm against the full declaration.
template
using IpV4Gateway =
PeerGateway::type&>,
PeerObserver,
IoContext>;
// Factory function to bind a PeerGateway to an IpV4Interface with the given address.
//
// io:       injected io context; used by reference to build the interface and
//           the messenger, then moved into the returned gateway.
// addr:     IPv4 address the interface is created for.
// observer: injected observer that the gateway notifies about peers.
// state:    initial node state handed to the messenger.
template
IpV4Gateway makeIpV4Gateway(
util::Injected io,
const asio::ip::address_v4& addr,
util::Injected observer,
NodeState state)
{
using namespace std;
using namespace util;
// Fixed messenger parameters passed to makeUdpMessenger.
// NOTE(review): ttl is presumably in seconds, matching how received ttl
// values are interpreted in PeerGateway::Impl::onPeerState; the meaning of
// ttlRatio is defined by makeUdpMessenger - confirm there.
const uint8_t ttl = 5;
const uint8_t ttlRatio = 20;
// Both the interface and the messenger receive a reference to *io. These
// references are taken before io itself is moved into the gateway below.
auto iface = makeIpV4Interface(injectRef(*io), addr);
auto messenger =
makeUdpMessenger(injectVal(move(iface)), move(state), injectRef(*io), ttl, ttlRatio);
return {injectVal(move(messenger)), move(observer), move(io)};
}
} // namespace discovery
} // namespace ableton