Commit 3b7bffc0 authored by Ignacio Corderi's avatar Ignacio Corderi
Browse files

Sending and receiving with hmacs in 3.0.0

parent 371b3f3c
Loading
Loading
Loading
Loading
+14 −14
Original line number Diff line number Diff line
@@ -75,13 +75,13 @@ class BaseAsync(Client):
            raise common.ConnectionFaulted("Connection {0} is faulted. Can't receive message when connection is on a faulted state.".format(self))

        try:
            header,value = self.network_recv()
            seq = header.command.header.ackSequence
            resp,value = self.network_recv()
            seq = resp.header.ackSequence
            LOG.debug("Received message with ackSequence={0} on connection {1}.".format(seq,self))
            onSuccess,_ = self._pending[seq]
            del self._pending[seq]
            try:
                self.dispatch(onSuccess,header,value)
                self.dispatch(onSuccess,resp,value)
            except Exception as e:
                self._raise(e)
        except Exception as e:
@@ -93,22 +93,22 @@ class BaseAsync(Client):

    ### Override BaseClient methods

    def send(self, header, value):
    def send(self, command, value):
        done = threading.Event()
        class Dummy : pass
        d = Dummy()
        d.error = None
        d.result = None

        def innerSuccess(header, value):
            d.result = (header, value)
        def innerSuccess(response, value):
            d.result = (response, value)
            done.set()

        def innerError(e):
            d.error = e
            done.set()

        self.sendAsync(header, value, innerSuccess, innerError)
        self.sendAsync(command, value, innerSuccess, innerError)

        done.wait() # TODO(Nacho): should be add a default timeout?
        if d.error: raise d.error
@@ -116,7 +116,7 @@ class BaseAsync(Client):

    ###

    def sendAsync(self, header, value, onSuccess, onError):
    def sendAsync(self, command, value, onSuccess, onError):
        if self.faulted: # TODO(Nacho): should we fault through onError on fault or bow up on the callers face?
            self._raise(common.ConnectionFaulted("Can't send message when connection is on a faulted state."), onError)
            return #skip the rest
@@ -126,21 +126,21 @@ class BaseAsync(Client):
            self._raise(common.NotConnected("Not connected."), onError)
            return #skip the rest

        def innerSuccess(header, value):
        def innerSuccess(response, value):
            try:
                operations._check_status(header)
                onSuccess(header, value)
                operations._check_status(response)
                onSuccess(response, value)
            except Exception as ex:
                onError(ex)

        # get sequence
        self.update_header(header)
        self.update_header(command)

        # add callback to pending dictionary
        self._pending[header.command.header.sequence] = (innerSuccess, onError)
        self._pending[command.header.sequence] = (innerSuccess, onError)

        # transmit
        self.network_send(header, value)
        self.network_send(command, value)

    def _process(self, op, *args, **kwargs):
        if not self.isConnected: raise common.NotConnected("Must call connect() before sending operations.")
+37 −19
Original line number Diff line number Diff line
@@ -33,7 +33,7 @@ ss = socket

LOG = logging.getLogger(__name__)

def calculate_hmac(secret, message):
def calculate_hmac(secret, command):
    mac = hmac.new(secret, digestmod=sha1)

    def update(entity):
@@ -45,12 +45,11 @@ def calculate_hmac(secret, message):
        mac.update(struct.pack(">I", len(entity)))
        mac.update(entity)

    # always add command
    update(message.command)
    update(command)

    d = mac.digest()
    if LOG.isEnabledFor(logging.DEBUG):
        LOG.debug('message hmac: %s' % hexlify(d))
        LOG.debug('command hmac: %s' % hexlify(d))
    return d

class BaseClient(object):
@@ -156,16 +155,15 @@ class BaseClient(object):
        self._sequence = itertools.count()


    def update_header(self, message):
    def update_header(self, command):
        """
        Updates the message header with connection specific information.
        The unique sequence is assigned by this method.

        :param message: message to be modified (message is modified in place)
        """
        header = message.command.header
        header = command.header
        header.clusterVersion = self.cluster_version
        header.identity = self.identity
        header.connectionID = self.connection_id
        header.sequence = self._sequence.next()
        if LOG.isEnabledFor(logging.DEBUG):
@@ -207,8 +205,17 @@ class BaseClient(object):
                        raise common.ServerDisconnect('Server send disconnect')
                    i += nbytes

    def authenticate(self, command):
        m = messages.Message()
        m.commandBytes = command.SerializeToString()

    def network_send(self, header, value):
        m.authType = messages.Message.HMACAUTH
        m.hmacAuth.identity = self.identity
        m.hmacAuth.hmac = calculate_hmac(self.secret, command)

        return m

    def network_send(self, command, value):
        """
        Sends a raw message.
        The HMAC is calculated and added to the message.
@@ -219,14 +226,14 @@ class BaseClient(object):
        # fail fast on NotConnected
        self.socket

        header.hmac = calculate_hmac(self.secret, header)
        m = self.authenticate(command)

        if self.debug:
            print header
            print command

        self._send_delimited_v2(header, value)
        self._send_delimited_v2(m, value)

        return header
        return m

    def toHexString(self, array):
        return ''.join('%02x ' % ord(byte) for byte in array)
@@ -276,9 +283,6 @@ class BaseClient(object):
        proto = messages.Message()
        proto.ParseFromString(str(raw_proto))

        if self.debug:
            print proto

        return (proto, value)

    def network_recv(self):
@@ -288,13 +292,27 @@ class BaseClient(object):
        return: the message received
        """

        resp = self._recv_delimited_v2()
        (m, value) = self._recv_delimited_v2()

        if m.authType == messages.Message.HMACAUTH:
            if m.hmacAuth.identity == self.identity:
                hmac = calculate_hmac(self.secret, m.commandBytes)
                if not hmac == m.hmacAuth.hmac:
                    raise Exception('Hmac does not match')
            else:
                raise Exception('Wrong identity received!')

        resp = messages.Command()
        resp.ParseFromString(m.commandBytes)

        if self.debug:
            print resp

        # update connectionId to whatever the drive said.
        if resp[0].command.header.connectionID:
            self.connection_id = resp[0].command.header.connectionID
        if resp.header.connectionID:
            self.connection_id = resp.header.connectionID

        return resp
        return (resp, value)

    def send(self, header, value):
       self.network_send(header, value)