Fix SrtCaller Crash problem (#4171)

1.Fix the crash problem when SrtPlayer reveives an Ack packet 
2.Remove SrtCaller's Check on streamid format to make it compatible with
other SRT streaming service。
3.Modify the coding format: replace tab to space
This commit is contained in:
baigao-X
2025-02-28 12:46:43 +08:00
committed by GitHub
parent 52ca731392
commit c0a93f3c8f
7 changed files with 793 additions and 819 deletions

View File

@@ -21,57 +21,38 @@ using namespace SRT;
namespace mediakit {
//zlm play format
//srt://127.0.0.1:9000?streamid=#!::r=live/test
//srt://127.0.0.1:9000?streamid=#!::r=live/test,h=__defaultVhost__
//zlm push format
//srt://127.0.0.1:9000?streamid=#!::r=live/test,m=publish
//srt://127.0.0.1:9000?streamid=#!::r=live/test,h=__defaultVhost__,m=publish
void SrtUrl::parse(const string &strUrl) {
//DebugL << "url: " << strUrl;
//DebugL << "url: " << strUrl;
_full_url = strUrl;
auto url = strUrl;
auto ip = findSubString(url.data(), "://", "?");
splitUrl(ip, _host, _port);
auto _params = findSubString(url.data(), "?" , NULL);
auto _params = findSubString(url.data(), "?" , NULL);
auto kv = Parser::parseArgs(_params);
auto it = kv.find("streamid");
if (it != kv.end()) {
auto streamid = it->second;
if (!toolkit::start_with(streamid, "#!::")) {
return;
}
std::string real_streamid = streamid.substr(4);
if (it != kv.end()) {
auto streamid = it->second;
if (!toolkit::start_with(streamid, "#!::")) {
return;
}
_streamid = streamid;
}
auto params = Parser::parseArgs(real_streamid, ",", "=");
for (auto iit : params) {
if (iit.first == "h") {
_vhost = iit.second;
} else if (iit.first == "r") {
auto tmps = toolkit::split(iit.second, "/");
if (tmps.size() < 2) {
continue;
}
_app = tmps[0];
_stream = tmps[1];
} else {
//nop
}
}
if (_vhost.empty()) {
_vhost = DEFAULT_VHOST;
}
}
//TraceL << "ip: " << ip;
//TraceL << "_host: " << _host;
//TraceL << "_port: " << _port;
//TraceL << "_params: " << _params;
//TraceL << "_vhost: " << _vhost;
//TraceL << "_app: " << _app;
//TraceL << "_stream: " << _stream;
return;
//TraceL << "ip: " << ip;
//TraceL << "_host: " << _host;
//TraceL << "_port: " << _port;
//TraceL << "_params: " << _params;
//TraceL << "_streamid: " << _streamid;
return;
}
@@ -79,10 +60,10 @@ void SrtUrl::parse(const string &strUrl) {
SrtCaller::SrtCaller(const toolkit::EventPoller::Ptr &poller) {
_poller = poller ? std::move(poller) : EventPollerPool::Instance().getPoller();
_start_timestamp = SteadyClock::now();
_socket_id = generateSocketId();
_socket_id = generateSocketId();
/* _init_seq_number = generateInitSeq(); */
_init_seq_number = 0;
/* _init_seq_number = generateInitSeq(); */
_init_seq_number = 0;
_last_pkt_seq = _init_seq_number - 1;
_pkt_recv_rate_context = std::make_shared<SRT::PacketRecvRateContext>(_start_timestamp);
@@ -93,16 +74,16 @@ SrtCaller::SrtCaller(const toolkit::EventPoller::Ptr &poller) {
}
SrtCaller::~SrtCaller(void) {
DebugL;
DebugL;
}
void SrtCaller::onConnect() {
//DebugL;
//DebugL;
auto peer_addr = SockUtil::make_sockaddr(_url._host.c_str(), (_url._port));
_socket = Socket::createSocket(_poller, false);
auto peer_addr = SockUtil::make_sockaddr(_url._host.c_str(), (_url._port));
_socket = Socket::createSocket(_poller, false);
_socket->bindUdpSock(0, SockUtil::is_ipv4(_url._host.data()) ? "0.0.0.0" : "::");
_socket->bindPeerAddr((struct sockaddr *)&peer_addr, 0, true);
_socket->bindPeerAddr((struct sockaddr *)&peer_addr, 0, true);
weak_ptr<SrtCaller> weak_self = shared_from_this();
_socket->setOnRead([weak_self](const Buffer::Ptr &buf, struct sockaddr *addr, int addr_len) mutable {
@@ -110,10 +91,10 @@ void SrtCaller::onConnect() {
if (!strong_self) {
return;
}
strong_self->inputSockData((uint8_t*)buf->data(), buf->size(), addr);
strong_self->inputSockData((uint8_t*)buf->data(), buf->size(), addr);
});
doHandshake();
doHandshake();
}
void SrtCaller::onResult(const SockException &ex) {
@@ -138,7 +119,7 @@ void SrtCaller::onResult(const SockException &ex) {
void SrtCaller::onHandShakeFinished() {
DebugL;
_is_handleshake_finished = true;
_is_handleshake_finished = true;
if (_handleshake_timer) {
_handleshake_timer.reset();
}
@@ -165,7 +146,7 @@ void SrtCaller::onHandShakeFinished() {
void SrtCaller::onSRTData(DataPacket::Ptr pkt) {
InfoL;
if (!isPlayer()) {
WarnL << "this is not a player data ignore";
WarnL << "this is not a player data ignore";
return;
}
}
@@ -215,7 +196,7 @@ void SrtCaller::onSendTSData(const Buffer::Ptr &buffer, bool flush) {
}
void SrtCaller::inputSockData(uint8_t *buf, int len, struct sockaddr *addr) {
//TraceL << hexdump((void*)buf, len);
//TraceL << hexdump((void*)buf, len);
using srt_control_handler = void (SrtCaller::*)(uint8_t * buf, int len, struct sockaddr *addr);
static std::unordered_map<uint16_t, srt_control_handler> s_control_functions;
@@ -277,16 +258,16 @@ void SrtCaller::doHandshake() {
_crypto = std::make_shared<SRT::Crypto>(getPassphrase());
}
sendHandshakeInduction();
sendHandshakeInduction();
return;
}
void SrtCaller::sendHandshakeInduction() {
DebugL;
DebugL;
_induction_ts = SteadyClock::now();
SRT::HandshakePacket::Ptr req = std::make_shared<SRT::HandshakePacket>();
req->timestamp = DurationCountMicroseconds(_induction_ts - _start_timestamp);
SRT::HandshakePacket::Ptr req = std::make_shared<SRT::HandshakePacket>();
req->timestamp = DurationCountMicroseconds(_induction_ts - _start_timestamp);
req->dst_socket_id = 0;
req->version = 4;
@@ -299,11 +280,11 @@ void SrtCaller::sendHandshakeInduction() {
req->srt_socket_id = _socket_id;
req->syn_cookie = 0;
auto dataSenderAddr = SockUtil::make_sockaddr(_url._host.c_str(), _url._port);
req->assignPeerIPBE(&dataSenderAddr);
auto dataSenderAddr = SockUtil::make_sockaddr(_url._host.c_str(), _url._port);
req->assignPeerIPBE(&dataSenderAddr);
req->storeToData();
_handleshake_req = req;
sendControlPacket(req, true);
_handleshake_req = req;
sendControlPacket(req, true);
std::weak_ptr<SrtCaller> weak_self = std::static_pointer_cast<SrtCaller>(shared_from_this());
_handleshake_timer = std::make_shared<Timer>(0.2, [weak_self]()->bool{
@@ -323,10 +304,10 @@ void SrtCaller::sendHandshakeInduction() {
}
void SrtCaller::sendHandshakeConclusion() {
DebugL;
DebugL;
SRT::HandshakePacket::Ptr req = std::make_shared<SRT::HandshakePacket>();
req->timestamp = DurationCountMicroseconds(_now - _start_timestamp);
SRT::HandshakePacket::Ptr req = std::make_shared<SRT::HandshakePacket>();
req->timestamp = DurationCountMicroseconds(_now - _start_timestamp);
req->dst_socket_id = 0;
req->version = 5;
@@ -345,13 +326,13 @@ void SrtCaller::sendHandshakeConclusion() {
req->srt_socket_id = _socket_id;
req->syn_cookie = _sync_cookie;
auto addr = SockUtil::make_sockaddr(_url._host.c_str(), _url._port);
req->assignPeerIPBE(&addr);
auto addr = SockUtil::make_sockaddr(_url._host.c_str(), _url._port);
req->assignPeerIPBE(&addr);
HSExtMessage::Ptr ext = std::make_shared<HSExtMessage>();
ext->extension_type = HSExt::SRT_CMD_HSREQ;
ext->srt_version = srtVersion(1, 5, 0);
ext->srt_flag = 0xbf;
HSExtMessage::Ptr ext = std::make_shared<HSExtMessage>();
ext->extension_type = HSExt::SRT_CMD_HSREQ;
ext->srt_version = srtVersion(1, 5, 0);
ext->srt_flag = 0xbf;
// if set latency, use set value
_delay = getLatency();
@@ -364,13 +345,13 @@ void SrtCaller::sendHandshakeConclusion() {
}
}
ext->recv_tsbpd_delay = _delay;
ext->send_tsbpd_delay = _delay;
req->ext_list.push_back(std::move(ext));
ext->recv_tsbpd_delay = _delay;
ext->send_tsbpd_delay = _delay;
req->ext_list.push_back(std::move(ext));
HSExtStreamID::Ptr extStreamId = std::make_shared<HSExtStreamID>();
extStreamId->streamid = generateStreamId();
req->ext_list.push_back(std::move(extStreamId));
HSExtStreamID::Ptr extStreamId = std::make_shared<HSExtStreamID>();
extStreamId->streamid = generateStreamId();
req->ext_list.push_back(std::move(extStreamId));
if (_crypto) {
HSExtKeyMaterial::Ptr keyMaterial = _crypto->generateKeyMaterialExt(HSExt::SRT_CMD_KMREQ);
@@ -378,8 +359,8 @@ void SrtCaller::sendHandshakeConclusion() {
}
req->storeToData();
_handleshake_req = req;
sendControlPacket(req);
_handleshake_req = req;
sendControlPacket(req);
return;
}
@@ -491,7 +472,7 @@ void SrtCaller::sendMsgDropReq(uint32_t first, uint32_t last) {
void SrtCaller::sendKeepLivePacket() {
auto now = SteadyClock::now();
SRT::KeepLivePacket::Ptr req = std::make_shared<SRT::KeepLivePacket>();
SRT::KeepLivePacket::Ptr req = std::make_shared<SRT::KeepLivePacket>();
req->timestamp = SRT::DurationCountMicroseconds(now - _start_timestamp);
req->dst_socket_id = _peer_socket_id;
req->storeToData();
@@ -510,7 +491,7 @@ void SrtCaller::sendShutDown() {
}
void SrtCaller::tryAnnounceKeyMaterial() {
//TraceL;
//TraceL;
if (!_crypto) {
return;
@@ -546,9 +527,9 @@ void SrtCaller::tryAnnounceKeyMaterial() {
}
void SrtCaller::sendControlPacket(SRT::ControlPacket::Ptr pkt, bool flush) {
//TraceL;
//TraceL;
sendPacket(pkt, flush);
return;
return;
}
void SrtCaller::sendDataPacket(SRT::DataPacket::Ptr pkt, char *buf, int len, bool flush) {
@@ -571,22 +552,22 @@ void SrtCaller::sendDataPacket(SRT::DataPacket::Ptr pkt, char *buf, int len, boo
pkt->storeToData((uint8_t *)data, size);
sendPacket(pkt, flush);
_send_buf->inputPacket(pkt);
return;
return;
}
void SrtCaller::sendPacket(Buffer::Ptr pkt, bool flush) {
//TraceL << pkt->size();
//TraceL << pkt->size();
auto tmp = _packet_pool.obtain2();
tmp->assign(pkt->data(), pkt->size());
_socket->send(std::move(tmp), nullptr, 0, flush);
_socket->send(std::move(tmp), nullptr, 0, flush);
_send_ticker.resetTime();
return;
return;
}
void SrtCaller::handleHandshake(uint8_t *buf, int len, struct sockaddr *addr) {
//DebugL;
SRT::HandshakePacket pkt;
//DebugL;
SRT::HandshakePacket pkt;
if(!pkt.loadFromData(buf, len)){
WarnL<< "is not vaild HandshakePacket";
return;
@@ -610,96 +591,96 @@ void SrtCaller::handleHandshake(uint8_t *buf, int len, struct sockaddr *addr) {
}
void SrtCaller::handleHandshakeInduction(SRT::HandshakePacket &pkt, struct sockaddr *addr) {
DebugL;
DebugL;
if (!_handleshake_req) {
WarnL << "must Induction Phase for handleshake";
return;
}
if (!_handleshake_req) {
WarnL << "must Induction Phase for handleshake";
return;
}
if (_handleshake_req->handshake_type == HandshakePacket::HS_TYPE_CONCLUSION) {
WarnL << "should be Conclusion Phase for handleshake ";
return;
} else if (_handleshake_req->handshake_type != HandshakePacket::HS_TYPE_INDUCTION) {
WarnL <<"not reach this";
return;
}
if (_handleshake_req->handshake_type == HandshakePacket::HS_TYPE_CONCLUSION) {
WarnL << "should be Conclusion Phase for handleshake ";
return;
} else if (_handleshake_req->handshake_type != HandshakePacket::HS_TYPE_INDUCTION) {
WarnL <<"not reach this";
return;
}
// Induction Phase
// Induction Phase
if (pkt.version != 5) {
WarnL << "not support handleshake version: " << pkt.version;
return;
}
WarnL << "not support handleshake version: " << pkt.version;
return;
}
if (pkt.extension_field != 0x4A17) {
WarnL << "not match SRT MAGIC";
return;
}
if (pkt.extension_field != 0x4A17) {
WarnL << "not match SRT MAGIC";
return;
}
if (pkt.dst_socket_id != _handleshake_req->srt_socket_id) {
WarnL << "not match _socket_id";
return;
}
if (pkt.dst_socket_id != _handleshake_req->srt_socket_id) {
WarnL << "not match _socket_id";
return;
}
// TODO: encryption_field
_sync_cookie = pkt.syn_cookie;
_sync_cookie = pkt.syn_cookie;
_mtu = std::min<uint32_t>(pkt.mtu, _mtu);
_max_flow_window_size = std::min<uint32_t>(pkt.max_flow_window_size, _max_flow_window_size);
sendHandshakeConclusion();
sendHandshakeConclusion();
return;
}
void SrtCaller::handleHandshakeConclusion(SRT::HandshakePacket &pkt, struct sockaddr *addr) {
DebugL;
DebugL;
if (!_handleshake_req) {
WarnL << "must Conclusion Phase for handleshake ";
return;
}
if (!_handleshake_req) {
WarnL << "must Conclusion Phase for handleshake ";
return;
}
if (_handleshake_req->handshake_type == HandshakePacket::HS_TYPE_INDUCTION) {
WarnL << "should be Conclusion Phase for handleshake ";
return;
} else if (_handleshake_req->handshake_type != HandshakePacket::HS_TYPE_CONCLUSION) {
WarnL <<"not reach this";
return;
}
if (_handleshake_req->handshake_type == HandshakePacket::HS_TYPE_INDUCTION) {
WarnL << "should be Conclusion Phase for handleshake ";
return;
} else if (_handleshake_req->handshake_type != HandshakePacket::HS_TYPE_CONCLUSION) {
WarnL <<"not reach this";
return;
}
// Conclusion Phase
// Conclusion Phase
if (pkt.version != 5) {
WarnL << "not support handleshake version: " << pkt.version;
return;
}
WarnL << "not support handleshake version: " << pkt.version;
return;
}
if (pkt.dst_socket_id != _handleshake_req->srt_socket_id) {
WarnL << "not match _socket_id";
return;
}
if (pkt.dst_socket_id != _handleshake_req->srt_socket_id) {
WarnL << "not match _socket_id";
return;
}
// TODO: encryption_field
_peer_socket_id = pkt.srt_socket_id;
_peer_socket_id = pkt.srt_socket_id;
HSExtMessage::Ptr resp;
HSExtMessage::Ptr resp;
HSExtKeyMaterial::Ptr keyMaterial;
for (auto& ext : pkt.ext_list) {
if (!resp) {
resp = std::dynamic_pointer_cast<HSExtMessage>(ext);
}
for (auto& ext : pkt.ext_list) {
if (!resp) {
resp = std::dynamic_pointer_cast<HSExtMessage>(ext);
}
if (!keyMaterial) {
keyMaterial = std::dynamic_pointer_cast<HSExtKeyMaterial>(ext);
}
}
}
if (resp) {
if (resp) {
_delay = std::max<uint16_t>(_delay, resp->recv_tsbpd_delay);
//DebugL << "flag " << resp->srt_flag;
//DebugL << "recv_tsbpd_delay " << resp->recv_tsbpd_delay;
//DebugL << "send_tsbpd_delay " << resp->send_tsbpd_delay;
}
//DebugL << "flag " << resp->srt_flag;
//DebugL << "recv_tsbpd_delay " << resp->recv_tsbpd_delay;
//DebugL << "send_tsbpd_delay " << resp->send_tsbpd_delay;
}
if (keyMaterial && _crypto) {
_crypto->loadFromKeyMaterial(keyMaterial);
@@ -715,7 +696,7 @@ void SrtCaller::handleHandshakeConclusion(SRT::HandshakePacket &pkt, struct sock
}
onHandShakeFinished();
return;
return;
}
void SrtCaller::handleACK(uint8_t *buf, int len, struct sockaddr *addr) {
@@ -730,7 +711,9 @@ void SrtCaller::handleACK(uint8_t *buf, int len, struct sockaddr *addr) {
pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp);
pkt->ack_number = ack.ack_number;
pkt->storeToData();
_send_buf->drop(ack.last_ack_pkt_seq_number);
if (_send_buf) {
_send_buf->drop(ack.last_ack_pkt_seq_number);
}
sendControlPacket(pkt, true);
// TraceL<<"ack number "<<ack.ack_number;
return;
@@ -892,9 +875,9 @@ void SrtCaller::handleKeyMaterialRspPacket(uint8_t *buf, int len, struct sockadd
}
void SrtCaller::handleDataPacket(uint8_t *buf, int len, struct sockaddr *addr) {
//TraceL;
DataPacket::Ptr pkt = std::make_shared<DataPacket>();
pkt->loadFromData(buf, len);
//TraceL;
DataPacket::Ptr pkt = std::make_shared<DataPacket>();
pkt->loadFromData(buf, len);
if (_crypto) {
auto payload = _crypto->decrypt(pkt, pkt->payloadData(), pkt->payloadSize());
@@ -906,10 +889,10 @@ void SrtCaller::handleDataPacket(uint8_t *buf, int len, struct sockaddr *addr) {
pkt->reloadPayload((uint8_t*)payload->data(), payload->size());
}
_estimated_link_capacity_context->inputPacket(_now, pkt);
_estimated_link_capacity_context->inputPacket(_now, pkt);
std::list<DataPacket::Ptr> list;
_recv_buf->inputPacket(pkt, list);
std::list<DataPacket::Ptr> list;
_recv_buf->inputPacket(pkt, list);
for (auto& data : list) {
if (_last_pkt_seq + 1 != data->packet_seq_number) {
TraceL << "pkt lost " << _last_pkt_seq + 1 << "->" << data->packet_seq_number;
@@ -1008,14 +991,7 @@ float SrtCaller::getTimeOutSec() {
};
std::string SrtCaller::generateStreamId() {
auto streamId = "#!::r=" + _url._app + "/" + _url._stream;
if (_url._vhost != DEFAULT_VHOST) {
streamId += ",h=" +_url._vhost;
}
if (!isPlayer()) {
streamId += ",m=publish";
}
return streamId;
return _url._streamid;
};
uint32_t SrtCaller::generateSocketId() {