diff options
-rw-r--r-- | bind.go | 8 | ||||
-rw-r--r-- | conn.go | 234 | ||||
-rw-r--r-- | control.go | 44 | ||||
-rw-r--r-- | filter.go | 121 | ||||
-rw-r--r-- | ldap.go | 6 | ||||
-rw-r--r-- | ldap_test.go | 2 |
6 files changed, 210 insertions, 205 deletions
@@ -33,8 +33,8 @@ func (l *Conn) Bind(username, password string) *Error { return NewError(ErrorNetwork, errors.New("Could not send message")) } defer l.finishMessage(messageID) - packet = <-channel + packet = <-channel if packet == nil { return NewError(ErrorNetwork, errors.New("Could not retrieve response")) } @@ -46,9 +46,9 @@ func (l *Conn) Bind(username, password string) *Error { ber.PrintPacket(packet) } - result_code, result_description := getLDAPResultCode(packet) - if result_code != 0 { - return NewError(result_code, errors.New(result_description)) + resultCode, resultDescription := getLDAPResultCode(packet) + if resultCode != 0 { + return NewError(resultCode, errors.New(resultDescription)) } return nil @@ -14,15 +14,31 @@ import ( "sync" ) +const ( + MessageQuit = 0 + MessageRequest = 1 + MessageResponse = 2 + MessageFinish = 3 +) + +type messagePacket struct { + Op int + MessageID uint64 + Packet *ber.Packet + Channel chan *ber.Packet +} + // LDAP Connection type Conn struct { - conn net.Conn - isSSL bool - Debug debugging - chanResults map[uint64]chan *ber.Packet - chanProcessMessage chan *messagePacket - chanMessageID chan uint64 - closeLock sync.Mutex + conn net.Conn + isSSL bool + isClosed bool + Debug debugging + chanConfirm chan int + chanResults map[uint64]chan *ber.Packet + chanMessage chan *messagePacket + chanMessageID chan uint64 + closeLock sync.Mutex } // Dial connects to the given address on the given network using net.Dial @@ -46,7 +62,6 @@ func DialSSL(network, addr string, config *tls.Config) (*Conn, *Error) { } conn := NewConn(c) conn.isSSL = true - conn.start() return conn, nil } @@ -59,7 +74,6 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, *Error) { return nil, NewError(ErrorNetwork, err) } conn := NewConn(c) - if err := conn.startTLS(config); err != nil { conn.Close() return nil, NewError(ErrorNetwork, err.Err) @@ -71,12 +85,14 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, *Error) { // NewConn returns a new Conn using conn for network I/O. func NewConn(conn net.Conn) *Conn { return &Conn{ - conn: conn, - isSSL: false, - Debug: false, - chanResults: map[uint64]chan *ber.Packet{}, - chanProcessMessage: make(chan *messagePacket), - chanMessageID: make(chan uint64), + conn: conn, + isSSL: false, + isClosed: false, + Debug: false, + chanConfirm: make(chan int), + chanMessageID: make(chan uint64), + chanMessage: make(chan *messagePacket, 10), + chanResults: map[uint64]chan *ber.Packet{}, } } @@ -90,27 +106,34 @@ func (l *Conn) Close() *Error { l.closeLock.Lock() defer l.closeLock.Unlock() - l.sendProcessMessage(&messagePacket{Op: MessageQuit}) + // close only once + if l.isClosed { + return nil + } - if l.conn != nil { - err := l.conn.Close() - if err != nil { - return NewError(ErrorNetwork, err) - } - l.conn = nil + l.Debug.Printf("Sending quit message\n") + l.chanMessage <- &messagePacket{Op: MessageQuit} + <-l.chanConfirm + l.chanConfirm = nil + + l.Debug.Printf("Closing network connection\n") + if err := l.conn.Close(); err != nil { + return NewError(ErrorNetwork, err) } + + l.isClosed = true return nil } // Returns the next available messageID -func (l *Conn) nextMessageID() (messageID uint64) { - defer func() { - if r := recover(); r != nil { - messageID = 0 +func (l *Conn) nextMessageID() uint64 { + // l.chanMessageID will be set to nil by processMessage() + if l.chanMessageID != nil { + if messageID, ok := <-l.chanMessageID; ok { + return messageID } - }() - messageID = <-l.chanMessageID - return + } + return 0 } // StartTLS sends the command to start a TLS session and then creates a new TLS Client @@ -154,136 +177,119 @@ func (l *Conn) startTLS(config *tls.Config) *Error { return nil } -const ( - MessageQuit = 0 - MessageRequest = 1 - MessageResponse = 2 - MessageFinish = 3 -) - -type messagePacket struct { - Op int - MessageID uint64 - Packet *ber.Packet - Channel chan *ber.Packet -} - -func (l *Conn) sendMessage(p *ber.Packet) (out chan *ber.Packet, err *Error) { - message_id := p.Children[0].Value.(uint64) - out = make(chan *ber.Packet) - - if l.chanProcessMessage == nil { - err = NewError(ErrorNetwork, errors.New("Connection closed")) - return +func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, *Error) { + out := make(chan *ber.Packet) + // l.chanMessage will be set to nil by processMessage() + if l.chanMessage != nil { + l.chanMessage <- &messagePacket{ + Op: MessageRequest, + MessageID: packet.Children[0].Value.(uint64), + Packet: packet, + Channel: out, + } + } else { + return nil, NewError(ErrorNetwork, errors.New("Connection closed")) } - message_packet := &messagePacket{Op: MessageRequest, MessageID: message_id, Packet: p, Channel: out} - l.sendProcessMessage(message_packet) - return + return out, nil } func (l *Conn) processMessages() { - defer l.closeAllChannels() + defer func() { + for messageID, channel := range l.chanResults { + if channel != nil { + l.Debug.Printf("Closing channel for MessageID %d\n", messageID) + close(channel) + l.chanResults[messageID] = nil + } + } + // l.chanMessage should be closed by sender but there is more than one + close(l.chanMessage) + l.chanMessage = nil + close(l.chanMessageID) + // l.chanMessageID should be set to nil by nextMessageID() but it is not a go routine + l.chanMessageID = nil + close(l.chanConfirm) + }() - var message_id uint64 = 1 - var message_packet *messagePacket + var messageID uint64 = 1 for { select { - case l.chanMessageID <- message_id: - if l.conn == nil { - return - } - message_id++ - case message_packet = <-l.chanProcessMessage: - if l.conn == nil { - return - } - switch message_packet.Op { + case l.chanMessageID <- messageID: + messageID++ + case messagePacket := <-l.chanMessage: + switch messagePacket.Op { case MessageQuit: - // Close all channels and quit l.Debug.Printf("Shutting down\n") + l.chanConfirm <- 1 return case MessageRequest: // Add to message list and write to network - l.Debug.Printf("Sending message %d\n", message_packet.MessageID) - l.chanResults[message_packet.MessageID] = message_packet.Channel - buf := message_packet.Packet.Bytes() + l.Debug.Printf("Sending message %d\n", messagePacket.MessageID) + l.chanResults[messagePacket.MessageID] = messagePacket.Channel + buf := messagePacket.Packet.Bytes() + // TODO: understand for len(buf) > 0 { n, err := l.conn.Write(buf) if err != nil { l.Debug.Printf("Error Sending Message: %s\n", err.Error()) - return + l.Close() + break } + // nothing else to send if n == len(buf) { break } + // the remaining buf content buf = buf[n:] } case MessageResponse: - // Pass back to waiting goroutine - l.Debug.Printf("Receiving message %d\n", message_packet.MessageID) - if chanResult, ok := l.chanResults[message_packet.MessageID]; ok { - // If the "Search Result Done" is read before the - // "Search Result Entry" no Entry can be returned - // go func() { chanResult <- message_packet.Packet }() - chanResult <- message_packet.Packet + l.Debug.Printf("Receiving message %d\n", messagePacket.MessageID) + if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok { + chanResult <- messagePacket.Packet } else { - log.Printf("Unexpected Message Result: %d\n", message_id) - ber.PrintPacket(message_packet.Packet) + log.Printf("Unexpected Message Result: %d\n", messagePacket.MessageID) + ber.PrintPacket(messagePacket.Packet) } case MessageFinish: // Remove from message list - l.Debug.Printf("Finished message %d\n", message_packet.MessageID) - l.chanResults[message_packet.MessageID] = nil + l.Debug.Printf("Finished message %d\n", messagePacket.MessageID) + l.chanResults[messagePacket.MessageID] = nil } } } } -func (l *Conn) closeAllChannels() { - log.Printf("closeAllChannels\n") - for messageID, channel := range l.chanResults { - if channel != nil { - l.Debug.Printf("Closing channel for MessageID %d\n", messageID) - close(channel) - l.chanResults[messageID] = nil - } - } - close(l.chanMessageID) - l.chanMessageID = nil - - close(l.chanProcessMessage) - l.chanProcessMessage = nil -} - func (l *Conn) finishMessage(MessageID uint64) { - message_packet := &messagePacket{Op: MessageFinish, MessageID: MessageID} - l.sendProcessMessage(message_packet) + // l.chanMessage will be set to nil by processMessage() + if l.chanMessage != nil { + l.chanMessage <- &messagePacket{Op: MessageFinish, MessageID: MessageID} + } } func (l *Conn) reader() { - defer l.Close() + defer func() { + l.Close() + l.conn = nil + }() + for { - p, err := ber.ReadPacket(l.conn) + packet, err := ber.ReadPacket(l.conn) if err != nil { l.Debug.Printf("ldap.reader: %s\n", err.Error()) return } - addLDAPDescriptions(p) + addLDAPDescriptions(packet) - message_id := p.Children[0].Value.(uint64) - message_packet := &messagePacket{Op: MessageResponse, MessageID: message_id, Packet: p} - if l.chanProcessMessage != nil { - l.chanProcessMessage <- message_packet + if l.chanMessage != nil { + l.chanMessage <- &messagePacket{ + Op: MessageResponse, + MessageID: packet.Children[0].Value.(uint64), + Packet: packet, + } } else { - log.Printf("ldap.reader: Cannot return message\n") + log.Printf("ldap.reader: Cannot return message, channel is already closed\n") return } } } - -func (l *Conn) sendProcessMessage(message *messagePacket) { - if l.chanProcessMessage != nil { - go func() { l.chanProcessMessage <- message }() - } -} @@ -34,14 +34,14 @@ func (c *ControlString) GetControlType() string { return c.ControlType } -func (c *ControlString) Encode() (p *ber.Packet) { - p = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") - p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, c.ControlType, "Control Type ("+ControlTypeMap[c.ControlType]+")")) +func (c *ControlString) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, c.ControlType, "Control Type ("+ControlTypeMap[c.ControlType]+")")) if c.Criticality { - p.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimative, ber.TagBoolean, c.Criticality, "Criticality")) + packet.AppendChild(ber.NewBoolean(ber.ClassUniversal, ber.TypePrimative, ber.TagBoolean, c.Criticality, "Criticality")) } - p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, c.ControlValue, "Control Value")) - return + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, c.ControlValue, "Control Value")) + return packet } func (c *ControlString) String() string { @@ -57,9 +57,9 @@ func (c *ControlPaging) GetControlType() string { return ControlTypePaging } -func (c *ControlPaging) Encode() (p *ber.Packet) { - p = ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") - p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, ControlTypePaging, "Control Type ("+ControlTypeMap[ControlTypePaging]+")")) +func (c *ControlPaging) Encode() *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Control") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, ControlTypePaging, "Control Type ("+ControlTypeMap[ControlTypePaging]+")")) p2 := ber.Encode(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, nil, "Control Value (Paging)") seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Search Control Value") @@ -70,8 +70,8 @@ func (c *ControlPaging) Encode() (p *ber.Packet) { seq.AppendChild(cookie) p2.AppendChild(seq) - p.AppendChild(p2) - return + packet.AppendChild(p2) + return packet } func (c *ControlPaging) String() string { @@ -97,16 +97,16 @@ func FindControl(Controls []Control, ControlType string) Control { return nil } -func DecodeControl(p *ber.Packet) Control { - ControlType := p.Children[0].Value.(string) +func DecodeControl(packet *ber.Packet) Control { + ControlType := packet.Children[0].Value.(string) Criticality := false - p.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" - value := p.Children[1] - if len(p.Children) == 3 { - value = p.Children[2] - p.Children[1].Description = "Criticality" - Criticality = p.Children[1].Value.(bool) + packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")" + value := packet.Children[1] + if len(packet.Children) == 3 { + value = packet.Children[2] + packet.Children[1].Description = "Criticality" + Criticality = packet.Children[1].Value.(bool) } value.Description = "Control Value" @@ -149,9 +149,9 @@ func NewControlPaging(PagingSize uint32) *ControlPaging { } func encodeControls(Controls []Control) *ber.Packet { - p := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls") + packet := ber.Encode(ber.ClassContext, ber.TypeConstructed, 0, nil, "Controls") for _, control := range Controls { - p.AppendChild(control.Encode()) + packet.AppendChild(control.Encode()) } - return p + return packet } @@ -138,11 +138,11 @@ func DecompileFilter(packet *ber.Packet) (ret string, err *Error) { func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, *Error) { for pos < len(filter) && filter[pos] == '(' { - child, new_pos, err := compileFilter(filter, pos+1) + child, newPos, err := compileFilter(filter, pos+1) if err != nil { return pos, err } - pos = new_pos + pos = newPos parent.AppendChild(child) } if pos == len(filter) { @@ -152,98 +152,99 @@ func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, *Error) return pos + 1, nil } -func compileFilter(filter string, pos int) (p *ber.Packet, new_pos int, err *Error) { +func compileFilter(filter string, pos int) (*ber.Packet, int, *Error) { + var packet *ber.Packet + var err *Error + defer func() { if r := recover(); r != nil { err = NewError(ErrorFilterCompile, errors.New("Error compiling filter")) } }() - p = nil - new_pos = pos - err = nil + newPos := pos switch filter[pos] { case '(': - p, new_pos, err = compileFilter(filter, pos+1) - new_pos++ - return + packet, newPos, err = compileFilter(filter, pos+1) + newPos++ + return packet, newPos, err case '&': - p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd]) - new_pos, err = compileFilterSet(filter, pos+1, p) - return + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd]) + newPos, err = compileFilterSet(filter, pos+1, packet) + return packet, newPos, err case '|': - p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr]) - new_pos, err = compileFilterSet(filter, pos+1, p) - return + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr]) + newPos, err = compileFilterSet(filter, pos+1, packet) + return packet, newPos, err case '!': - p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot]) + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot]) var child *ber.Packet - child, new_pos, err = compileFilter(filter, pos+1) - p.AppendChild(child) - return + child, newPos, err = compileFilter(filter, pos+1) + packet.AppendChild(child) + return packet, newPos, err default: attribute := "" condition := "" - for new_pos < len(filter) && filter[new_pos] != ')' { + for newPos < len(filter) && filter[newPos] != ')' { switch { - case p != nil: - condition += fmt.Sprintf("%c", filter[new_pos]) - case filter[new_pos] == '=': - p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch]) - case filter[new_pos] == '>' && filter[new_pos+1] == '=': - p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual]) - new_pos++ - case filter[new_pos] == '<' && filter[new_pos+1] == '=': - p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual]) - new_pos++ - case filter[new_pos] == '~' && filter[new_pos+1] == '=': - p = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual]) - new_pos++ - case p == nil: - attribute += fmt.Sprintf("%c", filter[new_pos]) + case packet != nil: + condition += fmt.Sprintf("%c", filter[newPos]) + case filter[newPos] == '=': + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch]) + case filter[newPos] == '>' && filter[newPos+1] == '=': + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual]) + newPos++ + case filter[newPos] == '<' && filter[newPos+1] == '=': + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual]) + newPos++ + case filter[newPos] == '~' && filter[newPos+1] == '=': + packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterLessOrEqual]) + newPos++ + case packet == nil: + attribute += fmt.Sprintf("%c", filter[newPos]) } - new_pos++ + newPos++ } - if new_pos == len(filter) { + if newPos == len(filter) { err = NewError(ErrorFilterCompile, errors.New("Unexpected end of filter")) - return + return packet, newPos, err } - if p == nil { + if packet == nil { err = NewError(ErrorFilterCompile, errors.New("Error parsing filter")) - return + return packet, newPos, err } - p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute")) + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, attribute, "Attribute")) switch { - case p.Tag == FilterEqualityMatch && condition == "*": - p.Tag = FilterPresent - p.Description = FilterMap[uint64(p.Tag)] - case p.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*': + case packet.Tag == FilterEqualityMatch && condition == "*": + packet.Tag = FilterPresent + packet.Description = FilterMap[uint64(packet.Tag)] + case packet.Tag == FilterEqualityMatch && condition[0] == '*' && condition[len(condition)-1] == '*': // Any - p.Tag = FilterSubstrings - p.Description = FilterMap[uint64(p.Tag)] + packet.Tag = FilterSubstrings + packet.Description = FilterMap[uint64(packet.Tag)] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsAny, condition[1:len(condition)-1], "Any Substring")) - p.AppendChild(seq) - case p.Tag == FilterEqualityMatch && condition[0] == '*': + packet.AppendChild(seq) + case packet.Tag == FilterEqualityMatch && condition[0] == '*': // Final - p.Tag = FilterSubstrings - p.Description = FilterMap[uint64(p.Tag)] + packet.Tag = FilterSubstrings + packet.Description = FilterMap[uint64(packet.Tag)] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsFinal, condition[1:], "Final Substring")) - p.AppendChild(seq) - case p.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*': + packet.AppendChild(seq) + case packet.Tag == FilterEqualityMatch && condition[len(condition)-1] == '*': // Initial - p.Tag = FilterSubstrings - p.Description = FilterMap[uint64(p.Tag)] + packet.Tag = FilterSubstrings + packet.Description = FilterMap[uint64(packet.Tag)] seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings") seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimative, FilterSubstringsInitial, condition[:len(condition)-1], "Initial Substring")) - p.AppendChild(seq) + packet.AppendChild(seq) default: - p.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, condition, "Condition")) + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimative, ber.TagOctetString, condition, "Condition")) } - new_pos++ - return + newPos++ + return packet, newPos, err } err = NewError(ErrorFilterCompile, errors.New("Reached end of filter without closing parens")) - return + return packet, newPos, err } @@ -290,9 +290,9 @@ func NewError(ResultCode uint8, Err error) *Error { return &Error{ResultCode: ResultCode, Err: Err} } -func getLDAPResultCode(p *ber.Packet) (code uint8, description string) { - if len(p.Children) >= 2 { - response := p.Children[1] +func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) { + if len(packet.Children) >= 2 { + response := packet.Children[1] if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) == 3 { code = uint8(response.Children[0].Value.(uint64)) description = response.Children[2].Value.(string) diff --git a/ldap_test.go b/ldap_test.go index f21a8a6..00cc7d5 100644 --- a/ldap_test.go +++ b/ldap_test.go @@ -90,13 +90,11 @@ func testMultiGoroutineSearch(t *testing.T, l *Conn, results chan *SearchResult, attributes, nil) sr, err := l.Search(search_request) - if err != nil { t.Errorf(err.String()) results <- nil return } - results <- sr } |