Commit f9f8a60c authored by Zhu Yong's avatar Zhu Yong
Browse files

ResponseHandler will wait for handler to be called, fix for multithread condition

parent 3d4e18b6
Loading
Loading
Loading
Loading
+34 −68
Original line number Diff line number Diff line
@@ -27,11 +27,9 @@ func (conn *BlockConnection) NoOp() (Status, error) {
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) get(key []byte, getCmd kproto.Command_MessageType) (*Record, Status, error) {
@@ -51,11 +49,9 @@ func (conn *BlockConnection) get(key []byte, getCmd kproto.Command_MessageType)
		return nil, callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return &callback.Entry, callback.Status(), nil
	return &callback.Entry, callback.Status(), err
}

func (conn *BlockConnection) Get(key []byte) (*Record, Status, error) {
@@ -78,11 +74,9 @@ func (conn *BlockConnection) GetKeyRange(r *KeyRange) ([][]byte, Status, error)
		return nil, callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Keys, callback.Status(), nil
	return callback.Keys, callback.Status(), err
}

func (conn *BlockConnection) GetVersion(key []byte) ([]byte, Status, error) {
@@ -93,11 +87,9 @@ func (conn *BlockConnection) GetVersion(key []byte) ([]byte, Status, error) {
		return nil, callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Version, callback.Status(), nil
	return callback.Version, callback.Status(), err
}

func (conn *BlockConnection) Flush() (Status, error) {
@@ -108,11 +100,9 @@ func (conn *BlockConnection) Flush() (Status, error) {
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) Delete(entry *Record) (Status, error) {
@@ -123,11 +113,9 @@ func (conn *BlockConnection) Delete(entry *Record) (Status, error) {
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) Put(entry *Record) (Status, error) {
@@ -138,11 +126,9 @@ func (conn *BlockConnection) Put(entry *Record) (Status, error) {
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) P2PPush(request *P2PPushRequest) ([]Status, Status, error) {
@@ -153,11 +139,9 @@ func (conn *BlockConnection) P2PPush(request *P2PPushRequest) ([]Status, Status,
		return nil, callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Statuses, callback.Status(), nil
	return callback.Statuses, callback.Status(), err
}

func (conn *BlockConnection) GetLog(logs []LogType) (*Log, Status, error) {
@@ -168,11 +152,9 @@ func (conn *BlockConnection) GetLog(logs []LogType) (*Log, Status, error) {
		return nil, callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return &callback.Logs, callback.Status(), nil
	return &callback.Logs, callback.Status(), err
}

func (conn *BlockConnection) pinop(pin []byte, op kproto.Command_PinOperation_PinOpType) (Status, error) {
@@ -194,11 +176,9 @@ func (conn *BlockConnection) pinop(pin []byte, op kproto.Command_PinOperation_Pi
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) SecureErase(pin []byte) (Status, error) {
@@ -226,11 +206,9 @@ func (conn *BlockConnection) UpdateFirmware(code []byte) (Status, error) {
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) SetClusterVersion(version int64) (Status, error) {
@@ -241,11 +219,9 @@ func (conn *BlockConnection) SetClusterVersion(version int64) (Status, error) {
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) SetLockPin(currentPin []byte, newPin []byte) (Status, error) {
@@ -256,11 +232,9 @@ func (conn *BlockConnection) SetLockPin(currentPin []byte, newPin []byte) (Statu
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) SetErasePin(currentPin []byte, newPin []byte) (Status, error) {
@@ -271,11 +245,9 @@ func (conn *BlockConnection) SetErasePin(currentPin []byte, newPin []byte) (Stat
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) SetACL(acls []SecurityACL) (Status, error) {
@@ -286,11 +258,9 @@ func (conn *BlockConnection) SetACL(acls []SecurityACL) (Status, error) {
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) MediaScan(op *MediaOperation, pri Priority) (Status, error) {
@@ -301,11 +271,9 @@ func (conn *BlockConnection) MediaScan(op *MediaOperation, pri Priority) (Status
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) MediaOptimize(op *MediaOperation, pri Priority) (Status, error) {
@@ -316,11 +284,9 @@ func (conn *BlockConnection) MediaOptimize(op *MediaOperation, pri Priority) (St
		return callback.Status(), err
	}

	for callback.Done() == false {
		conn.nbc.Run()
	}
	err = conn.nbc.Listen(h)

	return callback.Status(), nil
	return callback.Status(), err
}

func (conn *BlockConnection) Close() {
+23 −4
Original line number Diff line number Diff line
@@ -2,14 +2,17 @@ package kinetic

import (
	kproto "github.com/yongzhy/kinetic-go/proto"
	"sync"
)

// ResponseHandler is the handler for XXXXX_RESPONSE message from drive.
type ResponseHandler struct {
	callback Callback
	done     bool
	cond     *sync.Cond
}

func (h *ResponseHandler) Handle(cmd *kproto.Command, value []byte) error {
func (h *ResponseHandler) handle(cmd *kproto.Command, value []byte) error {
	if h.callback != nil {
		if cmd.Status != nil && cmd.Status.Code != nil {
			if cmd.GetStatus().GetCode() == kproto.Command_Status_SUCCESS {
@@ -18,18 +21,34 @@ func (h *ResponseHandler) Handle(cmd *kproto.Command, value []byte) error {
				h.callback.Failure(getStatusFromProto(cmd))
			}
		} else {
			klog.Info("Other status received")
			klog.Warn("Other status received")
			klog.Info("%v", cmd)
		}

	}
	h.cond.L.Lock()
	h.done = true
	h.cond.Signal()
	h.cond.L.Unlock()
	return nil
}

func (h *ResponseHandler) Error(s Status) {
func (h *ResponseHandler) fail(s Status) {
	if h.callback != nil {
		h.callback.Failure(s)
	}
	h.cond.L.Lock()
	h.done = true
	h.cond.Signal()
	h.cond.L.Unlock()
}

func (h *ResponseHandler) wait() {
	h.cond.L.Lock()
	if h.done == false {
		h.cond.Wait()
	}
	h.cond.L.Unlock()
}

func (h *ResponseHandler) SetCallback(call Callback) {
@@ -38,6 +57,6 @@ func (h *ResponseHandler) SetCallback(call Callback) {

// Helper function to build a ResponseHandler with call as the Callback.
func NewResponseHandler(call Callback) *ResponseHandler {
	h := &ResponseHandler{callback: call}
	h := &ResponseHandler{callback: call, done: false, cond: sync.NewCond(&sync.Mutex{})}
	return h
}
+4 −2
Original line number Diff line number Diff line
@@ -356,8 +356,10 @@ func (conn *NonBlockConnection) MediaOptimize(op *MediaOperation, pri Priority,
	return conn.service.submit(msg, cmd, nil, h)
}

func (conn *NonBlockConnection) Run() error {
	return conn.service.listen()
func (conn *NonBlockConnection) Listen(h *ResponseHandler) error {
	err := conn.service.listen()
	h.wait()
	return err
}

func (conn *NonBlockConnection) Close() {
+32 −21
Original line number Diff line number Diff line
@@ -37,8 +37,9 @@ func newCommand(t kproto.Command_MessageType) *kproto.Command {
}

type networkService struct {
	rmutex sync.Mutex
	wmutex sync.Mutex
	rxMu   sync.Mutex
	txMu   sync.Mutex
	mapMu  sync.Mutex
	conn   net.Conn
	seq    int64                      // Operation sequence ID
	connId int64                      // current conection ID
@@ -64,9 +65,11 @@ func newNetworkService(op ClientOptions) (*networkService, error) {
		fatal:  false,
	}

	ns.rmutex.Lock()
	ns.rxMu.Lock()
	// Do the handshake.
	// TODO: we can store the Device Configuration and Limits from handshake
	_, _, _, err = ns.receive()
	ns.rmutex.Unlock()
	ns.rxMu.Unlock()

	if err != nil {
		klog.Error("Can't establish connection to %s", op.Host)
@@ -79,15 +82,15 @@ func newNetworkService(op ClientOptions) (*networkService, error) {
// When client network service has error, call error handling
// from all Messagehandler current in Queue.
func (ns *networkService) clientError(s Status, mh *ResponseHandler) {
	ns.mapMu.Lock()
	for ack, h := range ns.hmap {
		if h.callback != nil {
			h.callback.Failure(s)
		}
		h.fail(s)
		delete(ns.hmap, ack)
	}
	ns.mapMu.Unlock()

	if mh != nil && mh.callback != nil {
		mh.callback.Failure(s)
	if mh != nil {
		mh.fail(s)
	}
}

@@ -96,19 +99,22 @@ func (ns *networkService) listen() error {
		return errors.New("Can't listen, network service has fatal error")
	}

	ns.mapMu.Lock()
	if len(ns.hmap) == 0 {
		ns.mapMu.Unlock()
		return nil
	}
	ns.mapMu.Unlock()

	ns.rmutex.Lock()
	ns.rxMu.Lock()
	msg, cmd, value, err := ns.receive()
	ns.rmutex.Unlock()
	ns.rxMu.Unlock()
	if err != nil {
		klog.Error("Network Service listen error")
		return err
	}

	klog.Info("Kinetic response received ", cmd.GetHeader().GetMessageType().String(),
	klog.Debug("Kinetic response received ", cmd.GetHeader().GetMessageType().String(),
		", AckSeq = ", cmd.GetHeader().GetAckSequence(),
		", Code = ", cmd.GetStatus().GetCode())

@@ -119,17 +125,19 @@ func (ns *networkService) listen() error {
	}

	ack := cmd.GetHeader().GetAckSequence()
	ns.mapMu.Lock()
	h, ok := ns.hmap[ack]
	ns.mapMu.Unlock()
	if ok == false {
		klog.Warn("Couldn't find a handler for acksequence ", ack)
		klog.Error("Couldn't find a handler for acksequence ", ack)
		return nil
	}

	(*h).Handle(cmd, value)
	h.handle(cmd, value)

	ns.wmutex.Lock()
	ns.mapMu.Lock()
	delete(ns.hmap, ack)
	ns.wmutex.Unlock()
	ns.mapMu.Unlock()

	return nil
}
@@ -142,7 +150,7 @@ func (ns *networkService) submit(msg *kproto.Message, cmd *kproto.Command, value
		return errors.New("Valid ResponseHandler is required")
	}

	ns.wmutex.Lock()
	ns.txMu.Lock()

	cmd.GetHeader().ConnectionID = &ns.connId
	cmd.GetHeader().Sequence = &ns.seq
@@ -160,17 +168,20 @@ func (ns *networkService) submit(msg *kproto.Message, cmd *kproto.Command, value
		msg.GetHmacAuth().Hmac = compute_hmac(msg.CommandBytes, ns.option.Hmac)
	}

	klog.Info("Kinetic message send ", cmd.GetHeader().GetMessageType().String(), " Seq = ", ns.seq)
	klog.Debug("Kinetic message send ", cmd.GetHeader().GetMessageType().String(), " Seq = ", ns.seq)

	err = ns.send(msg, value)

	if err != nil {
		return err
	}

	ns.mapMu.Lock()
	ns.hmap[ns.seq] = h
	ns.seq++
	ns.mapMu.Unlock()

	ns.wmutex.Unlock()
	ns.seq++
	ns.txMu.Unlock()

	return nil
}
@@ -296,5 +307,5 @@ func (ns *networkService) receive() (*kproto.Message, *kproto.Command, []byte, e

func (ns *networkService) close() {
	ns.conn.Close()
	klog.Infof("Connection to %s closed", ns.option.Host)
	klog.Debug("Connection to %s closed", ns.option.Host)
}