diff options
Diffstat (limited to 'server.go')
-rw-r--r-- | server.go | 595 |
1 files changed, 595 insertions, 0 deletions
diff --git a/server.go b/server.go new file mode 100644 index 0000000..4a46e6f --- /dev/null +++ b/server.go @@ -0,0 +1,595 @@ +package ldap + +import ( + "crypto/tls" + "errors" + "fmt" + "github.com/nmcclain/asn1-ber" + "io" + "log" + "net" + "strings" + "sync" +) + +type Binder interface { + Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) +} +type Searcher interface { + Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) +} +type Closer interface { + Close(conn net.Conn) error +} + +///////////////////////// +type Server struct { + bindFns map[string]Binder + searchFns map[string]Searcher + closeFns map[string]Closer + quit chan bool + EnforceLDAP bool + stats *Stats +} + +type Stats struct { + Conns int + Binds int + Unbinds int + Searches int + statsMutex sync.Mutex +} + +type ServerSearchResult struct { + Entries []*Entry + Referrals []string + Controls []Control + ResultCode uint64 +} + +///////////////////////// +func NewServer() *Server { + s := new(Server) + s.quit = make(chan bool) + + d := defaultHandler{} + s.bindFns = make(map[string]Binder) + s.searchFns = make(map[string]Searcher) + s.closeFns = make(map[string]Closer) + s.bindFns[""] = d + s.searchFns[""] = d + s.closeFns[""] = d + s.stats = nil + return s +} +func (server *Server) BindFunc(baseDN string, bindFn Binder) { + server.bindFns[baseDN] = bindFn +} +func (server *Server) SearchFunc(baseDN string, searchFn Searcher) { + server.searchFns[baseDN] = searchFn +} +func (server *Server) CloseFunc(baseDN string, closeFn Closer) { + server.closeFns[baseDN] = closeFn +} +func (server *Server) QuitChannel(quit chan bool) { + server.quit = quit +} + +func (server *Server) ListenAndServeTLS(listenString string, certFile string, keyFile string) error { + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return err + } + tlsConfig := tls.Config{Certificates: []tls.Certificate{cert}} + tlsConfig.ServerName = "localhost" + ln, err := tls.Listen("tcp", listenString, &tlsConfig) + if err != nil { + return err + } + err = server.serve(ln) + if err != nil { + return err + } + return nil +} + +func (server *Server) SetStats(enable bool) { + if enable { + server.stats = &Stats{} + } else { + server.stats = nil + } +} + +func (server *Server) GetStats() Stats { + defer func() { + server.stats.statsMutex.Unlock() + }() + server.stats.statsMutex.Lock() + return *server.stats +} + +func (server *Server) ListenAndServe(listenString string) error { + ln, err := net.Listen("tcp", listenString) + if err != nil { + return err + } + err = server.serve(ln) + if err != nil { + return err + } + return nil +} + +func (server *Server) serve(ln net.Listener) error { + newConn := make(chan net.Conn) + go func() { + for { + conn, err := ln.Accept() + if err != nil { + if !strings.HasSuffix(err.Error(), "use of closed network connection") { + log.Printf("Error accepting network connection: %s", err.Error()) + } + break + } + newConn <- conn + } + }() + +listener: + for { + select { + case c := <-newConn: + server.stats.countConns(1) + go server.handleConnection(c) + case <-server.quit: + ln.Close() + break listener + } + } + return nil +} + +///////////////////////// + +func (server *Server) handleConnection(conn net.Conn) { + boundDN := "" // "" == anonymous + +handler: + for { + // read incoming LDAP packet + packet, err := ber.ReadPacket(conn) + if err == io.EOF { // Client closed connection + break + } else if err != nil { + log.Printf("handleConnection ber.ReadPacket ERROR: %s", err.Error()) + break + } + + // sanity check this packet + if len(packet.Children) < 2 { + log.Print("len(packet.Children) < 2") + break + } + // check the message ID and ClassType + messageID := packet.Children[0].Value.(uint64) + req := packet.Children[1] + if req.ClassType != ber.ClassApplication { + log.Print("req.ClassType != ber.ClassApplication") + break + } + // handle controls if present + if len(packet.Children) > 2 { + controls := packet.Children[2] + ber.PrintPacket(controls) + log.Print("TODO Parse Controls") + /* + Controls ::= SEQUENCE OF control Control + + Control ::= SEQUENCE { + controlType LDAPOID, + criticality BOOLEAN DEFAULT FALSE, // unavailableCriticalExtension + controlValue OCTET STRING OPTIONAL } + */ + } + + // dispatch the LDAP operation + switch req.Tag { // ldap op code + default: + //log.Printf("Bound as %s", boundDN) + //ber.PrintPacket(packet) + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + + case ApplicationBindRequest: + server.stats.countBinds(1) + ldapResultCode := server.handleBindRequest(req, server.bindFns, conn) + if ldapResultCode == LDAPResultSuccess { + boundDN = req.Children[1].Value.(string) + } + responsePacket := encodeBindResponse(messageID, ldapResultCode) + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + break handler + } + case ApplicationSearchRequest: + server.stats.countSearches(1) + if err := server.handleSearchRequest(req, messageID, boundDN, server.searchFns, conn); err != nil { + log.Printf("handleSearchRequest error %s", err.Error()) // TODO: make this more testable/better err handling - stop using log, stop using breaks? + e := err.(*Error) + if err = sendPacket(conn, encodeSearchDone(messageID, uint64(e.ResultCode))); err != nil { + log.Printf("sendPacket error %s", err.Error()) + } + break handler + } else { + if err = sendPacket(conn, encodeSearchDone(messageID, LDAPResultSuccess)); err != nil { + log.Printf("sendPacket error %s", err.Error()) + break handler + } + } + case ApplicationUnbindRequest: + server.stats.countUnbinds(1) + break handler // simply disconnect - this IS implemented + case ApplicationExtendedRequest: + responsePacket := encodeLDAPResponse(messageID, ApplicationExtendedResponse, LDAPResultProtocolError, "Unsupported extended request") + if err = sendPacket(conn, responsePacket); err != nil { + log.Printf("sendPacket error %s", err.Error()) + } + break handler + case ApplicationAbandonRequest: + log.Printf("Abandoning request!") + break handler + + // Unimplemented LDAP operations: + case ApplicationModifyRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + case ApplicationAddRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + case ApplicationDelRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + case ApplicationModifyDNRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + case ApplicationCompareRequest: + log.Printf("Unhandled operation: %s [%d]", ApplicationMap[req.Tag], req.Tag) + break handler + } + } + + for _, c := range server.closeFns { + c.Close(conn) + } + + conn.Close() +} + +///////////////////////// +func (server *Server) handleSearchRequest(req *ber.Packet, messageID uint64, boundDN string, searchFns map[string]Searcher, conn net.Conn) (resultErr error) { + defer func() { + if r := recover(); r != nil { + resultErr = NewError(LDAPResultOperationsError, fmt.Errorf("Search function panic: %s", r)) + } + }() + + searchReq, err := parseSearchRequest(boundDN, req) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + + filterPacket, err := CompileFilter(searchReq.Filter) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + + fnNames := []string{} + for k := range searchFns { + fnNames = append(fnNames, k) + } + searchFn := routeFunc(searchReq.BaseDN, fnNames) + searchResp, err := searchFns[searchFn].Search(boundDN, searchReq, conn) + if err != nil { + return NewError(uint8(searchResp.ResultCode), err) + } + + if server.EnforceLDAP { + if searchReq.DerefAliases != NeverDerefAliases { // [-a {never|always|search|find} + // TODO: Server DerefAliases not implemented: RFC4511 4.5.1.3. SearchRequest.derefAliases + } + if len(searchReq.Controls) > 0 { + return NewError(LDAPResultOperationsError, errors.New("Server controls not implemented")) // TODO + } + if searchReq.TimeLimit > 0 { + return NewError(LDAPResultOperationsError, errors.New("Server TimeLimit not implemented")) // TODO + } + } + + for i, entry := range searchResp.Entries { + if server.EnforceLDAP { + // size limit + if searchReq.SizeLimit > 0 && i >= searchReq.SizeLimit { + break + } + + // filter + keep, resultCode := ServerApplyFilter(filterPacket, entry) + if resultCode != LDAPResultSuccess { + return NewError(uint8(resultCode), errors.New("ServerApplyFilter error")) + } + if !keep { + continue + } + + // constrained search scope + switch searchReq.Scope { + case ScopeWholeSubtree: // The scope is constrained to the entry named by baseObject and to all its subordinates. + case ScopeBaseObject: // The scope is constrained to the entry named by baseObject. + if entry.DN != searchReq.BaseDN { + continue + } + case ScopeSingleLevel: // The scope is constrained to the immediate subordinates of the entry named by baseObject. + parts := strings.Split(entry.DN, ",") + if len(parts) < 2 && entry.DN != searchReq.BaseDN { + continue + } + if dn := strings.Join(parts[1:], ","); dn != searchReq.BaseDN { + continue + } + } + + // attributes + if len(searchReq.Attributes) > 1 || (len(searchReq.Attributes) == 1 && len(searchReq.Attributes[0]) > 0) { + entry, err = filterAttributes(entry, searchReq.Attributes) + if err != nil { + return NewError(LDAPResultOperationsError, err) + } + } + } + + // respond + responsePacket := encodeSearchResponse(messageID, searchReq, entry) + if err = sendPacket(conn, responsePacket); err != nil { + return NewError(LDAPResultOperationsError, err) + } + } + return nil +} + +///////////////////////// +func (server *Server) handleBindRequest(req *ber.Packet, bindFns map[string]Binder, conn net.Conn) (resultCode uint64) { + defer func() { + if r := recover(); r != nil { + resultCode = LDAPResultOperationsError + } + }() + + // we only support ldapv3 + ldapVersion := req.Children[0].Value.(uint64) + if ldapVersion != 3 { + log.Printf("Unsupported LDAP version: %d", ldapVersion) + return LDAPResultInappropriateAuthentication + } + + // auth types + bindDN := req.Children[1].Value.(string) + bindAuth := req.Children[2] + switch bindAuth.Tag { + default: + log.Print("Unknown LDAP authentication method") + return LDAPResultInappropriateAuthentication + case LDAPBindAuthSimple: + if len(req.Children) == 3 { + fnNames := []string{} + for k := range bindFns { + fnNames = append(fnNames, k) + } + bindFn := routeFunc(bindDN, fnNames) + resultCode, err := bindFns[bindFn].Bind(bindDN, bindAuth.Data.String(), conn) + if err != nil { + log.Printf("BindFn Error %s", err.Error()) + } + return resultCode + } else { + log.Print("Simple bind request has wrong # children. len(req.Children) != 3") + return LDAPResultInappropriateAuthentication + } + case LDAPBindAuthSASL: + log.Print("SASL authentication is not supported") + return LDAPResultInappropriateAuthentication + } + return LDAPResultOperationsError +} + +///////////////////////// +func sendPacket(conn net.Conn, packet *ber.Packet) error { + _, err := conn.Write(packet.Bytes()) + if err != nil { + log.Printf("Error Sending Message: %s", err.Error()) + return err + } + return nil +} + +///////////////////////// +func parseSearchRequest(boundDN string, req *ber.Packet) (SearchRequest, error) { + if len(req.Children) != 8 { + return SearchRequest{}, NewError(LDAPResultOperationsError, errors.New("Bad search request")) + } + + // Parse the request + baseObject := req.Children[0].Value.(string) + scope := int(req.Children[1].Value.(uint64)) + derefAliases := int(req.Children[2].Value.(uint64)) + sizeLimit := int(req.Children[3].Value.(uint64)) + timeLimit := int(req.Children[4].Value.(uint64)) + typesOnly := false + if req.Children[5].Value != nil { + typesOnly = req.Children[5].Value.(bool) + } + filter, err := DecompileFilter(req.Children[6]) + if err != nil { + return SearchRequest{}, err + } + attributes := []string{} + for _, attr := range req.Children[7].Children { + attributes = append(attributes, attr.Value.(string)) + } + searchReq := SearchRequest{baseObject, scope, + derefAliases, sizeLimit, timeLimit, + typesOnly, filter, attributes, nil} + + return searchReq, nil +} + +///////////////////////// +func routeFunc(dn string, funcNames []string) string { + bestPick := "" + for _, fn := range funcNames { + if strings.HasSuffix(dn, fn) { + l := len(strings.Split(bestPick, ",")) + if bestPick == "" { + l = 0 + } + if len(strings.Split(fn, ",")) > l { + bestPick = fn + } + } + } + return bestPick +} + +///////////////////////// +func filterAttributes(entry *Entry, attributes []string) (*Entry, error) { + // only return requested attributes + newAttributes := []*EntryAttribute{} + + for _, attr := range entry.Attributes { + for _, requested := range attributes { + if strings.ToLower(attr.Name) == strings.ToLower(requested) { + newAttributes = append(newAttributes, attr) + } + } + } + entry.Attributes = newAttributes + + return entry, nil +} + +///////////////////////// +type defaultHandler struct { +} + +func (h defaultHandler) Bind(bindDN, bindSimplePw string, conn net.Conn) (uint64, error) { + return LDAPResultInappropriateAuthentication, nil +} +func (h defaultHandler) Search(boundDN string, searchReq SearchRequest, conn net.Conn) (ServerSearchResult, error) { + return ServerSearchResult{make([]*Entry, 0), []string{}, []Control{}, LDAPResultSuccess}, nil +} +func (h defaultHandler) Close(conn net.Conn) error { + conn.Close() + return nil +} + +///////////////////////// +func encodeBindResponse(messageID uint64, ldapResultCode uint64) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + + bindReponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindResponse, nil, "Bind Response") + bindReponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) + bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + bindReponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) + + responsePacket.AppendChild(bindReponse) + + // ber.PrintPacket(responsePacket) + return responsePacket +} +func encodeSearchResponse(messageID uint64, req SearchRequest, res *Entry) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + + searchEntry := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultEntry, nil, "Search Result Entry") + searchEntry.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, res.DN, "Object Name")) + + attrs := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attributes:") + for _, attribute := range res.Attributes { + attrs.AppendChild(encodeSearchAttribute(attribute.Name, attribute.Values)) + } + + searchEntry.AppendChild(attrs) + responsePacket.AppendChild(searchEntry) + + return responsePacket +} + +func encodeSearchAttribute(name string, values []string) *ber.Packet { + packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Attribute") + packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, name, "Attribute Name")) + + valuesPacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSet, nil, "Attribute Values") + for _, value := range values { + valuesPacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, value, "Attribute Value")) + } + + packet.AppendChild(valuesPacket) + + return packet +} + +func encodeSearchDone(messageID uint64, ldapResultCode uint64) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + donePacket := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationSearchResultDone, nil, "Search result done") + donePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) + donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + donePacket.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "errorMessage: ")) + responsePacket.AppendChild(donePacket) + + return responsePacket +} + +func encodeLDAPResponse(messageID uint64, responseType uint8, ldapResultCode uint64, message string) *ber.Packet { + responsePacket := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Response") + responsePacket.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "Message ID")) + reponse := ber.Encode(ber.ClassApplication, ber.TypeConstructed, responseType, nil, ApplicationMap[responseType]) + reponse.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagEnumerated, ldapResultCode, "resultCode: ")) + reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, "", "matchedDN: ")) + reponse.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, message, "errorMessage: ")) + responsePacket.AppendChild(reponse) + return responsePacket +} + +///////////////////////// +func (stats *Stats) countConns(delta int) { + if stats != nil { + stats.statsMutex.Lock() + stats.Conns += delta + stats.statsMutex.Unlock() + } +} +func (stats *Stats) countBinds(delta int) { + if stats != nil { + stats.statsMutex.Lock() + stats.Binds += delta + stats.statsMutex.Unlock() + } +} +func (stats *Stats) countUnbinds(delta int) { + if stats != nil { + stats.statsMutex.Lock() + stats.Unbinds += delta + stats.statsMutex.Unlock() + } +} +func (stats *Stats) countSearches(delta int) { + if stats != nil { + stats.statsMutex.Lock() + stats.Searches += delta + stats.statsMutex.Unlock() + } +} + +///////////////////////// |