diff options
Diffstat (limited to '')
-rw-r--r-- | server_search.go | 216 |
1 files changed, 216 insertions, 0 deletions
diff --git a/server_search.go b/server_search.go new file mode 100644 index 0000000..a7d78ac --- /dev/null +++ b/server_search.go @@ -0,0 +1,216 @@ +package ldap + +import ( + "errors" + "fmt" + "github.com/nmcclain/asn1-ber" + "net" + "strings" +) + +func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64, boundDN string, server *Server, conn net.Conn) (resultErr error) { + defer func() { + if r := recover(); r != nil { + resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r)) + } + }() + + searchReq, err := parseSearchRequest(boundDN, req, controls) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + + filterPacket, err := CompileFilter(searchReq.Filter) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + + fnNames := []string{} + for k := range server.SearchFns { + fnNames = append(fnNames, k) + } + fn := routeFunc(searchReq.BaseDN, fnNames) + searchResp, err := server.SearchFns[fn].Search(boundDN, searchReq, conn) + if err != nil { + return NewError(searchResp.ResultCode, err) + } + + if server.EnforceLDAP { + if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find} + // Server DerefAliases not supported: RFC4511 4.5.1.3 + return NewError(LDAPResultOperationsError, errors.New("Server DerefAliases not supported")) + } + if searchReq.TimeLimit > 0 { + // TODO: Server TimeLimit not implemented + } + } + + for i, entry := range searchResp.Entries { + if server.EnforceLDAP { + // size limit + if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit { + break + } + + // filter + keep, resultCode := ServerApplyFilter(filterPacket, entry) + if resultCode != LDAPResultSuccess { + return NewError(resultCode, errors.New("ServerApplyFilter error")) + } + if !keep { + continue + } + + // constrained search scope + switch searchReq.Scope { + case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates. + case ScopeBaseObject: // The scope is constrained to the entry named by baseObject. + if entry.DN != searchReq.BaseDN { + continue + } + case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject. + parts := strings.Split(entry.DN, ",") + if len(parts) < 2 && entry.DN != searchReq.BaseDN { + continue + } + if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN { + continue + } + } + + // attributes + if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) { + entry, err = filterAttributes(entry, searchReq.Attributes) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + } + } + + // respond + responsePacket := encodeSearchResponse(messageID, searchReq, entry) + if err = sendPacket(conn, responsePacket); err != nil { + return NewError(LDAPResultOperationsError, err) + } + } + return nil +} + +///////////////////////// +func parseSearchRequest(boundDN string, req *ber.Packet, controls *[]Control) (SearchRequest, error) { + if len(req.Children) != 8 { + return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request")) + } + + // Parse the request + baseObject, ok := req.Children[0].Value.(string) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + s, ok := req.Children[1].Value.(uint64) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + scope := int(s) + d, ok := req.Children[2].Value.(uint64) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + derefAliases := int(d) + s, ok = req.Children[3].Value.(uint64) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + sizeLimit := int(s) + t, ok := req.Children[4].Value.(uint64) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + timeLimit := int(t) + typesOnly := false + if req.Children[5].Value != nil { + typesOnly, ok = req.Children[5].Value.(bool) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + } + filter, err := DecompileFilter(req.Children[6]) + if err != nil { + return SearchRequest{}, err + } + attributes := []string{} + for _, attr := range req.Children[7].Children { + a, ok := attr.Value.(string) + if !ok { + return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request")) + } + attributes = append(attributes, a) + } + searchReq := SearchRequest{baseObject, scope, + derefAliases, sizeLimit, timeLimit, + typesOnly, filter, attributes, *controls} + + return searchReq, nil +} + +///////////////////////// +func filterAttributes(entry *Entry, attributes []string) (*Entry, error) { + // only return requested attributes + newAttributes := []*EntryAttribute{} + + for _, attr := range entry.Attributes { + for _, requested := range attributes { + if strings.ToLower(attr.Name) == strings.ToLower(requested) { + newAttributes = append(newAttributes, attr) + } + } + } + entry.Attributes = newAttributes + + return entry, nil +} + +///////////////////////// +func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + + searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry") + searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name")) + + attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:") + for _, attribute := range res.Attributes { + attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values)) + } + + searchEntry.AppendChild(attrs) + responsePacket.AppendChild(searchEntry) + + return responsePacket +} + +func encodeSearchAttribute(name string, values []string) *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name")) + + valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values") + for _, value := range values { + valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value")) + } + + packet.AppendChild(valuesPacket) + + return packet +} + +func encodeSearchDone(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done") + donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: ")) + donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) + responsePacket.AppendChild(donePacket) + + return responsePacket +} |