diff options
Diffstat (limited to 'filter.go')
-rw-r--r-- | filter.go | 249 |
1 files changed, 249 insertions, 0 deletions
diff --git a/filter.go b/filter.go new file mode 100644 index 0000000..1b8fd1e --- /dev/null +++ b/filter.go @@ -0,0 +1,249 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// File contains a filter compiler/decompiler +package ldap + +import ( + "fmt" + "os" + "github.com/mmitton/asn1-ber" +) + +const ( + FilterAnd = 0 + FilterOr = 1 + FilterNot = 2 + FilterEqualityMatch = 3 + FilterSubstrings = 4 + FilterGreaterOrEqual = 5 + FilterLessOrEqual = 6 + FilterPresent = 7 + FilterApproxMatch = 8 + FilterExtensibleMatch = 9 +) + +var FilterMap = map[ uint64 ] string { + FilterAnd : "And", + FilterOr : "Or", + FilterNot : "Not", + FilterEqualityMatch : "Equality Match", + FilterSubstrings : "Substrings", + FilterGreaterOrEqual : "Greater Or Equal", + FilterLessOrEqual : "Less Or Equal", + FilterPresent : "Present", + FilterApproxMatch : "Approx Match", + FilterExtensibleMatch : "Extensible Match", +} + +const ( + FilterSubstringsInitial = 0 + FilterSubstringsAny = 1 + FilterSubstringsFinal = 2 +) + +var FilterSubstringsMap = map[ uint64 ] string { + FilterSubstringsInitial : "Substrings Initial", + FilterSubstringsAny : "Substrings Any", + FilterSubstringsFinal : "Substrings Final", +} + +func CompileFilter( filter string ) ( *ber.Packet, *Error ) { + if len( filter ) == 0 || filter[ 0 ] != '(' { + return nil, NewError( ErrorFilterCompile, os.NewError( "Filter does not start with an '('" ) ) + } + packet, pos, err := compileFilter( filter, 1 ) + if err != nil { + return nil, err + } + if pos != len( filter ) { + return nil, NewError( ErrorFilterCompile, os.NewError( "Finished compiling filter with extra at end.\n" + fmt.Sprint( filter[pos:] ) ) ) + } + return packet, nil +} + +func DecompileFilter( packet *ber.Packet ) (ret string, err *Error) { + defer func() { + if r := recover(); r != nil { + err = NewError( ErrorFilterDecompile, os.NewError( "Error decompiling filter" ) ) + } + }() + ret = "(" + err = nil + child_str := "" + + switch packet.Tag { + case FilterAnd: + ret += "&" + for _, child := range packet.Children { + child_str, err = DecompileFilter( child ) + if err != nil { + return + } + ret += child_str + } + case FilterOr: + ret += "|" + for _, child := range packet.Children { + child_str, err = DecompileFilter( child ) + if err != nil { + return + } + ret += child_str + } + case FilterNot: + ret += "!" + child_str, err = DecompileFilter( packet.Children[ 0 ] ) + if err != nil { + return + } + ret += child_str + + case FilterSubstrings: + ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) + ret += "=" + switch packet.Children[ 1 ].Children[ 0 ].Tag { + case FilterSubstringsInitial: + ret += ber.DecodeString( packet.Children[ 1 ].Children[ 0 ].Data.Bytes() ) + "*" + case FilterSubstringsAny: + ret += "*" + ber.DecodeString( packet.Children[ 1 ].Children[ 0 ].Data.Bytes() ) + "*" + case FilterSubstringsFinal: + ret += "*" + ber.DecodeString( packet.Children[ 1 ].Children[ 0 ].Data.Bytes() ) + } + case FilterEqualityMatch: + ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) + ret += "=" + ret += ber.DecodeString( packet.Children[ 1 ].Data.Bytes() ) + case FilterGreaterOrEqual: + ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) + ret += ">=" + ret += ber.DecodeString( packet.Children[ 1 ].Data.Bytes() ) + case FilterLessOrEqual: + ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) + ret += "<=" + ret += ber.DecodeString( packet.Children[ 1 ].Data.Bytes() ) + case FilterPresent: + ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) + ret += "=*" + case FilterApproxMatch: + ret += ber.DecodeString( packet.Children[ 0 ].Data.Bytes() ) + ret += "~=" + ret += ber.DecodeString( packet.Children[ 1 ].Data.Bytes() ) + } + + ret += ")" + return +} + +func compileFilterSet( filter string, pos int, parent *ber.Packet ) ( int, *Error ) { + for pos < len( filter ) && filter[ pos ] == '(' { + child, new_pos, err := compileFilter( filter, pos + 1 ) + if err != nil { + return pos, err + } + pos = new_pos + parent.AppendChild( child ) + } + if pos == len( filter ) { + return pos, NewError( ErrorFilterCompile, os.NewError( "Unexpected end of filter" ) ) + } + + return pos + 1, nil +} + +func compileFilter( filter string, pos int ) ( p *ber.Packet, new_pos int, err *Error ) { + defer func() { + if r := recover(); r != nil { + err = NewError( ErrorFilterCompile, os.NewError( "Error compiling filter" ) ) + } + }() + p = nil + new_pos = pos + err = nil + + switch filter[pos] { + case '(': + p, new_pos, err = compileFilter( filter, pos + 1 ) + new_pos++ + return + case '&': + p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[ FilterAnd ] ) + new_pos, err = compileFilterSet( filter, pos + 1, p ) + return + case '|': + p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[ FilterOr ] ) + new_pos, err = compileFilterSet( filter, pos + 1, p ) + return + case '!': + p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[ FilterNot ] ) + var child *ber.Packet + child, new_pos, err = compileFilter( filter, pos + 1 ) + p.AppendChild( child ) + return + default: + attribute := "" + condition := "" + for new_pos < len( filter ) && filter[ new_pos ] != ')' { + switch { + case p != nil: + condition += fmt.Sprintf( "%c", filter[ new_pos ] ) + case filter[ new_pos ] == '=': + p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[ FilterEqualityMatch ] ) + case filter[ new_pos ] == '>' && filter[ new_pos + 1 ] == '=': + p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[ FilterGreaterOrEqual ] ) + new_pos++ + case filter[ new_pos ] == '<' && filter[ new_pos + 1 ] == '=': + p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[ FilterLessOrEqual ] ) + new_pos++ + case filter[ new_pos ] == '~' && filter[ new_pos + 1 ] == '=': + p = ber.Encode( ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[ FilterLessOrEqual ] ) + new_pos++ + case p == nil: + attribute += fmt.Sprintf( "%c", filter[ new_pos ] ) + } + new_pos++ + } + if new_pos == len( filter ) { + err = NewError( ErrorFilterCompile, os.NewError( "Unexpected end of filter" ) ) + return + } + if p == nil { + err = NewError( ErrorFilterCompile, os.NewError( "Error parsing filter" ) ) + return + } + p.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute" ) ) + switch { + case p.Tag == FilterEqualityMatch && condition == "*": + p.Tag = FilterPresent + p.Description = FilterMap[ uint64(p.Tag) ] + case p.Tag == FilterEqualityMatch && condition[ 0 ] == '*' && condition[ len( condition ) - 1 ] == '*': + // Any + p.Tag = FilterSubstrings + p.Description = FilterMap[ uint64(p.Tag) ] + seq := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings" ) + seq.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, FilterSubstringsAny, condition[ 1 : len( condition ) - 1 ], "Any Substring" ) ) + p.AppendChild( seq ) + case p.Tag == FilterEqualityMatch && condition[ 0 ] == '*': + // Final + p.Tag = FilterSubstrings + p.Description = FilterMap[ uint64(p.Tag) ] + seq := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings" ) + seq.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, FilterSubstringsFinal, condition[ 1: ], "Final Substring" ) ) + p.AppendChild( seq ) + case p.Tag == FilterEqualityMatch && condition[ len( condition ) - 1 ] == '*': + // Initial + p.Tag = FilterSubstrings + p.Description = FilterMap[ uint64(p.Tag) ] + seq := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings" ) + seq.AppendChild( ber.NewString( ber.ClassContext, ber.TypePrimative, FilterSubstringsInitial, condition[ :len( condition ) - 1 ], "Initial Substring" ) ) + p.AppendChild( seq ) + default: + p.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, condition, "Condition" ) ) + } + new_pos++ + return + } + err = NewError( ErrorFilterCompile, os.NewError( "Reached end of filter without closing parens" ) ) + return +} |