diff options
Diffstat (limited to 'conn.go')
-rw-r--r-- | conn.go | 270 |
1 files changed, 270 insertions, 0 deletions
@@ -0,0 +1,270 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This package provides LDAP client functions. +package ldap + +import ( + "github.com/mmitton/asn1-ber" + "crypto/tls" + "fmt" + "net" + "os" +) + +// LDAP Connection +type Conn struct { + conn net.Conn + isSSL bool + Debug bool + + chanResults map[ uint64 ] chan *ber.Packet + chanProcessMessage chan *messagePacket + chanMessageID chan uint64 +} + +// Dial connects to the given address on the given network using net.Dial +// and then returns a new Conn for the connection. +func Dial(network, addr string) (*Conn, *Error) { + c, err := net.Dial(network, "", addr) + if err != nil { + return nil, NewError( ErrorNetwork, err ) + } + conn := NewConn(c) + conn.start() + return conn, nil +} + +// Dial connects to the given address on the given network using net.Dial +// and then sets up SSL connection and returns a new Conn for the connection. +func DialSSL(network, addr string) (*Conn, *Error) { + c, err := tls.Dial(network, "", addr, nil) + if err != nil { + return nil, NewError( ErrorNetwork, err ) + } + conn := NewConn(c) + conn.isSSL = true + + conn.start() + return conn, nil +} + +// Dial connects to the given address on the given network using net.Dial +// and then starts a TLS session and returns a new Conn for the connection. +func DialTLS(network, addr string) (*Conn, *Error) { + c, err := net.Dial(network, "", addr) + if err != nil { + return nil, NewError( ErrorNetwork, err ) + } + conn := NewConn(c) + + err = conn.startTLS() + if err != nil { + conn.Close() + return nil, NewError( ErrorNetwork, err ) + } + conn.start() + return conn, nil +} + +// NewConn returns a new Conn using conn for network I/O. +func NewConn(conn net.Conn) *Conn { + return &Conn{ + conn: conn, + isSSL: false, + Debug: false, + chanResults: map[uint64] chan *ber.Packet{}, + chanProcessMessage: make( chan *messagePacket ), + chanMessageID: make( chan uint64 ), + } +} + +func (l *Conn) start() { + go l.reader() + go l.processMessages() +} + +// Close closes the connection. +func (l *Conn) Close() *Error { + if l.chanProcessMessage != nil { + message_packet := &messagePacket{ Op: MessageQuit } + l.chanProcessMessage <- message_packet + l.chanProcessMessage = nil + } + + if l.conn != nil { + err := l.conn.Close() + if err != nil { + return NewError( ErrorNetwork, err ) + } + l.conn = nil + } + return nil +} + +// Returns the next available messageID +func (l *Conn) nextMessageID() uint64 { + messageID := <-l.chanMessageID + return messageID +} + +// StartTLS sends the command to start a TLS session and then creates a new TLS Client +func (l *Conn) startTLS() *Error { + messageID := l.nextMessageID() + + if l.isSSL { + return NewError( ErrorNetwork, os.NewError( "Already encrypted" ) ) + } + + packet := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request" ) + packet.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID" ) ) + startTLS := ber.Encode( ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS" ) + startTLS.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command" ) ) + packet.AppendChild( startTLS ) + if l.Debug { + ber.PrintPacket( packet ) + } + + _, err := l.conn.Write( packet.Bytes() ) + if err != nil { + return NewError( ErrorNetwork, err ) + } + + packet, err = ber.ReadPacket( l.conn ) + if err != nil { + return NewError( ErrorNetwork, err ) + } + + if l.Debug { + if err := addLDAPDescriptions( packet ); err != nil { + return NewError( ErrorDebugging, err ) + } + ber.PrintPacket( packet ) + } + + if packet.Children[ 1 ].Children[ 0 ].Value.(uint64) == 0 { + conn := tls.Client( l.conn, nil ) + l.isSSL = true + l.conn = conn + } + + return nil +} + +const ( + MessageQuit = 0 + MessageRequest = 1 + MessageResponse = 2 + MessageFinish = 3 +) + +type messagePacket struct { + Op int + MessageID uint64 + Packet *ber.Packet + Channel chan *ber.Packet +} + +func (l *Conn) sendMessage( p *ber.Packet ) (out chan *ber.Packet, err *Error) { + message_id := p.Children[ 0 ].Value.(uint64) + out = make(chan *ber.Packet) + + message_packet := &messagePacket{ Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out } + if l.chanProcessMessage == nil { + err = NewError( ErrorNetwork, os.NewError( "Connection closed" ) ) + return + } + l.chanProcessMessage <- message_packet + return +} + +func (l *Conn) processMessages() { + defer l.closeAllChannels() + + var message_id uint64 = 1 + var message_packet *messagePacket + for { + select { + case l.chanMessageID <- message_id: + if l.conn == nil { + return + } + message_id++ + case message_packet = <-l.chanProcessMessage: + if l.conn == nil { + return + } + switch message_packet.Op { + case MessageQuit: + // Close all channels and quit + if l.Debug { + fmt.Printf( "Shutting down\n" ) + } + return + case MessageRequest: + // Add to message list and write to network + if l.Debug { + fmt.Printf( "Sending message %d\n", message_packet.MessageID ) + } + l.chanResults[ message_packet.MessageID ] = message_packet.Channel + l.conn.Write( message_packet.Packet.Bytes() ) + case MessageResponse: + // Pass back to waiting goroutine + if l.Debug { + fmt.Printf( "Receiving message %d\n", message_packet.MessageID ) + } + chanResult := l.chanResults[ message_packet.MessageID ] + if chanResult == nil { + fmt.Printf( "Unexpected Message Result: %d", message_id ) + } else { + chanResult <- message_packet.Packet + } + case MessageFinish: + // Remove from message list + if l.Debug { + fmt.Printf( "Finished message %d\n", message_packet.MessageID ) + } + l.chanResults[ message_packet.MessageID ] = nil, false + } + } + } +} + +func (l *Conn) closeAllChannels() { + for MessageID, Channel := range l.chanResults { + if l.Debug { + fmt.Printf( "Closing channel for MessageID %d\n", MessageID ); + } + close( Channel ) + l.chanResults[ MessageID ] = nil, false + } + close( l.chanMessageID ) + l.chanMessageID = nil +} + +func (l *Conn) finishMessage( MessageID uint64 ) { + message_packet := &messagePacket{ Op: MessageFinish, MessageID: MessageID } + if l.chanProcessMessage != nil { + l.chanProcessMessage <- message_packet + } +} + +func (l *Conn) reader() { + for { + p, err := ber.ReadPacket( l.conn ) + if err != nil { + if l.Debug { + fmt.Printf( "ldap.reader: %s\n", err.String() ) + } + break + } + + message_id := p.Children[ 0 ].Value.(uint64) + message_packet := &messagePacket{ Op: MessageResponse, MessageID: message_id, Packet: p } + l.chanProcessMessage <- message_packet + } + + l.Close() +} + |