summaryrefslogtreecommitdiffstats
path: root/conn.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--conn.go149
1 files changed, 76 insertions, 73 deletions
diff --git a/conn.go b/conn.go
index d62342e..bb9f6b5 100644
--- a/conn.go
+++ b/conn.go
@@ -32,13 +32,15 @@ type messagePacket struct {
type Conn struct {
conn net.Conn
isSSL bool
- isClosed bool
+ isClosing bool
Debug debugging
- chanConfirm chan int
+ chanConfirm chan bool
chanResults map[uint64]chan *ber.Packet
chanMessage chan *messagePacket
chanMessageID chan uint64
- closeLock sync.Mutex
+ wgSender sync.WaitGroup
+ wgClose sync.WaitGroup
+ once sync.Once
}
// Dial connects to the given address on the given network using net.Dial
@@ -86,10 +88,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, *Error) {
func NewConn(conn net.Conn) *Conn {
return &Conn{
conn: conn,
- isSSL: false,
- isClosed: false,
- Debug: false,
- chanConfirm: make(chan int),
+ chanConfirm: make(chan bool),
chanMessageID: make(chan uint64),
chanMessage: make(chan *messagePacket, 10),
chanResults: map[uint64]chan *ber.Packet{},
@@ -99,35 +98,33 @@ func NewConn(conn net.Conn) *Conn {
func (l *Conn) start() {
go l.reader()
go l.processMessages()
+ l.wgClose.Add(1)
}
// Close closes the connection.
-func (l *Conn) Close() *Error {
- l.closeLock.Lock()
- defer l.closeLock.Unlock()
-
- // close only once
- if l.isClosed {
- return nil
- }
-
- l.Debug.Printf("Sending quit message\n")
- l.chanMessage <- &messagePacket{Op: MessageQuit}
- <-l.chanConfirm
- l.chanConfirm = nil
+func (l *Conn) Close() {
+ l.once.Do(func() {
+ l.isClosing = true
+ l.wgSender.Wait()
+
+ l.Debug.Printf("Sending quit message and waiting for confirmation\n")
+ l.chanMessage <- &messagePacket{Op: MessageQuit}
+ <-l.chanConfirm
+ close(l.chanMessage)
- l.Debug.Printf("Closing network connection\n")
- if err := l.conn.Close(); err != nil {
- return NewError(ErrorNetwork, err)
- }
+ l.Debug.Printf("Closing network connection\n")
+ if err := l.conn.Close(); err != nil {
+ log.Print(err)
+ }
- l.isClosed = true
- return nil
+ l.conn = nil
+ l.wgClose.Done()
+ })
+ l.wgClose.Wait()
}
// Returns the next available messageID
func (l *Conn) nextMessageID() uint64 {
- // l.chanMessageID will be set to nil by processMessage()
if l.chanMessageID != nil {
if messageID, ok := <-l.chanMessageID; ok {
return messageID
@@ -178,36 +175,50 @@ func (l *Conn) startTLS(config *tls.Config) *Error {
}
func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, *Error) {
- out := make(chan *ber.Packet)
- // l.chanMessage will be set to nil by processMessage()
- if l.chanMessage != nil {
- l.chanMessage <- &messagePacket{
- Op: MessageRequest,
- MessageID: packet.Children[0].Value.(uint64),
- Packet: packet,
- Channel: out,
- }
- } else {
+ if l.isClosing {
return nil, NewError(ErrorNetwork, errors.New("Connection closed"))
}
+ out := make(chan *ber.Packet)
+ message := &messagePacket{
+ Op: MessageRequest,
+ MessageID: packet.Children[0].Value.(uint64),
+ Packet: packet,
+ Channel: out,
+ }
+ l.sendProcessMessage(message)
return out, nil
}
+func (l *Conn) finishMessage(MessageID uint64) {
+ if l.isClosing {
+ return
+ }
+ message := &messagePacket{
+ Op: MessageFinish,
+ MessageID: MessageID,
+ }
+ l.sendProcessMessage(message)
+}
+
+func (l *Conn) sendProcessMessage(message *messagePacket) bool {
+ if l.isClosing {
+ return false
+ }
+ l.wgSender.Add(1)
+ l.chanMessage <- message
+ l.wgSender.Done()
+ return true
+}
+
func (l *Conn) processMessages() {
defer func() {
for messageID, channel := range l.chanResults {
- if channel != nil {
- l.Debug.Printf("Closing channel for MessageID %d\n", messageID)
- close(channel)
- l.chanResults[messageID] = nil
- }
+ l.Debug.Printf("Closing channel for MessageID %d\n", messageID)
+ close(channel)
+ delete(l.chanResults, messageID)
}
- // l.chanMessage should be closed by sender but there is more than one
- close(l.chanMessage)
- l.chanMessage = nil
close(l.chanMessageID)
- // l.chanMessageID should be set to nil by nextMessageID() but it is not a go routine
- l.chanMessageID = nil
+ l.chanConfirm <- true
close(l.chanConfirm)
}()
@@ -216,23 +227,25 @@ func (l *Conn) processMessages() {
select {
case l.chanMessageID <- messageID:
messageID++
- case messagePacket := <-l.chanMessage:
+ case messagePacket, ok := <-l.chanMessage:
+ if !ok {
+ l.Debug.Printf("Shutting down - message channel is closed\n")
+ return
+ }
switch messagePacket.Op {
case MessageQuit:
- l.Debug.Printf("Shutting down\n")
- l.chanConfirm <- 1
+ l.Debug.Printf("Shutting down - quit message received\n")
return
case MessageRequest:
// Add to message list and write to network
l.Debug.Printf("Sending message %d\n", messagePacket.MessageID)
l.chanResults[messagePacket.MessageID] = messagePacket.Channel
+ // go routine
buf := messagePacket.Packet.Bytes()
- // TODO: understand
for len(buf) > 0 {
n, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s\n", err.Error())
- l.Close()
break
}
// nothing else to send
@@ -247,49 +260,39 @@ func (l *Conn) processMessages() {
if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok {
chanResult <- messagePacket.Packet
} else {
- log.Printf("Unexpected Message Result: %d\n", messagePacket.MessageID)
+ log.Printf("Received unexpected message %d\n", messagePacket.MessageID)
ber.PrintPacket(messagePacket.Packet)
}
case MessageFinish:
// Remove from message list
l.Debug.Printf("Finished message %d\n", messagePacket.MessageID)
- l.chanResults[messagePacket.MessageID] = nil
+ close(l.chanResults[messagePacket.MessageID])
+ delete(l.chanResults, messagePacket.MessageID)
}
}
}
}
-func (l *Conn) finishMessage(MessageID uint64) {
- // l.chanMessage will be set to nil by processMessage()
- if l.chanMessage != nil {
- l.chanMessage <- &messagePacket{Op: MessageFinish, MessageID: MessageID}
- }
-}
-
func (l *Conn) reader() {
defer func() {
l.Close()
- l.conn = nil
}()
for {
packet, err := ber.ReadPacket(l.conn)
if err != nil {
- l.Debug.Printf("ldap.reader: %s\n", err.Error())
+ l.Debug.Printf("reader: %s\n", err.Error())
return
}
-
addLDAPDescriptions(packet)
-
- if l.chanMessage != nil {
- l.chanMessage <- &messagePacket{
- Op: MessageResponse,
- MessageID: packet.Children[0].Value.(uint64),
- Packet: packet,
- }
- } else {
- log.Printf("ldap.reader: Cannot return message, channel is already closed\n")
+ message := &messagePacket{
+ Op: MessageResponse,
+ MessageID: packet.Children[0].Value.(uint64),
+ Packet: packet,
+ }
+ if !l.sendProcessMessage(message) {
return
}
+
}
}