【问题标题】:Asio SSL full-duplex socket synchronization problem（Asio SSL 全双工套接字同步问题）
【发布时间】:2022-12-11 13:08:45
【问题描述】:

我的 SSL 中继服务器 MVCE:

#pragma once

#include <stdint.h>

#include <array>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

#include <asio.hpp>
#include <asio/ssl.hpp>

namespace test
{

namespace setup
{
    // Largest accepted MessageHeader::messageLength (the header itself included).
    constexpr uint32_t maxMessageSize = 1024u * 1024u;

    // Number of session slots; valid session ids are 0 .. maxSessionsNum - 1.
    constexpr uint32_t maxSessionsNum = 10u;
}

// Discriminator carried in MessageHeader::messageType; the numeric values are
// what actually travels over the wire, so they are spelled out explicitly.
enum class MessageType
{
    LOG_ON       = 0, // client announces its sessionId / isClient0
    TEXT_MESSAGE = 1  // payload relayed to the other client of the session
};

// Fixed-size prefix of every frame. messageLength is the size of the whole frame,
// header included (the LogOn / TextMessage constructors start it at their own sizeof).
class MessageHeader
{
public:
    uint32_t messageType;   // a MessageType value
    uint32_t messageLength; // total frame size in bytes

    MessageHeader(uint32_t messageType, uint32_t messageLength)
        : messageType{messageType}
        , messageLength{messageLength}
    {
    }
};

// LOG_ON frame. The client writes this object verbatim (sizeof(LogOn) bytes), so
// every field is given a defined value instead of leaving stack garbage on the wire.
class LogOn
{
public:
    MessageHeader header;
    uint32_t      sessionId = 0; // index of the session slot to join
    uint32_t      isClient0 = 0; // non-zero: register as pClient0, zero: as pClient1

    LogOn() : header(static_cast<uint32_t>(MessageType::LOG_ON), sizeof(LogOn)) {}
};

// TEXT_MESSAGE frame: header followed by the payload bytes. The constructor sets
// messageLength to the header-only size; senders add the payload length themselves.
class TextMessage
{
public:
    MessageHeader header;
    // NOTE(review): `data[]` is a flexible array member, which C++ accepts only as a
    // compiler extension (GCC/Clang/MSVC); it typically adds nothing to sizeof.
    uint8_t       data[];

    TextMessage()
        : header(static_cast<uint32_t>(MessageType::TEXT_MESSAGE),
                 static_cast<uint32_t>(sizeof(TextMessage)))
    {
    }
};

class ClientSocket;

// Pairs the two connections of one relay session. The pointers are not owned by the
// session and start out null; each is filled in when that side sends LOG_ON.
class Session
{
public:
    ClientSocket* pClient0 = nullptr;
    ClientSocket* pClient1 = nullptr;
};

// Returns the session object for `sessionId` (defined after SSLRelayServer).
Session* getSession(uint32_t sessionId);

class ClientSocket
{
public:
    bool useTLS;

    std::shared_ptr<asio::ip::tcp::socket> socket;
    std::shared_ptr<asio::ssl::stream<asio::ip::tcp::socket>> socketSSL;

    Session* pSession;
    bool     isClient0;

    std::recursive_mutex writeBufferLock;
    std::vector<char>    readBuffer;
    uint32_t             readPos;

    ClientSocket(asio::ip::tcp::socket& socket) : useTLS(false)
    {
        this->socket = std::make_shared<asio::ip::tcp::socket>(std::move(socket));
        this->readBuffer.resize(setup::maxMessageSize + sizeof(MessageHeader));
        this->readPos = 0;
    }

 ClientSocket(asio::ssl::stream<asio::ip::tcp::socket>& socket) : useTLS(true)
    {
        this->socketSSL = std::make_shared<asio::ssl::stream<asio::ip::tcp::socket>>(std::move(socket));
        this->readBuffer.resize(setup::maxMessageSize + sizeof(MessageHeader));
        this->readPos = 0;
    }

    bool writeSocket(uint8_t* pBuffer, uint32_t bufferSize)
    {
        try
        {
            std::unique_lock<std::recursive_mutex>
lock(this->writeBufferLock);

            size_t writtenBytes = 0;

            if (true == this->useTLS)
            {
                writtenBytes = asio::write(*this->socketSSL,
asio::buffer(pBuffer, bufferSize));
            }
            else
            {
                writtenBytes = asio::write(*this->socket,
asio::buffer(pBuffer, bufferSize));
            }

            return (writtenBytes == bufferSize);
        }
        catch (asio::system_error e)
        {
            std::cout << e.what() << std::endl;
        }
        catch (std::exception e)
        {
            std::cout << e.what() << std::endl;
        }
        catch (...)
        {
            std::cout << "Some other exception" << std::endl;
        }

        return false;
    }

    void asyncReadNextMessage(uint32_t messageSize)
    {
        auto readMessageLambda = [&](const asio::error_code errorCode, std::size_t length)
        {
            this->readPos += (uint32_t)length;

            if (0 != errorCode.value())
            {
                //send socket to remove
                 printf("errorCode= %u, message=%s\n", errorCode.value(), errorCode.message().c_str());
                //sendRemoveMeSignal();
                return;
            }

            if ((this->readPos < sizeof(MessageHeader)))
            {
                asyncReadNextMessage(sizeof(MessageHeader) - this->readPos);
                return;
            }

            MessageHeader* pMessageHeader = (MessageHeader*)this->readBuffer.data();

            if (pMessageHeader->messageLength > setup::maxMessageSize)
            {
                //Message to big - should disconnect ?
                this->readPos = 0;
                asyncReadNextMessage(sizeof(MessageHeader));
                return;
            }

            if (this->readPos < pMessageHeader->messageLength)
            {
                asyncReadNextMessage(pMessageHeader->messageLength - this->readPos);
                return;
            }

            MessageType messageType = (MessageType)pMessageHeader->messageType;

            switch(messageType)
            {
                case MessageType::LOG_ON:
                {
                    LogOn* pLogOn = (LogOn*)pMessageHeader;
                    printf("LOG_ON message sessionId=%u, isClient0=%u\n", pLogOn->sessionId, pLogOn->isClient0);

                    this->isClient0 = pLogOn->isClient0;
                    this->pSession  = getSession(pLogOn->sessionId);

                    if (this->isClient0)
                        this->pSession->pClient0 = this;
                    else
                        this->pSession->pClient1 = this;

                }
                break;
                case MessageType::TEXT_MESSAGE:
                {
                    TextMessage* pTextMessage = (TextMessage*)pMessageHeader;

                    if (nullptr != pSession)
                    {
                        if (this->isClient0)
                        {
                            if (nullptr != pSession->pClient1)
                            {
                                pSession->pClient1->writeSocket((uint8_t*)pTextMessage, pTextMessage->header.messageLength);
                            }
                        }
                        else
                        {
                            if (nullptr != pSession->pClient0)
                            {
                                pSession->pClient0->writeSocket((uint8_t*)pTextMessage, pTextMessage->header.messageLength);
                            }
                        }
                    }
                }
                break;
            }

            this->readPos = 0;
            asyncReadNextMessage(sizeof(MessageHeader));
        };

        if (true == this->useTLS)
        {
            this->socketSSL->async_read_some(asio::buffer(this->readBuffer.data() + this->readPos, messageSize), readMessageLambda);
        }
        else
        {
            this->socket->async_read_some(asio::buffer(this->readBuffer.data() + this->readPos, messageSize), readMessageLambda);
        }
    }
};

class SSLRelayServer
{
public:
    static SSLRelayServer* pSingleton;

    asio::io_context   ioContext;
    asio::ssl::context sslContext;

    std::vector<std::thread> workerThreads;

    asio::ip::tcp::acceptor* pAcceptor;
    asio::ip::tcp::endpoint* pEndpoint;

    bool useTLS;

    Session* sessions[setup::maxSessionsNum];

    SSLRelayServer() : pAcceptor(nullptr), pEndpoint(nullptr), sslContext(asio::ssl::context::tlsv13_server)//sslContext(asio::ssl::context::sslv23)
    {
        this->useTLS     = false;
        this->pSingleton = this;

        //this->sslContext.set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2);
        this->sslContext.set_password_callback(std::bind(&SSLRelayServer::getPrivateKeyPEMFilePassword, this));
        this->sslContext.use_certificate_chain_file("server_cert.pem");
        this->sslContext.use_private_key_file("server_private_key.pem",
        asio::ssl::context::pem);
    }

    static SSLRelayServer* getSingleton()
    {
        return pSingleton;
    }

    std::string getPrivateKeyPEMFilePassword() const
    {
        return "";
    }

    void addClientSocket(asio::ip::tcp::socket& socket)
    {
        ClientSocket* pClientSocket = new ClientSocket(socket); // use smart pointers
        pClientSocket->asyncReadNextMessage(sizeof(MessageHeader));
    }

    void addSSLClientToken(asio::ssl::stream<asio::ip::tcp::socket>&sslSocket)
    {
        ClientSocket* pClientSocket = new ClientSocket(sslSocket); // use smart pointers
        pClientSocket->asyncReadNextMessage(sizeof(MessageHeader));
    }

    void handleAccept(asio::ip::tcp::socket& socket, const asio::error_code& errorCode)
    {
        if (!errorCode)
        {
            printf("accepted\n");

            if (true == socket.is_open())
            {
                asio::ip::tcp::no_delay no_delay_option(true);
                socket.set_option(no_delay_option);

                addClientSocket(socket);
            }
        }
    }

    void handleAcceptTLS(asio::ip::tcp::socket& socket, const asio::error_code& errorCode)
    {
        if (!errorCode)
        {
            printf("accepted\n");

            if (true == socket.is_open())
            {
                asio::ip::tcp::no_delay no_delay_option(true);

                asio::ssl::stream<asio::ip::tcp::socket> sslStream(std::move(socket), this->sslContext);

                try
                {
                    sslStream.handshake(asio::ssl::stream_base::server);
                    sslStream.lowest_layer().set_option(no_delay_option);

                    addSSLClientToken(sslStream);
                }
                catch (asio::system_error e)
                {
                    std::cout << e.what() << std::endl;
                    return;
                }
                catch (std::exception e)
                {
                    std::cout << e.what() << std::endl;
                    return;
                }
                catch (...)
                {
                    std::cout << "Other exception" << std::endl;
                    return;
                }

            }

        }
    }

    void startAccept()
    {
        auto acceptHandler = [this](const asio::error_code& errorCode, asio::ip::tcp::socket socket)
        {
            printf("acceptHandler\n");

            handleAccept(socket, errorCode);

            this->startAccept();
        };

        auto tlsAcceptHandler = [this](const asio::error_code& errorCode, asio::ip::tcp::socket socket)
        {
            printf("tlsAcceptHandler\n");

            handleAcceptTLS(socket, errorCode);

            this->startAccept();
        };

        if (true == this->useTLS)
        {
            this->pAcceptor->async_accept(tlsAcceptHandler);
        }
        else
        {
            this->pAcceptor->async_accept(acceptHandler);
        }
    }

    bool run(uint32_t servicePort, uint32_t threadsNum, bool useTLS)
    {
        this->useTLS = useTLS;

        this->pEndpoint = new asio::ip::tcp::endpoint(asio::ip::tcp::v4(), servicePort);
        this->pAcceptor = new asio::ip::tcp::acceptor(ioContext, *pEndpoint);

        this->pAcceptor->listen();

        this->startAccept();

        for (uint32_t threadIt = 0; threadIt < threadsNum; ++threadIt)
        {
            this->workerThreads.emplace_back([&]() {
#ifdef WINDOWS
                SetThreadDescription(GetCurrentThread(), L"SSLRelayServer worker thread");
#endif
                this->ioContext.run(); }
            );
        }

        return true;
    }

    Session* getSession(uint32_t sessionId)
    {
        if (nullptr == this->sessions[sessionId])
        {
            this->sessions[sessionId] = new Session();
        }

        return this->sessions[sessionId];
    }
};

SSLRelayServer* SSLRelayServer::pSingleton{nullptr};

// Free-function bridge for ClientSocket, which is declared before SSLRelayServer:
// forwards to the singleton server's session table.
Session* getSession(uint32_t sessionId)
{
    return SSLRelayServer::getSingleton()->getSession(sessionId);
}

class Client
{
public:
    asio::ssl::context sslContext;

    std::shared_ptr<asio::ip::tcp::socket> socket;
    std::shared_ptr<asio::ssl::stream<asio::ip::tcp::socket>> socketSSL;

    asio::io_context ioContext;

    bool useTLS;
    bool isClient0;

    uint32_t             readDataIt;
    std::vector<uint8_t> readBuffer;

    std::thread listenerThread;

    Client() : sslContext(asio::ssl::context::tlsv13_client)//sslContext(asio::ssl::context::sslv23)
    {
        sslContext.load_verify_file("server_cert.pem");
        //sslContext.set_verify_mode(asio::ssl::verify_peer);

        using asio::ip::tcp;
        using std::placeholders::_1;
        using std::placeholders::_2;
        sslContext.set_verify_callback(std::bind(&Client::verifyCertificate, this, _1, _2));

        this->readBuffer.resize(setup::maxMessageSize);
        this->readDataIt = 0;
    }

    bool verifyCertificate(bool preverified, asio::ssl::verify_context& verifyCtx)
    {
        return true;
    }

    void listenerRunner() 
    {
#ifdef WINDOWS
        if (this->isClient0)
        {
            SetThreadDescription(GetCurrentThread(), L"listenerRunner client0");
        }
        else
        {
            SetThreadDescription(GetCurrentThread(), L"listenerRunner client1");
        }
#endif

        while (1==1)
        {
            asio::error_code errorCode;

            size_t transferred = 0;
            if (true == this->useTLS)
            {
                transferred = this->socketSSL->read_some(asio::buffer(this->readBuffer.data() + this->readDataIt, sizeof(MessageHeader) - this->readDataIt), errorCode);
            }
            else
            {
                transferred = this->socket->read_some(asio::buffer(this->readBuffer.data() + this->readDataIt, sizeof(MessageHeader) - this->readDataIt), errorCode);
            }

            this->readDataIt += transferred;

            if (0 != errorCode.value())
            {
                this->readDataIt = 0;
                continue;
            }

            if (this->readDataIt < sizeof(MessageHeader))
                continue;

            MessageHeader* pMessageHeader = (MessageHeader*)this->readBuffer.data();

            if (pMessageHeader->messageLength > setup::maxMessageSize)
            {
                exit(1);
            }

            bool resetSocket = false;

            while (pMessageHeader->messageLength > this->readDataIt)
            {
                printf("readDataIt=%u, threadId=%u\n", this->readDataIt, GetCurrentThreadId());

                {
                    //message not complete
                    if (true == this->useTLS)
                    {
                        transferred = this->socketSSL->read_some(asio::buffer(this->readBuffer.data() + this->readDataIt, pMessageHeader->messageLength - this->readDataIt), errorCode);
                    }
                    else
                    {
                        transferred = this->socket->read_some(asio::buffer(this->readBuffer.data() + this->readDataIt, pMessageHeader->messageLength - this->readDataIt), errorCode);
                    }

                    this->readDataIt += transferred;
                }

                if (0 != errorCode.value())
                {
                    exit(1);
                }
            }

            MessageType messageType = (MessageType)pMessageHeader->messageType;

            switch (messageType)
            {
                case MessageType::TEXT_MESSAGE:
                {
                    TextMessage* pTextMessage = (TextMessage*)pMessageHeader;
                    printf("TEXT_MESSAGE: %s\n", pTextMessage->data);
                }
                break;
            }

            this->readDataIt = 0;
        }
    }

    void run(uint32_t sessionId, bool isClient0, bool useTLS, uint32_t servicePort)
    {
        this->useTLS    = useTLS;
        this->isClient0 = isClient0;

        if (useTLS)
        {
            socketSSL = std::make_shared<asio::ssl::stream<asio::ip::tcp::socket>>(ioContext, sslContext);
        }
        else
        {
            socket = std::make_shared<asio::ip::tcp::socket>(ioContext);
        }

        asio::ip::tcp::resolver resolver(ioContext);

        asio::ip::tcp::resolver::results_type endpoints = resolver.resolve(asio::ip::tcp::v4(), "127.0.0.1", std::to_string(servicePort));

        asio::ip::tcp::no_delay no_delay_option(true);

        if (true == useTLS)
        {
            asio::ip::tcp::endpoint sslEndpoint = asio::connect(socketSSL->lowest_layer(), endpoints);
            socketSSL->handshake(asio::ssl::stream_base::client);
            socketSSL->lowest_layer().set_option(no_delay_option);
        }
        else
        {
            asio::ip::tcp::endpoint endpoint = asio::connect(*socket, endpoints);
            socket->set_option(no_delay_option);
        }

        this->listenerThread = std::thread(&Client::listenerRunner, this);

        LogOn logOn;
        logOn.isClient0 = isClient0;
        logOn.sessionId = sessionId;

        const uint32_t logOnSize = sizeof(logOn);

        if (true == useTLS)
        {
            size_t transferred = asio::write(*socketSSL, asio::buffer(&logOn, sizeof(LogOn)));
        }
        else
        {
            size_t transferred = asio::write(*socket, asio::buffer(&logOn, sizeof(LogOn)));
        }

        uint32_t counter = 0;

        while (1 == 1)
        {
            std::string number  = std::to_string(counter);
            std::string message;

            if (this->isClient0)
            {
                message = "Client0: " + number;
            }
            else
            {
                message = "Client1: " + number;
            }

            TextMessage textMessage;
            textMessage.header.messageLength += message.size() + 1;

            if (this->useTLS)
            {
                size_t transferred = asio::write(*socketSSL, asio::buffer(&textMessage, sizeof(TextMessage)));
                transferred        = asio::write(*socketSSL, asio::buffer(message.c_str(), message.length() + 1));
            }
            else
            {
                size_t transferred = asio::write(*socket, asio::buffer(&textMessage, sizeof(TextMessage)));
                transferred        = asio::write(*socket, asio::buffer(message.c_str(), message.length() + 1));
            }

            ++counter;
            //Sleep(1000);
        }
    }
};

// Thread entry for one test client. Client::run() loops forever; if it ever returned,
// the thread parks here so `client` (and its listener thread) stays alive.
void clientTest(uint32_t sessionId, bool isClient0, bool useTLS, uint32_t servicePort)
{
#ifdef WINDOWS
    if (isClient0)
    {
        SetThreadDescription(GetCurrentThread(), L"Client0");
    }
    else
    {
        SetThreadDescription(GetCurrentThread(), L"Client1");
    }
#endif

    Client client;

    client.run(sessionId, isClient0, useTLS, servicePort);

    // Portable replacement for the Win32-only Sleep(); the other Win32 calls in this
    // file are guarded by #ifdef WINDOWS, this one was not.
    while (true)
    {
        std::this_thread::sleep_for(std::chrono::seconds(1));
    }
}

// Starts the relay on localhost:777 and two clients of session 0 that message each
// other; joins the client threads, which never finish.
void SSLRelayTest()
{
    SSLRelayServer relayServer;

    // Keep a single io thread: ClientSocket relays with synchronous writes from inside
    // read handlers, which is not safe with concurrent handlers on one SSL stream.
    const uint32_t threadsNum  = 1;
    const bool     useTLS      = true;
    const uint32_t servicePort = 777;
    relayServer.run(servicePort, threadsNum, useTLS);

    // Delay before starting the clients (the acceptor is already listening once run()
    // returns). std::this_thread::sleep_for replaces the Win32-only Sleep().
    std::this_thread::sleep_for(std::chrono::seconds(5));

    std::vector<std::thread> threads;

    const uint32_t sessionId = 0;
    threads.emplace_back(clientTest, sessionId, true, useTLS, servicePort);
    threads.emplace_back(clientTest, sessionId, false, useTLS, servicePort);

    for (std::thread& threadIt : threads)
    {
        threadIt.join();
    }
}

}

这个示例是做什么的？它在本地主机的 777 端口上运行一个 SSL 中继服务器，连接两个客户端，并允许它们之间互相发送文本消息。

问题：当我运行该示例时，服务器在 void asyncReadNextMessage(uint32_t messageSize) 中返回错误 “errorCode=167772441, message=decryption failed or bad record mac (SSL routines)”（解密失败或记录 MAC 错误）。我发现这是由客户端引起的：客户端在不同的线程中读取和写入同一个 SSL 套接字（把变量 useTLS 改为 0、改用普通套接字运行时问题消失，这证明是 SSL 套接字的问题）。显然 TLS 不是全双工协议（我之前并不知道）。我无法用互斥锁来同步读写访问，因为当套接字处于读取状态而又没有传入消息时，写入套接字的操作会被永远阻塞。在帖子 “Boost ASIO, SSL: How do strands help the implementation?” 中，有人建议使用 strand，但也有人指出 asio 的 strand 只是让读写处理程序依次执行而不是并发执行，这并不能解决问题。

我希望找到一种方法来同步对 SSL 套接字的读取和写入。我 100% 确定问题出在读写套接字的同步上，因为当我编写另一个示例、在同一个线程中完成对套接字的所有读写时，它能正常工作。然而，客户端总是在等待可读取的消息；如果一直没有消息到来，就会阻塞所有写入。在不为读取和写入分别使用独立套接字的前提下，能否解决这个问题？

【问题讨论】:

  • “显然 TLS 不是全双工协议”——它其实是全双工的。这不是协议本身的问题，而是具体实现的问题：OpenSSL 用一个结构保存当前的 TLS 状态，读取和写入时都需要更新它——因此需要像保护其他多线程共享资源一样对它加以保护。
  • 好的，如果我在客户端中只在一个线程上调用 ioContext.run()，并用 asio::async_write 和 asio::async_read 代替 asio::write 和 asio::read，这应该可以工作，因为这样就形成了隐式 strand；如果这样还不行，是否意味着用 asio 无法做到？
  • 我不熟悉 asio 的内部实现，但请注意 asio SSL 文档（the asio SSL documentation）末尾关于线程的说明。

标签: ssl asio


【解决方案1】:

好的，我通过编写许多不同的代码示例（包括使用 SSL 套接字的示例）解决了这个问题。当 asio::io_context 已经在运行时，不能直接从任意线程发起绑定到该套接字所用 strand 的 asio::async_write 或 asio::async_read。

所以，如果代码是：asio::async_write(*this->socketSSL, asio::buffer(pBuffer, bufferSize), asio::bind_executor(readWriteStrand, writeMessageLambda));，但当前执行的线程并不是在 readWriteStrand 中运行，那么应当改写为：asio::post(ioContext, asio::bind_executor(readWriteStrand, [&]() { asio::async_read(*this->socketSSL, asio::buffer(readBuffer.data() + this->readDataIt, messageSize), asio::bind_executor(readWriteStrand, readMessageLambda)); }));

【讨论】:

    猜你喜欢
    • 2011-02-13
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2011-12-05
    • 1970-01-01
    • 2013-09-14
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多