diff options
Diffstat (limited to 'conn.go')
-rw-r--r-- | conn.go | 18 |
1 files changed, 9 insertions, 9 deletions
@@ -39,8 +39,8 @@ func Dial(network, addr string) (*Conn, *Error) { // Dial connects to the given address on the given network using net.Dial // and then sets up SSL connection and returns a new Conn for the connection. -func DialSSL(network, addr string) (*Conn, *Error) { - c, err := tls.Dial(network, addr, nil) +func DialSSL(network, addr string, config *tls.Config) (*Conn, *Error) { + c, err := tls.Dial(network, addr, config) if err != nil { return nil, NewError(ErrorNetwork, err) } @@ -53,14 +53,14 @@ func DialSSL(network, addr string) (*Conn, *Error) { // Dial connects to the given address on the given network using net.Dial // and then starts a TLS session and returns a new Conn for the connection. -func DialTLS(network, addr string) (*Conn, *Error) { +func DialTLS(network, addr string, config *tls.Config) (*Conn, *Error) { c, err := net.Dial(network, addr) if err != nil { return nil, NewError(ErrorNetwork, err) } conn := NewConn(c) - if err := conn.startTLS(); err != nil { + if err := conn.startTLS(config); err != nil { conn.Close() return nil, NewError(ErrorNetwork, err.Err) } @@ -114,7 +114,7 @@ func (l *Conn) nextMessageID() (messageID uint64) { } // StartTLS sends the command to start a TLS session and then creates a new TLS Client -func (l *Conn) startTLS() *Error { +func (l *Conn) startTLS(config *tls.Config) *Error { messageID := l.nextMessageID() if l.isSSL { @@ -123,9 +123,9 @@ func (l *Conn) startTLS() *Error { packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID")) - startTLS := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") - startTLS.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) - packet.AppendChild(startTLS) + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationExtendedRequest, nil, "Start TLS") + request.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, 0, "1.3.6.1.4.1.1466.20037", "TLS Extended Command")) + packet.AppendChild(request) l.Debug.PrintPacket(packet) _, err := l.conn.Write(packet.Bytes()) @@ -146,7 +146,7 @@ func (l *Conn) startTLS() *Error { } if packet.Children[1].Children[0].Value.(uint64) == 0 { - conn := tls.Client(l.conn, nil) + conn := tls.Client(l.conn, config) l.isSSL = true l.conn = conn } |