Loading kinetic/baseasync.py +14 −14 Original line number Diff line number Diff line Loading @@ -75,13 +75,13 @@ class BaseAsync(Client): raise common.ConnectionFaulted("Connection {0} is faulted. Can't receive message when connection is on a faulted state.".format(self)) try: header,value = self.network_recv() seq = header.command.header.ackSequence resp,value = self.network_recv() seq = resp.header.ackSequence LOG.debug("Received message with ackSequence={0} on connection {1}.".format(seq,self)) onSuccess,_ = self._pending[seq] del self._pending[seq] try: self.dispatch(onSuccess,header,value) self.dispatch(onSuccess,resp,value) except Exception as e: self._raise(e) except Exception as e: Loading @@ -93,22 +93,22 @@ class BaseAsync(Client): ### Override BaseClient methods def send(self, header, value): def send(self, command, value): done = threading.Event() class Dummy : pass d = Dummy() d.error = None d.result = None def innerSuccess(header, value): d.result = (header, value) def innerSuccess(response, value): d.result = (response, value) done.set() def innerError(e): d.error = e done.set() self.sendAsync(header, value, innerSuccess, innerError) self.sendAsync(command, value, innerSuccess, innerError) done.wait() # TODO(Nacho): should be add a default timeout? if d.error: raise d.error Loading @@ -116,7 +116,7 @@ class BaseAsync(Client): ### def sendAsync(self, header, value, onSuccess, onError): def sendAsync(self, command, value, onSuccess, onError): if self.faulted: # TODO(Nacho): should we fault through onError on fault or bow up on the callers face? self._raise(common.ConnectionFaulted("Can't send message when connection is on a faulted state."), onError) return #skip the rest Loading @@ -126,21 +126,21 @@ class BaseAsync(Client): self._raise(common.NotConnected("Not connected."), onError) return #skip the rest def innerSuccess(header, value): def innerSuccess(response, value): try: operations._check_status(header) onSuccess(header, value) operations._check_status(response) onSuccess(response, value) except Exception as ex: onError(ex) # get sequence self.update_header(header) self.update_header(command) # add callback to pending dictionary self._pending[header.command.header.sequence] = (innerSuccess, onError) self._pending[command.header.sequence] = (innerSuccess, onError) # transmit self.network_send(header, value) self.network_send(command, value) def _process(self, op, *args, **kwargs): if not self.isConnected: raise common.NotConnected("Must call connect() before sending operations.") Loading kinetic/baseclient.py +37 −19 Original line number Diff line number Diff line Loading @@ -33,7 +33,7 @@ ss = socket LOG = logging.getLogger(__name__) def calculate_hmac(secret, message): def calculate_hmac(secret, command): mac = hmac.new(secret, digestmod=sha1) def update(entity): Loading @@ -45,12 +45,11 @@ def calculate_hmac(secret, message): mac.update(struct.pack(">I", len(entity))) mac.update(entity) # always add command update(message.command) update(command) d = mac.digest() if LOG.isEnabledFor(logging.DEBUG): LOG.debug('message hmac: %s' % hexlify(d)) LOG.debug('command hmac: %s' % hexlify(d)) return d class BaseClient(object): Loading Loading @@ -156,16 +155,15 @@ class BaseClient(object): self._sequence = itertools.count() def update_header(self, message): def update_header(self, command): """ Updates the message header with connection specific information. The unique sequence is assigned by this method. :param message: message to be modified (message is modified in place) """ header = message.command.header header = command.header header.clusterVersion = self.cluster_version header.identity = self.identity header.connectionID = self.connection_id header.sequence = self._sequence.next() if LOG.isEnabledFor(logging.DEBUG): Loading Loading @@ -207,8 +205,17 @@ class BaseClient(object): raise common.ServerDisconnect('Server send disconnect') i += nbytes def authenticate(self, command): m = messages.Message() m.commandBytes = command.SerializeToString() def network_send(self, header, value): m.authType = messages.Message.HMACAUTH m.hmacAuth.identity = self.identity m.hmacAuth.hmac = calculate_hmac(self.secret, command) return m def network_send(self, command, value): """ Sends a raw message. The HMAC is calculated and added to the message. Loading @@ -219,14 +226,14 @@ class BaseClient(object): # fail fast on NotConnected self.socket header.hmac = calculate_hmac(self.secret, header) m = self.authenticate(command) if self.debug: print header print command self._send_delimited_v2(header, value) self._send_delimited_v2(m, value) return header return m def toHexString(self, array): return ''.join('%02x ' % ord(byte) for byte in array) Loading Loading @@ -276,9 +283,6 @@ class BaseClient(object): proto = messages.Message() proto.ParseFromString(str(raw_proto)) if self.debug: print proto return (proto, value) def network_recv(self): Loading @@ -288,13 +292,27 @@ class BaseClient(object): return: the message received """ resp = self._recv_delimited_v2() (m, value) = self._recv_delimited_v2() if m.authType == messages.Message.HMACAUTH: if m.hmacAuth.identity == self.identity: hmac = calculate_hmac(self.secret, m.commandBytes) if not hmac == m.hmacAuth.hmac: raise Exception('Hmac does not match') else: raise Exception('Wrong identity received!') resp = messages.Command() resp.ParseFromString(m.commandBytes) if self.debug: print resp # update connectionId to whatever the drive said. if resp[0].command.header.connectionID: self.connection_id = resp[0].command.header.connectionID if resp.header.connectionID: self.connection_id = resp.header.connectionID return resp return (resp, value) def send(self, header, value): self.network_send(header, value) Loading Loading
kinetic/baseasync.py +14 −14 Original line number Diff line number Diff line Loading @@ -75,13 +75,13 @@ class BaseAsync(Client): raise common.ConnectionFaulted("Connection {0} is faulted. Can't receive message when connection is on a faulted state.".format(self)) try: header,value = self.network_recv() seq = header.command.header.ackSequence resp,value = self.network_recv() seq = resp.header.ackSequence LOG.debug("Received message with ackSequence={0} on connection {1}.".format(seq,self)) onSuccess,_ = self._pending[seq] del self._pending[seq] try: self.dispatch(onSuccess,header,value) self.dispatch(onSuccess,resp,value) except Exception as e: self._raise(e) except Exception as e: Loading @@ -93,22 +93,22 @@ class BaseAsync(Client): ### Override BaseClient methods def send(self, header, value): def send(self, command, value): done = threading.Event() class Dummy : pass d = Dummy() d.error = None d.result = None def innerSuccess(header, value): d.result = (header, value) def innerSuccess(response, value): d.result = (response, value) done.set() def innerError(e): d.error = e done.set() self.sendAsync(header, value, innerSuccess, innerError) self.sendAsync(command, value, innerSuccess, innerError) done.wait() # TODO(Nacho): should be add a default timeout? if d.error: raise d.error Loading @@ -116,7 +116,7 @@ class BaseAsync(Client): ### def sendAsync(self, header, value, onSuccess, onError): def sendAsync(self, command, value, onSuccess, onError): if self.faulted: # TODO(Nacho): should we fault through onError on fault or bow up on the callers face? self._raise(common.ConnectionFaulted("Can't send message when connection is on a faulted state."), onError) return #skip the rest Loading @@ -126,21 +126,21 @@ class BaseAsync(Client): self._raise(common.NotConnected("Not connected."), onError) return #skip the rest def innerSuccess(header, value): def innerSuccess(response, value): try: operations._check_status(header) onSuccess(header, value) operations._check_status(response) onSuccess(response, value) except Exception as ex: onError(ex) # get sequence self.update_header(header) self.update_header(command) # add callback to pending dictionary self._pending[header.command.header.sequence] = (innerSuccess, onError) self._pending[command.header.sequence] = (innerSuccess, onError) # transmit self.network_send(header, value) self.network_send(command, value) def _process(self, op, *args, **kwargs): if not self.isConnected: raise common.NotConnected("Must call connect() before sending operations.") Loading
kinetic/baseclient.py +37 −19 Original line number Diff line number Diff line Loading @@ -33,7 +33,7 @@ ss = socket LOG = logging.getLogger(__name__) def calculate_hmac(secret, message): def calculate_hmac(secret, command): mac = hmac.new(secret, digestmod=sha1) def update(entity): Loading @@ -45,12 +45,11 @@ def calculate_hmac(secret, message): mac.update(struct.pack(">I", len(entity))) mac.update(entity) # always add command update(message.command) update(command) d = mac.digest() if LOG.isEnabledFor(logging.DEBUG): LOG.debug('message hmac: %s' % hexlify(d)) LOG.debug('command hmac: %s' % hexlify(d)) return d class BaseClient(object): Loading Loading @@ -156,16 +155,15 @@ class BaseClient(object): self._sequence = itertools.count() def update_header(self, message): def update_header(self, command): """ Updates the message header with connection specific information. The unique sequence is assigned by this method. :param message: message to be modified (message is modified in place) """ header = message.command.header header = command.header header.clusterVersion = self.cluster_version header.identity = self.identity header.connectionID = self.connection_id header.sequence = self._sequence.next() if LOG.isEnabledFor(logging.DEBUG): Loading Loading @@ -207,8 +205,17 @@ class BaseClient(object): raise common.ServerDisconnect('Server send disconnect') i += nbytes def authenticate(self, command): m = messages.Message() m.commandBytes = command.SerializeToString() def network_send(self, header, value): m.authType = messages.Message.HMACAUTH m.hmacAuth.identity = self.identity m.hmacAuth.hmac = calculate_hmac(self.secret, command) return m def network_send(self, command, value): """ Sends a raw message. The HMAC is calculated and added to the message. Loading @@ -219,14 +226,14 @@ class BaseClient(object): # fail fast on NotConnected self.socket header.hmac = calculate_hmac(self.secret, header) m = self.authenticate(command) if self.debug: print header print command self._send_delimited_v2(header, value) self._send_delimited_v2(m, value) return header return m def toHexString(self, array): return ''.join('%02x ' % ord(byte) for byte in array) Loading Loading @@ -276,9 +283,6 @@ class BaseClient(object): proto = messages.Message() proto.ParseFromString(str(raw_proto)) if self.debug: print proto return (proto, value) def network_recv(self): Loading @@ -288,13 +292,27 @@ class BaseClient(object): return: the message received """ resp = self._recv_delimited_v2() (m, value) = self._recv_delimited_v2() if m.authType == messages.Message.HMACAUTH: if m.hmacAuth.identity == self.identity: hmac = calculate_hmac(self.secret, m.commandBytes) if not hmac == m.hmacAuth.hmac: raise Exception('Hmac does not match') else: raise Exception('Wrong identity received!') resp = messages.Command() resp.ParseFromString(m.commandBytes) if self.debug: print resp # update connectionId to whatever the drive said. if resp[0].command.header.connectionID: self.connection_id = resp[0].command.header.connectionID if resp.header.connectionID: self.connection_id = resp.header.connectionID return resp return (resp, value) def send(self, header, value): self.network_send(header, value) Loading