summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorned <ned@appliedtrust.com>2014-11-23 20:03:05 +0100
committerned <ned@appliedtrust.com>2014-11-23 20:03:05 +0100
commitc43d537d5bb0eeb491153b00cdefcb54a6178187 (patch)
tree45187fde4a720d3f53d13ec45ac4fea8e27356e4
parentLDAP server support (diff)
downloadldap-c43d537d5bb0eeb491153b00cdefcb54a6178187.tar
ldap-c43d537d5bb0eeb491153b00cdefcb54a6178187.tar.gz
ldap-c43d537d5bb0eeb491153b00cdefcb54a6178187.tar.bz2
ldap-c43d537d5bb0eeb491153b00cdefcb54a6178187.tar.lz
ldap-c43d537d5bb0eeb491153b00cdefcb54a6178187.tar.xz
ldap-c43d537d5bb0eeb491153b00cdefcb54a6178187.tar.zst
ldap-c43d537d5bb0eeb491153b00cdefcb54a6178187.zip
-rw-r--r--README.md16
-rw-r--r--control.go62
-rw-r--r--examples/server.go6
-rw-r--r--filter.go39
-rw-r--r--filter_test.go24
-rw-r--r--ldap.go44
-rw-r--r--modify.go6
-rw-r--r--server.go544
-rw-r--r--server_bind.go73
-rw-r--r--server_modify.go231
-rw-r--r--server_modify_test.go191
-rw-r--r--server_search.go216
-rw-r--r--server_search_test.go403
-rw-r--r--server_test.go376
-rw-r--r--tests/add.ldif6
-rw-r--r--tests/add2.ldif6
-rw-r--r--tests/cert_DONOTUSE.pem (renamed from examples/cert_DONOTUSE.pem)0
-rw-r--r--tests/key_DONOTUSE.pem (renamed from examples/key_DONOTUSE.pem)0
-rw-r--r--tests/modify.ldif16
-rw-r--r--tests/modify2.ldif10
20 files changed, 1515 insertions, 754 deletions
diff --git a/README.md b/README.md
index c72fca8..2418eab 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@ searchResults, err := l.Search(search)
The server library is modeled after net/http - you designate handlers for the LDAP operations you want to support (Bind/Search/etc.), then start the server with ListenAndServe(). You can specify different handlers for different baseDNs - they must implement the interfaces of the operations you want to support:
```go
type Binder interface {
- Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error)
+ Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
}
type Searcher interface {
Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error)
@@ -76,7 +76,7 @@ func main() {
}
type ldapHandler struct {
}
-func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) {
+func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (ldap.LDAPResultCode, error) {
if bindDN == "" && bindSimplePw == "" {
return ldap.LDAPResultSuccess, nil
}
@@ -89,25 +89,17 @@ func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, e
### LDAP server examples:
* examples/server.go: **Basic LDAP authentication (bind and search only)**
* examples/proxy.go: **Simple LDAP proxy server.**
-* server_test: **The tests have examples of all server functions.**
-
-*Warning: Do not use the example SSL certificates in production!*
+* server_test.go: **The _test.go files have examples of all server functions.**
### Known limitations:
* Golang's TLS implementation does not support SSLv2. Some old OSs require SSLv2, and are not able to connect to an LDAP server created with this library's ListenAndServeTLS() function. If you *must* support legacy (read: *insecure*) SSLv2 clients, run your LDAP server behind HAProxy.
### Not implemented:
-All of [RFC4510](http://tools.ietf.org/html/rfc4510) is implemented **except**:
-* 4.1.11. Controls
+From the server perspective, all of [RFC4510](http://tools.ietf.org/html/rfc4510) is implemented **except**:
* 4.5.1.3. SearchRequest.derefAliases
* 4.5.1.5. SearchRequest.timeLimit
* 4.5.1.6. SearchRequest.typesOnly
-* 4.6. Modify Operation
-* 4.7. Add Operation
-* 4.8. Delete Operation
-* 4.9. Modify DN Operation
-* 4.10. Compare Operation
* 4.14. StartTLS Operation
*Server library by: [nmcclain](https://github.com/nmcclain)*
diff --git a/control.go b/control.go
index dd46fea..8376dd7 100644
--- a/control.go
+++ b/control.go
@@ -6,7 +6,6 @@ package ldap
import (
"fmt"
-
"github.com/nmcclain/asn1-ber"
)
@@ -99,40 +98,41 @@ func FindControl(controls []Control, controlType string) Control {
func DecodeControl(packet *ber.Packet) Control {
ControlType := packet.Children[0].Value.(string)
- Criticality := false
-
packet.Children[0].Description = "Control Type (" + ControlTypeMap[ControlType] + ")"
- value := packet.Children[1]
- if len(packet.Children) == 3 {
- value = packet.Children[2]
- packet.Children[1].Description = "Criticality"
- Criticality = packet.Children[1].Value.(bool)
- }
+ c := new(ControlString)
+ c.ControlType = ControlType
+ c.Criticality = false
+
+ if len(packet.Children) > 1 {
+ value := packet.Children[1]
+ if len(packet.Children) == 3 {
+ value = packet.Children[2]
+ packet.Children[1].Description = "Criticality"
+ c.Criticality = packet.Children[1].Value.(bool)
+ }
- value.Description = "Control Value"
- switch ControlType {
- case ControlTypePaging:
- value.Description += " (Paging)"
- c := new(ControlPaging)
- if value.Value != nil {
- valueChildren := ber.DecodePacket(value.Data.Bytes())
- value.Data.Truncate(0)
- value.Value = nil
- value.AppendChild(valueChildren)
+ value.Description = "Control Value"
+ switch ControlType {
+ case ControlTypePaging:
+ value.Description += " (Paging)"
+ c := new(ControlPaging)
+ if value.Value != nil {
+ valueChildren := ber.DecodePacket(value.Data.Bytes())
+ value.Data.Truncate(0)
+ value.Value = nil
+ value.AppendChild(valueChildren)
+ }
+ value = value.Children[0]
+ value.Description = "Search Control Value"
+ value.Children[0].Description = "Paging Size"
+ value.Children[1].Description = "Cookie"
+ c.PagingSize = uint32(value.Children[0].Value.(uint64))
+ c.Cookie = value.Children[1].Data.Bytes()
+ value.Children[1].Value = c.Cookie
+ return c
}
- value = value.Children[0]
- value.Description = "Search Control Value"
- value.Children[0].Description = "Paging Size"
- value.Children[1].Description = "Cookie"
- c.PagingSize = uint32(value.Children[0].Value.(uint64))
- c.Cookie = value.Children[1].Data.Bytes()
- value.Children[1].Value = c.Cookie
- return c
+ c.ControlValue = value.Value.(string)
}
- c := new(ControlString)
- c.ControlType = ControlType
- c.Criticality = Criticality
- c.ControlValue = value.Value.(string)
return c
}
diff --git a/examples/server.go b/examples/server.go
index dca74ed..3341991 100644
--- a/examples/server.go
+++ b/examples/server.go
@@ -24,7 +24,9 @@ func main() {
s.SearchFunc("", handler)
// start the server
- if err := s.ListenAndServe("localhost:3389"); err != nil {
+ listen := "localhost:3389"
+ log.Printf("Starting example LDAP server on %s", listen)
+ if err := s.ListenAndServe(listen); err != nil {
log.Fatal("LDAP Server Failed: %s", err.Error())
}
}
@@ -33,7 +35,7 @@ type ldapHandler struct {
}
///////////// Allow anonymous binds only
-func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) {
+func (h ldapHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (ldap.LDAPResultCode, error) {
if bindDN == "" && bindSimplePw == "" {
return ldap.LDAPResultSuccess, nil
}
diff --git a/filter.go b/filter.go
index d7bc798..0c9706a 100644
--- a/filter.go
+++ b/filter.go
@@ -246,9 +246,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
}
}
-func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, uint64) {
- //log.Printf("%# v", pretty.Formatter(entry))
-
+func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, LDAPResultCode) {
switch FilterMap[f.Tag] {
default:
//log.Fatalf("Unknown LDAP filter code: %d", f.Tag)
@@ -308,30 +306,30 @@ func ServerApplyFilter(f *ber.Packet, entry *Entry) (bool, uint64) {
} else if !ok {
return true, LDAPResultSuccess
}
- case "FilterSubstrings":
+ case "FilterSubstrings": // TODO
return false, LDAPResultOperationsError
- case "FilterGreaterOrEqual":
+ case "FilterGreaterOrEqual": // TODO
return false, LDAPResultOperationsError
- case "FilterLessOrEqual":
+ case "FilterLessOrEqual": // TODO
return false, LDAPResultOperationsError
- case "FilterApproxMatch":
+ case "FilterApproxMatch": // TODO
return false, LDAPResultOperationsError
- case "FilterExtensibleMatch":
+ case "FilterExtensibleMatch": // TODO
return false, LDAPResultOperationsError
}
return false, LDAPResultSuccess
}
-func GetFilterType(filter string) (string, error) { // TODO <- test this
+func GetFilterObjectClass(filter string) (string, error) {
f, err := CompileFilter(filter)
if err != nil {
return "", err
}
- return parseFilterType(f)
+ return parseFilterObjectClass(f)
}
-func parseFilterType(f *ber.Packet) (string, error) {
- searchType := ""
+func parseFilterObjectClass(f *ber.Packet) (string, error) {
+ objectClass := ""
switch FilterMap[f.Tag] {
case "Equality Match":
if len(f.Children) != 2 {
@@ -339,42 +337,41 @@ func parseFilterType(f *ber.Packet) (string, error) {
}
attribute := strings.ToLower(f.Children[0].Value.(string))
value := f.Children[1].Value.(string)
-
if attribute == "objectclass" {
- searchType = strings.ToLower(value)
+ objectClass = strings.ToLower(value)
}
case "And":
for _, child := range f.Children {
- subType, err := parseFilterType(child)
+ subType, err := parseFilterObjectClass(child)
if err != nil {
return "", err
}
if len(subType) > 0 {
- searchType = subType
+ objectClass = subType
}
}
case "Or":
for _, child := range f.Children {
- subType, err := parseFilterType(child)
+ subType, err := parseFilterObjectClass(child)
if err != nil {
return "", err
}
if len(subType) > 0 {
- searchType = subType
+ objectClass = subType
}
}
case "Not":
if len(f.Children) != 1 {
return "", errors.New("Not filter must have only one child")
}
- subType, err := parseFilterType(f.Children[0])
+ subType, err := parseFilterObjectClass(f.Children[0])
if err != nil {
return "", err
}
if len(subType) > 0 {
- searchType = subType
+ objectClass = subType
}
}
- return strings.ToLower(searchType), nil
+ return strings.ToLower(objectClass), nil
}
diff --git a/filter_test.go b/filter_test.go
index fb54905..2e62f25 100644
--- a/filter_test.go
+++ b/filter_test.go
@@ -111,3 +111,27 @@ func BenchmarkFilterDecompile(b *testing.B) {
DecompileFilter(filters[i%maxIdx])
}
}
+
+func TestGetFilterObjectClass(t *testing.T) {
+ c, err := GetFilterObjectClass("(objectClass=*)")
+ if err != nil {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ if c != "" {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ c, err = GetFilterObjectClass("(objectClass=posixAccount)")
+ if err != nil {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ if c != "posixaccount" {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ c, err = GetFilterObjectClass("(&(cn=awesome)(objectClass=posixGroup))")
+ if err != nil {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+ if c != "posixgroup" {
+ t.Errorf("GetFilterObjectClass failed")
+ }
+}
diff --git a/ldap.go b/ldap.go
index 42c50d6..e6d6d52 100644
--- a/ldap.go
+++ b/ldap.go
@@ -107,7 +107,7 @@ const (
ErrorDebugging = 203
)
-var LDAPResultCodeMap = map[uint8]string{
+var LDAPResultCodeMap = map[LDAPResultCode]string{
LDAPResultSuccess: "Success",
LDAPResultOperationsError: "Operations Error",
LDAPResultProtocolError: "Protocol Error",
@@ -155,6 +155,38 @@ const (
LDAPBindAuthSASL = 3
)
+type LDAPResultCode uint8
+
+type Attribute struct {
+ attrType string
+ attrVals []string
+}
+type AddRequest struct {
+ dn string
+ attributes []Attribute
+}
+type DeleteRequest struct {
+ dn string
+}
+type ModifyDNRequest struct {
+ dn string
+ newrdn string
+ deleteoldrdn bool
+ newSuperior string
+}
+type AttributeValueAssertion struct {
+ attributeDesc string
+ assertionValue string
+}
+type CompareRequest struct {
+ dn string
+ ava []AttributeValueAssertion
+}
+type ExtendedRequest struct {
+ requestName string
+ requestValue string
+}
+
// Adds descriptions to an LDAP Response packet for debugging
func addLDAPDescriptions(packet *ber.Packet) (err error) {
defer func() {
@@ -259,7 +291,7 @@ func addRequestDescriptions(packet *ber.Packet) {
func addDefaultLDAPResponseDescriptions(packet *ber.Packet) {
resultCode := packet.Children[1].Children[0].Value.(uint64)
- packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[uint8(resultCode)] + ")"
+ packet.Children[1].Children[0].Description = "Result Code (" + LDAPResultCodeMap[LDAPResultCode(resultCode)] + ")"
packet.Children[1].Children[1].Description = "Matched DN"
packet.Children[1].Children[2].Description = "Error Message"
if len(packet.Children[1].Children) > 3 {
@@ -285,22 +317,22 @@ func DebugBinaryFile(fileName string) error {
type Error struct {
Err error
- ResultCode uint8
+ ResultCode LDAPResultCode
}
func (e *Error) Error() string {
return fmt.Sprintf("LDAP Result Code %d %q: %s", e.ResultCode, LDAPResultCodeMap[e.ResultCode], e.Err.Error())
}
-func NewError(resultCode uint8, err error) error {
+func NewError(resultCode LDAPResultCode, err error) error {
return &Error{ResultCode: resultCode, Err: err}
}
-func getLDAPResultCode(packet *ber.Packet) (code uint8, description string) {
+func getLDAPResultCode(packet *ber.Packet) (code LDAPResultCode, description string) {
if len(packet.Children) >= 2 {
response := packet.Children[1]
if response.ClassType == ber.ClassApplication && response.TagType == ber.TypeConstructed && len(response.Children) == 3 {
- return uint8(response.Children[0].Value.(uint64)), response.Children[2].Value.(string)
+ return LDAPResultCode(response.Children[0].Value.(uint64)), response.Children[2].Value.(string)
}
}
diff --git a/modify.go b/modify.go
index 7decf2c..6ffe314 100644
--- a/modify.go
+++ b/modify.go
@@ -42,6 +42,12 @@ const (
ReplaceAttribute = 2
)
+var LDAPModifyAttributeMap = map[uint64]string{
+ AddAttribute: "Add",
+ DeleteAttribute: "Delete",
+ ReplaceAttribute: "Replace",
+}
+
type PartialAttribute struct {
attrType string
attrVals []string
diff --git a/server.go b/server.go
index 4a46e6f..dcb6406 100644
--- a/server.go
+++ b/server.go
@@ -2,8 +2,6 @@ package ldap
import (
"crypto/tls"
- "errors"
- "fmt"
"github.com/nmcclain/asn1-ber"
"io"
"log"
@@ -13,23 +11,55 @@ import (
)
type Binder interface {
- Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error)
+ Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error)
}
type Searcher interface {
- Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error)
+ Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error)
+}
+type Adder interface {
+ Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Modifier interface {
+ Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Deleter interface {
+ Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error)
+}
+type ModifyDNr interface {
+ ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Comparer interface {
+ Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Abandoner interface {
+ Abandon(boundDN string, conn net.Conn) error
+}
+type Extender interface {
+ Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error)
+}
+type Unbinder interface {
+ Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error)
}
type Closer interface {
- Close(conn net.Conn) error
+ Close(boundDN string, conn net.Conn) error
}
-/////////////////////////
+//
type Server struct {
- bindFns map[string]Binder
- searchFns map[string]Searcher
- closeFns map[string]Closer
- quit chan bool
+ BindFns map[string]Binder
+ SearchFns map[string]Searcher
+ AddFns map[string]Adder
+ ModifyFns map[string]Modifier
+ DeleteFns map[string]Deleter
+ ModifyDNFns map[string]ModifyDNr
+ CompareFns map[string]Comparer
+ AbandonFns map[string]Abandoner
+ ExtendedFns map[string]Extender
+ UnbindFns map[string]Unbinder
+ CloseFns map[string]Closer
+ Quit chan bool
EnforceLDAP bool
- stats *Stats
+ Stats *Stats
}
type Stats struct {
@@ -44,35 +74,75 @@ type ServerSearchResult struct {
Entries []*Entry
Referrals []string
Controls []Control
- ResultCode uint64
+ ResultCode LDAPResultCode
}
-/////////////////////////
+//
func NewServer() *Server {
s := new(Server)
- s.quit = make(chan bool)
+ s.Quit = make(chan bool)
d := defaultHandler{}
- s.bindFns = make(map[string]Binder)
- s.searchFns = make(map[string]Searcher)
- s.closeFns = make(map[string]Closer)
- s.bindFns[""] = d
- s.searchFns[""] = d
- s.closeFns[""] = d
- s.stats = nil
+ s.BindFns = make(map[string]Binder)
+ s.SearchFns = make(map[string]Searcher)
+ s.AddFns = make(map[string]Adder)
+ s.ModifyFns = make(map[string]Modifier)
+ s.DeleteFns = make(map[string]Deleter)
+ s.ModifyDNFns = make(map[string]ModifyDNr)
+ s.CompareFns = make(map[string]Comparer)
+ s.AbandonFns = make(map[string]Abandoner)
+ s.ExtendedFns = make(map[string]Extender)
+ s.UnbindFns = make(map[string]Unbinder)
+ s.CloseFns = make(map[string]Closer)
+ s.BindFunc("", d)
+ s.SearchFunc("", d)
+ s.AddFunc("", d)
+ s.ModifyFunc("", d)
+ s.DeleteFunc("", d)
+ s.ModifyDNFunc("", d)
+ s.CompareFunc("", d)
+ s.AbandonFunc("", d)
+ s.ExtendedFunc("", d)
+ s.UnbindFunc("", d)
+ s.CloseFunc("", d)
+ s.Stats = nil
return s
}
-func (server *Server) BindFunc(baseDN string, bindFn Binder) {
- server.bindFns[baseDN] = bindFn
+func (server *Server) BindFunc(baseDN string, f Binder) {
+ server.BindFns[baseDN] = f
+}
+func (server *Server) SearchFunc(baseDN string, f Searcher) {
+ server.SearchFns[baseDN] = f
+}
+func (server *Server) AddFunc(baseDN string, f Adder) {
+ server.AddFns[baseDN] = f
+}
+func (server *Server) ModifyFunc(baseDN string, f Modifier) {
+ server.ModifyFns[baseDN] = f
}
-func (server *Server) SearchFunc(baseDN string, searchFn Searcher) {
- server.searchFns[baseDN] = searchFn
+func (server *Server) DeleteFunc(baseDN string, f Deleter) {
+ server.DeleteFns[baseDN] = f
}
-func (server *Server) CloseFunc(baseDN string, closeFn Closer) {
- server.closeFns[baseDN] = closeFn
+func (server *Server) ModifyDNFunc(baseDN string, f ModifyDNr) {
+ server.ModifyDNFns[baseDN] = f
+}
+func (server *Server) CompareFunc(baseDN string, f Comparer) {
+ server.CompareFns[baseDN] = f
+}
+func (server *Server) AbandonFunc(baseDN string, f Abandoner) {
+ server.AbandonFns[baseDN] = f
+}
+func (server *Server) ExtendedFunc(baseDN string, f Extender) {
+ server.ExtendedFns[baseDN] = f
+}
+func (server *Server) UnbindFunc(baseDN string, f Unbinder) {
+ server.UnbindFns[baseDN] = f
+}
+func (server *Server) CloseFunc(baseDN string, f Closer) {
+ server.CloseFns[baseDN] = f
}
func (server *Server) QuitChannel(quit chan bool) {
- server.quit = quit
+ server.Quit = quit
}
func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error {
@@ -95,18 +165,18 @@ func (server *Server) ListenAndServeTLS(listenString string, certFile string, ke
func (server *Server) SetStats(enable bool) {
if enable {
- server.stats = &Stats{}
+ server.Stats = &Stats{}
} else {
- server.stats = nil
+ server.Stats = nil
}
}
func (server *Server) GetStats() Stats {
defer func() {
- server.stats.statsMutex.Unlock()
+ server.Stats.statsMutex.Unlock()
}()
- server.stats.statsMutex.Lock()
- return *server.stats
+ server.Stats.statsMutex.Lock()
+ return *server.Stats
}
func (server *Server) ListenAndServe(listenString string) error {
@@ -140,9 +210,9 @@ listener:
for {
select {
case c := <-newConn:
- server.stats.countConns(1)
+ server.Stats.countConns(1)
go server.handleConnection(c)
- case <-server.quit:
+ case <-server.Quit:
ln.Close()
break listener
}
@@ -150,8 +220,7 @@ listener:
return nil
}
-/////////////////////////
-
+//
func (server *Server) handleConnection(conn net.Conn) {
boundDN := "" // "" == anonymous
@@ -172,40 +241,46 @@ handler:
break
}
// check the message ID and ClassType
- messageID := packet.Children[0].Value.(uint64)
+ messageID, ok := packet.Children[0].Value.(uint64)
+ if !ok {
+ log.Print("malformed messageID")
+ break
+ }
req := packet.Children[1]
if req.ClassType != ber.ClassApplication {
log.Print("req.ClassType != ber.ClassApplication")
break
}
// handle controls if present
+ controls := []Control{}
if len(packet.Children) > 2 {
- controls := packet.Children[2]
- ber.PrintPacket(controls)
- log.Print("TODO Parse Controls")
- /*
- Controls ::= SEQUENCE OF control Control
-
- Control ::= SEQUENCE {
- controlType LDAPOID,
- criticality BOOLEAN DEFAULT FALSE, // unavailableCriticalExtension
- controlValue OCTET STRING OPTIONAL }
- */
+ for _, child := range packet.Children[2].Children {
+ controls = append(controls, DecodeControl(child))
+ }
}
+ //log.Printf("DEBUG: handling operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
+ //ber.PrintPacket(packet) // DEBUG
+
// dispatch the LDAP operation
switch req.Tag { // ldap op code
default:
- //log.Printf("Bound as %s", boundDN)
- //ber.PrintPacket(packet)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, LDAPResultOperationsError, "Unsupported operation: add")
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ }
log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
break handler
case ApplicationBindRequest:
- server.stats.countBinds(1)
- ldapResultCode := server.handleBindRequest(req, server.bindFns, conn)
+ server.Stats.countBinds(1)
+ ldapResultCode := HandleBindRequest(req, server.BindFns, conn)
if ldapResultCode == LDAPResultSuccess {
- boundDN = req.Children[1].Value.(string)
+ boundDN, ok = req.Children[1].Value.(string)
+ if !ok {
+ log.Printf("Malformed Bind DN")
+ break handler
+ }
}
responsePacket := encodeBindResponse(messageID, ldapResultCode)
if err = sendPacket(conn, responsePacket); err != nil {
@@ -213,12 +288,13 @@ handler:
break handler
}
case ApplicationSearchRequest:
- server.stats.countSearches(1)
- if err := server.handleSearchRequest(req, messageID, boundDN, server.searchFns, conn); err != nil {
+ server.Stats.countSearches(1)
+ if err := HandleSearchRequest(req, &controls, messageID, boundDN, server, conn); err != nil {
log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks?
e := err.(*Error)
- if err = sendPacket(conn, encodeSearchDone(messageID, uint64(e.ResultCode))); err != nil {
+ if err = sendPacket(conn, encodeSearchDone(messageID, e.ResultCode)); err != nil {
log.Printf("sendPacket error %s", err.Error())
+ break handler
}
break handler
} else {
@@ -228,181 +304,65 @@ handler:
}
}
case ApplicationUnbindRequest:
- server.stats.countUnbinds(1)
- break handler // simply disconnect - this IS implemented
+ server.Stats.countUnbinds(1)
+ break handler // simply disconnect
case ApplicationExtendedRequest:
- responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, LDAPResultProtocolError, "Unsupported extended request")
+ ldapResultCode := HandleExtendedRequest(req, boundDN, server.ExtendedFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
if err = sendPacket(conn, responsePacket); err != nil {
log.Printf("sendPacket error %s", err.Error())
+ break handler
}
- break handler
case ApplicationAbandonRequest:
- log.Printf("Abandoning request!")
+ HandleAbandonRequest(req, boundDN, server.AbandonFns, conn)
break handler
- // Unimplemented LDAP operations:
- case ApplicationModifyRequest:
- log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
- break handler
case ApplicationAddRequest:
- log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
- break handler
- case ApplicationDelRequest:
- log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
- break handler
- case ApplicationModifyDNRequest:
- log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
- break handler
- case ApplicationCompareRequest:
- log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag)
- break handler
- }
- }
-
- for _, c := range server.closeFns {
- c.Close(conn)
- }
-
- conn.Close()
-}
-
-/////////////////////////
-func (server *Server) handleSearchRequest(req *ber.Packet, messageID uint64, boundDN string, searchFns map[string]Searcher, conn net.Conn) (resultErr error) {
- defer func() {
- if r := recover(); r != nil {
- resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r))
- }
- }()
-
- searchReq, err := parseSearchRequest(boundDN, req)
- if err != nil {
- return NewError(LDAPResultOperationsError, err)
- }
-
- filterPacket, err := CompileFilter(searchReq.Filter)
- if err != nil {
- return NewError(LDAPResultOperationsError, err)
- }
-
- fnNames := []string{}
- for k := range searchFns {
- fnNames = append(fnNames, k)
- }
- searchFn := routeFunc(searchReq.BaseDN, fnNames)
- searchResp, err := searchFns[searchFn].Search(boundDN, searchReq, conn)
- if err != nil {
- return NewError(uint8(searchResp.ResultCode), err)
- }
-
- if server.EnforceLDAP {
- if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find}
- // TODO: Server DerefAliases not implemented: RFC4511 4.5.1.3. SearchRequest.derefAliases
- }
- if len(searchReq.Controls) > 0 {
- return NewError(LDAPResultOperationsError, errors.New("Server controls not implemented")) // TODO
- }
- if searchReq.TimeLimit > 0 {
- return NewError(LDAPResultOperationsError, errors.New("Server TimeLimit not implemented")) // TODO
- }
- }
-
- for i, entry := range searchResp.Entries {
- if server.EnforceLDAP {
- // size limit
- if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit {
- break
+ ldapResultCode := HandleAddRequest(req, boundDN, server.AddFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationAddResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
}
-
- // filter
- keep, resultCode := ServerApplyFilter(filterPacket, entry)
- if resultCode != LDAPResultSuccess {
- return NewError(uint8(resultCode), errors.New("ServerApplyFilter error"))
+ case ApplicationModifyRequest:
+ ldapResultCode := HandleModifyRequest(req, boundDN, server.ModifyFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationModifyResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
}
- if !keep {
- continue
+ case ApplicationDelRequest:
+ ldapResultCode := HandleDeleteRequest(req, boundDN, server.DeleteFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationDelResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
}
-
- // constrained search scope
- switch searchReq.Scope {
- case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates.
- case ScopeBaseObject: // The scope is constrained to the entry named by baseObject.
- if entry.DN != searchReq.BaseDN {
- continue
- }
- case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject.
- parts := strings.Split(entry.DN, ",")
- if len(parts) < 2 && entry.DN != searchReq.BaseDN {
- continue
- }
- if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN {
- continue
- }
+ case ApplicationModifyDNRequest:
+ ldapResultCode := HandleModifyDNRequest(req, boundDN, server.ModifyDNFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationModifyDNResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
}
-
- // attributes
- if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) {
- entry, err = filterAttributes(entry, searchReq.Attributes)
- if err != nil {
- return NewError(LDAPResultOperationsError, err)
- }
+ case ApplicationCompareRequest:
+ ldapResultCode := HandleCompareRequest(req, boundDN, server.CompareFns, conn)
+ responsePacket := encodeLDAPResponse(messageID, ApplicationCompareResponse, ldapResultCode, LDAPResultCodeMap[ldapResultCode])
+ if err = sendPacket(conn, responsePacket); err != nil {
+ log.Printf("sendPacket error %s", err.Error())
+ break handler
}
}
-
- // respond
- responsePacket := encodeSearchResponse(messageID, searchReq, entry)
- if err = sendPacket(conn, responsePacket); err != nil {
- return NewError(LDAPResultOperationsError, err)
- }
}
- return nil
-}
-/////////////////////////
-func (server *Server) handleBindRequest(req *ber.Packet, bindFns map[string]Binder, conn net.Conn) (resultCode uint64) {
- defer func() {
- if r := recover(); r != nil {
- resultCode = LDAPResultOperationsError
- }
- }()
-
- // we only support ldapv3
- ldapVersion := req.Children[0].Value.(uint64)
- if ldapVersion != 3 {
- log.Printf("Unsupported LDAP version: %d", ldapVersion)
- return LDAPResultInappropriateAuthentication
+ for _, c := range server.CloseFns {
+ c.Close(boundDN, conn)
}
- // auth types
- bindDN := req.Children[1].Value.(string)
- bindAuth := req.Children[2]
- switch bindAuth.Tag {
- default:
- log.Print("Unknown LDAP authentication method")
- return LDAPResultInappropriateAuthentication
- case LDAPBindAuthSimple:
- if len(req.Children) == 3 {
- fnNames := []string{}
- for k := range bindFns {
- fnNames = append(fnNames, k)
- }
- bindFn := routeFunc(bindDN, fnNames)
- resultCode, err := bindFns[bindFn].Bind(bindDN, bindAuth.Data.String(), conn)
- if err != nil {
- log.Printf("BindFn Error %s", err.Error())
- }
- return resultCode
- } else {
- log.Print("Simple bind request has wrong # children. len(req.Children) != 3")
- return LDAPResultInappropriateAuthentication
- }
- case LDAPBindAuthSASL:
- log.Print("SASL authentication is not supported")
- return LDAPResultInappropriateAuthentication
- }
- return LDAPResultOperationsError
+ conn.Close()
}
-/////////////////////////
+//
func sendPacket(conn net.Conn, packet *ber.Packet) error {
_, err := conn.Write(packet.Bytes())
if err != nil {
@@ -412,38 +372,7 @@ func sendPacket(conn net.Conn, packet *ber.Packet) error {
return nil
}
-/////////////////////////
-func parseSearchRequest(boundDN string, req *ber.Packet) (SearchRequest, error) {
- if len(req.Children) != 8 {
- return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request"))
- }
-
- // Parse the request
- baseObject := req.Children[0].Value.(string)
- scope := int(req.Children[1].Value.(uint64))
- derefAliases := int(req.Children[2].Value.(uint64))
- sizeLimit := int(req.Children[3].Value.(uint64))
- timeLimit := int(req.Children[4].Value.(uint64))
- typesOnly := false
- if req.Children[5].Value != nil {
- typesOnly = req.Children[5].Value.(bool)
- }
- filter, err := DecompileFilter(req.Children[6])
- if err != nil {
- return SearchRequest{}, err
- }
- attributes := []string{}
- for _, attr := range req.Children[7].Children {
- attributes = append(attributes, attr.Value.(string))
- }
- searchReq := SearchRequest{baseObject, scope,
- derefAliases, sizeLimit, timeLimit,
- typesOnly, filter, attributes, nil}
-
- return searchReq, nil
-}
-
-/////////////////////////
+//
func routeFunc(dn string, funcNames []string) string {
bestPick := ""
for _, fn := range funcNames {
@@ -460,109 +389,58 @@ func routeFunc(dn string, funcNames []string) string {
return bestPick
}
-/////////////////////////
-func filterAttributes(entry *Entry, attributes []string) (*Entry, error) {
- // only return requested attributes
- newAttributes := []*EntryAttribute{}
-
- for _, attr := range entry.Attributes {
- for _, requested := range attributes {
- if strings.ToLower(attr.Name) == strings.ToLower(requested) {
- newAttributes = append(newAttributes, attr)
- }
- }
- }
- entry.Attributes = newAttributes
-
- return entry, nil
+//
+func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode LDAPResultCode, message string) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+ reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType])
+ reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+ reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: "))
+ responsePacket.AppendChild(reponse)
+ return responsePacket
}
-/////////////////////////
+//
type defaultHandler struct {
}
-func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) {
- return LDAPResultInappropriateAuthentication, nil
+func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInvalidCredentials, nil
}
-func (h defaultHandler) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+func (h defaultHandler) Search(boundDN string, req SearchRequest, conn net.Conn) (ServerSearchResult, error) {
return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil
}
-func (h defaultHandler) Close(conn net.Conn) error {
- conn.Close()
- return nil
+func (h defaultHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
}
-
-/////////////////////////
-func encodeBindResponse(messageID uint64, ldapResultCode uint64) *ber.Packet {
- responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
- responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
-
- bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response")
- bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: "))
- bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
- bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
-
- responsePacket.AppendChild(bindReponse)
-
- // ber.PrintPacket(responsePacket)
- return responsePacket
+func (h defaultHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
}
-func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet {
- responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
- responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
-
- searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry")
- searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name"))
-
- attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:")
- for _, attribute := range res.Attributes {
- attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values))
- }
-
- searchEntry.AppendChild(attrs)
- responsePacket.AppendChild(searchEntry)
-
- return responsePacket
+func (h defaultHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
}
-
-func encodeSearchAttribute(name string, values []string) *ber.Packet {
- packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
- packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name"))
-
- valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values")
- for _, value := range values {
- valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value"))
- }
-
- packet.AppendChild(valuesPacket)
-
- return packet
+func (h defaultHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
}
-
-func encodeSearchDone(messageID uint64, ldapResultCode uint64) *ber.Packet {
- responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
- responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
- donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done")
- donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: "))
- donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
- donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
- responsePacket.AppendChild(donePacket)
-
- return responsePacket
+func (h defaultHandler) Compare(boundDN string, req CompareRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
}
-
-func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode uint64, message string) *ber.Packet {
- responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
- responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
- reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType])
- reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: "))
- reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
- reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: "))
- responsePacket.AppendChild(reponse)
- return responsePacket
+func (h defaultHandler) Abandon(boundDN string, conn net.Conn) error {
+ return nil
+}
+func (h defaultHandler) Extended(boundDN string, req ExtendedRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultProtocolError, nil
+}
+func (h defaultHandler) Unbind(boundDN string, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultSuccess, nil
+}
+func (h defaultHandler) Close(boundDN string, conn net.Conn) error {
+ conn.Close()
+ return nil
}
-/////////////////////////
+//
func (stats *Stats) countConns(delta int) {
if stats != nil {
stats.statsMutex.Lock()
@@ -592,4 +470,4 @@ func (stats *Stats) countSearches(delta int) {
}
}
-/////////////////////////
+//
diff --git a/server_bind.go b/server_bind.go
new file mode 100644
index 0000000..5a80bf5
--- /dev/null
+++ b/server_bind.go
@@ -0,0 +1,73 @@
+package ldap
+
+import (
+ "github.com/nmcclain/asn1-ber"
+ "log"
+ "net"
+)
+
+func HandleBindRequest(req *ber.Packet, fns map[string]Binder, conn net.Conn) (resultCode LDAPResultCode) {
+ defer func() {
+ if r := recover(); r != nil {
+ resultCode = LDAPResultOperationsError
+ }
+ }()
+
+ // we only support ldapv3
+ ldapVersion, ok := req.Children[0].Value.(uint64)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ if ldapVersion != 3 {
+ log.Printf("Unsupported LDAP version: %d", ldapVersion)
+ return LDAPResultInappropriateAuthentication
+ }
+
+ // auth types
+ bindDN, ok := req.Children[1].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ bindAuth := req.Children[2]
+ switch bindAuth.Tag {
+ default:
+ log.Print("Unknown LDAP authentication method")
+ return LDAPResultInappropriateAuthentication
+ case LDAPBindAuthSimple:
+ if len(req.Children) == 3 {
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(bindDN, fnNames)
+ resultCode, err := fns[fn].Bind(bindDN, bindAuth.Data.String(), conn)
+ if err != nil {
+ log.Printf("BindFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+ } else {
+ log.Print("Simple bind request has wrong # children. len(req.Children) != 3")
+ return LDAPResultInappropriateAuthentication
+ }
+ case LDAPBindAuthSASL:
+ log.Print("SASL authentication is not supported")
+ return LDAPResultInappropriateAuthentication
+ }
+ return LDAPResultOperationsError
+}
+
+func encodeBindResponse(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+
+ bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response")
+ bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+ bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
+
+ responsePacket.AppendChild(bindReponse)
+
+ // ber.PrintPacket(responsePacket)
+ return responsePacket
+}
diff --git a/server_modify.go b/server_modify.go
new file mode 100644
index 0000000..0dca219
--- /dev/null
+++ b/server_modify.go
@@ -0,0 +1,231 @@
+package ldap
+
+import (
+ "github.com/nmcclain/asn1-ber"
+ "log"
+ "net"
+)
+
+func HandleAddRequest(req *ber.Packet, boundDN string, fns map[string]Adder, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ var ok bool
+ addReq := AddRequest{}
+ addReq.dn, ok = req.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ addReq.attributes = []Attribute{}
+ for _, attr := range req.Children[1].Children {
+ if len(attr.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+
+ a := Attribute{}
+ a.attrType, ok = attr.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ a.attrVals = []string{}
+ for _, val := range attr.Children[1].Children {
+ v, ok := val.Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ a.attrVals = append(a.attrVals, v)
+ }
+ addReq.attributes = append(addReq.attributes, a)
+ }
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Add(boundDN, addReq, conn)
+ if err != nil {
+ log.Printf("AddFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleDeleteRequest(req *ber.Packet, boundDN string, fns map[string]Deleter, conn net.Conn) (resultCode LDAPResultCode) {
+ deleteDN := ber.DecodeString(req.Data.Bytes())
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Delete(boundDN, deleteDN, conn)
+ if err != nil {
+ log.Printf("DeleteFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleModifyRequest(req *ber.Packet, boundDN string, fns map[string]Modifier, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ var ok bool
+ modReq := ModifyRequest{}
+ modReq.dn, ok = req.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ for _, change := range req.Children[1].Children {
+ if len(change.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ attr := PartialAttribute{}
+ attrs := change.Children[1].Children
+ if len(attrs) != 2 {
+ return LDAPResultProtocolError
+ }
+ attr.attrType, ok = attrs[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ for _, val := range attrs[1].Children {
+ v, ok := val.Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ attr.attrVals = append(attr.attrVals, v)
+ }
+ op, ok := change.Children[0].Value.(uint64)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ switch op {
+ default:
+ log.Printf("Unrecognized Modify attribute %d", op)
+ return LDAPResultProtocolError
+ case AddAttribute:
+ modReq.Add(attr.attrType, attr.attrVals)
+ case DeleteAttribute:
+ modReq.Delete(attr.attrType, attr.attrVals)
+ case ReplaceAttribute:
+ modReq.Replace(attr.attrType, attr.attrVals)
+ }
+ }
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Modify(boundDN, modReq, conn)
+ if err != nil {
+ log.Printf("ModifyFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleCompareRequest(req *ber.Packet, boundDN string, fns map[string]Comparer, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ var ok bool
+ compReq := CompareRequest{}
+ compReq.dn, ok = req.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ ava := req.Children[1]
+ if len(ava.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ attr, ok := ava.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ val, ok := ava.Children[1].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ compReq.ava = []AttributeValueAssertion{AttributeValueAssertion{attr, val}}
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Compare(boundDN, compReq, conn)
+ if err != nil {
+ log.Printf("CompareFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleExtendedRequest(req *ber.Packet, boundDN string, fns map[string]Extender, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 1 && len(req.Children) != 2 {
+ return LDAPResultProtocolError
+ }
+ name := ber.DecodeString(req.Children[0].Data.Bytes())
+ var val string
+ if len(req.Children) == 2 {
+ val = ber.DecodeString(req.Children[1].Data.Bytes())
+ }
+ extReq := ExtendedRequest{name, val}
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].Extended(boundDN, extReq, conn)
+ if err != nil {
+ log.Printf("ExtendedFn Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
+
+func HandleAbandonRequest(req *ber.Packet, boundDN string, fns map[string]Abandoner, conn net.Conn) error {
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ err := fns[fn].Abandon(boundDN, conn)
+ return err
+}
+
+func HandleModifyDNRequest(req *ber.Packet, boundDN string, fns map[string]ModifyDNr, conn net.Conn) (resultCode LDAPResultCode) {
+ if len(req.Children) != 3 && len(req.Children) != 4 {
+ return LDAPResultProtocolError
+ }
+ var ok bool
+ mdnReq := ModifyDNRequest{}
+ mdnReq.dn, ok = req.Children[0].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ mdnReq.newrdn, ok = req.Children[1].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ mdnReq.deleteoldrdn, ok = req.Children[2].Value.(bool)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ if len(req.Children) == 4 {
+ mdnReq.newSuperior, ok = req.Children[3].Value.(string)
+ if !ok {
+ return LDAPResultProtocolError
+ }
+ }
+ fnNames := []string{}
+ for k := range fns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(boundDN, fnNames)
+ resultCode, err := fns[fn].ModifyDN(boundDN, mdnReq, conn)
+ if err != nil {
+ log.Printf("ModifyDN Error %s", err.Error())
+ return LDAPResultOperationsError
+ }
+ return resultCode
+}
diff --git a/server_modify_test.go b/server_modify_test.go
new file mode 100644
index 0000000..d45b810
--- /dev/null
+++ b/server_modify_test.go
@@ -0,0 +1,191 @@
+package ldap
+
+import (
+ "net"
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+)
+
+//
+func TestAdd(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", modifyTestHandler{})
+ s.AddFunc("", modifyTestHandler{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+ go func() {
+ cmd := exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add.ldif")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapadd failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add2.ldif")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "ldap_add: Insufficient access") {
+ t.Errorf("ldapadd should have failed: %v", string(out))
+ }
+ if strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapadd should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapadd command timed out")
+ }
+ quit <- true
+}
+
+//
+func TestDelete(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", modifyTestHandler{})
+ s.DeleteFunc("", modifyTestHandler{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+ go func() {
+ cmd := exec.Command("ldapdelete", "-v", "-H", ldapURL, "-x", "cn=Delete Me,dc=example,dc=com")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "Delete Result: Success (0)") || !strings.Contains(string(out), "Additional info: Success") {
+ t.Errorf("ldapdelete failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapdelete", "-v", "-H", ldapURL, "-x", "cn=Bob,dc=example,dc=com")
+ out, _ = cmd.CombinedOutput()
+ if strings.Contains(string(out), "Success") || !strings.Contains(string(out), "ldap_delete: Insufficient access") {
+ t.Errorf("ldapdelete should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapdelete command timed out")
+ }
+ quit <- true
+}
+
+func TestModify(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", modifyTestHandler{})
+ s.ModifyFunc("", modifyTestHandler{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+ go func() {
+ cmd := exec.Command("ldapmodify", "-v", "-H", ldapURL, "-x", "-f", "tests/modify.ldif")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapmodify failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapmodify", "-v", "-H", ldapURL, "-x", "-f", "tests/modify2.ldif")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "ldap_modify: Insufficient access") || strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapmodify should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapadd command timed out")
+ }
+ quit <- true
+}
+
+/*
+func TestModifyDN(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", modifyTestHandler{})
+ s.AddFunc("", modifyTestHandler{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+ go func() {
+ cmd := exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add.ldif")
+ //ldapmodrdn -H ldap://localhost:3389 -x "uid=babs,dc=example,dc=com" "uid=babsy,dc=example,dc=com"
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapadd failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapadd", "-v", "-H", ldapURL, "-x", "-f", "tests/add2.ldif")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "ldap_add: Insufficient access") {
+ t.Errorf("ldapadd should have failed: %v", string(out))
+ }
+ if strings.Contains(string(out), "modify complete") {
+ t.Errorf("ldapadd should have failed: %v", string(out))
+ }
+ done <- true
+ }()
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapadd command timed out")
+ }
+ quit <- true
+}
+*/
+
+//
+type modifyTestHandler struct {
+}
+
+func (h modifyTestHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
+ if bindDN == "" && bindSimplePw == "" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInvalidCredentials, nil
+}
+func (h modifyTestHandler) Add(boundDN string, req AddRequest, conn net.Conn) (LDAPResultCode, error) {
+ // only succeed on expected contents of add.ldif:
+ if len(req.attributes) == 5 && req.dn == "cn=Barbara Jensen,dc=example,dc=com" &&
+ req.attributes[2].attrType == "sn" && len(req.attributes[2].attrVals) == 1 &&
+ req.attributes[2].attrVals[0] == "Jensen" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h modifyTestHandler) Delete(boundDN, deleteDN string, conn net.Conn) (LDAPResultCode, error) {
+ // only succeed on expected deleteDN
+ if deleteDN == "cn=Delete Me,dc=example,dc=com" {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h modifyTestHandler) Modify(boundDN string, req ModifyRequest, conn net.Conn) (LDAPResultCode, error) {
+ // only succeed on expected contents of modify.ldif:
+ if req.dn == "cn=testy,dc=example,dc=com" && len(req.addAttributes) == 1 &&
+ len(req.deleteAttributes) == 3 && len(req.replaceAttributes) == 2 &&
+ req.deleteAttributes[2].attrType == "details" && len(req.deleteAttributes[2].attrVals) == 0 {
+ return LDAPResultSuccess, nil
+ }
+ return LDAPResultInsufficientAccessRights, nil
+}
+func (h modifyTestHandler) ModifyDN(boundDN string, req ModifyDNRequest, conn net.Conn) (LDAPResultCode, error) {
+ return LDAPResultInsufficientAccessRights, nil
+}
diff --git a/server_search.go b/server_search.go
new file mode 100644
index 0000000..a7d78ac
--- /dev/null
+++ b/server_search.go
@@ -0,0 +1,216 @@
+package ldap
+
+import (
+ "errors"
+ "fmt"
+ "github.com/nmcclain/asn1-ber"
+ "net"
+ "strings"
+)
+
+func HandleSearchRequest(req *ber.Packet, controls *[]Control, messageID uint64, boundDN string, server *Server, conn net.Conn) (resultErr error) {
+ defer func() {
+ if r := recover(); r != nil {
+ resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r))
+ }
+ }()
+
+ searchReq, err := parseSearchRequest(boundDN, req, controls)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+
+ filterPacket, err := CompileFilter(searchReq.Filter)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+
+ fnNames := []string{}
+ for k := range server.SearchFns {
+ fnNames = append(fnNames, k)
+ }
+ fn := routeFunc(searchReq.BaseDN, fnNames)
+ searchResp, err := server.SearchFns[fn].Search(boundDN, searchReq, conn)
+ if err != nil {
+ return NewError(searchResp.ResultCode, err)
+ }
+
+ if server.EnforceLDAP {
+ if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find}
+ // Server DerefAliases not supported: RFC4511 4.5.1.3
+ return NewError(LDAPResultOperationsError, errors.New("Server DerefAliases not supported"))
+ }
+ if searchReq.TimeLimit > 0 {
+ // TODO: Server TimeLimit not implemented
+ }
+ }
+
+ for i, entry := range searchResp.Entries {
+ if server.EnforceLDAP {
+ // size limit
+ if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit {
+ break
+ }
+
+ // filter
+ keep, resultCode := ServerApplyFilter(filterPacket, entry)
+ if resultCode != LDAPResultSuccess {
+ return NewError(resultCode, errors.New("ServerApplyFilter error"))
+ }
+ if !keep {
+ continue
+ }
+
+ // constrained search scope
+ switch searchReq.Scope {
+ case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates.
+ case ScopeBaseObject: // The scope is constrained to the entry named by baseObject.
+ if entry.DN != searchReq.BaseDN {
+ continue
+ }
+ case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject.
+ parts := strings.Split(entry.DN, ",")
+ if len(parts) < 2 && entry.DN != searchReq.BaseDN {
+ continue
+ }
+ if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN {
+ continue
+ }
+ }
+
+ // attributes
+ if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) {
+ entry, err = filterAttributes(entry, searchReq.Attributes)
+ if err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+ }
+ }
+
+ // respond
+ responsePacket := encodeSearchResponse(messageID, searchReq, entry)
+ if err = sendPacket(conn, responsePacket); err != nil {
+ return NewError(LDAPResultOperationsError, err)
+ }
+ }
+ return nil
+}
+
+/////////////////////////
+func parseSearchRequest(boundDN string, req *ber.Packet, controls *[]Control) (SearchRequest, error) {
+ if len(req.Children) != 8 {
+ return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request"))
+ }
+
+ // Parse the request
+ baseObject, ok := req.Children[0].Value.(string)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ s, ok := req.Children[1].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ scope := int(s)
+ d, ok := req.Children[2].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ derefAliases := int(d)
+ s, ok = req.Children[3].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ sizeLimit := int(s)
+ t, ok := req.Children[4].Value.(uint64)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ timeLimit := int(t)
+ typesOnly := false
+ if req.Children[5].Value != nil {
+ typesOnly, ok = req.Children[5].Value.(bool)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ }
+ filter, err := DecompileFilter(req.Children[6])
+ if err != nil {
+ return SearchRequest{}, err
+ }
+ attributes := []string{}
+ for _, attr := range req.Children[7].Children {
+ a, ok := attr.Value.(string)
+ if !ok {
+ return SearchRequest{}, NewError(LDAPResultProtocolError, errors.New("Bad search request"))
+ }
+ attributes = append(attributes, a)
+ }
+ searchReq := SearchRequest{baseObject, scope,
+ derefAliases, sizeLimit, timeLimit,
+ typesOnly, filter, attributes, *controls}
+
+ return searchReq, nil
+}
+
+/////////////////////////
+func filterAttributes(entry *Entry, attributes []string) (*Entry, error) {
+ // only return requested attributes
+ newAttributes := []*EntryAttribute{}
+
+ for _, attr := range entry.Attributes {
+ for _, requested := range attributes {
+ if strings.ToLower(attr.Name) == strings.ToLower(requested) {
+ newAttributes = append(newAttributes, attr)
+ }
+ }
+ }
+ entry.Attributes = newAttributes
+
+ return entry, nil
+}
+
+/////////////////////////
+func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+
+ searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry")
+ searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name"))
+
+ attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:")
+ for _, attribute := range res.Attributes {
+ attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values))
+ }
+
+ searchEntry.AppendChild(attrs)
+ responsePacket.AppendChild(searchEntry)
+
+ return responsePacket
+}
+
+func encodeSearchAttribute(name string, values []string) *ber.Packet {
+ packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute")
+ packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name"))
+
+ valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values")
+ for _, value := range values {
+ valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value"))
+ }
+
+ packet.AppendChild(valuesPacket)
+
+ return packet
+}
+
+func encodeSearchDone(messageID uint64, ldapResultCode LDAPResultCode) *ber.Packet {
+ responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response")
+ responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID"))
+ donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done")
+ donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, uint64(ldapResultCode), "resultCode: "))
+ donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: "))
+ donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: "))
+ responsePacket.AppendChild(donePacket)
+
+ return responsePacket
+}
diff --git a/server_search_test.go b/server_search_test.go
new file mode 100644
index 0000000..c3f42b0
--- /dev/null
+++ b/server_search_test.go
@@ -0,0 +1,403 @@
+package ldap
+
+import (
+ "os/exec"
+ "strings"
+ "testing"
+ "time"
+)
+
+//
+func TestSearchSimpleOK(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ serverBaseDN := "o=testers,c=test"
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "uidNumber: 5000") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numResponses: 4") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+func TestSearchSizelimit(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.EnforceLDAP = true
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "9") // effectively no limit for this test
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numEntries: 3") {
+ t.Errorf("ldapsearch sizelimit failed - not enough entries: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "2")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numEntries: 2") {
+ t.Errorf("ldapsearch sizelimit failed - too many entries: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestBindSearchMulti(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.BindFunc("", bindSimple{})
+ s.BindFunc("c=testz", bindSimple2{})
+ s.SearchFunc("", searchSimple{})
+ s.SearchFunc("c=testz", searchSimple2{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test",
+ "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("error routing default bind/search functions: %v", string(out))
+ }
+ if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+ t.Errorf("search default routing failed: %v", string(out))
+ }
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=testz",
+ "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("error routing custom bind/search functions: %v", string(out))
+ }
+ if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+ t.Errorf("search custom routing failed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+
+ quit <- true
+}
+
+/////////////////////////
+func TestSearchPanic(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchPanic{})
+ s.BindFunc("", bindAnonOK{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "result: 1 Operations error") {
+ t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+type compileSearchFilterTest struct {
+ name string
+ filterStr string
+ numResponses string
+}
+
+var searchFilterTestFilters = []compileSearchFilterTest{
+ compileSearchFilterTest{name: "equalityOk", filterStr: "(uid=ned)", numResponses: "2"},
+ compileSearchFilterTest{name: "equalityNo", filterStr: "(uid=foo)", numResponses: "1"},
+ compileSearchFilterTest{name: "equalityOk", filterStr: "(objectclass=posixaccount)", numResponses: "4"},
+ compileSearchFilterTest{name: "presentEmptyOk", filterStr: "", numResponses: "4"},
+ compileSearchFilterTest{name: "presentOk", filterStr: "(objectclass=*)", numResponses: "4"},
+ compileSearchFilterTest{name: "presentOk", filterStr: "(description=*)", numResponses: "3"},
+ compileSearchFilterTest{name: "presentNo", filterStr: "(foo=*)", numResponses: "1"},
+ compileSearchFilterTest{name: "andOk", filterStr: "(&(uid=ned)(objectclass=posixaccount))", numResponses: "2"},
+ compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(objectclass=posixgroup))", numResponses: "1"},
+ compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(uid=trent))", numResponses: "1"},
+ compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(uid=trent))", numResponses: "3"},
+ compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(objectclass=posixaccount))", numResponses: "4"},
+ compileSearchFilterTest{name: "orNo", filterStr: "(|(uid=foo)(objectclass=foo))", numResponses: "1"},
+ compileSearchFilterTest{name: "andOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(objectclass=posixaccount))", numResponses: "3"},
+ compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=ned))", numResponses: "3"},
+ compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=foo))", numResponses: "4"},
+ compileSearchFilterTest{name: "notAndOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(!(objectclass=posixgroup)))", numResponses: "3"},
+ /*
+ compileSearchFilterTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings},
+ compileSearchFilterTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings},
+ compileSearchFilterTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings},
+ compileSearchFilterTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual},
+ compileSearchFilterTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual},
+ compileSearchFilterTest{filterStr: "(sn~=Miller)", filterType: FilterApproxMatch},
+ */
+}
+
+/////////////////////////
+func TestSearchFiltering(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.EnforceLDAP = true
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ for _, i := range searchFilterTestFilters {
+ t.Log(i.name)
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", i.filterStr)
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "numResponses: "+i.numResponses) {
+ t.Errorf("ldapsearch failed - expected numResponses==%d: %v", i.numResponses, string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestSearchAttributes(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.EnforceLDAP = true
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ filterString := ""
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", filterString, "cn")
+ out, _ := cmd.CombinedOutput()
+
+ if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
+ t.Errorf("ldapsearch failed - missing requested DN attribute: %v", string(out))
+ }
+ if !strings.Contains(string(out), "cn: ned") {
+ t.Errorf("ldapsearch failed - missing requested CN attribute: %v", string(out))
+ }
+ if strings.Contains(string(out), "uidNumber") {
+ t.Errorf("ldapsearch failed - uidNumber attr should not be displayed: %v", string(out))
+ }
+ if strings.Contains(string(out), "accountstatus") {
+ t.Errorf("ldapsearch failed - accountstatus attr should not be displayed: %v", string(out))
+ }
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+/////////////////////////
+func TestSearchScope(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.EnforceLDAP = true
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchSimple{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "sub", "cn=trent")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'sub' scope failed - didn't find expected DN: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'one' scope failed - didn't find expected DN: %v", string(out))
+ }
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent")
+ out, _ = cmd.CombinedOutput()
+ if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'one' scope failed - found unexpected DN: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "cn=trent,o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent")
+ out, _ = cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'base' scope failed - didn't find expected DN: %v", string(out))
+ }
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent")
+ out, _ = cmd.CombinedOutput()
+ if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
+ t.Errorf("ldapsearch 'base' scope failed - found unexpected DN: %v", string(out))
+ }
+
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
+
+func TestSearchControls(t *testing.T) {
+ quit := make(chan bool)
+ done := make(chan bool)
+ go func() {
+ s := NewServer()
+ s.QuitChannel(quit)
+ s.SearchFunc("", searchControls{})
+ s.BindFunc("", bindSimple{})
+ if err := s.ListenAndServe(listenString); err != nil {
+ t.Errorf("s.ListenAndServe failed: %s", err.Error())
+ }
+ }()
+
+ serverBaseDN := "o=testers,c=test"
+
+ go func() {
+ cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-e", "1.2.3.4.5")
+ out, _ := cmd.CombinedOutput()
+ if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+ t.Errorf("ldapsearch with control failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch with control failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numResponses: 2") {
+ t.Errorf("ldapsearch with control failed: %v", string(out))
+ }
+
+ cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
+ "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test")
+ out, _ = cmd.CombinedOutput()
+ if strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
+ t.Errorf("ldapsearch without control failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "result: 0 Success") {
+ t.Errorf("ldapsearch without control failed: %v", string(out))
+ }
+ if !strings.Contains(string(out), "numResponses: 1") {
+ t.Errorf("ldapsearch without control failed: %v", string(out))
+ }
+
+ done <- true
+ }()
+
+ select {
+ case <-done:
+ case <-time.After(timeout):
+ t.Errorf("ldapsearch command timed out")
+ }
+ quit <- true
+}
diff --git a/server_test.go b/server_test.go
index 9386a4a..7e813ec 100644
--- a/server_test.go
+++ b/server_test.go
@@ -61,7 +61,7 @@ func TestBindAnonFail(t *testing.T) {
go func() {
cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
out, _ := cmd.CombinedOutput()
- if !strings.Contains(string(out), "ldap_bind: Inappropriate authentication (48)") {
+ if !strings.Contains(string(out), "ldap_bind: Invalid credentials (49)") {
t.Errorf("ldapsearch failed: %v", string(out))
}
done <- true
@@ -186,7 +186,7 @@ func TestBindSSL(t *testing.T) {
s := NewServer()
s.QuitChannel(quit)
s.BindFunc("", bindAnonOK{})
- if err := s.ListenAndServeTLS(listenString, "examples/cert_DONOTUSE.pem", "examples/key_DONOTUSE.pem"); err != nil {
+ if err := s.ListenAndServeTLS(listenString, "tests/cert_DONOTUSE.pem", "tests/key_DONOTUSE.pem"); err != nil {
t.Errorf("s.ListenAndServeTLS failed: %s", err.Error())
}
}()
@@ -240,348 +240,6 @@ func TestBindPanic(t *testing.T) {
}
/////////////////////////
-func TestSearchSimpleOK(t *testing.T) {
- quit := make(chan bool)
- done := make(chan bool)
- go func() {
- s := NewServer()
- s.QuitChannel(quit)
- s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
- if err := s.ListenAndServe(listenString); err != nil {
- t.Errorf("s.ListenAndServe failed: %s", err.Error())
- }
- }()
-
- serverBaseDN := "o=testers,c=test"
-
- go func() {
- cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test")
- out, _ := cmd.CombinedOutput()
- if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
- t.Errorf("ldapsearch failed: %v", string(out))
- }
- if !strings.Contains(string(out), "uidNumber: 5000") {
- t.Errorf("ldapsearch failed: %v", string(out))
- }
- if !strings.Contains(string(out), "result: 0 Success") {
- t.Errorf("ldapsearch failed: %v", string(out))
- }
- if !strings.Contains(string(out), "numResponses: 4") {
- t.Errorf("ldapsearch failed: %v", string(out))
- }
- done <- true
- }()
-
- select {
- case <-done:
- case <-time.After(timeout):
- t.Errorf("ldapsearch command timed out")
- }
- quit <- true
-}
-
-func TestSearchSizelimit(t *testing.T) {
- quit := make(chan bool)
- done := make(chan bool)
- go func() {
- s := NewServer()
- s.EnforceLDAP = true
- s.QuitChannel(quit)
- s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
- if err := s.ListenAndServe(listenString); err != nil {
- t.Errorf("s.ListenAndServe failed: %s", err.Error())
- }
- }()
-
- go func() {
- cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "9") // effectively no limit for this test
- out, _ := cmd.CombinedOutput()
- if !strings.Contains(string(out), "result: 0 Success") {
- t.Errorf("ldapsearch failed: %v", string(out))
- }
- if !strings.Contains(string(out), "numEntries: 3") {
- t.Errorf("ldapsearch sizelimit failed - not enough entries: %v", string(out))
- }
-
- cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", "-z", "2")
- out, _ = cmd.CombinedOutput()
- if !strings.Contains(string(out), "result: 0 Success") {
- t.Errorf("ldapsearch failed: %v", string(out))
- }
- if !strings.Contains(string(out), "numEntries: 2") {
- t.Errorf("ldapsearch sizelimit failed - too many entries: %v", string(out))
- }
- done <- true
- }()
-
- select {
- case <-done:
- case <-time.After(timeout):
- t.Errorf("ldapsearch command timed out")
- }
- quit <- true
-}
-
-/////////////////////////
-func TestBindSearchMulti(t *testing.T) {
- quit := make(chan bool)
- done := make(chan bool)
- go func() {
- s := NewServer()
- s.QuitChannel(quit)
- s.BindFunc("", bindSimple{})
- s.BindFunc("c=testz", bindSimple2{})
- s.SearchFunc("", searchSimple{})
- s.SearchFunc("c=testz", searchSimple2{})
- if err := s.ListenAndServe(listenString); err != nil {
- t.Errorf("s.ListenAndServe failed: %s", err.Error())
- }
- }()
-
- go func() {
- cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test",
- "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "cn=ned")
- out, _ := cmd.CombinedOutput()
- if !strings.Contains(string(out), "result: 0 Success") {
- t.Errorf("error routing default bind/search functions: %v", string(out))
- }
- if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
- t.Errorf("search default routing failed: %v", string(out))
- }
- cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=testz",
- "-D", "cn=testy,o=testers,c=testz", "-w", "ZLike2test", "cn=hamburger")
- out, _ = cmd.CombinedOutput()
- if !strings.Contains(string(out), "result: 0 Success") {
- t.Errorf("error routing custom bind/search functions: %v", string(out))
- }
- if !strings.Contains(string(out), "dn: cn=hamburger,o=testers,c=testz") {
- t.Errorf("search custom routing failed: %v", string(out))
- }
- done <- true
- }()
-
- select {
- case <-done:
- case <-time.After(timeout):
- t.Errorf("ldapsearch command timed out")
- }
-
- quit <- true
-}
-
-/////////////////////////
-func TestSearchPanic(t *testing.T) {
- quit := make(chan bool)
- done := make(chan bool)
- go func() {
- s := NewServer()
- s.QuitChannel(quit)
- s.SearchFunc("", searchPanic{})
- s.BindFunc("", bindAnonOK{})
- if err := s.ListenAndServe(listenString); err != nil {
- t.Errorf("s.ListenAndServe failed: %s", err.Error())
- }
- }()
-
- go func() {
- cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x", "-b", "o=testers,c=test")
- out, _ := cmd.CombinedOutput()
- if !strings.Contains(string(out), "result: 1 Operations error") {
- t.Errorf("ldapsearch should have returned operations error due to panic: %v", string(out))
- }
- done <- true
- }()
-
- select {
- case <-done:
- case <-time.After(timeout):
- t.Errorf("ldapsearch command timed out")
- }
- quit <- true
-}
-
-/////////////////////////
-type compileSearchFilterTest struct {
- name string
- filterStr string
- numResponses string
-}
-
-var searchFilterTestFilters = []compileSearchFilterTest{
- compileSearchFilterTest{name: "equalityOk", filterStr: "(uid=ned)", numResponses: "2"},
- compileSearchFilterTest{name: "equalityNo", filterStr: "(uid=foo)", numResponses: "1"},
- compileSearchFilterTest{name: "equalityOk", filterStr: "(objectclass=posixaccount)", numResponses: "4"},
- compileSearchFilterTest{name: "presentEmptyOk", filterStr: "", numResponses: "4"},
- compileSearchFilterTest{name: "presentOk", filterStr: "(objectclass=*)", numResponses: "4"},
- compileSearchFilterTest{name: "presentOk", filterStr: "(description=*)", numResponses: "3"},
- compileSearchFilterTest{name: "presentNo", filterStr: "(foo=*)", numResponses: "1"},
- compileSearchFilterTest{name: "andOk", filterStr: "(&(uid=ned)(objectclass=posixaccount))", numResponses: "2"},
- compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(objectclass=posixgroup))", numResponses: "1"},
- compileSearchFilterTest{name: "andNo", filterStr: "(&(uid=ned)(uid=trent))", numResponses: "1"},
- compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(uid=trent))", numResponses: "3"},
- compileSearchFilterTest{name: "orOk", filterStr: "(|(uid=ned)(objectclass=posixaccount))", numResponses: "4"},
- compileSearchFilterTest{name: "orNo", filterStr: "(|(uid=foo)(objectclass=foo))", numResponses: "1"},
- compileSearchFilterTest{name: "andOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(objectclass=posixaccount))", numResponses: "3"},
- compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=ned))", numResponses: "3"},
- compileSearchFilterTest{name: "notOk", filterStr: "(!(uid=foo))", numResponses: "4"},
- compileSearchFilterTest{name: "notAndOrOk", filterStr: "(&(|(uid=ned)(uid=trent))(!(objectclass=posixgroup)))", numResponses: "3"},
- /*
- compileSearchFilterTest{filterStr: "(sn=Mill*)", filterType: FilterSubstrings},
- compileSearchFilterTest{filterStr: "(sn=*Mill)", filterType: FilterSubstrings},
- compileSearchFilterTest{filterStr: "(sn=*Mill*)", filterType: FilterSubstrings},
- compileSearchFilterTest{filterStr: "(sn>=Miller)", filterType: FilterGreaterOrEqual},
- compileSearchFilterTest{filterStr: "(sn<=Miller)", filterType: FilterLessOrEqual},
- compileSearchFilterTest{filterStr: "(sn~=Miller)", filterType: FilterApproxMatch},
- */
-}
-
-/////////////////////////
-func TestSearchFiltering(t *testing.T) {
- quit := make(chan bool)
- done := make(chan bool)
- go func() {
- s := NewServer()
- s.EnforceLDAP = true
- s.QuitChannel(quit)
- s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
- if err := s.ListenAndServe(listenString); err != nil {
- t.Errorf("s.ListenAndServe failed: %s", err.Error())
- }
- }()
-
- for _, i := range searchFilterTestFilters {
- t.Log(i.name)
-
- go func() {
- cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", i.filterStr)
- out, _ := cmd.CombinedOutput()
- if !strings.Contains(string(out), "numResponses: "+i.numResponses) {
- t.Errorf("ldapsearch failed - expected numResponses==%d: %v", i.numResponses, string(out))
- }
- done <- true
- }()
-
- select {
- case <-done:
- case <-time.After(timeout):
- t.Errorf("ldapsearch command timed out")
- }
- }
- quit <- true
-}
-
-/////////////////////////
-func TestSearchAttributes(t *testing.T) {
- quit := make(chan bool)
- done := make(chan bool)
- go func() {
- s := NewServer()
- s.EnforceLDAP = true
- s.QuitChannel(quit)
- s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
- if err := s.ListenAndServe(listenString); err != nil {
- t.Errorf("s.ListenAndServe failed: %s", err.Error())
- }
- }()
-
- go func() {
- filterString := ""
- cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", serverBaseDN, "-D", "cn=testy,"+serverBaseDN, "-w", "iLike2test", filterString, "cn")
- out, _ := cmd.CombinedOutput()
-
- if !strings.Contains(string(out), "dn: cn=ned,o=testers,c=test") {
- t.Errorf("ldapsearch failed - missing requested DN attribute: %v", string(out))
- }
- if !strings.Contains(string(out), "cn: ned") {
- t.Errorf("ldapsearch failed - missing requested CN attribute: %v", string(out))
- }
- if strings.Contains(string(out), "uidNumber") {
- t.Errorf("ldapsearch failed - uidNumber attr should not be displayed: %v", string(out))
- }
- if strings.Contains(string(out), "accountstatus") {
- t.Errorf("ldapsearch failed - accountstatus attr should not be displayed: %v", string(out))
- }
- done <- true
- }()
-
- select {
- case <-done:
- case <-time.After(timeout):
- t.Errorf("ldapsearch command timed out")
- }
- quit <- true
-}
-
-/////////////////////////
-func TestSearchScope(t *testing.T) {
- quit := make(chan bool)
- done := make(chan bool)
- go func() {
- s := NewServer()
- s.EnforceLDAP = true
- s.QuitChannel(quit)
- s.SearchFunc("", searchSimple{})
- s.BindFunc("", bindSimple{})
- if err := s.ListenAndServe(listenString); err != nil {
- t.Errorf("s.ListenAndServe failed: %s", err.Error())
- }
- }()
-
- go func() {
- cmd := exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "sub", "cn=trent")
- out, _ := cmd.CombinedOutput()
- if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
- t.Errorf("ldapsearch 'sub' scope failed - didn't find expected DN: %v", string(out))
- }
-
- cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent")
- out, _ = cmd.CombinedOutput()
- if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
- t.Errorf("ldapsearch 'one' scope failed - didn't find expected DN: %v", string(out))
- }
- cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", "c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "one", "cn=trent")
- out, _ = cmd.CombinedOutput()
- if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
- t.Errorf("ldapsearch 'one' scope failed - found unexpected DN: %v", string(out))
- }
-
- cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", "cn=trent,o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent")
- out, _ = cmd.CombinedOutput()
- if !strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
- t.Errorf("ldapsearch 'base' scope failed - didn't find expected DN: %v", string(out))
- }
- cmd = exec.Command("ldapsearch", "-H", ldapURL, "-x",
- "-b", "o=testers,c=test", "-D", "cn=testy,o=testers,c=test", "-w", "iLike2test", "-s", "base", "cn=trent")
- out, _ = cmd.CombinedOutput()
- if strings.Contains(string(out), "dn: cn=trent,o=testers,c=test") {
- t.Errorf("ldapsearch 'base' scope failed - found unexpected DN: %v", string(out))
- }
-
- done <- true
- }()
-
- select {
- case <-done:
- case <-time.After(timeout):
- t.Errorf("ldapsearch command timed out")
- }
- quit <- true
-}
-
-/////////////////////////
type testStatsWriter struct {
buffer *bytes.Buffer
}
@@ -625,7 +283,8 @@ func TestSearchStats(t *testing.T) {
}
stats := s.GetStats()
- if stats.Conns != 1 || stats.Binds != 1 {
+ log.Println(stats)
+ if stats.Conns != 2 || stats.Binds != 1 {
t.Errorf("Stats data missing or incorrect: %v", w.buffer.String())
}
quit <- true
@@ -635,7 +294,7 @@ func TestSearchStats(t *testing.T) {
type bindAnonOK struct {
}
-func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) {
+func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
if bindDN == "" && bindSimplePw == "" {
return LDAPResultSuccess, nil
}
@@ -645,7 +304,7 @@ func (b bindAnonOK) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, er
type bindSimple struct {
}
-func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) {
+func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
if bindDN == "cn=testy,o=testers,c=test" && bindSimplePw == "iLike2test" {
return LDAPResultSuccess, nil
}
@@ -655,7 +314,7 @@ func (b bindSimple) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, er
type bindSimple2 struct {
}
-func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) {
+func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
if bindDN == "cn=testy,o=testers,c=testz" && bindSimplePw == "ZLike2test" {
return LDAPResultSuccess, nil
}
@@ -665,7 +324,7 @@ func (b bindSimple2) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, e
type bindPanic struct {
}
-func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) {
+func (b bindPanic) Bind(bindDN, bindSimplePw string, conn net.Conn) (LDAPResultCode, error) {
panic("test panic at the disco")
return LDAPResultInvalidCredentials, nil
}
@@ -730,3 +389,22 @@ func (s searchPanic) Search(boundDN string, searchReq SearchRequest, conn net.Co
panic("this is a test panic")
return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
}
+
+type searchControls struct {
+}
+
+func (s searchControls) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) {
+ entries := []*Entry{}
+ if len(searchReq.Controls) == 1 && searchReq.Controls[0].GetControlType() == "1.2.3.4.5" {
+ newEntry := &Entry{"cn=hamburger,o=testers,c=testz", []*EntryAttribute{
+ &EntryAttribute{"cn", []string{"hamburger"}},
+ &EntryAttribute{"o", []string{"testers"}},
+ &EntryAttribute{"uidNumber", []string{"5000"}},
+ &EntryAttribute{"accountstatus", []string{"active"}},
+ &EntryAttribute{"uid", []string{"hamburger"}},
+ &EntryAttribute{"objectclass", []string{"posixaccount"}},
+ }}
+ entries = append(entries, newEntry)
+ }
+ return ServerSearchResult{entries, []string{}, []Control{}, LDAPResultSuccess}, nil
+}
diff --git a/tests/add.ldif b/tests/add.ldif
new file mode 100644
index 0000000..f8cdf71
--- /dev/null
+++ b/tests/add.ldif
@@ -0,0 +1,6 @@
+dn: cn=Barbara Jensen,dc=example,dc=com
+objectClass: person
+cn: Barbara Jensen
+sn: Jensen
+mail: bjensen@example.com
+uid: bjensen
diff --git a/tests/add2.ldif b/tests/add2.ldif
new file mode 100644
index 0000000..ccb71ad
--- /dev/null
+++ b/tests/add2.ldif
@@ -0,0 +1,6 @@
+dn: cn=Big Bob,dc=example,dc=com
+objectClass: person
+cn: Big Bob
+sn: Bob
+mail: bob@example.com
+uid: bob
diff --git a/examples/cert_DONOTUSE.pem b/tests/cert_DONOTUSE.pem
index ee14324..ee14324 100644
--- a/examples/cert_DONOTUSE.pem
+++ b/tests/cert_DONOTUSE.pem
diff --git a/examples/key_DONOTUSE.pem b/tests/key_DONOTUSE.pem
index 7feaa11..7feaa11 100644
--- a/examples/key_DONOTUSE.pem
+++ b/tests/key_DONOTUSE.pem
diff --git a/tests/modify.ldif b/tests/modify.ldif
new file mode 100644
index 0000000..ac969cc
--- /dev/null
+++ b/tests/modify.ldif
@@ -0,0 +1,16 @@
+dn: cn=testy,dc=example,dc=com
+changetype: modify
+replace: mail
+mail: modme@example.com
+-
+delete: manager
+-
+add: title
+title: Grand Poobah
+-
+delete: description
+-
+delete: details
+-
+replace: fullname
+fullname: Test Testerson
diff --git a/tests/modify2.ldif b/tests/modify2.ldif
new file mode 100644
index 0000000..794d7f4
--- /dev/null
+++ b/tests/modify2.ldif
@@ -0,0 +1,10 @@
+dn: cn=testo,dc=example,dc=com
+changetype: modify
+replace: mail
+mail: modid@example.com
+-
+delete: manager
+-
+add: title
+title: Other Poobah
+-