diff options
Diffstat (limited to 'filter.go')
-rw-r--r-- | filter.go | 161 |
1 files changed, 147 insertions, 14 deletions
@@ -7,8 +7,8 @@ package ldap import ( "errors" "fmt" - - "github.com/vanackere/asn1-ber" + "github.com/nmcclain/asn1-ber" + "strings" ) const ( @@ -24,7 +24,7 @@ const ( FilterExtensibleMatch = 9 ) -var filterMap = map[uint8]string{ +var FilterMap = map[uint8]string{ FilterAnd: "And", FilterOr: "Or", FilterNot: "Not", @@ -163,15 +163,15 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { newPos++ return packet, newPos, err case '&': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, filterMap[FilterAnd]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd]) newPos, err = compileFilterSet(filter, pos+1, packet) return packet, newPos, err case '|': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, filterMap[FilterOr]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr]) newPos, err = compileFilterSet(filter, pos+1, packet) return packet, newPos, err case '!': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, filterMap[FilterNot]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot]) var child *ber.Packet child, newPos, err = compileFilter(filter, pos+1) packet.AppendChild(child) @@ -184,15 +184,15 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { case packet != nil: condition += fmt.Sprintf("%c", filter[newPos]) case filter[newPos] == '=': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, filterMap[FilterEqualityMatch]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch]) case filter[newPos] == '>' && filter[newPos+1] == '=': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, filterMap[FilterGreaterOrEqual]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual]) newPos++ case filter[newPos] == '<' && filter[newPos+1] == '=': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, filterMap[FilterLessOrEqual]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual]) newPos++ case filter[newPos] == '~' && filter[newPos+1] == '=': - packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, filterMap[FilterLessOrEqual]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual]) newPos++ case packet == nil: attribute += fmt.Sprintf("%c", filter[newPos]) @@ -211,7 +211,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { if packet.Tag == FilterEqualityMatch && condition == "*" { packet.TagType = ber.TypePrimitive packet.Tag = FilterPresent - packet.Description = filterMap[packet.Tag] + packet.Description = FilterMap[packet.Tag] packet.Data.WriteString(attribute) return packet, newPos + 1, nil } @@ -220,21 +220,21 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { case packet.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*': // Any packet.Tag = FilterSubstrings - packet.Description = filterMap[packet.Tag] + packet.Description = FilterMap[packet.Tag] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring")) packet.AppendChild(seq) case packet.Tag == FilterEqualityMatch && condition[0] == '*': // Final packet.Tag = FilterSubstrings - packet.Description = filterMap[packet.Tag] + packet.Description = FilterMap[packet.Tag] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsFinal, condition[1:], "Final Substring")) packet.AppendChild(seq) case packet.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*': // Initial packet.Tag = FilterSubstrings - packet.Description = filterMap[packet.Tag] + packet.Description = FilterMap[packet.Tag] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring")) packet.AppendChild(seq) @@ -245,3 +245,136 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) { return packet, newPos, err } } + +func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, uint64) { + //log.Printf("%# v", pretty.Formatter(entry)) + + switch FilterMap[f.Tag] { + default: + //log.Fatalf("Unknown LDAP filter code: %d", f.Tag) + return false, LDAPResultOperationsError + case "Equality Match": + if len(f.Children) != 2 { + return false, LDAPResultOperationsError + } + attribute := f.Children[0].Value.(string) + value := f.Children[1].Value.(string) + for _, a := range entry.Attributes { + if strings.ToLower(a.Name) == strings.ToLower(attribute) { + for _, v := range a.Values { + if strings.ToLower(v) == strings.ToLower(value) { + return true, LDAPResultSuccess + } + } + } + } + case "Present": + for _, a := range entry.Attributes { + if strings.ToLower(a.Name) == strings.ToLower(f.Data.String()) { + return true, LDAPResultSuccess + } + } + case "And": + for _, child := range f.Children { + ok, exitCode := ServerApplyFilter(child, entry) + if exitCode != LDAPResultSuccess { + return false, exitCode + } + if !ok { + return false, LDAPResultSuccess + } + } + return true, LDAPResultSuccess + case "Or": + anyOk := false + for _, child := range f.Children { + ok, exitCode := ServerApplyFilter(child, entry) + if exitCode != LDAPResultSuccess { + return false, exitCode + } else if ok { + anyOk = true + } + } + if anyOk { + return true, LDAPResultSuccess + } + case "Not": + if len(f.Children) != 1 { + return false, LDAPResultOperationsError + } + ok, exitCode := ServerApplyFilter(f.Children[0], entry) + if exitCode != LDAPResultSuccess { + return false, exitCode + } else if !ok { + return true, LDAPResultSuccess + } + case "FilterSubstrings": + return false, LDAPResultOperationsError + case "FilterGreaterOrEqual": + return false, LDAPResultOperationsError + case "FilterLessOrEqual": + return false, LDAPResultOperationsError + case "FilterApproxMatch": + return false, LDAPResultOperationsError + case "FilterExtensibleMatch": + return false, LDAPResultOperationsError + } + + return false, LDAPResultSuccess +} + +func GetFilterType(filter string) (string, error) { // TODO <- test this + f, err := CompileFilter(filter) + if err != nil { + return "", err + } + return parseFilterType(f) +} +func parseFilterType(f *ber.Packet) (string, error) { + searchType := "" + switch FilterMap[f.Tag] { + case "Equality Match": + if len(f.Children) != 2 { + return "", errors.New("Equality match must have only two children") + } + attribute := strings.ToLower(f.Children[0].Value.(string)) + value := f.Children[1].Value.(string) + + if attribute == "objectclass" { + searchType = strings.ToLower(value) + } + case "And": + for _, child := range f.Children { + subType, err := parseFilterType(child) + if err != nil { + return "", err + } + if len(subType) > 0 { + searchType = subType + } + } + case "Or": + for _, child := range f.Children { + subType, err := parseFilterType(child) + if err != nil { + return "", err + } + if len(subType) > 0 { + searchType = subType + } + } + case "Not": + if len(f.Children) != 1 { + return "", errors.New("Not filter must have only one child") + } + subType, err := parseFilterType(f.Children[0]) + if err != nil { + return "", err + } + if len(subType) > 0 { + searchType = subType + } + + } + return strings.ToLower(searchType), nil +} |