summaryrefslogtreecommitdiffstats
path: root/server_search.go
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--server_search.go216
1 files changed, 216 insertions, 0 deletions
diff --git a/server_search.go b/server_search.go
new file mode 100644
index 0000000..a7d78ac
--- /dev/null
+++ b/server_search.go
@@ -0,0 +1,216 @@
+package ldap
+
+import (
+ "errors"
+ "fmt"
+ "github.com/nmcclain/asn1-ber"
+ "net"
+ "strings"
+)
+
+func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64, boundDN string, server *Server, conn net.Conn) (resultErr error) {
+ defer func() {
+ if r := recover(); r != nil {
+ resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r))
+ }
+ }()
+
+ searchReq, err := parseSearchRequest(boundDN, req, controls)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+
+ filterPacket, err := CompileFilter(searchReq.Filter)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+
+ fnNames := []string{}
+ for k := range server.SearchFns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(searchReq.BaseDN, fnNames)
+ searchResp, err := server.SearchFns[fn].Search(boundDN, searchReq, conn)
+ if err != nil {
+ return NewError(searchResp.ResultCode, err)
+ }
+
+ if server.EnforceLDAP {
+ if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find}
+ // Server DerefAliases not supported: RFC4511 4.5.1.3
+ return NewError(LDAPResultOperationsError, errors.New("Server DerefAliases not supported"))
+ }
+ if searchReq.TimeLimit > 0 {
+ // TODO: Server TimeLimit not implemented
+ }
+ }
+
+ for i, entry := range searchResp.Entries {
+ if server.EnforceLDAP {
+ // size limit
+ if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit {
+ break
+ }
+
+ // filter
+ keep, resultCode := ServerApplyFilter(filterPacket, entry)
+ if resultCode != LDAPResultSuccess {
+ return NewError(resultCode, errors.New("ServerApplyFilter error"))
+ }
+ if !keep {
+ continue
+ }
+
+ // constrained search scope
+ switch searchReq.Scope {
+ case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates.
+ case ScopeBaseObject: // The scope is constrained to the entry named by baseObject.
+ if entry.DN != searchReq.BaseDN {
+ continue
+ }
+ case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject.
+ parts := strings.Split(entry.DN, ",")
+ if len(parts) < 2 && entry.DN != searchReq.BaseDN {
+ continue
+ }
+ if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN {
+ continue
+ }
+ }
+
+ // attributes
+ if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) {
+ entry, err = filterAttributes(entry, searchReq.Attributes)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+ }
+ }
+
+ // respond
+ responsePacket := encodeSearchResponse(messageID, searchReq, entry)
+ if err = sendPacket(conn, responsePacket); err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+ }
+ return nil
+}
+
+/////////////////////////
+func parseSearchRequest(boundDN string, req *ber.Packet, controls *[]Control) (SearchRequest, error) {
+ if len(req.Children) != 8 {
+ return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request"))
+ }
+
+ // Parse the request
+ baseObject, ok := req.Children[0].Value.(string)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ s, ok := req.Children[1].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ scope := int(s)
+ d, ok := req.Children[2].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ derefAliases := int(d)
+ s, ok = req.Children[3].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ sizeLimit := int(s)
+ t, ok := req.Children[4].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ timeLimit := int(t)
+ typesOnly := false
+ if req.Children[5].Value != nil {
+ typesOnly, ok = req.Children[5].Value.(bool)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ }
+ filter, err := DecompileFilter(req.Children[6])
+ if err != nil {
+ return SearchRequest{}, err
+ }
+ attributes := []string{}
+ for _, attr := range req.Children[7].Children {
+ a, ok := attr.Value.(string)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ attributes = append(attributes, a)
+ }
+ searchReq := SearchRequest{baseObject, scope,
+ derefAliases, sizeLimit, timeLimit,
+ typesOnly, filter, attributes, *controls}
+
+ return searchReq, nil
+}
+
+/////////////////////////
+func filterAttributes(entry *Entry, attributes []string) (*Entry, error) {
+ // only return requested attributes
+ newAttributes := []*EntryAttribute{}
+
+ for _, attr := range entry.Attributes {
+ for _, requested := range attributes {
+ if strings.ToLower(attr.Name) == strings.ToLower(requested) {
+ newAttributes = append(newAttributes, attr)
+ }
+ }
+ }
+ entry.Attributes = newAttributes
+
+ return entry, nil
+}
+
+/////////////////////////
+func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+
+ searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry")
+ searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name"))
+
+ attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:")
+ for _, attribute := range res.Attributes {
+ attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values))
+ }
+
+ searchEntry.AppendChild(attrs)
+ responsePacket.AppendChild(searchEntry)
+
+ return responsePacket
+}
+
+func encodeSearchAttribute(name string, values []string) *ber.Packet {
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
+ packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name"))
+
+ valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values")
+ for _, value := range values {
+ valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value"))
+ }
+
+ packet.AppendChild(valuesPacket)
+
+ return packet
+}
+
+func encodeSearchDone(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+ donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done")
+ donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+ donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
+ responsePacket.AppendChild(donePacket)
+
+ return responsePacket
+}