diff options
Diffstat (limited to '')
-rw-r--r-- | conn.go | 149 |
1 files changed, 76 insertions, 73 deletions
@@ -32,13 +32,15 @@ type messagePacket struct { type Conn struct { conn net.Conn isSSL bool - isClosed bool + isClosing bool Debug debugging - chanConfirm chan int + chanConfirm chan bool chanResults map[uint64]chan *ber.Packet chanMessage chan *messagePacket chanMessageID chan uint64 - closeLock sync.Mutex + wgSender sync.WaitGroup + wgClose sync.WaitGroup + once sync.Once } // Dial connects to the given address on the given network using net.Dial @@ -86,10 +88,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, *Error) { func NewConn(conn net.Conn) *Conn { return &Conn{ conn: conn, - isSSL: false, - isClosed: false, - Debug: false, - chanConfirm: make(chan int), + chanConfirm: make(chan bool), chanMessageID: make(chan uint64), chanMessage: make(chan *messagePacket, 10), chanResults: map[uint64]chan *ber.Packet{}, @@ -99,35 +98,33 @@ func NewConn(conn net.Conn) *Conn { func (l *Conn) start() { go l.reader() go l.processMessages() + l.wgClose.Add(1) } // Close closes the connection. -func (l *Conn) Close() *Error { - l.closeLock.Lock() - defer l.closeLock.Unlock() - - // close only once - if l.isClosed { - return nil - } - - l.Debug.Printf("Sending quit message\n") - l.chanMessage <- &messagePacket{Op: MessageQuit} - <-l.chanConfirm - l.chanConfirm = nil +func (l *Conn) Close() { + l.once.Do(func() { + l.isClosing = true + l.wgSender.Wait() + + l.Debug.Printf("Sending quit message and waiting for confirmation\n") + l.chanMessage <- &messagePacket{Op: MessageQuit} + <-l.chanConfirm + close(l.chanMessage) - l.Debug.Printf("Closing network connection\n") - if err := l.conn.Close(); err != nil { - return NewError(ErrorNetwork, err) - } + l.Debug.Printf("Closing network connection\n") + if err := l.conn.Close(); err != nil { + log.Print(err) + } - l.isClosed = true - return nil + l.conn = nil + l.wgClose.Done() + }) + l.wgClose.Wait() } // Returns the next available messageID func (l *Conn) nextMessageID() uint64 { - // l.chanMessageID will be set to nil by processMessage() if l.chanMessageID != nil { if messageID, ok := <-l.chanMessageID; ok { return messageID @@ -178,36 +175,50 @@ func (l *Conn) startTLS(config *tls.Config) *Error { } func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, *Error) { - out := make(chan *ber.Packet) - // l.chanMessage will be set to nil by processMessage() - if l.chanMessage != nil { - l.chanMessage <- &messagePacket{ - Op: MessageRequest, - MessageID: packet.Children[0].Value.(uint64), - Packet: packet, - Channel: out, - } - } else { + if l.isClosing { return nil, NewError(ErrorNetwork, errors.New("Connection closed")) } + out := make(chan *ber.Packet) + message := &messagePacket{ + Op: MessageRequest, + MessageID: packet.Children[0].Value.(uint64), + Packet: packet, + Channel: out, + } + l.sendProcessMessage(message) return out, nil } +func (l *Conn) finishMessage(MessageID uint64) { + if l.isClosing { + return + } + message := &messagePacket{ + Op: MessageFinish, + MessageID: MessageID, + } + l.sendProcessMessage(message) +} + +func (l *Conn) sendProcessMessage(message *messagePacket) bool { + if l.isClosing { + return false + } + l.wgSender.Add(1) + l.chanMessage <- message + l.wgSender.Done() + return true +} + func (l *Conn) processMessages() { defer func() { for messageID, channel := range l.chanResults { - if channel != nil { - l.Debug.Printf("Closing channel for MessageID %d\n", messageID) - close(channel) - l.chanResults[messageID] = nil - } + l.Debug.Printf("Closing channel for MessageID %d\n", messageID) + close(channel) + delete(l.chanResults, messageID) } - // l.chanMessage should be closed by sender but there is more than one - close(l.chanMessage) - l.chanMessage = nil close(l.chanMessageID) - // l.chanMessageID should be set to nil by nextMessageID() but it is not a go routine - l.chanMessageID = nil + l.chanConfirm <- true close(l.chanConfirm) }() @@ -216,23 +227,25 @@ func (l *Conn) processMessages() { select { case l.chanMessageID <- messageID: messageID++ - case messagePacket := <-l.chanMessage: + case messagePacket, ok := <-l.chanMessage: + if !ok { + l.Debug.Printf("Shutting down - message channel is closed\n") + return + } switch messagePacket.Op { case MessageQuit: - l.Debug.Printf("Shutting down\n") - l.chanConfirm <- 1 + l.Debug.Printf("Shutting down - quit message received\n") return case MessageRequest: // Add to message list and write to network l.Debug.Printf("Sending message %d\n", messagePacket.MessageID) l.chanResults[messagePacket.MessageID] = messagePacket.Channel + // go routine buf := messagePacket.Packet.Bytes() - // TODO: understand for len(buf) > 0 { n, err := l.conn.Write(buf) if err != nil { l.Debug.Printf("Error Sending Message: %s\n", err.Error()) - l.Close() break } // nothing else to send @@ -247,49 +260,39 @@ func (l *Conn) processMessages() { if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok { chanResult <- messagePacket.Packet } else { - log.Printf("Unexpected Message Result: %d\n", messagePacket.MessageID) + log.Printf("Received unexpected message %d\n", messagePacket.MessageID) ber.PrintPacket(messagePacket.Packet) } case MessageFinish: // Remove from message list l.Debug.Printf("Finished message %d\n", messagePacket.MessageID) - l.chanResults[messagePacket.MessageID] = nil + close(l.chanResults[messagePacket.MessageID]) + delete(l.chanResults, messagePacket.MessageID) } } } } -func (l *Conn) finishMessage(MessageID uint64) { - // l.chanMessage will be set to nil by processMessage() - if l.chanMessage != nil { - l.chanMessage <- &messagePacket{Op: MessageFinish, MessageID: MessageID} - } -} - func (l *Conn) reader() { defer func() { l.Close() - l.conn = nil }() for { packet, err := ber.ReadPacket(l.conn) if err != nil { - l.Debug.Printf("ldap.reader: %s\n", err.Error()) + l.Debug.Printf("reader: %s\n", err.Error()) return } - addLDAPDescriptions(packet) - - if l.chanMessage != nil { - l.chanMessage <- &messagePacket{ - Op: MessageResponse, - MessageID: packet.Children[0].Value.(uint64), - Packet: packet, - } - } else { - log.Printf("ldap.reader: Cannot return message, channel is already closed\n") + message := &messagePacket{ + Op: MessageResponse, + MessageID: packet.Children[0].Value.(uint64), + Packet: packet, + } + if !l.sendProcessMessage(message) { return } + } } |