Add srt caller mode and stream encryption support. (#4088)

Add srt caller mode and stream encryption support.
1. Support srt caller mode, realize srt proxy pull stream proxy push
stream;
url parameter format such as: srt://127.0.0.1:9000?streamid=#!
::r=live/test11
2. Support srt stream encrypted transmission in caller and listener
mode.

---------

Co-authored-by: xiongguangjie <xiong_panda@163.com>
This commit is contained in:
baigao-X
2024-12-28 20:21:29 +08:00
committed by GitHub
parent cb4db80502
commit 1c8ed1c55a
27 changed files with 3002 additions and 17 deletions

View File

@@ -28,6 +28,10 @@ static inline int64_t DurationCountMicroseconds(SteadyClock::duration dur) {
return std::chrono::duration_cast<std::chrono::microseconds>(dur).count();
}
static inline uint32_t DurationCountSeconds(SteadyClock::duration dur) {
return std::chrono::duration_cast<std::chrono::seconds>(dur).count();
}
static inline uint32_t loadUint32(uint8_t *ptr) {
return ptr[0] << 24 | ptr[1] << 16 | ptr[2] << 8 | ptr[3];
}
@@ -113,4 +117,4 @@ private:
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_COMMON_H
#endif // ZLMEDIAKIT_SRT_COMMON_H

507
srt/Crypto.cpp Normal file
View File

@@ -0,0 +1,507 @@
#include <atomic>
#include "Util/MD5.h"
#include "Util/logger.h"
#include "Crypto.hpp"
#if defined(ENABLE_OPENSSL)
#include "openssl/evp.h"
#endif
using namespace toolkit;
using namespace std;
using namespace SRT;
namespace SRT {
#if defined(ENABLE_OPENSSL)
inline const EVP_CIPHER* aes_key_len_mapping_wrap_cipher(int key_len) {
switch (key_len) {
case 192/8: return EVP_aes_192_wrap();
case 256/8: return EVP_aes_256_wrap();
case 128/8:
default:
return EVP_aes_128_wrap();
}
}
inline const EVP_CIPHER* aes_key_len_mapping_ctr_cipher(int key_len) {
switch (key_len) {
case 192/8: return EVP_aes_192_ctr();
case 256/8: return EVP_aes_256_ctr();
case 128/8:
default:
return EVP_aes_128_ctr();
}
}
#endif
/**
* @brief: aes_wrap
* @param [in]: in 待warp的数据
* @param [in]: in_len 待warp的数据长度
* @param [out]: out warp后输出的数据
* @param [out]: outLen 加密后输出的数据长度
* @param [in]: key 密钥
* @param [in]: key_len 密钥长度
* @return : true: 成功false: 失败
**/
static bool aes_wrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len) {
#if defined(ENABLE_OPENSSL)
EVP_CIPHER_CTX* ctx = NULL;
*outLen = 0;
do {
if (!(ctx = EVP_CIPHER_CTX_new())) {
WarnL << "EVP_CIPHER_CTX_new fail";
break;
}
EVP_CIPHER_CTX_set_flags(ctx, EVP_CIPHER_CTX_FLAG_WRAP_ALLOW);
if (1 != EVP_EncryptInit_ex(ctx, aes_key_len_mapping_wrap_cipher(key_len), NULL, key, NULL)) {
WarnL << "EVP_EncryptInit_ex fail";
break;
}
int len1 = 0;
if (1 != EVP_EncryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) {
WarnL << "EVP_EncryptUpdate fail";
break;
}
int len2 = 0;
if (1 != EVP_EncryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) {
WarnL << "EVP_EncryptFinal_ex fail";
break;
}
*outLen = len1 + len2;
} while (0);
if (ctx != NULL) {
EVP_CIPHER_CTX_free(ctx);
}
return *outLen != 0;
#else
return false;
#endif
}
/**
* @brief: aes_unwrap
* @param [in]: in 待unwrap的数据
* @param [in]: in_len 待unwrap的数据长度
* @param [out]: out unwrap后输出的数据
* @param [out]: outLen unwrap后输出的数据长度
* @param [in]: key 密钥
* @param [in]: key_len 密钥长度
* @return : true: 成功false: 失败
**/
static bool aes_unwrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len) {
#if defined(ENABLE_OPENSSL)
EVP_CIPHER_CTX* ctx = NULL;
*outLen = 0;
do {
if (!(ctx = EVP_CIPHER_CTX_new())) {
WarnL << "EVP_CIPHER_CTX_new fail";
break;
}
EVP_CIPHER_CTX_set_flags(ctx, EVP_CIPHER_CTX_FLAG_WRAP_ALLOW);
if (1 != EVP_DecryptInit_ex(ctx, aes_key_len_mapping_wrap_cipher(key_len), NULL, key, NULL)) {
WarnL << "EVP_DecryptInit_ex fail";
break;
}
//设置pkcs7padding
if (1 != EVP_CIPHER_CTX_set_padding(ctx, 1)) {
WarnL << "EVP_CIPHER_CTX_set_padding fail";
break;
}
int len1 = 0;
if (1 != EVP_DecryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) {
WarnL << "EVP_DecryptUpdate fail";
break;
}
int len2 = 0;
if (1 != EVP_DecryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) {
WarnL << "EVP_DecryptFinal_ex fail";
break;
}
*outLen = len1 + len2;
} while (0);
if (ctx != NULL) {
EVP_CIPHER_CTX_free(ctx);
}
return *outLen != 0;
#else
return false;
#endif
}
/**
* @brief: aes ctr 加密
* @param [in]: in 待加密的数据
* @param [in]: in_len 待加密的数据长度
* @param [out]: out 加密后输出的数据
* @param [out]: outLen 加密后输出的数据长度
* @param [in]: key 密钥
* @param [in]: key_len 密钥长度
* @param [in]: iv iv向量(16byte)
* @return : true: 成功false: 失败
**/
static bool aes_ctr_encrypt(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len, uint8_t* iv) {
#if defined(ENABLE_OPENSSL)
EVP_CIPHER_CTX* ctx = NULL;
*outLen = 0;
do {
if (!(ctx = EVP_CIPHER_CTX_new())) {
WarnL << "EVP_CIPHER_CTX_new fail";
break;
}
if (1 != EVP_EncryptInit_ex(ctx, aes_key_len_mapping_ctr_cipher(key_len), NULL, key, iv)) {
WarnL << "EVP_EncryptInit_ex fail";
break;
}
int len1 = 0;
if (1 != EVP_EncryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) {
WarnL << "EVP_EncryptUpdate fail";
break;
}
int len2 = 0;
if (1 != EVP_EncryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) {
WarnL << "EVP_EncryptFinal_ex fail";
break;
}
*outLen = len1 + len2;
} while (0);
if (ctx != NULL) {
EVP_CIPHER_CTX_free(ctx);
}
return *outLen != 0;
#else
return false;
#endif
}
/**
* @brief: aes ctr 解密
* @param [in]: in 待解密的数据
* @param [in]: in_len 待解密的数据长度
* @param [out]: out 解密后输出的数据
* @param [out]: outLen 解密后输出的数据长度
* @param [in]: key 密钥
* @param [in]: key_len 密钥长度
* @param [in]: iv iv向量(16byte)
* @return : true: 成功false: 失败
**/
static bool aes_ctr_decrypt(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len, uint8_t* iv) {
#if defined(ENABLE_OPENSSL)
EVP_CIPHER_CTX* ctx = NULL;
*outLen = 0;
do {
if (!(ctx = EVP_CIPHER_CTX_new())) {
WarnL << "EVP_CIPHER_CTX_new fail";
break;
}
if (1 != EVP_DecryptInit_ex(ctx, aes_key_len_mapping_ctr_cipher(key_len), NULL, key, iv)) {
WarnL << "EVP_DecryptInit_ex fail";
break;
}
int len1 = 0;
if (1 != EVP_DecryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) {
WarnL << "EVP_DecryptUpdate fail";
break;
}
int len2 = 0;
if (1 != EVP_DecryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) {
WarnL << "EVP_DecryptFinal_ex fail";
break;
}
*outLen = len1 + len2;
} while (0);
if (ctx != NULL) {
EVP_CIPHER_CTX_free(ctx);
}
return *outLen != 0;
#else
return false;
#endif
}
///////////////////////////////////////////////////
// CryptoContext
CryptoContext::CryptoContext(const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet) :
_passparase(passparase), _kk(kk) {
if (packet) {
loadFromKeyMaterial(packet);
} else {
refresh();
}
}
void CryptoContext::refresh() {
if (_salt.empty()) {
_salt = makeRandStr(_slen, false);
generateKEK();
}
_sek = makeRandStr(_klen, false);
return;
}
std::string CryptoContext::generateWarppedKey() {
string warpped_key;
int size = (_sek.size() + 15) /16 * 16 + 8;
warpped_key.resize(size);
auto res = aes_wrap((uint8_t*)_sek.data(), _sek.size(), (uint8_t*)warpped_key.data(), &size, (uint8_t*)_kek.data(), _kek.size());
if (!res) {
return "";
}
warpped_key.resize(size);
return warpped_key;
}
void CryptoContext::loadFromKeyMaterial(KeyMaterial::Ptr packet) {
_slen = packet->_slen;
_klen = packet->_klen;
_salt = packet->_salt;
generateKEK();
auto warpped_key = packet->_warpped_key;
BufferLikeString sek;
int size = warpped_key.size();
sek.resize(size);
auto ret = aes_unwrap((uint8_t*)warpped_key.data(), warpped_key.size(), (uint8_t*)sek.data(), &size, (uint8_t*)_kek.data(), _kek.size());
if (!ret) {
throw std::runtime_error(StrPrinter <<"warpped_key unwrap fail, password may mismatch");
}
sek.resize(size);
if (packet->_kk == KeyMaterial::KEY_BASED_ENCRYPTION_BOTH_SEK) {
if (_kk == KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK) {
_sek = sek.substr(0, _slen);
} else {
_sek = sek.substr(_slen, _slen);
}
} else {
_sek = sek;
}
return;
}
bool CryptoContext::generateKEK() {
/**
SEK = PRNG(KLen)
Salt = PRNG(128)
KEK = PBKDF2(passphrase, LSB(64,Salt), Iter, KLen)
**/
_kek.resize(_klen);
#if defined(ENABLE_OPENSSL)
if (PKCS5_PBKDF2_HMAC(_passparase.data(), _passparase.length(), (uint8_t*)_salt.data() + _slen - 64/8, 64 /8, _iter, EVP_sha1(), _klen, (uint8_t*)_kek.data()) != 1) {
return false;
}
return true;
#else
return false;
#endif
}
BufferLikeString::Ptr CryptoContext::generateIv(uint32_t pkt_seq_no) {
auto iv = std::make_shared<BufferLikeString>();
iv->resize(128 /8);
uint8_t* saltData = (uint8_t*)_salt.data();
uint8_t* ivData = (uint8_t*)iv->data();
memset((void*)ivData, 0, iv->size());
memcpy((void*)(ivData + 10), (void*)&pkt_seq_no, 4);
for (size_t i = 0; i < std::min<size_t>(_salt.size(), (size_t)112 /8); ++i) {
ivData[i] ^= saltData[i];
}
return iv;
}
///////////////////////////////////////////////////
// AesCtrCryptoContext
AesCtrCryptoContext::AesCtrCryptoContext(const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet) :
CryptoContext(passparase, kk, packet) {
}
BufferLikeString::Ptr AesCtrCryptoContext::encrypt(uint32_t pkt_seq_no, const char *buf, int len) {
auto iv = generateIv(htonl(pkt_seq_no));
auto payload = std::make_shared<BufferLikeString>();
int size = (len + 15) /16 * 16 + 8;
payload->resize(size);
auto ret = aes_ctr_encrypt((const uint8_t*)buf, len, (uint8_t*)payload->data(), &size, (uint8_t*)_sek.data(), _sek.size(), (uint8_t*)iv->data());
if (!ret) {
return nullptr;
}
payload->resize(size);
return payload;
}
BufferLikeString::Ptr AesCtrCryptoContext::decrypt(uint32_t pkt_seq_no, const char *buf, int len) {
auto iv = generateIv(htonl(pkt_seq_no));
auto payload = std::make_shared<BufferLikeString>();
int size = len;
payload->resize(size);
auto ret = aes_ctr_decrypt((const uint8_t*)buf, len, (uint8_t*)payload->data(), &size, (uint8_t*)_sek.data(), _sek.size(), (uint8_t*)iv->data());
if (!ret) {
return nullptr;
}
payload->resize(size);
return payload;
}
///////////////////////////////////////////////////
// Crypto
Crypto::Crypto(const std::string& passparase) :
_passparase(passparase) {
#ifndef ENABLE_OPENSSL
throw std::invalid_argument("openssl disable, please set ENABLE_OPENSSL when compile");
#endif
_ctx_pair[0] = createCtx(KeyMaterial::CIPHER_AES_CTR, _passparase, KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK);
_ctx_pair[1] = createCtx(KeyMaterial::CIPHER_AES_CTR, _passparase, KeyMaterial::KEY_BASED_ENCRYPTION_ODD_SEK);
_ctx_idx = 0;
}
CryptoContext::Ptr Crypto::createCtx(int cipher, const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet) {
switch (cipher){
case KeyMaterial::CIPHER_AES_CTR:
return std::make_shared<AesCtrCryptoContext>(passparase, kk, packet);
case KeyMaterial::CIPHER_AES_ECB:
case KeyMaterial::CIPHER_AES_CBC:
case KeyMaterial::CIPHER_AES_GCM:
default:
throw std::runtime_error(StrPrinter <<"not support cipher " << cipher);
}
}
HSExtKeyMaterial::Ptr Crypto::generateKeyMaterialExt(uint16_t extension_type) {
HSExtKeyMaterial::Ptr ext = std::make_shared<HSExtKeyMaterial>();
ext->extension_type = extension_type;
ext->_kk = _ctx_pair[_ctx_idx]->_kk;
ext->_cipher = _ctx_pair[_ctx_idx]->getCipher();
ext->_slen = _ctx_pair[_ctx_idx]->_slen;
ext->_klen = _ctx_pair[_ctx_idx]->_klen;
ext->_salt = _ctx_pair[_ctx_idx]->_salt;
ext->_warpped_key = _ctx_pair[_ctx_idx]->generateWarppedKey();
return ext;
}
KeyMaterialPacket::Ptr Crypto::generateAnnouncePacket(CryptoContext::Ptr ctx) {
KeyMaterialPacket::Ptr pkt = std::make_shared<KeyMaterialPacket>();
pkt->sub_type = HSExt::SRT_CMD_KMREQ;
pkt->_kk = ctx->_kk;
pkt->_cipher = ctx->getCipher();
pkt->_slen = ctx->_slen;
pkt->_klen = ctx->_klen;
pkt->_salt = ctx->_salt;
pkt->_warpped_key = ctx->generateWarppedKey();
return pkt;
}
KeyMaterialPacket::Ptr Crypto::takeAwayAnnouncePacket() {
auto pkt = _re_announce_pkt;
_re_announce_pkt = nullptr;
return pkt;
}
bool Crypto::loadFromKeyMaterial(KeyMaterial::Ptr packet) {
try {
if (packet->_kk == KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK) {
_ctx_pair[0] = createCtx(packet->_cipher, _passparase, packet->_kk, packet);
} else if (packet->_kk == KeyMaterial::KEY_BASED_ENCRYPTION_ODD_SEK) {
_ctx_pair[1] = createCtx(packet->_cipher, _passparase, packet->_kk, packet);
} else if (packet->_kk == KeyMaterial::KEY_BASED_ENCRYPTION_BOTH_SEK) {
_ctx_pair[0] = createCtx(packet->_cipher, _passparase, KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK, packet);
_ctx_pair[1] = createCtx(packet->_cipher, _passparase, KeyMaterial::KEY_BASED_ENCRYPTION_ODD_SEK, packet);
}
} catch (std::exception &ex) {
WarnL << ex.what();
return false;
}
return true;
}
BufferLikeString::Ptr Crypto::encrypt(DataPacket::Ptr pkt, const char *buf, int len) {
_pkt_count++;
//refresh
if (_pkt_count == _re_announcement_period) {
auto ctx = createCtx(KeyMaterial::CIPHER_AES_CTR, _passparase, _ctx_pair[!_ctx_idx]->_kk);
_ctx_pair[!_ctx_idx] = ctx;
_re_announce_pkt = generateAnnouncePacket(ctx);
}
if (_pkt_count > _refresh_period) {
_pkt_count = 0;
_ctx_idx = !_ctx_idx;
}
pkt->KK = _ctx_pair[_ctx_idx]->_kk;
return _ctx_pair[_ctx_idx]->encrypt(pkt->packet_seq_number, buf, len);
}
BufferLikeString::Ptr Crypto::decrypt(DataPacket::Ptr pkt, const char *buf, int len) {
CryptoContext::Ptr _ctx;
if (pkt->KK == KeyMaterial::KEY_BASED_ENCRYPTION_NO_SEK) {
auto payload = std::make_shared<BufferLikeString>();
payload->assign(buf, len);
return payload;
} else if (pkt->KK == KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK) {
_ctx = _ctx_pair[0];
} else if (pkt->KK == KeyMaterial::KEY_BASED_ENCRYPTION_ODD_SEK) {
_ctx = _ctx_pair[1];
}
if (!_ctx) {
WarnL << "not has effective KeyMaterial with kk: " << pkt->KK;
return nullptr;
}
return _ctx->decrypt(pkt->packet_seq_number, buf, len);
}
} // namespace SRT

102
srt/Crypto.hpp Normal file
View File

@@ -0,0 +1,102 @@
#ifndef ZLMEDIAKIT_SRT_CRYPTO_H
#define ZLMEDIAKIT_SRT_CRYPTO_H
#include <stdint.h>
#include <vector>
#include "Network/Buffer.h"
#include "Network/sockutil.h"
#include "Util/logger.h"
#include "Common.hpp"
#include "HSExt.hpp"
#include "Packet.hpp"
namespace SRT {
class CryptoContext : public std::enable_shared_from_this<CryptoContext> {
public:
using Ptr = std::shared_ptr<CryptoContext>;
CryptoContext(const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet = nullptr);
virtual ~CryptoContext() = default;
virtual void refresh();
virtual std::string generateWarppedKey();
virtual BufferLikeString::Ptr encrypt(uint32_t pkt_seq_no, const char *buf, int len) = 0;
virtual BufferLikeString::Ptr decrypt(uint32_t pkt_seq_no, const char *buf, int len) = 0;
virtual uint8_t getCipher() const = 0;
protected:
virtual void loadFromKeyMaterial(KeyMaterial::Ptr packet);
virtual bool generateKEK();
BufferLikeString::Ptr generateIv(uint32_t pkt_seq_no);
private:
public:
std::string _passparase;
uint8_t _kk = SRT::KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK;
BufferLikeString _kek;
const uint32_t _iter = 2048;
size_t _slen = 16;
BufferLikeString _salt;
size_t _klen = 16;
BufferLikeString _sek;
};
class AesCtrCryptoContext : public CryptoContext {
public:
using Ptr = std::shared_ptr<AesCtrCryptoContext>;
AesCtrCryptoContext(const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet = nullptr);
virtual ~AesCtrCryptoContext() = default;
uint8_t getCipher() const override {
return KeyMaterial::CIPHER_AES_CTR;
}
BufferLikeString::Ptr encrypt(uint32_t pkt_seq_no, const char *buf, int len) override;
BufferLikeString::Ptr decrypt(uint32_t pkt_seq_no, const char *buf, int len) override;
};
class Crypto : public std::enable_shared_from_this<Crypto>{
public:
using Ptr = std::shared_ptr<Crypto>;
Crypto(const std::string& passparase);
virtual ~Crypto() = default;
HSExtKeyMaterial::Ptr generateKeyMaterialExt(uint16_t extension_type);
KeyMaterialPacket::Ptr takeAwayAnnouncePacket();
bool loadFromKeyMaterial(KeyMaterial::Ptr packet);
// for encryption
std::string _passparase;
//The recommended KM Refresh Period is after 2^25 packets encrypted with the same SEK are sent.
const uint32_t _refresh_period = 1 <<25;
const uint32_t _re_announcement_period = (1 <<25) - 4000;
uint32_t _pkt_count = 0;
KeyMaterialPacket::Ptr _re_announce_pkt;
CryptoContext::Ptr _ctx_pair[2]; /* Even(0)/Odd(1) crypto contexts */
uint32_t _ctx_idx = 0;
BufferLikeString::Ptr encrypt(DataPacket::Ptr pkt, const char *buf, int len);
BufferLikeString::Ptr decrypt(DataPacket::Ptr pkt, const char *buf, int len);
private:
CryptoContext::Ptr createCtx(int cipher, const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet = nullptr);
KeyMaterialPacket::Ptr generateAnnouncePacket(CryptoContext::Ptr ctx);
};
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_CRYPTO_H

View File

@@ -131,4 +131,162 @@ std::string HSExtStreamID::dump() {
return std::move(printer);
}
} // namespace SRT
size_t KeyMaterial::getContentSize() {
size_t variable_width = _slen + _warpped_key.size();
size_t content_size = variable_width + 16;
return content_size;
}
bool KeyMaterial::loadFromData(uint8_t *buf, size_t len) {
if (buf == NULL || len < 16) {
return false;
}
uint8_t *ptr = (uint8_t *)buf;
_km_version = (*ptr & 0x70) >> 4;
_pt = *ptr & 0x0f;
ptr += 1;
_sign = loadUint16(ptr);
ptr += 2;
_kk = *ptr & 0x03;
auto sek_num = 1;
if (_kk == KEY_BASED_ENCRYPTION_BOTH_SEK) {
sek_num = 2;
}
ptr += 1;
_keki = loadUint32(ptr);
ptr += 4;
_cipher = *ptr;
ptr += 1;
_auth = *ptr;
ptr += 1;
_se = *ptr;
ptr += 1;
//Resv2
ptr += 1;
//Resv3
ptr += 2;
_slen = *ptr *4;
ptr += 1;
_klen = *ptr *4;
ptr += 1;
size_t wrapped_key_len = 8 + sek_num * _klen;
size_t variable_width = _slen + wrapped_key_len;
if (len < variable_width + 16) {
return false;
}
_salt.assign((const char*)ptr, (size_t)_slen);
ptr += _slen;
_warpped_key.assign((const char*)ptr, (size_t)wrapped_key_len);
return true;
}
bool KeyMaterial::storeToData(uint8_t *buf, size_t len) {
auto content_size = getContentSize();
if (len < content_size) {
return false;
}
uint8_t *ptr = (uint8_t *)buf;
memset(ptr, 0, len);
*ptr = ((_km_version << 4)& 0x70) | (_pt & 0x0f);
ptr += 1;
storeUint16(ptr, _sign);
ptr += 2;
*ptr = _kk & 0x03;
ptr += 1;
storeUint32(ptr, _keki);
ptr += 4;
*ptr = _cipher;
ptr += 1;
*ptr = _auth;
ptr += 1;
*ptr = _se;
ptr += 1;
*ptr = 0; //Resv2
ptr += 1;
storeUint16(ptr, 0);//Resv3
ptr += 2;
*ptr = (uint8_t)(_slen/4);
ptr += 1;
*ptr = (uint8_t)(_klen/4);
ptr += 1;
const char *src = _salt.data();
for (size_t i = 0; i < _salt.size(); ptr++, src++, i++) {
*ptr = *src;
}
src = _warpped_key.data();
for (size_t i = 0; i < _warpped_key.size(); ptr++, src++, i++) {
*ptr = *src;
}
return true;
}
std::string KeyMaterial::dump() {
_StrPrinter printer;
printer << "kmVersion: " << _km_version
<< " pt : " << _pt
<< " sign : " << std::hex << _sign
<< " kk : " << _kk
<< " keki : " << _keki
<< " cipher : " << _cipher
<< " auth : " << _auth
<< " se : " << _se
<< " sLen : " << _slen
<< " salt : " << std::hex << _salt.data()
<< " kLen : " << _klen;
return std::move(printer);
}
bool HSExtKeyMaterial::loadFromData(uint8_t *buf, size_t len) {
if (buf == NULL || len < 4) {
return false;
}
HSExt::_data = BufferRaw::create();
HSExt::_data->assign((char *)buf, len);
HSExt::loadHeader();
assert(extension_type == SRT_CMD_KMREQ || extension_type == SRT_CMD_KMRSP);
return KeyMaterial::loadFromData(buf +4, len -4);
}
bool HSExtKeyMaterial::storeToData() {
size_t content_size = ((KeyMaterial::getContentSize() + 4) + 3) / 4 * 4;
HSExt::_data = BufferRaw::create();
HSExt::_data->setCapacity(content_size);
HSExt::_data->setSize(content_size);
extension_length = (content_size - 4) / 4;
HSExt::storeHeader();
return KeyMaterial::storeToData((uint8_t*)_data->data() + 4, content_size - 4);
}
std::string HSExtKeyMaterial::dump() {
return KeyMaterial::dump();
}
} // namespace SRT

View File

@@ -125,5 +125,118 @@ public:
std::string dump() override;
std::string streamid;
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|S| V | PT | Sign | Resv1 | KK|
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| KEKI |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Cipher | Auth | SE | Resv2 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Resv3 | SLen/4 | KLen/4 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Salt |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ Wrapped Key +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 11: Key Material Message structure
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-key-material
*/
class KeyMaterial {
public:
using Ptr = std::shared_ptr<KeyMaterial>;
KeyMaterial() = default;
virtual ~KeyMaterial() = default;
bool loadFromData(uint8_t *buf, size_t len);
bool storeToData(uint8_t *buf, size_t len);
std::string dump();
protected:
size_t getContentSize();
public:
enum {
PACKET_TYPE_RESERVED = 0b0000,
PACKET_TYPE_MSMSG = 0b0001, // 1-Media Strem Message
PACKET_TYPE_KMMSG = 0b0010, // 2-Keying Material Message
PACKET_TYPE_MPEG_TS = 0b0111, // 7-MPEG-TS packet
};
enum {
KEY_BASED_ENCRYPTION_NO_SEK = 0b00,
KEY_BASED_ENCRYPTION_EVEN_SEK = 0b01,
KEY_BASED_ENCRYPTION_ODD_SEK = 0b10,
KEY_BASED_ENCRYPTION_BOTH_SEK = 0b11,
};
enum {
CIPHER_NONE = 0x00,
CIPHER_AES_ECB = 0x01, //reserved, not support
CIPHER_AES_CTR = 0x02,
CIPHER_AES_CBC = 0x03, //reserved, not support
CIPHER_AES_GCM = 0x04
};
enum {
AUTHENTICATION_NONE = 0x00,
AUTH_AES_GCM = 0x01,
};
enum {
STREAM_ENCAPSUALTION_UNSPECIFIED = 0x00,
STREAM_ENCAPSUALTION_MPEG_TS_UDP = 0x01,
STREAM_ENCAPSUALTION_MPEG_TS_SRT = 0x02,
};
uint8_t _km_version = 0b001;
uint8_t _pt = PACKET_TYPE_KMMSG;
uint16_t _sign = 0x2029;
uint8_t _kk = KEY_BASED_ENCRYPTION_EVEN_SEK;
uint32_t _keki = 0;
uint8_t _cipher = CIPHER_AES_CTR;
uint8_t _auth = AUTHENTICATION_NONE;
uint8_t _se = STREAM_ENCAPSUALTION_MPEG_TS_SRT;
uint16_t _slen = 16;
uint16_t _klen = 16;
BufferLikeString _salt;
BufferLikeString _warpped_key;
};
class HSExtKeyMaterial : public HSExt, public KeyMaterial {
public:
using Ptr = std::shared_ptr<HSExtKeyMaterial>;
HSExtKeyMaterial() = default;
virtual ~HSExtKeyMaterial() = default;
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
std::string dump() override;
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| KM State |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 7: KM Response Error
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-key-material-extension-mess
*/
class HSExtKMResponseError : public HSExt {
public:
using Ptr = std::shared_ptr<HSExtKMResponseError>;
HSExtKMResponseError() = default;
~HSExtKMResponseError() = default;
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
std::string dump() override;
};
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_HS_EXT_H
#endif // ZLMEDIAKIT_SRT_HS_EXT_H

View File

@@ -55,6 +55,13 @@ bool DataPacket::loadFromData(uint8_t *buf, size_t len) {
return true;
}
bool DataPacket::reloadPayload(uint8_t *buf, size_t len) {
_data->setCapacity(len + HEADER_SIZE);
_data->setSize(len + HEADER_SIZE);
memcpy(_data->data() + HEADER_SIZE, buf, len);
return true;
}
bool DataPacket::storeToHeader() {
if (!_data || _data->size() < HEADER_SIZE) {
WarnL << "data size less " << HEADER_SIZE;
@@ -162,6 +169,12 @@ uint16_t ControlPacket::getControlType(uint8_t *buf, size_t len) {
return control_type;
}
uint16_t ControlPacket::getSubType(uint8_t *buf, size_t len) {
uint8_t *ptr = buf;
uint16_t subtype = loadUint16(ptr + 2);
return subtype;
}
bool ControlPacket::loadHeader() {
uint8_t *ptr = (uint8_t *)_data->data();
f = ptr[0] >> 7;
@@ -225,6 +238,20 @@ size_t ControlPacket::size() const {
uint32_t ControlPacket::getSocketID(uint8_t *buf, size_t len) {
return loadUint32(buf + 12);
}
#define XX(name, value, str) {str, name},
std::map<std::string, SRT_REJECT_REASON> reject_map = {REJ_MAP(XX)};
#undef XX
std::string getRejectReason(SRT_REJECT_REASON code) {
switch (code) {
#define XX(name, value, str) case name : return str;
REJ_MAP(XX)
#undef XX
default : return "invalid";
}
}
std::string HandshakePacket::dump(){
_StrPrinter printer;
printer <<"flag:"<< (int)f<<"\r\n";
@@ -324,6 +351,9 @@ bool HandshakePacket::loadExtMessage(uint8_t *buf, size_t len) {
case HSExt::SRT_CMD_HSREQ:
case HSExt::SRT_CMD_HSRSP: ext = std::make_shared<HSExtMessage>(); break;
case HSExt::SRT_CMD_SID: ext = std::make_shared<HSExtStreamID>(); break;
case HSExt::SRT_CMD_KMREQ:
case HSExt::SRT_CMD_KMRSP:
ext = std::make_shared<HSExtKeyMaterial>(); break;
default: WarnL << "not support ext " << type; break;
}
if (ext) {
@@ -451,6 +481,23 @@ void HandshakePacket::assignPeerIP(struct sockaddr_storage *addr) {
}
}
void HandshakePacket::assignPeerIPBE(struct sockaddr_storage *addr) {
memset(peer_ip_addr, 0, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]));
if (addr->ss_family == AF_INET) {
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
storeUint32(peer_ip_addr, ipv4->sin_addr.s_addr);
} else if (addr->ss_family == AF_INET6) {
if (IN6_IS_ADDR_V4MAPPED(&((struct sockaddr_in6 *)addr)->sin6_addr)) {
struct in_addr addr4;
memcpy(&addr4, 12 + (char *)&(((struct sockaddr_in6 *)addr)->sin6_addr), 4);
storeUint32(peer_ip_addr, addr4.s_addr);
} else {
const sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
memcpy(peer_ip_addr, ipv6->sin6_addr.s6_addr, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]));
}
}
}
uint32_t HandshakePacket::generateSynCookie(
struct sockaddr_storage *addr, TimePoint ts, uint32_t current_cookie, int correction) {
static std::atomic<uint32_t> distractor { 0 };
@@ -619,4 +666,4 @@ bool MsgDropReqPacket::storeToData() {
ptr += 4;
return true;
}
} // namespace SRT
} // namespace SRT

View File

@@ -57,6 +57,7 @@ public:
static bool isDataPacket(uint8_t *buf, size_t len);
static uint32_t getSocketID(uint8_t *buf, size_t len);
bool loadFromData(uint8_t *buf, size_t len);
bool reloadPayload(uint8_t *buf, size_t len);
bool storeToData(uint8_t *buf, size_t len);
bool storeToHeader();
@@ -105,6 +106,7 @@ public:
static const size_t HEADER_SIZE = 16;
static bool isControlPacket(uint8_t *buf, size_t len);
static uint16_t getControlType(uint8_t *buf, size_t len);
static uint16_t getSubType(uint8_t *buf, size_t len);
static uint32_t getSocketID(uint8_t *buf, size_t len);
ControlPacket() = default;
@@ -180,6 +182,37 @@ protected:
Figure 5: Handshake packet structure
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-handshake
*/
// REJ code,from libsrt
#define REJ_MAP(XX) \
XX(SRT_REJ_UNKNOWN, 1000, "Unknown or erroneous") \
XX(SRT_REJ_SYSTEM, 1001, "Error in system calls") \
XX(SRT_REJ_PEER, 1002, "Peer rejected connection") \
XX(SRT_REJ_RESOURCE, 1003, "Resource allocation failure") \
XX(SRT_REJ_ROGUE, 1004, "Rogue peer or incorrect parameters") \
XX(SRT_REJ_BACKLOG, 1005, "Listener's backlog exceeded") \
XX(SRT_REJ_IPE, 1006, "Internal Program Error") \
XX(SRT_REJ_CLOSE, 1007, "Socket is being closed") \
XX(SRT_REJ_VERSION, 1008, "Peer version too old") \
XX(SRT_REJ_RDVCOOKIE, 1009, "Rendezvous-mode cookie collision") \
XX(SRT_REJ_BADSECRET, 1010, "Incorrect passphrase") \
XX(SRT_REJ_UNSECURE, 1011, "Password required or unexpected") \
XX(SRT_REJ_MESSAGEAPI, 1012, "MessageAPI/StreamAPI collision") \
XX(SRT_REJ_CONGESTION, 1013, "Congestion controller type collision") \
XX(SRT_REJ_FILTER, 1014, "Packet Filter settings error") \
XX(SRT_REJ_GROUP, 1015, "Group settings collision") \
XX(SRT_REJ_TIMEOUT, 1016, "Connection timeout") \
XX(SRT_REJ_CRYPTO, 1017, "Crypto mode")
typedef enum {
#define XX(name, value, str) name = value,
REJ_MAP(XX)
#undef XX
SRT_REJ_E_SIZE
} SRT_REJECT_REASON;
std::string getRejectReason(SRT_REJECT_REASON code);
class HandshakePacket : public ControlPacket {
public:
using Ptr = std::shared_ptr<HandshakePacket>;
@@ -205,6 +238,10 @@ public:
generateSynCookie(struct sockaddr_storage *addr, TimePoint ts, uint32_t current_cookie = 0, int correction = 0);
std::string dump();
void assignPeerIP(struct sockaddr_storage *addr);
void assignPeerIPBE(struct sockaddr_storage *addr);
bool isReject() {
return (handshake_type >= SRT_REJ_UNKNOWN && handshake_type < SRT_REJ_E_SIZE);
}
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
@@ -367,6 +404,56 @@ public:
}
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+
|1| Control Type = 0x7FFF | Subtype = 3/4 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type-specific Information |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Timestamp |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Socket ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
the Control Type field of the SRT packet header is set to User-Defined Type (see Table 1),
the Subtype field of the header is set to SRT_CMD_KMREQ for key-refresh request
and SRT_CMD_KMRSP for key-refresh response (Table 5). The KM Refresh mechanism is described in Section 6.1.6.
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-key-material
*/
class KeyMaterialPacket : public ControlPacket, public KeyMaterial {
public:
using Ptr = std::shared_ptr<KeyMaterialPacket>;
KeyMaterialPacket() = default;
~KeyMaterialPacket() = default;
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override {
if (len < HEADER_SIZE) {
WarnL << "data size" << len << " less " << HEADER_SIZE;
return false;
}
_data = BufferRaw::create();
_data->assign((char *)buf, len);
loadHeader();
assert(sub_type == HSExt::SRT_CMD_KMREQ || sub_type == HSExt::SRT_CMD_KMRSP);
return KeyMaterial::loadFromData(buf + HEADER_SIZE, len - HEADER_SIZE);
}
bool storeToData() override {
size_t content_size = ((KeyMaterial::getContentSize() + HEADER_SIZE) + 3) / 4 * 4;
control_type = ControlPacket::USERDEFINEDTYPE;
/* sub_type = HSExt::SRT_CMD_KMREQ; */
/* sub_type = HSExt::SRT_CMD_KMRSP; */
_data = BufferRaw::create();
_data->setCapacity(content_size);
_data->setSize(content_size);
storeToHeader();
return KeyMaterial::storeToData((uint8_t*)_data->data() + HEADER_SIZE, content_size - HEADER_SIZE);
}
};
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_PACKET_H
#endif // ZLMEDIAKIT_SRT_PACKET_H

View File

@@ -18,12 +18,14 @@ const std::string kTimeOutSec = SRT_FIELD "timeoutSec";
const std::string kPort = SRT_FIELD "port";
const std::string kLatencyMul = SRT_FIELD "latencyMul";
const std::string kPktBufSize = SRT_FIELD "pktBufSize";
const std::string kPassPhrase = SRT_FIELD "passPhrase";
static onceToken token([]() {
mINI::Instance()[kTimeOutSec] = 5;
mINI::Instance()[kPort] = 9000;
mINI::Instance()[kLatencyMul] = 4;
mINI::Instance()[kPktBufSize] = 8192;
mINI::Instance()[kPassPhrase] = "";
});
static std::atomic<uint32_t> s_srt_socket_id_generate { 125 };
@@ -228,6 +230,8 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
// first
HSExtMessage::Ptr req;
HSExtStreamID::Ptr sid;
HSExtKeyMaterial::Ptr keyMaterial;
uint32_t srt_flag = 0xbf;
uint16_t delay = DurationCountMicroseconds(_now - _induction_ts) * getLatencyMul() / 1000;
if (delay <= 120) {
@@ -241,6 +245,9 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
if (!sid) {
sid = std::dynamic_pointer_cast<HSExtStreamID>(ext);
}
if (!keyMaterial) {
keyMaterial = std::dynamic_pointer_cast<HSExtKeyMaterial>(ext);
}
}
if (sid) {
_stream_id = sid->streamid;
@@ -252,6 +259,22 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
srt_flag = req->srt_flag;
delay = delay <= req->recv_tsbpd_delay ? req->recv_tsbpd_delay : delay;
}
if (!keyMaterial && getPassphrase().empty()) {
//nop
} else if (keyMaterial && !getPassphrase().empty()) {
_crypto = std::make_shared<SRT::Crypto>(getPassphrase());
if (!_crypto->loadFromKeyMaterial(keyMaterial)) {
sendRejectPacket(SRT_REJ_BADSECRET, addr);
onShutdown(SockException(Err_other, StrPrinter << "handshake fail, reject resaon: " << SRT::getRejectReason(SRT_REJ_BADSECRET)));
return;
}
} else {
sendRejectPacket(SRT_REJ_UNSECURE, addr);
onShutdown(SockException(Err_other, StrPrinter << "handshake fail, reject resaon: " << SRT::getRejectReason(SRT_REJ_UNSECURE)));
return;
}
TraceL << getIdentifier() << " CONCLUSION Phase from"<<SockUtil::inet_ntoa((struct sockaddr *)addr) << ":" << SockUtil::inet_port((struct sockaddr *)addr);;
HandshakePacket::Ptr res = std::make_shared<HandshakePacket>();
res->dst_socket_id = _peer_socket_id;
@@ -262,6 +285,12 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
res->version = 5;
res->encryption_field = HandshakePacket::NO_ENCRYPTION;
res->extension_field = HandshakePacket::HS_EXT_FILED_HSREQ;
if (_crypto) {
//The default value is 0 (no encryption advertised).
//If neither peer advertises encryption, AES-128 is selected by default
/* req->encryption_field = SRT::HandshakePacket::AES_128; */
res->extension_field |= HandshakePacket::HS_EXT_FILED_KMREQ;
}
res->handshake_type = HandshakePacket::HS_TYPE_CONCLUSION;
res->srt_socket_id = _socket_id;
res->syn_cookie = 0;
@@ -272,6 +301,10 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
ext->srt_flag = srt_flag;
ext->recv_tsbpd_delay = ext->send_tsbpd_delay = delay;
res->ext_list.push_back(std::move(ext));
if (keyMaterial) {
keyMaterial->extension_type = HSExt::SRT_CMD_KMRSP;
res->ext_list.push_back(std::move(keyMaterial));
}
res->storeToData();
_handleshake_res = res;
unregisterSelfHandshake();
@@ -366,6 +399,42 @@ void SrtTransport::sendMsgDropReq(uint32_t first, uint32_t last) {
sendControlPacket(pkt, true);
}
void SrtTransport::tryAnnounceKeyMaterial() {
//TraceL;
if (!_crypto) {
return;
}
auto pkt = _crypto->takeAwayAnnouncePacket();
if (!pkt) {
return;
}
auto now = SteadyClock::now();
pkt->dst_socket_id = _peer_socket_id;
pkt->timestamp = SRT::DurationCountMicroseconds(now - _start_timestamp);
pkt->storeToData();
_announce_req = pkt;
sendControlPacket(pkt, true);
std::weak_ptr<SrtTransport> weak_self = std::static_pointer_cast<SrtTransport>(shared_from_this());
_announce_timer = std::make_shared<Timer>(0.2, [weak_self]()->bool{
auto strong_self = weak_self.lock();
if (!strong_self) {
return false;
}
if (!strong_self->_announce_req) {
return false;
}
strong_self->sendControlPacket(strong_self->_announce_req, true);
return true;
}, getPoller());
return;
}
void SrtTransport::handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr) {
// TraceL;
NAKPacket pkt;
@@ -433,6 +502,8 @@ void SrtTransport::handleDropReq(uint8_t *buf, int len, struct sockaddr_storage
*/
}
void SrtTransport::checkAndSendAckNak(){
//SRT Periodic NAK reports are sent with a period of (RTT + 4 * RTTVar) / 2 (so called NAKInterval),
//with a 20 milliseconds floor
auto nak_interval = (_rtt + _rtt_variance * 4) / 2;
if (nak_interval <= 20 * 1000) {
nak_interval = 20 * 1000;
@@ -468,7 +539,52 @@ void SrtTransport::checkAndSendAckNak(){
_light_ack_pkt_count++;
}
void SrtTransport::handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr) {
TraceL;
/* TraceL; */
using srt_userd_defined_handler = void (SrtTransport::*)(uint8_t * buf, int len, struct sockaddr_storage *addr);
static std::unordered_map<uint16_t /*sub_type*/, srt_userd_defined_handler> s_userd_defined_functions;
static onceToken token([]() {
s_userd_defined_functions.emplace(SRT::HSExt::SRT_CMD_KMREQ, &SrtTransport::handleKeyMaterialReqPacket);
s_userd_defined_functions.emplace(SRT::HSExt::SRT_CMD_KMRSP, &SrtTransport::handleKeyMaterialRspPacket);
});
uint16_t subtype = ControlPacket::getSubType(buf, len);
auto it = s_userd_defined_functions.find(subtype);
if (it == s_userd_defined_functions.end()) {
WarnL << " not support subtype in user defined msg ignore: " << subtype;
return;
} else {
(this->*(it->second))(buf, len, addr);
}
return;
}
void SrtTransport::handleKeyMaterialReqPacket(uint8_t *buf, int len, struct sockaddr_storage *addr) {
/* TraceL; */
if (!_crypto) {
WarnL << " not enable crypto, ignore";
return;
}
KeyMaterialPacket::Ptr pkt = std::make_shared<KeyMaterialPacket>();
pkt->loadFromData(buf, len);
_crypto->loadFromKeyMaterial(pkt);
//rsp
pkt->sub_type = SRT::HSExt::SRT_CMD_KMRSP;
pkt->dst_socket_id = _peer_socket_id;
pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp);
pkt->storeToData();
sendControlPacket(pkt, true);
return;
}
void SrtTransport::handleKeyMaterialRspPacket(uint8_t *buf, int len, struct sockaddr_storage *addr) {
/* TraceL; */
_announce_req = nullptr;
return;
}
void SrtTransport::handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr) {
@@ -603,6 +719,25 @@ void SrtTransport::sendNAKPacket(std::list<PacketQueue::LostPair> &lost_list) {
// TraceL<<"send NAK "<<pkt->dump();
}
void SrtTransport::sendRejectPacket(SRT_REJECT_REASON reason, struct sockaddr_storage *addr) {
HandshakePacket::Ptr res = std::make_shared<HandshakePacket>();
res->dst_socket_id = _peer_socket_id;
res->timestamp = DurationCountMicroseconds(_now - _start_timestamp);
res->mtu = _mtu;
res->max_flow_window_size = _max_window_size;
res->initial_packet_sequence_number = _init_seq_number;
res->version = 5;
res->encryption_field = HandshakePacket::NO_ENCRYPTION;
res->extension_field = HandshakePacket::HS_EXT_FILED_HSREQ;
res->handshake_type = reason;
res->srt_socket_id = _socket_id;
res->syn_cookie = 0;
res->assignPeerIP(addr);
res->storeToData();
sendControlPacket(res, true);
return;
}
void SrtTransport::sendShutDown() {
ShutDownPacket::Ptr pkt = std::make_shared<ShutDownPacket>();
pkt->dst_socket_id = _peer_socket_id;
@@ -615,6 +750,16 @@ void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_stora
DataPacket::Ptr pkt = std::make_shared<DataPacket>();
pkt->loadFromData(buf, len);
if (_crypto) {
auto payload = _crypto->decrypt(pkt, pkt->payloadData(), pkt->payloadSize());
if (!payload) {
WarnL << "decrypt pkt->packet_seq_number: " << pkt->packet_seq_number << ", timestamp: " << "pkt->timestamp " << " fail";
return;
}
pkt->reloadPayload((uint8_t*)payload->data(), payload->size());
}
_estimated_link_capacity_context->inputPacket(_now,pkt);
std::list<DataPacket::Ptr> list;
@@ -684,9 +829,26 @@ void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_stora
}
void SrtTransport::sendDataPacket(DataPacket::Ptr pkt, char *buf, int len, bool flush) {
pkt->storeToData((uint8_t *)buf, len);
auto data = buf;
auto size = len;
BufferLikeString::Ptr payload;
if (_crypto) {
payload = _crypto->encrypt(pkt, const_cast<char*>(buf), len);
if (!payload) {
WarnL << "encrypt pkt->packet_seq_number: " << pkt->packet_seq_number << ", timestamp: " << "pkt->timestamp " << " fail";
return;
}
data = payload->data();
size = payload->size();
tryAnnounceKeyMaterial();
}
pkt->storeToData((uint8_t *)data, size);
sendPacket(pkt, flush);
_send_buf->inputPacket(pkt);
return;
}
void SrtTransport::sendControlPacket(ControlPacket::Ptr pkt, bool flush) {
@@ -836,4 +998,4 @@ SrtTransport::Ptr SrtTransportManager::getHandshakeItem(const uint32_t key) {
return it->second.lock();
}
} // namespace SRT
} // namespace SRT

View File

@@ -13,6 +13,7 @@
#include "Common.hpp"
#include "NackContext.hpp"
#include "Packet.hpp"
#include "Crypto.hpp"
#include "PacketQueue.hpp"
#include "PacketSendQueue.hpp"
#include "Statistic.hpp"
@@ -24,6 +25,7 @@ extern const std::string kPort;
extern const std::string kTimeOutSec;
extern const std::string kLatencyMul;
extern const std::string kPktBufSize;
extern const std::string kPassPhrase;
class SrtTransport : public std::enable_shared_from_this<SrtTransport> {
public:
@@ -60,6 +62,7 @@ protected:
virtual int getLatencyMul() { return 4; };
virtual int getPktBufSize() { return 8192; };
virtual float getTimeOutSec(){return 5.0;};
virtual std::string getPassphrase() {return "";};
private:
void registerSelf();
@@ -79,15 +82,19 @@ private:
void handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleKeyMaterialReqPacket(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleKeyMaterialRspPacket(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr);
void sendNAKPacket(std::list<PacketQueue::LostPair> &lost_list);
void sendACKPacket();
void sendRejectPacket(SRT_REJECT_REASON reason, struct sockaddr_storage *addr);
void sendLightACKPacket();
void sendKeepLivePacket();
void sendShutDown();
void sendMsgDropReq(uint32_t first, uint32_t last);
void tryAnnounceKeyMaterial();
size_t getPayloadSize() const;
@@ -159,6 +166,11 @@ private:
Ticker _alive_ticker;
bool _is_handleshake_finished = false;
// for encryption
Crypto::Ptr _crypto;
Timer::Ptr _announce_timer;
KeyMaterialPacket::Ptr _announce_req;
};
class SrtTransportManager {
@@ -185,4 +197,4 @@ private:
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_TRANSPORT_H
#endif // ZLMEDIAKIT_SRT_TRANSPORT_H

View File

@@ -370,6 +370,11 @@ float SrtTransportImp::getTimeOutSec() {
return timeOutSec;
}
std::string SrtTransportImp::getPassphrase() {
GET_CONFIG(string, passphrase, kPassPhrase);
return passphrase;
}
int SrtTransportImp::getPktBufSize() {
// kPktBufSize
GET_CONFIG(int, pktBufSize, kPktBufSize);
@@ -380,4 +385,4 @@ int SrtTransportImp::getPktBufSize() {
return pktBufSize;
}
} // namespace SRT
} // namespace SRT

View File

@@ -38,6 +38,7 @@ protected:
int getLatencyMul() override;
int getPktBufSize() override;
float getTimeOutSec() override;
std::string getPassphrase() override;
void onSRTData(DataPacket::Ptr pkt) override;
void onShutdown(const SockException &ex) override;
void onHandShakeFinished(std::string &streamid, struct sockaddr_storage *addr) override;