diff options
Diffstat (limited to 'search.go')
-rw-r--r-- | search.go | 535 |
1 files changed, 307 insertions, 228 deletions
@@ -1,269 +1,348 @@ // Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. - +// // File contains Search functionality +// +// https://tools.ietf.org/html/rfc4511 +// +// SearchRequest ::= [APPLICATION 3] SEQUENCE { +// baseObject LDAPDN, +// scope ENUMERATED { +// baseObject (0), +// singleLevel (1), +// wholeSubtree (2), +// ... }, +// derefAliases ENUMERATED { +// neverDerefAliases (0), +// derefInSearching (1), +// derefFindingBaseObj (2), +// derefAlways (3) }, +// sizeLimit INTEGER (0 .. maxInt), +// timeLimit INTEGER (0 .. maxInt), +// typesOnly BOOLEAN, +// filter Filter, +// attributes AttributeSelection } +// +// AttributeSelection ::= SEQUENCE OF selector LDAPString +// -- The LDAPString is constrained to +// -- <attributeSelector> in Section 4.5.1.8 +// +// Filter ::= CHOICE { +// and [0] SET SIZE (1..MAX) OF filter Filter, +// or [1] SET SIZE (1..MAX) OF filter Filter, +// not [2] Filter, +// equalityMatch [3] AttributeValueAssertion, +// substrings [4] SubstringFilter, +// greaterOrEqual [5] AttributeValueAssertion, +// lessOrEqual [6] AttributeValueAssertion, +// present [7] AttributeDescription, +// approxMatch [8] AttributeValueAssertion, +// extensibleMatch [9] MatchingRuleAssertion, +// ... } +// +// SubstringFilter ::= SEQUENCE { +// type AttributeDescription, +// substrings SEQUENCE SIZE (1..MAX) OF substring CHOICE { +// initial [0] AssertionValue, -- can occur at most once +// any [1] AssertionValue, +// final [2] AssertionValue } -- can occur at most once +// } +// +// MatchingRuleAssertion ::= SEQUENCE { +// matchingRule [1] MatchingRuleId OPTIONAL, +// type [2] AttributeDescription OPTIONAL, +// matchValue [3] AssertionValue, +// dnAttributes [4] BOOLEAN DEFAULT FALSE } +// +// package ldap import ( - "github.com/mmitton/asn1-ber" - "fmt" - "os" + "errors" + "fmt" + "github.com/tmfkams/asn1-ber" + "strings" ) const ( - ScopeBaseObject = 0 - ScopeSingleLevel = 1 - ScopeWholeSubtree = 2 + ScopeBaseObject = 0 + ScopeSingleLevel = 1 + ScopeWholeSubtree = 2 ) -var ScopeMap = map[ int ] string { - ScopeBaseObject : "Base Object", - ScopeSingleLevel : "Single Level", - ScopeWholeSubtree : "Whole Subtree", +var ScopeMap = map[int]string{ + ScopeBaseObject: "Base Object", + ScopeSingleLevel: "Single Level", + ScopeWholeSubtree: "Whole Subtree", } const ( - NeverDerefAliases = 0 - DerefInSearching = 1 - DerefFindingBaseObj = 2 - DerefAlways = 3 + NeverDerefAliases = 0 + DerefInSearching = 1 + DerefFindingBaseObj = 2 + DerefAlways = 3 ) -var DerefMap = map[ int ] string { - NeverDerefAliases : "NeverDerefAliases", - DerefInSearching : "DerefInSearching", - DerefFindingBaseObj : "DerefFindingBaseObj", - DerefAlways : "DerefAlways", +var DerefMap = map[int]string{ + NeverDerefAliases: "NeverDerefAliases", + DerefInSearching: "DerefInSearching", + DerefFindingBaseObj: "DerefFindingBaseObj", + DerefAlways: "DerefAlways", } type Entry struct { - DN string - Attributes []*EntryAttribute + DN string + Attributes []*EntryAttribute +} + +func (e *Entry) GetAttributeValues(Attribute string) []string { + for _, attr := range e.Attributes { + if attr.Name == Attribute { + return attr.Values + } + } + return []string{} +} + +func (e *Entry) GetAttributeValue(Attribute string) string { + values := e.GetAttributeValues(Attribute) + if len(values) == 0 { + return "" + } + return values[0] +} + +func (e *Entry) Print() { + fmt.Printf("DN: %s\n", e.DN) + for _, attr := range e.Attributes { + attr.Print() + } +} + +func (e *Entry) PrettyPrint(indent int) { + fmt.Printf("%sDN: %s\n", strings.Repeat(" ", indent), e.DN) + for _, attr := range e.Attributes { + attr.PrettyPrint(indent + 2) + } } type EntryAttribute struct { - Name string - Values []string + Name string + Values []string } -type SearchResult struct { - Entries []*Entry - Referrals []string - Controls []Control +func (e *EntryAttribute) Print() { + fmt.Printf("%s: %s\n", e.Name, e.Values) } -func (e *Entry) GetAttributeValues( Attribute string ) []string { - for _, attr := range e.Attributes { - if attr.Name == Attribute { - return attr.Values - } - } +func (e *EntryAttribute) PrettyPrint(indent int) { + fmt.Printf("%s%s: %s\n", strings.Repeat(" ", indent), e.Name, e.Values) +} + +type SearchResult struct { + Entries []*Entry + Referrals []string + Controls []Control +} - return []string{ } +func (s *SearchResult) Print() { + for _, entry := range s.Entries { + entry.Print() + } } -func (e *Entry) GetAttributeValue( Attribute string ) string { - values := e.GetAttributeValues( Attribute ) - if len( values ) == 0 { - return "" - } - return values[ 0 ] +func (s *SearchResult) PrettyPrint(indent int) { + for _, entry := range s.Entries { + entry.PrettyPrint(indent) + } } type SearchRequest struct { - BaseDN string - Scope int - DerefAliases int - SizeLimit int - TimeLimit int - TypesOnly bool - Filter string - Attributes []string - Controls []Control + BaseDN string + Scope int + DerefAliases int + SizeLimit int + TimeLimit int + TypesOnly bool + Filter string + Attributes []string + Controls []Control +} + +func (s *SearchRequest) encode() (*ber.Packet, *Error) { + request := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request") + request.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, s.BaseDN, "Base DN")) + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(s.Scope), "Scope")) + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(s.DerefAliases), "Deref Aliases")) + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(s.SizeLimit), "Size Limit")) + request.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(s.TimeLimit), "Time Limit")) + request.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimative, ber.TagBoolean, s.TypesOnly, "Types Only")) + // compile and encode filter + filterPacket, err := CompileFilter(s.Filter) + if err != nil { + return nil, err + } + request.AppendChild(filterPacket) + // encode attributes + attributesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes") + for _, attribute := range s.Attributes { + attributesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute")) + } + request.AppendChild(attributesPacket) + return request, nil } func NewSearchRequest( - BaseDN string, - Scope, DerefAliases, SizeLimit, TimeLimit int, - TypesOnly bool, - Filter string, - Attributes []string, - Controls []Control, - ) (*SearchRequest) { - return &SearchRequest{ - BaseDN: BaseDN, - Scope: Scope, - DerefAliases: DerefAliases, - SizeLimit: SizeLimit, - TimeLimit: TimeLimit, - TypesOnly: TypesOnly, - Filter: Filter, - Attributes: Attributes, - Controls: Controls, - } + BaseDN string, + Scope, DerefAliases, SizeLimit, TimeLimit int, + TypesOnly bool, + Filter string, + Attributes []string, + Controls []Control, +) *SearchRequest { + return &SearchRequest{ + BaseDN: BaseDN, + Scope: Scope, + DerefAliases: DerefAliases, + SizeLimit: SizeLimit, + TimeLimit: TimeLimit, + TypesOnly: TypesOnly, + Filter: Filter, + Attributes: Attributes, + Controls: Controls, + } } -func (l *Conn) SearchWithPaging( SearchRequest *SearchRequest, PagingSize uint32 ) (*SearchResult, *Error) { - if SearchRequest.Controls == nil { - SearchRequest.Controls = make( []Control, 0 ) - } - - PagingControl := NewControlPaging( PagingSize ) - SearchRequest.Controls = append( SearchRequest.Controls, PagingControl ) - SearchResult := new( SearchResult ) - for { - result, err := l.Search( SearchRequest ) - if l.Debug { - fmt.Printf( "Looking for Paging Control...\n" ) - } - if err != nil { - return SearchResult, err - } - if result == nil { - return SearchResult, NewError( ErrorNetwork, os.NewError( "Packet not received" ) ) - } - - for _, entry := range result.Entries { - SearchResult.Entries = append( SearchResult.Entries, entry ) - } - for _, referral := range result.Referrals { - SearchResult.Referrals = append( SearchResult.Referrals, referral ) - } - for _, control := range result.Controls { - SearchResult.Controls = append( SearchResult.Controls, control ) - } - - if l.Debug { - fmt.Printf( "Looking for Paging Control...\n" ) - } - paging_result := FindControl( result.Controls, ControlTypePaging ) - if paging_result == nil { - PagingControl = nil - if l.Debug { - fmt.Printf( "Could not find paging control. Breaking...\n" ) - } - break - } - - cookie := paging_result.(*ControlPaging).Cookie - if len( cookie ) == 0 { - PagingControl = nil - if l.Debug { - fmt.Printf( "Could not find cookie. Breaking...\n" ) - } - break - } - PagingControl.SetCookie( cookie ) - } - - if PagingControl != nil { - if l.Debug { - fmt.Printf( "Abandoning Paging...\n" ) - } - PagingControl.PagingSize = 0 - l.Search( SearchRequest ) - } - - return SearchResult, nil +func (l *Conn) SearchWithPaging(searchRequest *SearchRequest, pagingSize uint32) (*SearchResult, *Error) { + if searchRequest.Controls == nil { + searchRequest.Controls = make([]Control, 0) + } + + pagingControl := NewControlPaging(pagingSize) + searchRequest.Controls = append(searchRequest.Controls, pagingControl) + searchResult := new(SearchResult) + for { + result, err := l.Search(searchRequest) + l.Debug.Printf("Looking for Paging Control...\n") + if err != nil { + return searchResult, err + } + if result == nil { + return searchResult, NewError(ErrorNetwork, errors.New("Packet not received")) + } + + for _, entry := range result.Entries { + searchResult.Entries = append(searchResult.Entries, entry) + } + for _, referral := range result.Referrals { + searchResult.Referrals = append(searchResult.Referrals, referral) + } + for _, control := range result.Controls { + searchResult.Controls = append(searchResult.Controls, control) + } + + l.Debug.Printf("Looking for Paging Control...\n") + pagingResult := FindControl(result.Controls, ControlTypePaging) + if pagingResult == nil { + pagingControl = nil + l.Debug.Printf("Could not find paging control. Breaking...\n") + break + } + + cookie := pagingResult.(*ControlPaging).Cookie + if len(cookie) == 0 { + pagingControl = nil + l.Debug.Printf("Could not find cookie. Breaking...\n") + break + } + pagingControl.SetCookie(cookie) + } + + if pagingControl != nil { + l.Debug.Printf("Abandoning Paging...\n") + pagingControl.PagingSize = 0 + l.Search(searchRequest) + } + + return searchResult, nil } -func (l *Conn) Search( SearchRequest *SearchRequest ) (*SearchResult, *Error) { - messageID := l.nextMessageID() - - packet := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request" ) - packet.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID" ) ) - searchRequest := ber.Encode( ber.ClassApplication, ber.TypeConstructed, ApplicationSearchRequest, nil, "Search Request" ) - searchRequest.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, SearchRequest.BaseDN, "Base DN" ) ) - searchRequest.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(SearchRequest.Scope), "Scope" ) ) - searchRequest.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagEnumerated, uint64(SearchRequest.DerefAliases), "Deref Aliases" ) ) - searchRequest.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(SearchRequest.SizeLimit), "Size Limit" ) ) - searchRequest.AppendChild( ber.NewInteger( ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, uint64(SearchRequest.TimeLimit), "Time Limit" ) ) - searchRequest.AppendChild( ber.NewBoolean( ber.ClassUniversal, ber.TypePrimative, ber.TagBoolean, SearchRequest.TypesOnly, "Types Only" ) ) - filterPacket, err := CompileFilter( SearchRequest.Filter ) - if err != nil { - return nil, err - } - searchRequest.AppendChild( filterPacket ) - attributesPacket := ber.Encode( ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes" ) - for _, attribute := range SearchRequest.Attributes { - attributesPacket.AppendChild( ber.NewString( ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute" ) ) - } - searchRequest.AppendChild( attributesPacket ) - packet.AppendChild( searchRequest ) - if SearchRequest.Controls != nil { - packet.AppendChild( encodeControls( SearchRequest.Controls ) ) - } - - if l.Debug { - ber.PrintPacket( packet ) - } - - channel, err := l.sendMessage( packet ) - if err != nil { - return nil, err - } - if channel == nil { - return nil, NewError( ErrorNetwork, os.NewError( "Could not send message" ) ) - } - defer l.finishMessage( messageID ) - - result := &SearchResult{ - Entries: make( []*Entry, 0 ), - Referrals: make( []string, 0 ), - Controls: make( []Control, 0 ) } - - foundSearchResultDone := false - for !foundSearchResultDone { - if l.Debug { - fmt.Printf( "%d: waiting for response\n", messageID ) - } - packet = <-channel - if l.Debug { - fmt.Printf( "%d: got response %p\n", messageID, packet ) - } - if packet == nil { - return nil, NewError( ErrorNetwork, os.NewError( "Could not retrieve message" ) ) - } - - if l.Debug { - if err := addLDAPDescriptions( packet ); err != nil { - return nil, NewError( ErrorDebugging, err ) - } - ber.PrintPacket( packet ) - } - - switch packet.Children[ 1 ].Tag { - case 4: - entry := new( Entry ) - entry.DN = packet.Children[ 1 ].Children[ 0 ].Value.(string) - for _, child := range packet.Children[ 1 ].Children[ 1 ].Children { - attr := new( EntryAttribute ) - attr.Name = child.Children[ 0 ].Value.(string) - for _, value := range child.Children[ 1 ].Children { - attr.Values = append( attr.Values, value.Value.(string) ) - } - entry.Attributes = append( entry.Attributes, attr ) - } - result.Entries = append( result.Entries, entry ) - case 5: - result_code, result_description := getLDAPResultCode( packet ) - if result_code != 0 { - return result, NewError( result_code, os.NewError( result_description ) ) - } - if len( packet.Children ) == 3 { - for _, child := range packet.Children[ 2 ].Children { - result.Controls = append( result.Controls, DecodeControl( child ) ) - } - } - foundSearchResultDone = true - case 19: - result.Referrals = append( result.Referrals, packet.Children[ 1 ].Children[ 0 ].Value.(string) ) - } - } - if l.Debug { - fmt.Printf( "%d: returning\n", messageID ) - } - - return result, nil +func (l *Conn) Search(searchRequest *SearchRequest) (*SearchResult, *Error) { + messageID := l.nextMessageID() + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request") + packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimative, ber.TagInteger, messageID, "MessageID")) + // encode search request + encodedSearchRequest, err := searchRequest.encode() + if err != nil { + return nil, err + } + packet.AppendChild(encodedSearchRequest) + // encode search controls + if searchRequest.Controls != nil { + packet.AppendChild(encodeControls(searchRequest.Controls)) + } + + l.Debug.PrintPacket(packet) + + channel, err := l.sendMessage(packet) + if err != nil { + return nil, err + } + if channel == nil { + return nil, NewError(ErrorNetwork, errors.New("Could not send message")) + } + defer l.finishMessage(messageID) + + result := &SearchResult{ + Entries: make([]*Entry, 0), + Referrals: make([]string, 0), + Controls: make([]Control, 0)} + + foundSearchResultDone := false + for !foundSearchResultDone { + l.Debug.Printf("%d: waiting for response\n", messageID) + packet = <-channel + l.Debug.Printf("%d: got response %p\n", messageID, packet) + if packet == nil { + return nil, NewError(ErrorNetwork, errors.New("Could not retrieve message")) + } + + if l.Debug { + if err := addLDAPDescriptions(packet); err != nil { + return nil, NewError(ErrorDebugging, err.Err) + } + ber.PrintPacket(packet) + } + + switch packet.Children[1].Tag { + case 4: + entry := new(Entry) + entry.DN = packet.Children[1].Children[0].Value.(string) + for _, child := range packet.Children[1].Children[1].Children { + attr := new(EntryAttribute) + attr.Name = child.Children[0].Value.(string) + for _, value := range child.Children[1].Children { + attr.Values = append(attr.Values, value.Value.(string)) + } + entry.Attributes = append(entry.Attributes, attr) + } + result.Entries = append(result.Entries, entry) + case 5: + resultCode, resultDescription := getLDAPResultCode(packet) + if resultCode != 0 { + return result, NewError(resultCode, errors.New(resultDescription)) + } + if len(packet.Children) == 3 { + for _, child := range packet.Children[2].Children { + result.Controls = append(result.Controls, DecodeControl(child)) + } + } + foundSearchResultDone = true + case 19: + result.Referrals = append(result.Referrals, packet.Children[1].Children[0].Value.(string)) + } + } + l.Debug.Printf("%d: returning\n", messageID) + return result, nil } |