(file) Return to protocol.c CVS log (file) (dir) Up to [OMI] / omi / protocol

File: [OMI] / omi / protocol / protocol.c (download)
Revision: 1.2, Mon Jun 25 18:51:02 2012 UTC (11 years, 11 months ago) by mike
Branch: MAIN
CVS Tags: OMI_1_0_2
Changes since 1.1: +4 -4 lines
1.0.2 drop

/*
**==============================================================================
**
** Open Management Infrastructure (OMI)
**
** Copyright (c) Microsoft Corporation
** 
** Licensed under the Apache License, Version 2.0 (the "License"); you may not 
** use this file except in compliance with the License. You may obtain a copy 
** of the License at 
**
**     http://www.apache.org/licenses/LICENSE-2.0 
**
** THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
** KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED 
** WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, 
** MERCHANTABLITY OR NON-INFRINGEMENT. 
**
** See the Apache 2 License for the specific language governing permissions 
** and limitations under the License.
**
**==============================================================================
*/

#include <assert.h>
#include "protocol.h"
#include <sock/addr.h>
#include <sock/sock.h>
#include <sock/selector.h>
#include "header.h"
#include <sock/thread.h>
#include <base/buf.h>
#include <base/log.h>
#include <base/result.h>
#include <base/user.h>
#include <base/io.h>

/* Shorthand for MI_T, used below to wrap the format strings passed to LOGW(). */
#define T MI_T

/* Change "#if 0" to "#if 1" to compile in verbose send/receive tracing. */
#if 0
# define ENABLE_TRACING
#endif

/* PRINTF((fmt, ...)) takes a parenthesized printf argument list. Without
   tracing it expands to nothing, so its arguments are never evaluated. */
#if defined(ENABLE_TRACING)
# define PRINTF(a) printf a
#else
# define PRINTF(a)
#endif

/*
**==============================================================================
**
** Local definitions:
**
**==============================================================================
*/

/* Magic tag for Protocol objects. NOTE(review): presumably stored in
   Protocol.magic to validate handles - the assignment is not visible here. */
static const MI_Uint32 _MAGIC = 0xC764445E;
/* Outgoing-queue watermarks, compared against
   Protocol_SR_SocketData.sendingQueueLength in _PrepareMessageForSending():
   reaching HIGH stops reading from the socket, dropping back to LOW resumes it. */
#define SR_SOCKET_OUT_QUEUE_WATERMARK_LOW  4
#define SR_SOCKET_OUT_QUEUE_WATERMARK_HIGH  12
/* NOTE(review): not referenced in this part of the file - confirm its use. */
#define SR_SOCKET_OUT_QUEUE_WATERMARK_CRITICAL  256

/* Per-socket authentication state machine. */
typedef enum _Protocol_AuthState
{
    /* Authentication failed. Deliberately 0 (presumably so a zero-filled
       handler starts out unauthenticated). */
    PRT_AUTH_FAILED = 0,

    /* Server side: waiting for the client's first connect request. */
    PRT_AUTH_WAIT_CONNECTION_REQUEST = 1,

    /* Server side: waiting for the second connect request, which carries the
       random data the client read from the auth file. */
    PRT_AUTH_WAIT_CONNECTION_REQUEST_WITH_FILE_DATA = 2,

    /* Client side: waiting for the server's connect response. */
    PRT_AUTH_WAIT_CONNECTION_RESPONSE = 3,

    /* Authentication completed; regular traffic may flow. */
    PRT_AUTH_OK = 4
}
Protocol_AuthState;

/* How a Protocol object was created. */
typedef enum _Protocol_Type
{
    PRT_TYPE_LISTENER = 0,      /* accepts incoming connections */
    PRT_TYPE_CONNECTOR = 1,     /* initiates an outgoing connection */
    PRT_TYPE_FROM_SOCKET = 2    /* wraps an already-existing socket */
}
Protocol_Type;

/* Protocol object: registers sockets with a selector and dispatches received
   messages and connection events to the owner's callbacks. */
struct _Protocol
{
    /* NOTE(review): presumably set to _MAGIC to validate handles - not visible here */
    MI_Uint32 magic;
    /* NOTE(review): presumably the selector used when the caller supplies none
       (see internal_selector_used) - confirm */
    Selector internal_selector;
    /* selector that socket handlers are registered with (Selector_AddHandler) */
    Selector*   selector;
    Addr addr;
    /* connector socket; reset to 0 when that handler is removed */
    struct _Protocol_SR_SocketData* connectorHandle;
    /* invoked for every message received after authentication */
    ProtocolCallback callback;
    void* callbackData;
    /* invoked with PROTOCOLEVENT_* notifications for connector sockets */
    ProtocolEventCallback eventCallback;
    void* eventCallbackData;
    Protocol_Type   type;
    MI_Boolean  internal_selector_used;
    /* Indicates whether instance has to be unpacked or stored as byte array */
    MI_Boolean  skipInstanceUnpack;
};

/* Server-side state for file-based authentication: the auth file path
   (unlinked by _FreeAuthData on POSIX) and the random data the client must
   echo back to prove it could read that file. */
typedef struct _Protocol_AuthData
{
    char path[MAX_PATH_SIZE];
    char   authRandom[AUTH_RANDOM_DATA_SIZE];
}
Protocol_AuthData;

/* Per-connection state: selector handler, outgoing queue, partial send and
   receive progress, and authentication state. */
typedef struct _Protocol_SR_SocketData
{
    /* base member; must stay first so Handler* can be cast to this type */
    Handler     base;

    /* sending data */
    /* Linked list of messages to send */
    ListElem* head;
    ListElem* tail;

    /* ref counter 
    NOTE:
        socket may be disconnected, but structure is still alive
        if some outstanding request are not completed.
        refcounter gets '+1' for being connected and '+1' for each outstanding request */
    int     refcounter;

    /* currently sending message */
    Message*    message;
    size_t          sentCurrentBlockBytes;
    int             sendingPageIndex; /* 0 for header otherwise 1-N page index */
    int             sendingQueueLength; /* number of messages in list to send; -1 to disable watermarks */

    /* receiving data */
    Batch       *receivingBatch;
    size_t          receivedCurrentBlockBytes;
    int             receivingPageIndex; /* 0 for header otherwise 1-N page index */

    /* send/recv buffers */
    Header recv_buffer;
    Header send_buffer;

    /* Auth state */
    Protocol_AuthState  authState;
    /* server side - authenticated user's ids */
    uid_t uid;
    gid_t gid;
    /* non-null only while file-based authentication is in progress */
    Protocol_AuthData* authData;

    /* Whether this is a connector */
    MI_Boolean isConnector;

    /* Whether connection has been established */
    MI_Boolean isConnected;

} 
Protocol_SR_SocketData;

/* Result of the read helpers (_ReadHeader / _ReadAllPages):
   keep looping, or return MI_TRUE / MI_FALSE from the selector callback. */
typedef enum _Protocol_CallbackResult
{
    PRT_CONTINUE = 0,       /* keep processing in the current callback */
    PRT_RETURN_TRUE = 1,    /* stop; keep the connection open */
    PRT_RETURN_FALSE = 2    /* stop; close the connection */
}
Protocol_CallbackResult;

/* Forward declarations: the auth helpers below queue messages and kick the
   writer before these functions are defined. */
static void _PrepareMessageForSending(
    Protocol_SR_SocketData *handler);

static MI_Boolean _RequestCallbackWrite(
    Protocol_SR_SocketData* handler);

/**************** Auth-support **********************************************************/
/* Releases file-based auth state, if any: on POSIX the auth file is unlinked,
   then the data is freed and the handler's pointer cleared. Safe to call
   repeatedly. */
static void _FreeAuthData(
    Protocol_SR_SocketData* h)
{
    Protocol_AuthData* data = h->authData;

    if (!data)
        return;

#if defined(CONFIG_POSIX)
    unlink(data->path);
#endif
    h->authData = 0;
    free(data);
}

/*
    Creates a connect-request notification and queues it ahead of any
    pending traffic, then tries to write it immediately.

    user/password - optional explicit credentials (NULL or "" = not sent)
    fileContent   - optional random data read from the server's auth file;
                    sizeof(req->authData) bytes are copied when non-NULL

    Returns MI_FALSE only if the message cannot be allocated. Write errors are
    not reported by this function.
*/
static MI_Boolean _SendAuthRequest(
    Protocol_SR_SocketData* h,
    const char* user,
    const char* password,
    const char* fileContent)
{
    BinProtocolNotification* msg = BinProtocolNotification_New(BinNotificationConnectRequest);

    if (msg == NULL)
        return MI_FALSE;

    if (user != NULL && user[0] != '\0')
        msg->user = Batch_Strdup(msg->base.batch, user);

    if (password != NULL && password[0] != '\0')
        msg->password = Batch_Strdup(msg->base.batch, password);

    /* the server uses this uid when it falls back to file-based auth */
    msg->uid = geteuid();
    msg->gid = getegid();

    if (fileContent != NULL)
        memcpy(msg->authData, fileContent, sizeof(msg->authData));

    /* auth traffic jumps the queue; the queue holds its own reference */
    List_Prepend(&h->head, &h->tail, (ListElem*)msg);
    if (h->sendingQueueLength != -1)
        h->sendingQueueLength++;
    Message_AddRef(&msg->base);

    _PrepareMessageForSending(h);
    _RequestCallbackWrite(h);

    /* drop the reference taken by BinProtocolNotification_New */
    BinProtocolNotification_Release(msg);
    return MI_TRUE;
}

/*
    Creates a connect-response notification carrying 'result' and, when 'path'
    is non-empty, the auth file name the client must read. The response is
    queued ahead of pending traffic and written immediately if possible.

    Returns MI_FALSE only if the message cannot be allocated. Write errors are
    not reported by this function.
*/
static MI_Boolean _SendAuthResponse(
    Protocol_SR_SocketData* h,
    MI_Result result,
    const char* path)
{
    BinProtocolNotification* rsp = BinProtocolNotification_New(BinNotificationConnectResponse);

    if (rsp == NULL)
        return MI_FALSE;

    rsp->result = result;

    if (path != NULL && path[0] != '\0')
        rsp->authFile = Batch_Strdup(rsp->base.batch, path);

    /* auth traffic jumps the queue; the queue holds its own reference */
    List_Prepend(&h->head, &h->tail, (ListElem*)rsp);
    if (h->sendingQueueLength != -1)
        h->sendingQueueLength++;
    Message_AddRef(&rsp->base);

    _PrepareMessageForSending(h);
    _RequestCallbackWrite(h);

    /* drop the reference taken by BinProtocolNotification_New */
    BinProtocolNotification_Release(rsp);
    return MI_TRUE;
}

/*
    Processes the client's second connect request during file-based
    authentication. The request must echo the random data the server wrote
    to the auth file for the claimed uid.
    Updates auth states correspondingly.

    The user's gid is resolved BEFORE MI_RESULT_OK is sent. This way the client
    is never told "OK" and then dropped because the gid lookup failed.

    Parameters:
    handler - socket handler
    binMsg - BinProtocolNotification message with connect request/response

    Return:
    "TRUE" if connection should stay open; "FALSE" if auth failed
        and connection should be closed immediately
*/
static MI_Boolean _ProcessAuthMessageWaitingConnectRequestFileData(
    Protocol_SR_SocketData* handler,
    BinProtocolNotification* binMsg)
{
    /* un-expected message */
    if (BinNotificationConnectRequest != binMsg->type)
        return MI_FALSE;

    /* Check internal state */
    if (!handler->authData)
        return MI_FALSE;

    if (0 != memcmp(binMsg->authData, handler->authData->authRandom, AUTH_RANDOM_DATA_SIZE))
    {
        LOGW_CHAR(("auth failed - random data mismatch"));

        /* Auth failed; the auth file is useless now, remove it right away */
        _FreeAuthData(handler);
        _SendAuthResponse(handler, MI_RESULT_ACCESS_DENIED, NULL);
        handler->authState = PRT_AUTH_FAILED;
        return MI_FALSE;
    }

    /* Client proved it can read the file; the file is single-use */
    _FreeAuthData(handler);

    /* Get gid from uid before reporting success to the client */
    if (0 != GetUserGidByUid(handler->uid, &handler->gid))
    {
        LOGW_CHAR(("cannot get user's gid for uid %d", (int)handler->uid));
        _SendAuthResponse(handler, MI_RESULT_ACCESS_DENIED, NULL);
        handler->authState = PRT_AUTH_FAILED;
        return MI_FALSE;
    }

    if (!_SendAuthResponse(handler, MI_RESULT_OK, NULL))
        return MI_FALSE;

    /* Auth ok */
    handler->authState = PRT_AUTH_OK;
    return MI_TRUE;
}

/*
    Processes the client's first connect request on the server side and picks
    the authentication method. Updates auth states correspondingly.

    Methods, in order:
    1. Explicit credentials: if the request carries a user name, it is checked
       with AuthenticateUser() (a missing password is treated as "") and
       uid/gid are resolved with LookupUser().
    2. Connection-based: GetUIDByConnection() on the socket, if the system
       supports it.
    3. Windows builds: authentication is skipped and uid/gid are set to -1.
       Other builds: file-based challenge. CreateAuthFile() is called for the
       uid the client claims; it fills authRandom and the file path. The path
       is sent back with MI_RESULT_IN_PROGRESS, and the state moves to
       PRT_AUTH_WAIT_CONNECTION_REQUEST_WITH_FILE_DATA.
       NOTE(review): presumably the file is readable only by that uid - confirm
       in CreateAuthFile().

    Parameters:
    handler - socket handler
    binMsg - BinProtocolNotification message with connect request/response

    Return:
    "TRUE" if connection should stay open; "FALSE" if auth failed
        and connection should be closed immediately
*/
static MI_Boolean _ProcessAuthMessageWaitingConnectRequest(
    Protocol_SR_SocketData* handler,
    BinProtocolNotification* binMsg)
{
    /* un-expected message */
    if (BinNotificationConnectRequest != binMsg->type)
        return MI_FALSE;

    /* Use explicit credentials if provided */
    if (binMsg->user)
    {
        /* use empty password if not set */
        if (!binMsg->password)
            binMsg->password = "";

        if ( 0 == AuthenticateUser(binMsg->user, binMsg->password) &&
            0 == LookupUser(binMsg->user, &handler->uid, &handler->gid))
        {
            if (!_SendAuthResponse(handler, MI_RESULT_OK, NULL))
                return MI_FALSE;

            /* Auth ok */
            handler->authState = PRT_AUTH_OK;
            _FreeAuthData(handler);
            return MI_TRUE;
        }

        LOGW_CHAR(("auth failed for user [%s]", binMsg->user));

        /* Auth failed */
        _SendAuthResponse(handler, MI_RESULT_ACCESS_DENIED, NULL);
        handler->authState = PRT_AUTH_FAILED;
        return MI_FALSE;
    }

    /* If system supports connection-based auth, use it for 
        implicit auth */
    if (0 == GetUIDByConnection((int)handler->base.sock, &handler->uid, &handler->gid))
    {
        if (!_SendAuthResponse(handler, MI_RESULT_OK, NULL))
            return MI_FALSE;

        /* Auth ok */
        handler->authState = PRT_AUTH_OK;
        return MI_TRUE;
    }
#if defined(CONFIG_OS_WINDOWS)
    /* ignore auth on Windows */
    {
        if (!_SendAuthResponse(handler, MI_RESULT_OK, NULL))
            return MI_FALSE;

        /* Auth ok */
        handler->uid = -1;
        handler->gid = -1;
        handler->authState = PRT_AUTH_OK;
        return MI_TRUE;
    }
#else

    /* If valid uid provided, try implicit credentials (file-based) 
        gid will be taken from user name */
    {
        handler->authData = (Protocol_AuthData*)calloc(1, sizeof(Protocol_AuthData));

        if (!handler->authData)
        {
            /* Auth failed */
            _SendAuthResponse(handler, MI_RESULT_ACCESS_DENIED, NULL);
            handler->authState = PRT_AUTH_FAILED;
            return MI_FALSE;
        }

        /* authData stays allocated on failure; it is freed when the handler is removed */
        if (0 != CreateAuthFile(binMsg->uid, handler->authData->authRandom, AUTH_RANDOM_DATA_SIZE, handler->authData->path))
        {
            LOGW_CHAR(("cannot create file for user uid [%d]", (int)binMsg->uid));

            /* Auth failed */
            _SendAuthResponse(handler, MI_RESULT_ACCESS_DENIED, NULL);
            handler->authState = PRT_AUTH_FAILED;
            return MI_FALSE;
        }

        /* send file name to the client */
        if (!_SendAuthResponse(handler, MI_RESULT_IN_PROGRESS, handler->authData->path))
            return MI_FALSE;

        /* Auth postponed until the client echoes the file's content */
        handler->authState = PRT_AUTH_WAIT_CONNECTION_REQUEST_WITH_FILE_DATA;

        /* Remember uid we used to create file; gid is resolved after the
           second request succeeds */
        handler->uid = binMsg->uid;
        handler->gid = -1;

        return MI_TRUE;

    }
#endif
}

/*
    Processes an auth message (either connect request or connect response)
    according to the handler's current auth state.
    Updates auth states correspondingly.

    Server states are delegated to the helpers above. In the client state
    (PRT_AUTH_WAIT_CONNECTION_RESPONSE) the result is handled as follows:
      - MI_RESULT_OK: authentication complete; queued (backlog) messages are
        released for sending.
      - MI_RESULT_IN_PROGRESS with a file name: the auth data is read from that
        file and sent back in a second connect request.
      - anything else: authentication rejected; the connection is closed.

    Connector events (PROTOCOLEVENT_DISCONNECT / PROTOCOLEVENT_CONNECT_FAILED)
    are NOT fired here. A MI_FALSE return propagates through
    _ProcessReceivedMessage and _RequestCallbackRead to _RequestCallback,
    which reports the event exactly once.

    Parameters:
    handler - socket handler
    msg - BinProtocolNotification message with connect request/response

    Return:
    "TRUE" if connection should stay open; "FALSE" if auth failed
        and connection should be closed immediately
*/
static MI_Boolean _ProcessAuthMessage(
    Protocol_SR_SocketData* handler,
    Message *msg)
{
    BinProtocolNotification* binMsg;

    if (msg->tag != BinProtocolNotificationTag)
        return MI_FALSE;

    binMsg = (BinProtocolNotification*) msg;

    /* server waiting client's first request? */
    if (PRT_AUTH_WAIT_CONNECTION_REQUEST == handler->authState)
    {
        return _ProcessAuthMessageWaitingConnectRequest(handler, binMsg);
    }

    /* server waiting for client's file's content request? */
    if (PRT_AUTH_WAIT_CONNECTION_REQUEST_WITH_FILE_DATA == handler->authState)
    {
        return _ProcessAuthMessageWaitingConnectRequestFileData(handler, binMsg);
    }

    /* client waiting for server's response? */
    if (PRT_AUTH_WAIT_CONNECTION_RESPONSE == handler->authState)
    {
        /* un-expected message */
        if (BinNotificationConnectResponse != binMsg->type)
            return MI_FALSE;

        if (binMsg->result == MI_RESULT_OK)
        {
            handler->authState = PRT_AUTH_OK;

            /* process backlog items (if any) */
            _PrepareMessageForSending(handler);
            _RequestCallbackWrite(handler);
            return MI_TRUE;
        }

        if (binMsg->result == MI_RESULT_IN_PROGRESS && binMsg->authFile)
        {
            /* send back file's content */
            char buf[AUTH_RANDOM_DATA_SIZE];
            FILE* is = Fopen(binMsg->authFile, "r");

            if (!is)
            {
                LOGE_CHAR(("cannot open auth data file: %s", binMsg->authFile));
                return MI_FALSE;
            }

            /* Read auth data from the file. */
            if (sizeof(buf) != fread(buf, 1, sizeof(buf), is))
            {
                LOGE_CHAR(("cannot read from auth data file: %s", binMsg->authFile));
                fclose(is);
                return MI_FALSE;
            }

            fclose(is);
            return _SendAuthRequest(handler, 0, 0, buf);
        }

        /* Authentication rejected by the server. The caller chain reports
           the connector event when this MI_FALSE reaches _RequestCallback.
           Firing it here as well notified the owner twice. */
        return MI_FALSE;
    }

    /* unknown state? */
    return MI_FALSE;
}


/****************************************************************************************/
/* Drops every message still queued for sending, releasing the queue's
   reference to each and keeping sendingQueueLength in step (unless
   watermarks are disabled with -1). */
static void _RemoveAllMessages(
    Protocol_SR_SocketData* handler)
{
    Message* msg;

    for (msg = (Message*)handler->head; msg; msg = (Message*)handler->head)
    {
        List_Remove(&handler->head, &handler->tail, (ListElem*)msg);
        Message_Release(msg);

        if (handler->sendingQueueLength != -1)
            handler->sendingQueueLength--;
    }
}

/* Drops one reference to the socket data and frees it when the last one goes. */
static void _Release(
    Protocol_SR_SocketData* handler)
{
    handler->refcounter--;

    if (handler->refcounter != 0)
        return;

    free(handler);
}

/*
    Picks the next queued message (if no message is currently being sent)
    and prepares send_buffer: protocol magic/version, page count, original
    message pointer and per-page info. Sets SELECTOR_WRITE when there is
    something to send and clears it otherwise.

    Also applies back-pressure on reading. When the queue length reaches the
    high watermark, SELECTOR_READ is cleared. When it drops to the low
    watermark, SELECTOR_READ is set again.

    Before authentication completes, only BinProtocolNotification
    (auth) messages are eligible. Other messages wait in the queue.
*/
static void _PrepareMessageForSending(
    Protocol_SR_SocketData *handler)
{
    /* check for hi watermark: stop reading while the send queue is full.
       Exact equality: assumes the length moves one step at a time */
    if (handler->sendingQueueLength == SR_SOCKET_OUT_QUEUE_WATERMARK_HIGH)
        handler->base.mask &= ~SELECTOR_READ;

    /* check for low watermark: resume reading once the queue has drained */
    if (handler->sendingQueueLength == SR_SOCKET_OUT_QUEUE_WATERMARK_LOW)
        handler->base.mask |= SELECTOR_READ;

    if (handler->message)
        return; /* already sending */

    if (!handler->head)
    {
        handler->base.mask &= ~SELECTOR_WRITE;
        return; /*nothing to do*/
    }

    /* before auth is complete, only auth-related messages should be sent */
    if (PRT_AUTH_OK != handler->authState && BinProtocolNotificationTag != ((Message*)handler->head)->tag)
    {
        handler->base.mask &= ~SELECTOR_WRITE;
        return; /*nothing to do*/
    }

    /* take ownership of the queue's reference; released when fully sent */
    handler->message = (Message*)handler->head;
    List_Remove(&handler->head, &handler->tail, (ListElem*)handler->message);

    if (-1 != handler->sendingQueueLength)
        handler->sendingQueueLength--;

    /* reset sending attributes */
    handler->sendingPageIndex = 0;
    handler->sentCurrentBlockBytes = 0;

    memset(&handler->send_buffer,0,sizeof(handler->send_buffer));
    handler->send_buffer.base.magic = PROTOCOL_MAGIC;
    handler->send_buffer.base.version = PROTOCOL_VERSION;
    handler->send_buffer.base.pageCount = (MI_Uint32)Batch_GetPageCount(handler->message->batch);
    handler->send_buffer.base.originalMessagePointer = handler->message;

    /* ATTN! NOTE(review): assert() is compiled out in release builds. If
       batchInfo holds PROTOCOL_HEADER_MAX_PAGES entries (presumably - confirm
       in header.h), a larger batch would overflow it below. */
    assert (handler->send_buffer.base.pageCount <= PROTOCOL_HEADER_MAX_PAGES);

    /* get page info */

    Batch_GetPageInfo(
        handler->message->batch, handler->send_buffer.batchInfo);

    /* mark handler as 'want-write' */
    handler->base.mask |= SELECTOR_WRITE;

}

/*
    Writes as much of the current message (and following messages) as the
    socket accepts without blocking.

    A message is sent as block 0 (the header: HeaderBase plus one
    Header_BatchInfoItem per page) followed by blocks 1..pageCount (the batch
    pages). sendingPageIndex is the block in progress, and
    sentCurrentBlockBytes is the offset already sent within it. Up to 32 blocks
    are handed to Sock_WriteV per iteration.

    Returns:
        MI_TRUE  - nothing left to send, or the socket would block
        MI_FALSE - connection closed (zero bytes written) or write error
*/
static MI_Boolean _RequestCallbackWrite(
    Protocol_SR_SocketData* handler)
{
    /* try to write to socket as much as possible */
    size_t sent;
    MI_Result r;

    for(;;)
    {
        /* buffers to write */
        IOVec   buffers[32];
        size_t counter;

        if ( !handler->message )
        { /* nothing to send */
            handler->base.mask &= ~SELECTOR_WRITE;
            return MI_TRUE;
        }

        /* build the iovec list starting at the block in progress */
        for ( counter = 0; counter < MI_COUNT(buffers); counter++ )
        {
            const char* buf;
            MI_Uint32 index = (MI_Uint32)(handler->sendingPageIndex + counter);

            buf = (index == 0) ? 
                &handler->send_buffer : 
                handler->send_buffer.batchInfo[index - 1].pagePointer;

            /* only the first block can be partially sent already */
            if (!counter)
                buf += handler->sentCurrentBlockBytes;

            buffers[counter].ptr = (void*)buf;

            buffers[counter].len = (index == 0) ? (sizeof(HeaderBase) + sizeof(Header_BatchInfoItem) * handler->send_buffer.base.pageCount)
                : handler->send_buffer.batchInfo[index - 1].pageSize;

            if (!counter)
                buffers[counter].len -= handler->sentCurrentBlockBytes;

            /* last page of the message: include it and stop */
            if ( index == handler->send_buffer.base.pageCount)
            {
                counter++;
                break;
            }
        }

        sent = 0;

        r = Sock_WriteV(handler->base.sock, buffers, counter, &sent);

        PRINTF(("sent %d\n", sent));

        if ( r == MI_RESULT_OK && 0 == sent )
            return MI_FALSE; /* connection closed */

        if ( r != MI_RESULT_OK && r != MI_RESULT_WOULD_BLOCK )
            return MI_FALSE;

        if (!sent)
            return MI_TRUE;

        /* update index: advance over fully-sent blocks, record partial progress */
        for ( counter = 0; counter < MI_COUNT(buffers); counter++ )
        {
            if (!sent)
                break;

            if (sent >= buffers[counter].len)
            {
                sent -= buffers[counter].len;
                handler->sendingPageIndex++;
                handler->sentCurrentBlockBytes = 0;
                continue;
            }

            handler->sentCurrentBlockBytes += sent;
            break;
        }

        /* header plus all pageCount pages sent => message complete */
        if ( (handler->sendingPageIndex - 1) == (int)handler->send_buffer.base.pageCount )
        {
            PRINTF(("done with sending message tag %d\n", handler->message->tag));

            /* next message */
            Message_Release(handler->message);
            handler->message = 0;

            _PrepareMessageForSending(handler);
        }
    }
}

/*
    Turns a fully received batch into a message and dispatches it, then resets
    the receive state for the next message.

    Before authentication completes, the message goes to _ProcessAuthMessage.
    Afterwards, it is stamped with the client id (the handler pointer) and the
    authenticated uid/gid, the handler gains a reference for the outstanding
    request, and the protocol callback is invoked.

    If the batch cannot be decoded, it is destroyed and MI_TRUE is returned.

    Parameters:
        handler - pointer to received data
    Returns:
        MI_TRUE - to continue normal operations
        MI_FALSE - to close connection
*/
static MI_Boolean _ProcessReceivedMessage(
    Protocol_SR_SocketData* handler)
{
    Protocol* self = (Protocol*)handler->base.data;
    Message* msg = 0;
    MI_Boolean keepOpen = MI_TRUE;
    MI_Result r = MessageFromBatch(
        handler->receivingBatch,
        handler->recv_buffer.base.originalMessagePointer,
        handler->recv_buffer.batchInfo,
        handler->recv_buffer.base.pageCount,
        self->skipInstanceUnpack,
        &msg);

    if (MI_RESULT_OK != r)
    {
        LOGW((T("failed to restore message %d [%s]\n"), r,
            Result_ToString(r)));
        Batch_Destroy( handler->receivingBatch );
    }
    else
    {
        PRINTF(("done with receiving message tag %d\n", msg->tag));

        if (PRT_AUTH_OK == handler->authState)
        {
            /* attach client id and the authenticated identity */
            msg->clientID = PtrToUint64(handler);
            handler->refcounter++;
            msg->uid = handler->uid;
            msg->gid = handler->gid;

            /* back-pressure accounting applies to instances on wrapped sockets only */
            if (PRT_TYPE_FROM_SOCKET == self->type && PostInstanceMsgTag == msg->tag)
                Selector_NewInstanceCreated(self->selector, msg);

            keepOpen = (*self->callback)(self, msg, self->callbackData);
        }
        else
        {
            keepOpen = _ProcessAuthMessage(handler, msg);
        }

        Message_Release(msg);
    }

    /* the batch is now owned by msg (or destroyed); start over with a header */
    handler->receivingBatch = 0;
    handler->receivingPageIndex = 0;
    memset(&handler->recv_buffer,0,sizeof(handler->recv_buffer));

    return keepOpen;
}


/*
    Reads the message header (receivingPageIndex == 0) in two phases.
    recv_buffer starts zeroed, so the first target size is sizeof(HeaderBase).
    Once that arrives, pageCount is known, the target size grows to include the
    Header_BatchInfoItem array, and the loop continues reading. After each
    complete read the header is validated: pageCount <= PROTOCOL_HEADER_MAX_PAGES,
    magic == PROTOCOL_MAGIC, and every pageSize <= 64KB.

    When the full header is in, the receiving batch is allocated. A message with
    zero pages is dispatched immediately.

    Returns:
        PRT_CONTINUE     - header already complete, or just completed
        PRT_RETURN_TRUE  - socket would block; wait for more data
        PRT_RETURN_FALSE - connection closed, read error, invalid header,
                           allocation failure, or the message handler asked to close
*/
static Protocol_CallbackResult _ReadHeader(
    Protocol_SR_SocketData* handler)
{
    char* buf;
    size_t buf_size, received;
    MI_Result r;
    MI_Uint32 index;

    /* are we done with header? */
    if (0!= handler->receivingPageIndex)
        return PRT_CONTINUE;

    for (;;)
    {
        buf = (char*)&handler->recv_buffer;
        /* depends on pageCount, which is 0 until HeaderBase has been read */
        buf_size = (sizeof(HeaderBase) + sizeof(Header_BatchInfoItem) * handler->recv_buffer.base.pageCount);
        received = 0;

        r = Sock_Read(handler->base.sock, buf + handler->receivedCurrentBlockBytes, buf_size - handler->receivedCurrentBlockBytes, &received);

        PRINTF(("read %d\n", received));

        if ( r == MI_RESULT_OK && 0 == received )
            return PRT_RETURN_FALSE; /* connection closed */

        if ( r != MI_RESULT_OK && r != MI_RESULT_WOULD_BLOCK )
            return PRT_RETURN_FALSE;

        if (!received)
            return PRT_RETURN_TRUE;

        handler->receivedCurrentBlockBytes += received;

        if (handler->receivedCurrentBlockBytes == buf_size)
        {
            /* got header - validate/allocate as required */
            if (handler->recv_buffer.base.pageCount > PROTOCOL_HEADER_MAX_PAGES)
                return PRT_RETURN_FALSE;

            if (handler->recv_buffer.base.magic != PROTOCOL_MAGIC)
                return PRT_RETURN_FALSE;

            /* in the first phase these entries are still zero and pass trivially */
            for (index =0; index < handler->recv_buffer.base.pageCount; index++)
            {
                if (handler->recv_buffer.batchInfo[index].pageSize > (64*1024))
                    return PRT_RETURN_FALSE;
            }

            /* check if page info is also retrieved */
            if (buf_size != ((sizeof(HeaderBase) + sizeof(Header_BatchInfoItem) * handler->recv_buffer.base.pageCount)) )
                continue;

            /* create a batch */
            if (!Batch_CreateBatchByPageInfo(
                &handler->receivingBatch, 
                handler->recv_buffer.batchInfo,
                handler->recv_buffer.base.pageCount))
                return PRT_RETURN_FALSE;

            /* skip to next page */
            handler->receivingPageIndex++;
            handler->receivedCurrentBlockBytes = 0;

            if ( (handler->receivingPageIndex - 1) == (int)handler->recv_buffer.base.pageCount )
            {   /* received the whole message - process it */
                if (!_ProcessReceivedMessage(handler))
                    return PRT_RETURN_FALSE;
            }
            break;
        } /* if we read the whole buffer */
    } /* for(;;)*/
    return PRT_CONTINUE;
}


/*
    Reads batch pages once the header is complete (receivingPageIndex >= 1).
    Page N maps to batch page N-1. receivedCurrentBlockBytes is the offset
    already read into the current page. Up to 32 pages are passed to Sock_ReadV
    per call. When the last page completes, the message is dispatched via
    _ProcessReceivedMessage.

    Returns:
        PRT_CONTINUE     - header not read yet, or progress made
        PRT_RETURN_TRUE  - socket would block; wait for more data
        PRT_RETURN_FALSE - connection closed, read error, or the message
                           handler asked to close
*/
static Protocol_CallbackResult _ReadAllPages(
    Protocol_SR_SocketData* handler)
{
    size_t received;
    MI_Result r;
    /* buffers to read into */
    IOVec   buffers[32];
    size_t counter;

    /* are we done with header? - if not, return 'continue' */
    if (0== handler->receivingPageIndex)
        return PRT_CONTINUE;


    for ( counter = 0; counter < MI_COUNT(buffers); counter++ )
    {
        const char* buf;
        MI_Uint32 index = (MI_Uint32)(handler->receivingPageIndex + counter);

        buf = Batch_GetPageByIndex(handler->receivingBatch, index - 1);
        
        /* only the first page can be partially received already */
        if (!counter)
            buf += handler->receivedCurrentBlockBytes;

        buffers[counter].ptr = (void*)buf;
        buffers[counter].len = handler->recv_buffer.batchInfo[index - 1].pageSize;

        if (!counter)
            buffers[counter].len -= handler->receivedCurrentBlockBytes;

        /* last page of the message: include it and stop */
        if ( index == handler->recv_buffer.base.pageCount)
        {
            counter++;
            break;
        }
    }

    received = 0;

    r = Sock_ReadV(handler->base.sock, buffers, counter, &received);

    PRINTF(("read %d\n", received));

    if ( r == MI_RESULT_OK && 0 == received )
        return PRT_RETURN_FALSE; /* connection closed */

    if ( r != MI_RESULT_OK && r != MI_RESULT_WOULD_BLOCK )
        return PRT_RETURN_FALSE;

    if (!received)
        return PRT_RETURN_TRUE;

    /* update index: advance over fully-read pages, record partial progress */
    for ( counter = 0; counter < MI_COUNT(buffers); counter++ )
    {
        if (!received)
            break;

        if (received >= buffers[counter].len)
        {
            received -= buffers[counter].len;
            handler->receivingPageIndex++;
            handler->receivedCurrentBlockBytes = 0;
            continue;
        }

        handler->receivedCurrentBlockBytes += received;
        break;
    }

    if ( (handler->receivingPageIndex - 1) == (int)handler->recv_buffer.base.pageCount )
    {   /* received the whole message - process it */
        if (!_ProcessReceivedMessage(handler))
            return PRT_RETURN_FALSE;
    }

    return PRT_CONTINUE;
}

/*
    Reads and dispatches as many messages as are available on the socket.

    Keeps reading until a helper reports that the socket would block
    (MI_TRUE) or that the connection must be closed (MI_FALSE). Windows does
    not reset the socket event until the read buffer is drained.

    The previous version guarded the loop with a "fullMessagesREceived < 3"
    counter that was never incremented. The loop was therefore already
    unbounded, so the dead counter was removed rather than implying a
    per-callback cap that never existed.
*/
static MI_Boolean _RequestCallbackRead(
    Protocol_SR_SocketData* handler)
{
    for (;;)
    {
        switch (_ReadHeader(handler))
        {
        case PRT_RETURN_TRUE: return MI_TRUE;
        case PRT_RETURN_FALSE: return MI_FALSE;
        case PRT_CONTINUE:
        default: break;
        }

        switch (_ReadAllPages(handler))
        {
        case PRT_RETURN_TRUE: return MI_TRUE;
        case PRT_RETURN_FALSE: return MI_FALSE;
        case PRT_CONTINUE:
        default: break;
        }
    }
}

static MI_Boolean _RequestCallback(
    Selector* sel,
    Handler* handlerIn,
    MI_Uint32 mask, 
    MI_Uint64 currentTimeUsec)
{
    Protocol_SR_SocketData* handler = (Protocol_SR_SocketData*)handlerIn;

    MI_UNUSED(sel);
    MI_UNUSED(currentTimeUsec);

    if (mask & SELECTOR_READ)
    {
        if (!_RequestCallbackRead(handler))
        {
            /* PROTOCOLEVENT_DISCONNECT */
            if (handler->isConnector)
            {
                Protocol* self = (Protocol*)handler->base.data;
                if (self->eventCallback)
                {
                    (*self->eventCallback)(self,  handler->isConnected ? PROTOCOLEVENT_DISCONNECT : PROTOCOLEVENT_CONNECT_FAILED,
                        self->eventCallbackData);
                }
                handler->isConnected = MI_FALSE;
            }
            goto closeConnection;
        }
        else if (handler->isConnector && !handler->isConnected)
        {
            Protocol* self = (Protocol*)handler->base.data;
            if (self->eventCallback)
            {
                (*self->eventCallback)(self, PROTOCOLEVENT_CONNECT,
                    self->eventCallbackData);
            }
            handler->isConnected = MI_TRUE;
        }
    }

    if (mask & SELECTOR_WRITE)
    {
        if (!_RequestCallbackWrite(handler))
        {
            /* PROTOCOLEVENT_DISCONNECT */
            if (handler->isConnector && handler->isConnected)
            {
                Protocol* self = (Protocol*)handler->base.data;
                if (self->eventCallback)
                {
                    (*self->eventCallback)(self, PROTOCOLEVENT_DISCONNECT,
                        self->eventCallbackData);
                }
                handler->isConnected = MI_FALSE;
            }
            goto closeConnection;
        }
        else if (handler->isConnector && !handler->isConnected)
        {
            Protocol* self = (Protocol*)handler->base.data;
            if (self->eventCallback)
            {
                (*self->eventCallback)(self, PROTOCOLEVENT_CONNECT,
                    self->eventCallbackData);
            }
            handler->isConnected = MI_TRUE;
        }
    }

    /* Close connection by timeout */
    if (mask & SELECTOR_TIMEOUT)
        return MI_FALSE;

    if ((mask & SELECTOR_REMOVE) != 0 ||
        (mask & SELECTOR_DESTROY) != 0)
    {
        Protocol* self = (Protocol*)handler->base.data;

        _FreeAuthData(handler);

        /* free outstanding messages, batch */
        if (handler->receivingBatch)
            Batch_Destroy( handler->receivingBatch );

        handler->receivingBatch = 0;

        if (handler->message)
            Message_Release(handler->message);

        handler->message = 0;

        _RemoveAllMessages(handler);

        Sock_Close(handler->base.sock);

        /* Mark handler as closed */
        handler->base.sock = INVALID_SOCK;

        /* if connection sokcet was released, invalidate pointer to it */
        if (self && handler == self->connectorHandle)
            self->connectorHandle = 0;

        if (handler->isConnector)
            free(handler);
        else
            _Release(handler);
    }

    return MI_TRUE;

closeConnection:

    PRINTF(("LOG: closed client connection\n"));

    return MI_FALSE;
}

/*
    Selector callback for the listening socket.

    READ: accepts one connection, makes it non-blocking, and registers a new
    Protocol_SR_SocketData handler (refcount 1, waiting for the client's
    connect request) with the protocol's selector. Accept/setup failures are
    logged and the listener stays active (returns MI_TRUE). If registration
    fails, the accepted socket is closed and the new handler is freed.
    Previously both were leaked.

    REMOVE/DESTROY: closes the listening socket and frees its handler.
*/
static MI_Boolean _ListenerCallback(
    Selector* sel,
    Handler* handler,
    MI_Uint32 mask, 
    MI_Uint64 currentTimeUsec)
{
    Protocol* self = (Protocol*)handler->data;
    MI_Result r;
    Sock s;
    Addr addr;
    Protocol_SR_SocketData* h;

    MI_UNUSED(sel);
    MI_UNUSED(currentTimeUsec);

    if (mask & SELECTOR_READ)
    {
        /* Accept the incoming connection */
        r = Sock_Accept(handler->sock, &s, &addr);

        if (MI_RESULT_WOULD_BLOCK == r)
            return MI_TRUE;

        if (r != MI_RESULT_OK)
        {
            LOGW((T("Sock_Accept() failed; err %d\n"), Sock_GetLastError()));
            return MI_TRUE;
        }

        r = Sock_SetBlocking(s, MI_FALSE);
        if (r != MI_RESULT_OK)
        {
            LOGW((T("Sock_SetBlocking() failed\n")));
            Sock_Close(s);
            return MI_TRUE;
        }

        /* Create handler */
        h = (Protocol_SR_SocketData*)calloc(1, sizeof(Protocol_SR_SocketData));

        if (!h)
        {
            Sock_Close(s);
            return MI_TRUE;
        }

        h->base.sock = s;
        h->base.mask = SELECTOR_READ | SELECTOR_EXCEPTION;
        h->base.callback = _RequestCallback;
        h->base.data = self;

        /* get '1' for connected */
        h->refcounter = 1;

        /* waiting for connect-request */
        h->authState = PRT_AUTH_WAIT_CONNECTION_REQUEST;

        /* Watch for read events on the incoming connection */
        r = Selector_AddHandler(self->selector, &h->base);

        if (r != MI_RESULT_OK)
        {
            LOGW((T("Selector_AddHandler() failed\n")));
            /* the selector did not take ownership: release socket and handler */
            Sock_Close(s);
            free(h);
            return MI_TRUE;
        }
    }

    if ((mask & SELECTOR_REMOVE) != 0 ||
        (mask & SELECTOR_DESTROY) != 0)
    {
        Sock_Close(handler->sock);
        free(handler);
    }

    return MI_TRUE;
}

/*
** Create the listening socket described by 'locator'.
**
** If 'locator' contains no ':' it is handed unchanged to
** Sock_CreateLocalListener(). Otherwise it must look like "host:port" or
** ":port"; an empty host listens on any address (Addr_InitAny).
**
** Returns MI_RESULT_FAILED when the host part does not fit in the local
** buffer, the port is not a decimal number in 0..65535 with nothing after
** it, or Addr_Init() fails; otherwise the result of the socket-creation call.
*/
static MI_Result _CreateListener(
    Sock* s,
    const char* locator)
{
    const char* posColon;

    posColon = strchr(locator, ':');

    if (!posColon)
        return Sock_CreateLocalListener(s, locator);

    /* create listener for remote address like host:port or :port (ANYADDR) */
    {
        const char* portStr = posColon + 1;
        char* portEnd;
        unsigned long portValue;
        unsigned short port;
        char host[128];
        unsigned int len = (unsigned int)(posColon - locator);
        Addr addr;
        MI_Result r;

        /* Validate the port rather than using atol(): atol() maps garbage
           to 0, and the cast to unsigned short silently wrapped values
           above 65535 to an unrelated port. */
        portValue = strtoul(portStr, &portEnd, 10);
        if (portEnd == portStr || *portEnd != '\0' || portValue > 65535UL)
            return MI_RESULT_FAILED;
        port = (unsigned short)portValue;

        if (len > 0)
        {
            if (len >= sizeof(host))
                return MI_RESULT_FAILED;

            memcpy(host, locator, len);
            host[len] = 0;

            // Initialize address.
            r = Addr_Init(&addr, host, port);
            if (r != MI_RESULT_OK)
                return MI_RESULT_FAILED;
        }
        else
        {
            /* ":port" - listen on any address */
            Addr_InitAny(&addr, port);
        }

        return Sock_CreateListener(s, &addr);
    }
}

/*
** Create a non-blocking client socket and start connecting it to 'locator'.
**
** If 'locator' contains no ':' it is handed unchanged to
** Sock_CreateLocalConnector() and that call's result is returned. Otherwise
** it must look like "host:port".
**
** For "host:port" returns MI_RESULT_OK if the connection completed
** immediately, MI_RESULT_WOULD_BLOCK if it is still in progress, and
** MI_RESULT_FAILED on any error (bad port, host too long, Addr_Init failure,
** socket creation/setup/connect failure). On MI_RESULT_FAILED no socket
** created here is left open.
*/
static MI_Result _CreateConnector(
    Sock* s,
    const char* locator)
{
    const char* posColon;

    posColon = strchr(locator, ':');

    if (!posColon)
        return Sock_CreateLocalConnector(s, locator);

    /* create connector to remote address like host:port */
    {
        const char* portStr = posColon + 1;
        char* portEnd;
        unsigned long portValue;
        unsigned short port;
        char host[128];
        unsigned int len = (unsigned int)(posColon - locator);
        Addr addr;
        MI_Result r;

        /* Validate the port rather than using atol(): atol() maps garbage
           to 0, and the cast to unsigned short silently wrapped values
           above 65535 to an unrelated port. */
        portValue = strtoul(portStr, &portEnd, 10);
        if (portEnd == portStr || *portEnd != '\0' || portValue > 65535UL)
            return MI_RESULT_FAILED;
        port = (unsigned short)portValue;

        if (len >= sizeof(host))
            return MI_RESULT_FAILED;

        memcpy(host, locator, len);
        host[len] = 0;

        // Initialize address.
        r = Addr_Init(&addr, host, port);
        if (r != MI_RESULT_OK)
            return MI_RESULT_FAILED;

        // Create client socket.
        r = Sock_Create(s);
        if (r != MI_RESULT_OK)
        {
            /* Creation failed, so there is no socket to close; *s may not
               hold a valid descriptor, and closing it could close an
               unrelated, reused descriptor. */
            return MI_RESULT_FAILED;
        }

        r = Sock_SetBlocking(*s, MI_FALSE);
        if (r != MI_RESULT_OK)
        {
            Sock_Close(*s);
            return MI_RESULT_FAILED;
        }

        // Connect to server (non-blocking: WOULD_BLOCK means "in progress").
        r = Sock_Connect(*s, &addr);
        if (r != MI_RESULT_OK && r != MI_RESULT_WOULD_BLOCK)
        {
            Sock_Close(*s);
            return MI_RESULT_FAILED;
        }
        return r;
    }
}

/*
** Allocate and initialize the state shared by every Protocol flavour.
**
** If 'selector' is NULL a private selector is created (after Sock_Start())
** and marked as internal; Protocol_Delete() later destroys it and calls
** Sock_Stop(). Otherwise the caller's selector is used and is not destroyed
** by Protocol_Delete().
**
** On success stores the new object in *selfOut and returns MI_RESULT_OK.
** Returns MI_RESULT_INVALID_PARAMETER if 'selfOut' is NULL, and
** MI_RESULT_FAILED if allocation or selector initialization fails; in the
** latter case *selfOut is NULL.
*/
static MI_Result _New_Protocol(
    Protocol** selfOut,
    Selector* selector, /*optional, maybe NULL*/
    ProtocolCallback callback,
    void* callbackData,
    ProtocolEventCallback eventCallback,
    void* eventCallbackData)
{
    Protocol* self;

    /* Check parameters */
    if (!selfOut)
        return MI_RESULT_INVALID_PARAMETER;

    /* Clear output parameter */
    *selfOut = NULL;

    /* Allocate structure */
    {
        self = (Protocol*)calloc(1, sizeof(Protocol));

        if (!self)
            return MI_RESULT_FAILED;
    }

    if (selector)
    {   /* attach the existing selector */
        self->selector = selector;
        self->internal_selector_used = MI_FALSE;
    }
    else
    {   /* create a new selector */
        /* Initialize the network */
        Sock_Start();

        /* Initialize the selector */
        if (Selector_Init(&self->internal_selector) != MI_RESULT_OK)
        {
            /* Balance the Sock_Start() above; Protocol_Delete() will never
               run for this object, so nobody else would call Sock_Stop(). */
            Sock_Stop();
            free(self);
            return MI_RESULT_FAILED;
        }
        self->selector = &self->internal_selector;
        self->internal_selector_used = MI_TRUE;
    }

    /* Save the callback and callbackData */
    self->callback = callback;
    self->callbackData = callbackData;
    self->eventCallback = eventCallback;
    self->eventCallbackData = eventCallbackData;

    /* Set the magic number */
    self->magic = _MAGIC;

    /* Set output parameter */
    *selfOut = self;
    return MI_RESULT_OK;
}

/*
**==============================================================================
**
** Public definitions:
**
**==============================================================================
*/

/*
** Create a Protocol that listens for client connections on 'locator'
** ("host:port", ":port", or a string without ':' for a local listener).
** Incoming connections are accepted by _ListenerCallback.
**
** On success stores the new Protocol in *selfOut and returns MI_RESULT_OK.
** On failure returns an error code and *selfOut is NULL.
*/
MI_Result Protocol_New_Listener(
    Protocol** selfOut,
    Selector* selector, /*optional, maybe NULL*/
    const char* locator,
    ProtocolCallback callback,
    void* callbackData)
{
    Protocol* self;
    MI_Result r;
    Sock listener;

    r = _New_Protocol(selfOut, selector, callback, callbackData, NULL, NULL);

    if (MI_RESULT_OK != r)
        return r;

    /* Keep *selfOut NULL until construction fully succeeds: the error paths
       below call Protocol_Delete(self), which would otherwise leave the
       caller holding a dangling pointer. */
    self = *selfOut;
    *selfOut = NULL;

    self->type = PRT_TYPE_LISTENER;

    /* Create listener socket */
    {
        r = _CreateListener(&listener, locator);

        if (r != MI_RESULT_OK)
        {
            Protocol_Delete(self);
            return r;
        }

        r = Sock_SetBlocking(listener, MI_FALSE);

        if (r != MI_RESULT_OK)
        {
            Sock_Close(listener);
            Protocol_Delete(self);
            return r;
        }
    }

    /* Watch for read events on the listener socket (client connections) */
    {
        Handler* h = (Handler*)calloc(1, sizeof(Handler));

        if (!h)
        {
            Sock_Close(listener);
            Protocol_Delete(self);
            return MI_RESULT_FAILED;
        }

        h->sock = listener;
        h->mask = SELECTOR_READ | SELECTOR_EXCEPTION;
        h->callback = _ListenerCallback;
        h->data = self;

        r = Selector_AddHandler(self->selector, h);

        if (r != MI_RESULT_OK)
        {
            Sock_Close(listener);
            free(h);
            Protocol_Delete(self);
            return r;
        }
    }

    /* Set output parameter */
    *selfOut = self;
    return MI_RESULT_OK;
}

/*
** Create a client Protocol that connects (asynchronously) to 'locator'
** ("host:port" or a string without ':' for a local connector) and sends
** the initial authentication request built from 'user' / 'password'.
**
** The connection handler is serviced by _RequestCallback, starts in
** PRT_AUTH_WAIT_CONNECTION_RESPONSE, and has send-queue watermarks disabled.
**
** On success stores the Protocol in *selfOut and returns MI_RESULT_OK.
** On failure returns an error code and *selfOut is NULL.
*/
MI_Result Protocol_New_Connector(
    Protocol** selfOut,
    Selector* selector, /*optional, maybe NULL*/
    const char* locator,
    ProtocolCallback callback,
    void* callbackData,
    ProtocolEventCallback eventCallback,
    void* eventCallbackData,
    const char* user,
    const char* password)
{
    Protocol* self;
    MI_Result r;
    Sock connector;

    r = _New_Protocol(selfOut, selector, callback, callbackData,
        eventCallback, eventCallbackData);

    if (MI_RESULT_OK != r)
        return r;

    /* Publish the object only once construction has fully succeeded */
    self = *selfOut;
    *selfOut = 0;

    self->type = PRT_TYPE_CONNECTOR;

    /* Create connector socket */
    {
        // Connect to server (WOULD_BLOCK: connection still in progress).
        r = _CreateConnector(&connector, locator);
        if (r != MI_RESULT_OK && r != MI_RESULT_WOULD_BLOCK)
        {
            Protocol_Delete(self);
            return MI_RESULT_FAILED;
        }
    }

    /* Allocating connector's structure */
    {
        Protocol_SR_SocketData* h = (Protocol_SR_SocketData*)calloc(1, sizeof(Protocol_SR_SocketData));


        if (!h)
        {
            Sock_Close(connector);
            Protocol_Delete(self);
            return MI_RESULT_FAILED;
        }

        h->base.sock = connector;
        h->base.mask = SELECTOR_READ | SELECTOR_WRITE | SELECTOR_EXCEPTION;
        h->base.callback = _RequestCallback;
        h->base.data = self;
        h->sendingQueueLength = -1; /* disable watermarks for client */
        h->isConnector = MI_TRUE;
        h->isConnected = MI_FALSE;
        h->authState = PRT_AUTH_WAIT_CONNECTION_RESPONSE;

        r = Selector_AddHandler(self->selector, &h->base);

        if (r != MI_RESULT_OK)
        {
            Sock_Close(connector);
            Protocol_Delete(self);
            free(h);
            return MI_RESULT_FAILED;
        }
        self->connectorHandle = h;

        /* send connect request */
        if (!_SendAuthRequest(h, user, password, NULL))
        {
            /* remove handler will free 'h' pointer */
            Selector_RemoveHandler(self->selector, &h->base);

            /* 'h' is gone now: forget it so that Protocol_Delete() does not
               write into the freed handler through self->connectorHandle */
            self->connectorHandle = NULL;
            Protocol_Delete(self);
            return MI_RESULT_FAILED;
        }
    }

    /* Set output parameter */
    *selfOut = self;
    return MI_RESULT_OK;
}

/*
** Wrap an already-connected socket 's' in a new Protocol.
**
** The socket is registered with the selector as a connected connector whose
** authentication is treated as already complete (PRT_AUTH_OK). When
** 'skipInstanceUnpack' is set, the handler additionally carries
** SELECTOR_IGNORE_READ_OVERLOAD. This function never closes 's', not even
** on failure.
**
** On success stores the Protocol in *selfOut and returns MI_RESULT_OK.
** On failure returns an error code and *selfOut is NULL.
*/
MI_Result Protocol_New_From_Socket(
    Protocol** selfOut,
    Selector* selector, /*optional, maybe NULL*/
    Sock s,
    MI_Boolean skipInstanceUnpack,
    ProtocolCallback callback,
    void* callbackData,
    ProtocolEventCallback eventCallback,
    void* eventCallbackData)
{
    Protocol* proto;
    Protocol_SR_SocketData* conn;
    MI_Result result;

    result = _New_Protocol(selfOut, selector, callback, callbackData,
        eventCallback, eventCallbackData);
    if (result != MI_RESULT_OK)
        return result;

    /* Hand the object to the caller only after it is fully wired up */
    proto = *selfOut;
    *selfOut = NULL;

    proto->type = PRT_TYPE_FROM_SOCKET;
    proto->skipInstanceUnpack = skipInstanceUnpack;

    conn = (Protocol_SR_SocketData*)calloc(1, sizeof(Protocol_SR_SocketData));
    if (conn == NULL)
    {
        Protocol_Delete(proto);
        return MI_RESULT_FAILED;
    }

    conn->base.sock = s;

    /* skipInstanceUnpack means the call comes from the server and 's' leads
       to an agent: there the back-pressure feature may be used, i.e. socket
       reads can be ignored while the system is under stress. */
    conn->base.mask = skipInstanceUnpack
        ? (SELECTOR_READ | SELECTOR_EXCEPTION | SELECTOR_IGNORE_READ_OVERLOAD)
        : (SELECTOR_READ | SELECTOR_EXCEPTION);

    conn->base.callback = _RequestCallback;
    conn->base.data = proto;
    conn->isConnector = MI_TRUE;
    conn->isConnected = MI_TRUE;

    /* The link is already established (server/agent communication only),
       so no authentication handshake is performed. */
    conn->authState = PRT_AUTH_OK;

    result = Selector_AddHandler(proto->selector, &conn->base);
    if (result != MI_RESULT_OK)
    {
        Protocol_Delete(proto);
        free(conn);
        return MI_RESULT_FAILED;
    }
    proto->connectorHandle = conn;

    *selfOut = proto;
    return MI_RESULT_OK;
}

/*
** Destroy a Protocol created by one of the Protocol_New_* functions.
**
** If the Protocol owns its selector (internal_selector_used), the selector
** is destroyed and the network layer is shut down with Sock_Stop(). If a
** connector handler is attached, its back-pointer to this Protocol is
** cleared. The object's magic number is overwritten and the object is freed.
**
** Returns MI_RESULT_INVALID_PARAMETER for NULL or for an object whose magic
** number does not match; otherwise MI_RESULT_OK.
*/
MI_Result Protocol_Delete(
    Protocol* self)
{
    /* Reject NULL and objects that are not (or no longer) valid Protocols */
    if (self == NULL || self->magic != _MAGIC)
        return MI_RESULT_INVALID_PARAMETER;

    if (self->internal_selector_used)
    {
        /* Destroying the selector closes every socket still in its list,
           including the connector and listener sockets. */
        Selector_Destroy(self->selector);
        Sock_Stop();
    }

    /* Invalidate the connector's 'self' pointer so it no longer refers to
       this object.
       NOTE(review): with an internal selector, Selector_Destroy() above may
       already have released the connector handler; if so, this write hits
       freed memory - confirm against _RequestCallback's destroy handling. */
    if (self->connectorHandle != NULL)
        self->connectorHandle->base.data = NULL;

    /* Overwrite the magic number before freeing (helps detect later use of
       a deleted Protocol) */
    self->magic = 0xDDDDDDDD;
    free(self);

    return MI_RESULT_OK;
}

/*
** Run the Protocol's selector, passing 'timeoutUsec' (presumably a timeout
** in microseconds) straight through, and return Selector_Run()'s result.
*/
MI_Result Protocol_Run(
    Protocol* self,
    MI_Uint64 timeoutUsec)
{
    Selector* sel = self->selector;

    return Selector_Run(sel, timeoutUsec);
}

/*
** Queue 'message' for sending and try to flush it. This runs on the
** selector's I/O thread: Protocol_Send schedules it (through
** _SendIN_IO_thread_wrapper) with Selector_CallInIOThread.
**
** Destination: the connector handler if the Protocol has one; otherwise the
** Protocol_SR_SocketData whose address is encoded in message->clientID.
**
** The message is appended to the destination's send queue, a reference is
** taken with Message_AddRef, and a write is attempted immediately. While the
** queue length exceeds SR_SOCKET_OUT_QUEUE_WATERMARK_CRITICAL, the function
** keeps retrying the write with a 1 ms sleep between attempts. After 40000
** attempts it drops every queued message.
**
** Returns MI_RESULT_OK when the message was queued;
** MI_RESULT_INVALID_PARAMETER for a NULL or invalid-magic Protocol or a NULL
** message; MI_RESULT_FAILED when the destination is gone or its queue had
** to be discarded.
*/
static MI_Result _SendIN_IO_thread(
    void* self_,
    Message* message)
{
    Protocol* self = (Protocol*)self_;
    Protocol_SR_SocketData* sendSock;

    /* check params */
    if (!self || !message )
        return MI_RESULT_INVALID_PARAMETER;

    /* reject a deleted or corrupted Protocol (Protocol_Delete overwrites the magic) */
    if (self->magic != _MAGIC)
    {
        LOGW((T("_SendIN_IO_thread: invalid magic!") ));
        return MI_RESULT_INVALID_PARAMETER;
    }

    /* find where to send it: connector Protocols (Protocol_New_Connector /
       Protocol_New_From_Socket) have a single handler; otherwise the target
       connection is the handler pointer carried in message->clientID */
    if (self->connectorHandle)
        sendSock = self->connectorHandle;
    else
        sendSock = (Protocol_SR_SocketData*)Uint64ToPtr(message->clientID);

    /* validate handler */
    if (!sendSock || INVALID_SOCK == sendSock->base.sock) 
    {
        //LOGW((T("cannot send message: expired handler (msg->clientID) %p\n"), sendSock));

        /* connection was closed - ignore message, but release handler if needed */
        if (sendSock && Message_IsFinalRepsonse(message))
            _Release(sendSock);

        return MI_RESULT_FAILED;
    }

    /* decrement number of outstanding requests: for accepted (non-connector)
       connections a final response releases one handler reference
       (presumably the one taken when the request arrived) */
    if (Message_IsFinalRepsonse(message) && !sendSock->isConnector)
    {
        DEBUG_ASSERT(sendSock->refcounter > 1);
        _Release(sendSock);
    }


    /* add message to the list */
    List_Append(&sendSock->head, &sendSock->tail, (ListElem*)message);

    /* -1 means queue-length watermarks are disabled for this socket */
    if (-1 != sendSock->sendingQueueLength)
        sendSock->sendingQueueLength++;

    /* take a reference for the queued message (presumably released by the
       send / _RemoveAllMessages path) */
    Message_AddRef(message);

    _PrepareMessageForSending(sendSock);
    
    /* try to write now; if that fails on an accepted (non-connector)
       connection, drop its whole queue */
    if (!_RequestCallbackWrite(sendSock) && !sendSock->isConnector)
    {
        //LOGW((T("cannot send message: queue overflow) %p\n"), sendSock));
        _RemoveAllMessages(sendSock);
        return MI_RESULT_FAILED;
    }

    /* Back-pressure: while the queue is above the critical watermark, keep
       flushing and sleeping 1 ms; give up after 40000 attempts.
       NOTE(review): this sleeps on the I/O thread, and the result of
       _RequestCallbackWrite() inside the loop is ignored - confirm intended. */
    {
        int counter = 0;

        while (sendSock->sendingQueueLength > SR_SOCKET_OUT_QUEUE_WATERMARK_CRITICAL)
        {
            _RequestCallbackWrite(sendSock);

            counter++;
            /* give system a chance to clear backlog */
            Sleep_ms(1);

            if (counter > 40000)
            {
                LOGW((T("cannot send message: queue overflow) %p\n"), sendSock));
                _RemoveAllMessages(sendSock);
                return MI_RESULT_FAILED;
            }
        }
    }
    return MI_RESULT_OK;
}

/*
** Adapter for Selector_CallInIOThread(): its callback must not return a
** value, so the MI_Result of _SendIN_IO_thread is discarded here. Some of
** _SendIN_IO_thread's failure paths log a warning themselves; others (e.g.
** a connection that has already been closed) are deliberately silent.
*/
static void _SendIN_IO_thread_wrapper(void* self, Message* message)
{
    /* ATTN: log failed result? */
    (void)_SendIN_IO_thread(self, message);
}

/*
** Hand 'message' to the selector's I/O thread, where _SendIN_IO_thread
** queues and writes it. The value returned is Selector_CallInIOThread()'s
** result, not the outcome of the send itself; that outcome is discarded by
** _SendIN_IO_thread_wrapper.
*/
MI_Result Protocol_Send(
    Protocol* self,
    Message* message)
{
    Selector* sel = self->selector;

    return Selector_CallInIOThread(sel, _SendIN_IO_thread_wrapper, self, message);
}

ViewCVS 0.9.2