diff options
Diffstat (limited to 'conn.go')
-rw-r--r-- | conn.go | 234 |
1 files changed, 120 insertions, 114 deletions
@@ -14,15 +14,31 @@ import ( "sync" ) +const ( + MessageQuit = 0 + MessageRequest = 1 + MessageResponse = 2 + MessageFinish = 3 +) + +type messagePacket struct { + Op int + MessageID uint64 + Packet *ber.Packet + Channel chan *ber.Packet +} + // LDAP Connection type Conn struct { - conn net.Conn - isSSL bool - Debug debugging - chanResults map[uint64]chan *ber.Packet - chanProcessMessage chan *messagePacket - chanMessageID chan uint64 - closeLock sync.Mutex + conn net.Conn + isSSL bool + isClosed bool + Debug debugging + chanConfirm chan int + chanResults map[uint64]chan *ber.Packet + chanMessage chan *messagePacket + chanMessageID chan uint64 + closeLock sync.Mutex } // Dial connects to the given address on the given network using net.Dial @@ -46,7 +62,6 @@ func DialSSL(network, addr string, config *tls.Config) (*Conn, *Error) { } conn := NewConn(c) conn.isSSL = true - conn.start() return conn, nil } @@ -59,7 +74,6 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, *Error) { return nil, NewError(ErrorNetwork, err) } conn := NewConn(c) - if err := conn.startTLS(config); err != nil { conn.Close() return nil, NewError(ErrorNetwork, err.Err) @@ -71,12 +85,14 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, *Error) { // NewConn returns a new Conn using conn for network I/O. func NewConn(conn net.Conn) *Conn { return &Conn{ - conn: conn, - isSSL: false, - Debug: false, - chanResults: map[uint64]chan *ber.Packet{}, - chanProcessMessage: make(chan *messagePacket), - chanMessageID: make(chan uint64), + conn: conn, + isSSL: false, + isClosed: false, + Debug: false, + chanConfirm: make(chan int), + chanMessageID: make(chan uint64), + chanMessage: make(chan *messagePacket, 10), + chanResults: map[uint64]chan *ber.Packet{}, } } @@ -90,27 +106,34 @@ func (l *Conn) Close() *Error { l.closeLock.Lock() defer l.closeLock.Unlock() - l.sendProcessMessage(&messagePacket{Op: MessageQuit}) + // close only once + if l.isClosed { + return nil + } - if l.conn != nil { - err := l.conn.Close() - if err != nil { - return NewError(ErrorNetwork, err) - } - l.conn = nil + l.Debug.Printf("Sending quit message\n") + l.chanMessage <- &messagePacket{Op: MessageQuit} + <-l.chanConfirm + l.chanConfirm = nil + + l.Debug.Printf("Closing network connection\n") + if err := l.conn.Close(); err != nil { + return NewError(ErrorNetwork, err) } + + l.isClosed = true return nil } // Returns the next available messageID -func (l *Conn) nextMessageID() (messageID uint64) { - defer func() { - if r := recover(); r != nil { - messageID = 0 +func (l *Conn) nextMessageID() uint64 { + // l.chanMessageID will be set to nil by processMessage() + if l.chanMessageID != nil { + if messageID, ok := <-l.chanMessageID; ok { + return messageID } - }() - messageID = <-l.chanMessageID - return + } + return 0 } // StartTLS sends the command to start a TLS session and then creates a new TLS Client @@ -154,136 +177,119 @@ func (l *Conn) startTLS(config *tls.Config) *Error { return nil } -const ( - MessageQuit = 0 - MessageRequest = 1 - MessageResponse = 2 - MessageFinish = 3 -) - -type messagePacket struct { - Op int - MessageID uint64 - Packet *ber.Packet - Channel chan *ber.Packet -} - -func (l *Conn) sendMessage(p *ber.Packet) (out chan *ber.Packet, err *Error) { - message_id := p.Children[0].Value.(uint64) - out = make(chan *ber.Packet) - - if l.chanProcessMessage == nil { - err = NewError(ErrorNetwork, errors.New("Connection closed")) - return +func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, *Error) { + out := make(chan *ber.Packet) + // l.chanMessage will be set to nil by processMessage() + if l.chanMessage != nil { + l.chanMessage <- &messagePacket{ + Op: MessageRequest, + MessageID: packet.Children[0].Value.(uint64), + Packet: packet, + Channel: out, + } + } else { + return nil, NewError(ErrorNetwork, errors.New("Connection closed")) } - message_packet := &messagePacket{Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out} - l.sendProcessMessage(message_packet) - return + return out, nil } func (l *Conn) processMessages() { - defer l.closeAllChannels() + defer func() { + for messageID, channel := range l.chanResults { + if channel != nil { + l.Debug.Printf("Closing channel for MessageID %d\n", messageID) + close(channel) + l.chanResults[messageID] = nil + } + } + // l.chanMessage should be closed by sender but there is more than one + close(l.chanMessage) + l.chanMessage = nil + close(l.chanMessageID) + // l.chanMessageID should be set to nil by nextMessageID() but it is not a go routine + l.chanMessageID = nil + close(l.chanConfirm) + }() - var message_id uint64 = 1 - var message_packet *messagePacket + var messageID uint64 = 1 for { select { - case l.chanMessageID <- message_id: - if l.conn == nil { - return - } - message_id++ - case message_packet = <-l.chanProcessMessage: - if l.conn == nil { - return - } - switch message_packet.Op { + case l.chanMessageID <- messageID: + messageID++ + case messagePacket := <-l.chanMessage: + switch messagePacket.Op { case MessageQuit: - // Close all channels and quit l.Debug.Printf("Shutting down\n") + l.chanConfirm <- 1 return case MessageRequest: // Add to message list and write to network - l.Debug.Printf("Sending message %d\n", message_packet.MessageID) - l.chanResults[message_packet.MessageID] = message_packet.Channel - buf := message_packet.Packet.Bytes() + l.Debug.Printf("Sending message %d\n", messagePacket.MessageID) + l.chanResults[messagePacket.MessageID] = messagePacket.Channel + buf := messagePacket.Packet.Bytes() + // TODO: understand for len(buf) > 0 { n, err := l.conn.Write(buf) if err != nil { l.Debug.Printf("Error Sending Message: %s\n", err.Error()) - return + l.Close() + break } + // nothing else to send if n == len(buf) { break } + // the remaining buf content buf = buf[n:] } case MessageResponse: - // Pass back to waiting goroutine - l.Debug.Printf("Receiving message %d\n", message_packet.MessageID) - if chanResult, ok := l.chanResults[message_packet.MessageID]; ok { - // If the "Search Result Done" is read before the - // "Search Result Entry" no Entry can be returned - // go func() { chanResult <- message_packet.Packet }() - chanResult <- message_packet.Packet + l.Debug.Printf("Receiving message %d\n", messagePacket.MessageID) + if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok { + chanResult <- messagePacket.Packet } else { - log.Printf("Unexpected Message Result: %d\n", message_id) - ber.PrintPacket(message_packet.Packet) + log.Printf("Unexpected Message Result: %d\n", messagePacket.MessageID) + ber.PrintPacket(messagePacket.Packet) } case MessageFinish: // Remove from message list - l.Debug.Printf("Finished message %d\n", message_packet.MessageID) - l.chanResults[message_packet.MessageID] = nil + l.Debug.Printf("Finished message %d\n", messagePacket.MessageID) + l.chanResults[messagePacket.MessageID] = nil } } } } -func (l *Conn) closeAllChannels() { - log.Printf("closeAllChannels\n") - for messageID, channel := range l.chanResults { - if channel != nil { - l.Debug.Printf("Closing channel for MessageID %d\n", messageID) - close(channel) - l.chanResults[messageID] = nil - } - } - close(l.chanMessageID) - l.chanMessageID = nil - - close(l.chanProcessMessage) - l.chanProcessMessage = nil -} - func (l *Conn) finishMessage(MessageID uint64) { - message_packet := &messagePacket{Op: MessageFinish, MessageID: MessageID} - l.sendProcessMessage(message_packet) + // l.chanMessage will be set to nil by processMessage() + if l.chanMessage != nil { + l.chanMessage <- &messagePacket{Op: MessageFinish, MessageID: MessageID} + } } func (l *Conn) reader() { - defer l.Close() + defer func() { + l.Close() + l.conn = nil + }() + for { - p, err := ber.ReadPacket(l.conn) + packet, err := ber.ReadPacket(l.conn) if err != nil { l.Debug.Printf("ldap.reader: %s\n", err.Error()) return } - addLDAPDescriptions(p) + addLDAPDescriptions(packet) - message_id := p.Children[0].Value.(uint64) - message_packet := &messagePacket{Op: MessageResponse, MessageID: message_id, Packet: p} - if l.chanProcessMessage != nil { - l.chanProcessMessage <- message_packet + if l.chanMessage != nil { + l.chanMessage <- &messagePacket{ + Op: MessageResponse, + MessageID: packet.Children[0].Value.(uint64), + Packet: packet, + } } else { - log.Printf("ldap.reader: Cannot return message\n") + log.Printf("ldap.reader: Cannot return message, channel is already closed\n") return } } } - -func (l *Conn) sendProcessMessage(message *messagePacket) { - if l.chanProcessMessage != nil { - go func() { l.chanProcessMessage <- message }() - } -} |