package client

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"os"
	"sync"
	"time"

	"github.com/go-ldap/ldap/v3"
	"github.com/rs/zerolog"

	"github.com/gosec/gsc-ops-api/internal/config"
)

// ErrClientClosed is returned by Acquire (and therefore by every operation
// that needs a connection) once Close has been called.
var ErrClientClosed = errors.New("LDAP client is closed")

// LDAPClient manages a pool of bound LDAP connections.
//
// The pool is a buffered channel of idle connections. Acquire takes an idle
// connection, or dials a new one when none is idle. Release puts it back, or
// closes it when the pool is already full. The pool size therefore caps the
// number of idle connections, not the number of connections in use.
type LDAPClient struct {
	cfg  config.LDAPConfig
	pool chan *ldap.Conn
	// mu guards closed and serializes sends on pool against Close, so that
	// Release never sends on a closed channel (which would panic).
	mu     sync.Mutex
	closed bool
	logger zerolog.Logger
}

// NewLDAPClient creates a new LDAP client and pre-fills its pool with up to
// cfg.PoolSize bound connections.
//
// A failure to open an initial connection is logged and tolerated; the
// client dials on demand later. An error is returned only when the
// configuration is unusable: no servers, or a negative pool size.
func NewLDAPClient(cfg config.LDAPConfig, logger zerolog.Logger) (*LDAPClient, error) {
	if len(cfg.Servers) == 0 {
		return nil, fmt.Errorf("no LDAP servers configured")
	}
	if cfg.PoolSize < 0 {
		// make(chan T, n) panics for n < 0; surface it as a config error.
		return nil, fmt.Errorf("invalid LDAP pool size %d", cfg.PoolSize)
	}

	c := &LDAPClient{
		cfg:    cfg,
		pool:   make(chan *ldap.Conn, cfg.PoolSize),
		logger: logger.With().Str("component", "ldap").Logger(),
	}

	// Pre-fill the pool. The channel has exactly PoolSize slots and nothing
	// else can use it yet, so these sends never block.
	for i := 0; i < cfg.PoolSize; i++ {
		conn, err := c.connect()
		if err != nil {
			c.logger.Warn().Err(err).Int("index", i).Msg("failed to create initial LDAP connection")
			continue
		}
		c.pool <- conn
	}

	return c, nil
}

// buildTLSConfig returns the TLS settings used for dialing. TLS 1.2 is the
// minimum version. When cfg.CAFile is set, the certificates in it are the
// only trusted roots.
func (c *LDAPClient) buildTLSConfig() (*tls.Config, error) {
	tlsCfg := &tls.Config{MinVersion: tls.VersionTLS12}
	if c.cfg.CAFile == "" {
		return tlsCfg, nil
	}

	caCert, err := os.ReadFile(c.cfg.CAFile)
	if err != nil {
		return nil, fmt.Errorf("failed to read CA file: %w", err)
	}
	roots := x509.NewCertPool()
	if !roots.AppendCertsFromPEM(caCert) {
		// An empty root pool would make every handshake fail with an opaque
		// "unknown authority" error; report the real cause instead.
		return nil, fmt.Errorf("no PEM certificates found in CA file %s", c.cfg.CAFile)
	}
	tlsCfg.RootCAs = roots
	return tlsCfg, nil
}

// connect tries each configured server in order. It returns the first
// connection that both dials and binds as cfg.BindDN. Every returned
// connection has a 10s request timeout. If all servers fail, the error from
// the last attempt is wrapped.
func (c *LDAPClient) connect() (*ldap.Conn, error) {
	// Load the CA material once per connect rather than once per server.
	var baseTLS *tls.Config
	if c.cfg.UseTLS {
		var err error
		if baseTLS, err = c.buildTLSConfig(); err != nil {
			return nil, err
		}
	}

	var lastErr error
	for _, server := range c.cfg.Servers {
		var conn *ldap.Conn
		var err error
		if baseTLS != nil {
			// Hand each dial its own copy, as the original per-server
			// construction did, so no dial can observe another's mutations.
			conn, err = ldap.DialURL(server, ldap.DialWithTLSConfig(baseTLS.Clone()))
		} else {
			conn, err = ldap.DialURL(server)
		}
		if err != nil {
			lastErr = fmt.Errorf("failed to connect to %s: %w", server, err)
			continue
		}

		conn.SetTimeout(10 * time.Second)

		if err := conn.Bind(c.cfg.BindDN, c.cfg.BindPass); err != nil {
			conn.Close()
			lastErr = fmt.Errorf("failed to bind to %s: %w", server, err)
			continue
		}

		return conn, nil
	}

	return nil, fmt.Errorf("all LDAP servers failed: %w", lastErr)
}

// Acquire returns a bound connection and reuses an idle pooled one when it
// can. A pooled connection is first probed with a base-scope search of the
// empty DN. If the probe fails, that connection is closed and a fresh one is
// dialed. Callers must hand the connection back with Release.
//
// After Close, Acquire returns ErrClientClosed.
func (c *LDAPClient) Acquire() (*ldap.Conn, error) {
	select {
	case conn, ok := <-c.pool:
		if !ok {
			// The pool was closed and fully drained.
			return nil, ErrClientClosed
		}
		// Probe the connection with a cheap no-op search.
		_, err := conn.Search(&ldap.SearchRequest{
			BaseDN:    "",
			Scope:     ldap.ScopeBaseObject,
			Filter:    "(objectClass=*)",
			SizeLimit: 1,
		})
		if err != nil {
			conn.Close()
			return c.connect()
		}
		return conn, nil
	default:
		return c.connect()
	}
}

// Release returns conn to the pool for reuse. A nil conn is ignored. The
// connection is closed instead of pooled when go-ldap reports it is
// closing, when the pool is full, or when the client has been closed.
func (c *LDAPClient) Release(conn *ldap.Conn) {
	if conn == nil {
		return
	}
	if conn.IsClosing() {
		conn.Close()
		return
	}
	if !c.tryPool(conn) {
		conn.Close()
	}
}

// tryPool attempts a non-blocking put of conn into the pool. It reports
// false when the client is closed or the pool is full.
func (c *LDAPClient) tryPool(conn *ldap.Conn) bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.closed {
		return false
	}
	select {
	case c.pool <- conn:
		return true
	default:
		return false
	}
}

// Close marks the client closed and closes every idle pooled connection.
// It is safe to call more than once. Connections that are checked out at
// the time are closed when they are passed to Release.
func (c *LDAPClient) Close() {
	c.mu.Lock()
	if c.closed {
		c.mu.Unlock()
		return
	}
	c.closed = true
	close(c.pool)
	c.mu.Unlock()

	// No more sends can happen (Release checks closed under mu), so draining
	// outside the lock is safe.
	for conn := range c.pool {
		conn.Close()
	}
}

// Health checks LDAP connectivity by acquiring a connection and returning
// it to the pool. It returns the error from Acquire, if any.
func (c *LDAPClient) Health() error {
	conn, err := c.Acquire()
	if err != nil {
		return err
	}
	c.Release(conn)
	return nil
}

// Search runs a whole-subtree search under baseDN with the given filter.
// It requests only attrs, and sizeLimit is passed to the server. A
// size-limit-exceeded result that carries entries is treated as success,
// and those partial entries are returned.
func (c *LDAPClient) Search(baseDN, filter string, attrs []string, sizeLimit int) ([]*ldap.Entry, error) {
	conn, err := c.Acquire()
	if err != nil {
		return nil, fmt.Errorf("failed to acquire LDAP connection: %w", err)
	}
	defer c.Release(conn)

	sr, err := conn.Search(&ldap.SearchRequest{
		BaseDN:     baseDN,
		Scope:      ldap.ScopeWholeSubtree,
		Filter:     filter,
		Attributes: attrs,
		SizeLimit:  sizeLimit,
	})
	if err != nil {
		// FreeIPA returns SizeLimitExceeded with partial results.
		if ldap.IsErrorWithCode(err, ldap.LDAPResultSizeLimitExceeded) && sr != nil {
			return sr.Entries, nil
		}
		return nil, fmt.Errorf("LDAP search failed: %w", err)
	}

	return sr.Entries, nil
}

// SearchOne runs Search with a size limit of 1 and returns the first entry,
// or (nil, nil) when nothing matches.
//
// It does not enforce uniqueness. If several entries match, the first one
// the server returns is used, because Search treats the resulting
// size-limit-exceeded response as success.
func (c *LDAPClient) SearchOne(baseDN, filter string, attrs []string) (*ldap.Entry, error) {
	entries, err := c.Search(baseDN, filter, attrs, 1)
	if err != nil {
		return nil, err
	}
	if len(entries) == 0 {
		return nil, nil
	}
	return entries[0], nil
}

// Add adds an LDAP entry using a pooled connection. An error from go-ldap
// is returned unwrapped, so callers can inspect result codes directly.
func (c *LDAPClient) Add(req *ldap.AddRequest) error {
	conn, err := c.Acquire()
	if err != nil {
		return fmt.Errorf("failed to acquire LDAP connection: %w", err)
	}
	defer c.Release(conn)

	return conn.Add(req)
}
entry func (c *LDAPClient) Modify(req *ldap.ModifyRequest) error { conn, err := c.Acquire() if err != nil { return fmt.Errorf("failed to acquire LDAP connection: %w", err) } defer c.Release(conn) return conn.Modify(req) } // Delete deletes an LDAP entry func (c *LDAPClient) Delete(dn string) error { conn, err := c.Acquire() if err != nil { return fmt.Errorf("failed to acquire LDAP connection: %w", err) } defer c.Release(conn) return conn.Del(&ldap.DelRequest{DN: dn}) } // PasswordModify changes a user's password func (c *LDAPClient) PasswordModify(userDN, newPassword string) error { conn, err := c.Acquire() if err != nil { return fmt.Errorf("failed to acquire LDAP connection: %w", err) } defer c.Release(conn) _, err = conn.PasswordModify(&ldap.PasswordModifyRequest{ UserIdentity: userDN, NewPassword: newPassword, }) return err }