Initial import — snapshot from admin host /srv/gosec/gsc-ops-api
This repo had no version control prior to this commit. The import is a
straight snapshot of the working tree at 2026-05-03; the deployed
binary on fihelvop01 was being rebuilt from this source via `make
build` + scp into place, with no upstream review path.
The snapshot already includes one in-flight fix made on 2026-05-03 to
internal/service/persona.go:GetSelfModel — the handler queried
`source` and `strength` columns plus an `is_active = true` filter on
persona.persona_commitments, none of which exist on that table (its
shape is session-bound commitments with `status`, `commitment_meta`,
etc.). The query returned a 500 every time SynapseHub bootstrapped a
persona's self-model, dropping the IdentityConstraints / Commitments /
ConscienceStandards layer from the assembled prompt. The patched
query reads existing columns only (commitment_text, commitment_type),
filters on `status='active'`, and synthesises Source="learned" /
Strength=1.0 to keep the SelfModel response shape stable for callers.
Verified live: `GET /api/v1/personas/70f7cfd9-.../self-model` now
returns 200 with `{identityConstraints:[],commitments:[],
conscienceStandards:[]}` instead of 500.
Future changes go through PRs against this repo — no more bin-only
deploys.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
449
internal/service/carddav.go
Normal file
449
internal/service/carddav.go
Normal file
@@ -0,0 +1,449 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/md5"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// CardDAVService handles CardDAV principal, address book, and contact operations
|
||||
type CardDAVService struct {
|
||||
pool *pgxpool.Pool
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewCardDAVService creates a new CardDAV service
|
||||
func NewCardDAVService(pool *pgxpool.Pool, logger zerolog.Logger) *CardDAVService {
|
||||
return &CardDAVService{
|
||||
pool: pool,
|
||||
logger: logger.With().Str("service", "carddav").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// --- Principals ---
|
||||
|
||||
// ListPrincipals lists all principals
|
||||
func (s *CardDAVService) ListPrincipals(ctx context.Context) ([]types.CardDAVPrincipal, error) {
|
||||
rows, err := s.pool.Query(ctx,
|
||||
`SELECT id, uri, email, displayname FROM principals ORDER BY id`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
principals := make([]types.CardDAVPrincipal, 0)
|
||||
for rows.Next() {
|
||||
var p types.CardDAVPrincipal
|
||||
var email, displayName *string
|
||||
if err := rows.Scan(&p.ID, &p.URI, &email, &displayName); err != nil {
|
||||
return nil, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
if email != nil {
|
||||
p.Email = *email
|
||||
}
|
||||
if displayName != nil {
|
||||
p.DisplayName = *displayName
|
||||
}
|
||||
principals = append(principals, p)
|
||||
}
|
||||
return principals, nil
|
||||
}
|
||||
|
||||
// GetPrincipal gets a principal by username
|
||||
func (s *CardDAVService) GetPrincipal(ctx context.Context, username string) (*types.CardDAVPrincipal, error) {
|
||||
uri := "principals/" + username
|
||||
|
||||
var p types.CardDAVPrincipal
|
||||
var email, displayName *string
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, uri, email, displayname FROM principals WHERE uri = $1`, uri).
|
||||
Scan(&p.ID, &p.URI, &email, &displayName)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
if email != nil {
|
||||
p.Email = *email
|
||||
}
|
||||
if displayName != nil {
|
||||
p.DisplayName = *displayName
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// CreatePrincipal creates a new principal
|
||||
func (s *CardDAVService) CreatePrincipal(ctx context.Context, req *types.CardDAVPrincipalCreate) (*types.CardDAVPrincipal, error) {
|
||||
uri := "principals/" + req.Username
|
||||
|
||||
var id int
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`INSERT INTO principals (uri, email, displayname) VALUES ($1, $2, $3) RETURNING id`,
|
||||
uri, nilIfEmpty(req.Email), nilIfEmpty(req.DisplayName)).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info().Str("username", req.Username).Int("id", id).Msg("Created principal")
|
||||
return s.GetPrincipal(ctx, req.Username)
|
||||
}
|
||||
|
||||
// DeletePrincipal deletes a principal and cascades to address books and contacts
|
||||
func (s *CardDAVService) DeletePrincipal(ctx context.Context, username string) error {
|
||||
uri := "principals/" + username
|
||||
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx failed: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
// Delete contacts and changes for all address books owned by this principal
|
||||
_, err = tx.Exec(ctx,
|
||||
`DELETE FROM cards WHERE addressbookid IN (SELECT id FROM addressbooks WHERE principaluri = $1)`, uri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete contacts failed: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx,
|
||||
`DELETE FROM addressbookchanges WHERE addressbookid IN (SELECT id FROM addressbooks WHERE principaluri = $1)`, uri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete changes failed: %w", err)
|
||||
}
|
||||
|
||||
// Delete address books
|
||||
_, err = tx.Exec(ctx,
|
||||
`DELETE FROM addressbooks WHERE principaluri = $1`, uri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete addressbooks failed: %w", err)
|
||||
}
|
||||
|
||||
// Delete principal
|
||||
ct, err := tx.Exec(ctx, `DELETE FROM principals WHERE uri = $1`, uri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete principal failed: %w", err)
|
||||
}
|
||||
if ct.RowsAffected() == 0 {
|
||||
return fmt.Errorf("principal not found")
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return fmt.Errorf("commit failed: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info().Str("username", username).Msg("Deleted principal with cascade")
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Address Books ---
|
||||
|
||||
// ListAddressBooks lists address books, optionally filtered by principal
|
||||
func (s *CardDAVService) ListAddressBooks(ctx context.Context, principal string) ([]types.AddressBook, error) {
|
||||
query := `SELECT id, principaluri, displayname, uri, description, synctoken FROM addressbooks`
|
||||
args := []interface{}{}
|
||||
|
||||
if principal != "" {
|
||||
query += ` WHERE principaluri = $1`
|
||||
args = append(args, "principals/"+principal)
|
||||
}
|
||||
query += ` ORDER BY id`
|
||||
|
||||
rows, err := s.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
books := make([]types.AddressBook, 0)
|
||||
for rows.Next() {
|
||||
var ab types.AddressBook
|
||||
var description *string
|
||||
if err := rows.Scan(&ab.ID, &ab.PrincipalURI, &ab.DisplayName, &ab.URI, &description, &ab.SyncToken); err != nil {
|
||||
return nil, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
if description != nil {
|
||||
ab.Description = *description
|
||||
}
|
||||
books = append(books, ab)
|
||||
}
|
||||
return books, nil
|
||||
}
|
||||
|
||||
// GetAddressBook gets an address book by ID
|
||||
func (s *CardDAVService) GetAddressBook(ctx context.Context, id int) (*types.AddressBook, error) {
|
||||
var ab types.AddressBook
|
||||
var description *string
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, principaluri, displayname, uri, description, synctoken FROM addressbooks WHERE id = $1`, id).
|
||||
Scan(&ab.ID, &ab.PrincipalURI, &ab.DisplayName, &ab.URI, &description, &ab.SyncToken)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
if description != nil {
|
||||
ab.Description = *description
|
||||
}
|
||||
return &ab, nil
|
||||
}
|
||||
|
||||
// CreateAddressBook creates a new address book
|
||||
func (s *CardDAVService) CreateAddressBook(ctx context.Context, req *types.AddressBookCreate) (*types.AddressBook, error) {
|
||||
var id int
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`INSERT INTO addressbooks (principaluri, displayname, uri, description, synctoken)
|
||||
VALUES ($1, $2, $3, $4, 1) RETURNING id`,
|
||||
req.PrincipalURI, req.DisplayName, req.URI, nilIfEmpty(req.Description)).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info().Int("id", id).Str("uri", req.URI).Msg("Created address book")
|
||||
return s.GetAddressBook(ctx, id)
|
||||
}
|
||||
|
||||
// UpdateAddressBook updates an address book
|
||||
func (s *CardDAVService) UpdateAddressBook(ctx context.Context, id int, req *types.AddressBookUpdate) (*types.AddressBook, error) {
|
||||
setClauses := []string{}
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if req.DisplayName != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("displayname = $%d", argIdx))
|
||||
args = append(args, *req.DisplayName)
|
||||
argIdx++
|
||||
}
|
||||
if req.Description != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("description = $%d", argIdx))
|
||||
args = append(args, *req.Description)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(setClauses) == 0 {
|
||||
return s.GetAddressBook(ctx, id)
|
||||
}
|
||||
|
||||
args = append(args, id)
|
||||
query := fmt.Sprintf("UPDATE addressbooks SET %s WHERE id = $%d",
|
||||
join(setClauses, ", "), argIdx)
|
||||
|
||||
ct, err := s.pool.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
if ct.RowsAffected() == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return s.GetAddressBook(ctx, id)
|
||||
}
|
||||
|
||||
// DeleteAddressBook deletes an address book and its contacts
|
||||
func (s *CardDAVService) DeleteAddressBook(ctx context.Context, id int) error {
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx failed: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx, `DELETE FROM cards WHERE addressbookid = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete contacts failed: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, `DELETE FROM addressbookchanges WHERE addressbookid = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete changes failed: %w", err)
|
||||
}
|
||||
|
||||
ct, err := tx.Exec(ctx, `DELETE FROM addressbooks WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete addressbook failed: %w", err)
|
||||
}
|
||||
if ct.RowsAffected() == 0 {
|
||||
return fmt.Errorf("address book not found")
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return fmt.Errorf("commit failed: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info().Int("id", id).Msg("Deleted address book with contacts")
|
||||
return nil
|
||||
}
|
||||
|
||||
// --- Contacts ---
|
||||
|
||||
// ListContacts lists contacts in an address book (metadata only, no carddata)
|
||||
func (s *CardDAVService) ListContacts(ctx context.Context, addressBookID int) ([]types.Contact, error) {
|
||||
rows, err := s.pool.Query(ctx,
|
||||
`SELECT id, addressbookid, uri, lastmodified, etag, size
|
||||
FROM cards WHERE addressbookid = $1 ORDER BY id`, addressBookID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
contacts := make([]types.Contact, 0)
|
||||
for rows.Next() {
|
||||
var c types.Contact
|
||||
if err := rows.Scan(&c.ID, &c.AddressBookID, &c.URI, &c.LastModified, &c.ETag, &c.Size); err != nil {
|
||||
return nil, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
contacts = append(contacts, c)
|
||||
}
|
||||
return contacts, nil
|
||||
}
|
||||
|
||||
// GetContact gets a contact by address book ID and URI (returns full carddata)
|
||||
func (s *CardDAVService) GetContact(ctx context.Context, addressBookID int, uri string) (*types.Contact, error) {
|
||||
var c types.Contact
|
||||
var cardData []byte
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, addressbookid, carddata, uri, lastmodified, etag, size
|
||||
FROM cards WHERE addressbookid = $1 AND uri = $2`, addressBookID, uri).
|
||||
Scan(&c.ID, &c.AddressBookID, &cardData, &c.URI, &c.LastModified, &c.ETag, &c.Size)
|
||||
if err != nil {
|
||||
if err == pgx.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("query failed: %w", err)
|
||||
}
|
||||
c.CardData = string(cardData)
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// CreateContact creates a new contact in an address book
|
||||
func (s *CardDAVService) CreateContact(ctx context.Context, addressBookID int, req *types.ContactCreate) (*types.Contact, error) {
|
||||
etag := computeETag(req.CardData)
|
||||
size := len(req.CardData)
|
||||
lastModified := int(time.Now().Unix())
|
||||
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin tx failed: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Exec(ctx,
|
||||
`INSERT INTO cards (addressbookid, carddata, uri, lastmodified, etag, size)
|
||||
VALUES ($1, $2, $3, $4, $5, $6)`,
|
||||
addressBookID, []byte(req.CardData), req.URI, lastModified, etag, size)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
|
||||
// Record change and bump sync token (operation 1 = add)
|
||||
if err := addChange(ctx, tx, addressBookID, req.URI, 1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return nil, fmt.Errorf("commit failed: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info().Int("addressbookId", addressBookID).Str("uri", req.URI).Msg("Created contact")
|
||||
return s.GetContact(ctx, addressBookID, req.URI)
|
||||
}
|
||||
|
||||
// UpdateContact updates a contact's vCard data
|
||||
func (s *CardDAVService) UpdateContact(ctx context.Context, addressBookID int, uri string, req *types.ContactUpdate) (*types.Contact, error) {
|
||||
etag := computeETag(req.CardData)
|
||||
size := len(req.CardData)
|
||||
lastModified := int(time.Now().Unix())
|
||||
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin tx failed: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
ct, err := tx.Exec(ctx,
|
||||
`UPDATE cards SET carddata = $1, lastmodified = $2, etag = $3, size = $4
|
||||
WHERE addressbookid = $5 AND uri = $6`,
|
||||
[]byte(req.CardData), lastModified, etag, size, addressBookID, uri)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
if ct.RowsAffected() == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Record change and bump sync token (operation 2 = modify)
|
||||
if err := addChange(ctx, tx, addressBookID, uri, 2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return nil, fmt.Errorf("commit failed: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info().Int("addressbookId", addressBookID).Str("uri", uri).Msg("Updated contact")
|
||||
return s.GetContact(ctx, addressBookID, uri)
|
||||
}
|
||||
|
||||
// DeleteContact deletes a contact from an address book
|
||||
func (s *CardDAVService) DeleteContact(ctx context.Context, addressBookID int, uri string) error {
|
||||
tx, err := s.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx failed: %w", err)
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
ct, err := tx.Exec(ctx,
|
||||
`DELETE FROM cards WHERE addressbookid = $1 AND uri = $2`, addressBookID, uri)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
if ct.RowsAffected() == 0 {
|
||||
return fmt.Errorf("contact not found")
|
||||
}
|
||||
|
||||
// Record change and bump sync token (operation 3 = delete)
|
||||
if err := addChange(ctx, tx, addressBookID, uri, 3); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return fmt.Errorf("commit failed: %w", err)
|
||||
}
|
||||
|
||||
s.logger.Info().Int("addressbookId", addressBookID).Str("uri", uri).Msg("Deleted contact")
|
||||
return nil
|
||||
}
|
||||
|
||||
// addChange records a change in addressbookchanges and bumps the sync token.
|
||||
// This is critical for CardDAV sync — without it, clients won't see incremental changes.
|
||||
// Operations: 1=add, 2=modify, 3=delete
|
||||
func addChange(ctx context.Context, tx pgx.Tx, addressBookID int, uri string, operation int) error {
|
||||
_, err := tx.Exec(ctx,
|
||||
`INSERT INTO addressbookchanges (uri, synctoken, addressbookid, operation)
|
||||
SELECT $1, synctoken, $2, $3 FROM addressbooks WHERE id = $2`,
|
||||
uri, addressBookID, operation)
|
||||
if err != nil {
|
||||
return fmt.Errorf("record change failed: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx,
|
||||
`UPDATE addressbooks SET synctoken = synctoken + 1 WHERE id = $1`, addressBookID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bump synctoken failed: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// computeETag computes the ETag for card data (raw MD5 hex, matching sabre/dav DB format)
|
||||
func computeETag(cardData string) string {
|
||||
hash := md5.Sum([]byte(cardData))
|
||||
return fmt.Sprintf("%x", hash)
|
||||
}
|
||||
161
internal/service/certificate.go
Normal file
161
internal/service/certificate.go
Normal file
@@ -0,0 +1,161 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/internal/client"
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// CertificateService handles EJBCA certificate operations
|
||||
type CertificateService struct {
|
||||
client *client.EJBCAClient
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewCertificateService creates a new certificate service
|
||||
func NewCertificateService(ejbcaClient *client.EJBCAClient, logger zerolog.Logger) *CertificateService {
|
||||
return &CertificateService{
|
||||
client: ejbcaClient,
|
||||
logger: logger.With().Str("service", "certificate").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// ListCertificates searches for certificates
|
||||
func (s *CertificateService) ListCertificates(search string, limit int) ([]types.Certificate, error) {
|
||||
if limit <= 0 {
|
||||
limit = 50
|
||||
}
|
||||
|
||||
criteria := []client.CertSearchCriterion{}
|
||||
if search != "" {
|
||||
criteria = append(criteria, client.CertSearchCriterion{
|
||||
Property: "QUERY",
|
||||
Value: search,
|
||||
Operation: "LIKE",
|
||||
})
|
||||
}
|
||||
|
||||
certs, err := s.client.SearchCertificates(&client.CertSearchRequest{
|
||||
MaxResults: limit,
|
||||
Criteria: criteria,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]types.Certificate, 0, len(certs))
|
||||
for _, c := range certs {
|
||||
cert := types.Certificate{
|
||||
SerialNumber: c.SerialNumber,
|
||||
SubjectDN: c.SubjectDN,
|
||||
IssuerDN: c.IssuerDN,
|
||||
Status: c.Status,
|
||||
CAName: c.CAName,
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, c.NotBefore); err == nil {
|
||||
cert.NotBefore = t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, c.NotAfter); err == nil {
|
||||
cert.NotAfter = t
|
||||
}
|
||||
result = append(result, cert)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetCertificate gets a certificate by serial number
|
||||
func (s *CertificateService) GetCertificate(serialNumber, issuerDN string) (*types.Certificate, error) {
|
||||
c, err := s.client.GetCertificate(issuerDN, serialNumber)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert := &types.Certificate{
|
||||
SerialNumber: c.SerialNumber,
|
||||
SubjectDN: c.SubjectDN,
|
||||
IssuerDN: c.IssuerDN,
|
||||
Status: c.Status,
|
||||
CAName: c.CAName,
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, c.NotBefore); err == nil {
|
||||
cert.NotBefore = t
|
||||
}
|
||||
if t, err := time.Parse(time.RFC3339, c.NotAfter); err == nil {
|
||||
cert.NotAfter = t
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// RequestCertificate requests a new certificate from EJBCA
|
||||
func (s *CertificateService) RequestCertificate(req *types.CertRequest) (*types.Certificate, error) {
|
||||
san := buildSANString(req.SubjectDN, req.SANs)
|
||||
|
||||
enrollReq := &client.CertEnrollRequest{
|
||||
CertificateProfileName: req.CertProfileName,
|
||||
EndEntityProfileName: req.EndEntityName,
|
||||
CAName: req.CAName,
|
||||
Username: req.EndEntityName,
|
||||
Password: "internal",
|
||||
IncludeChain: true,
|
||||
SubjectAltName: san,
|
||||
}
|
||||
|
||||
c, err := s.client.EnrollCertificate(enrollReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cert := &types.Certificate{
|
||||
SerialNumber: c.SerialNumber,
|
||||
SubjectDN: c.SubjectDN,
|
||||
IssuerDN: c.IssuerDN,
|
||||
Status: c.Status,
|
||||
}
|
||||
return cert, nil
|
||||
}
|
||||
|
||||
// buildSANString builds an EJBCA-format SAN string (e.g. "dNSName=foo,dNSName=bar").
|
||||
// If sans is empty, extracts CN from subjectDN as a fallback DNS SAN.
|
||||
func buildSANString(subjectDN string, sans []string) string {
|
||||
if len(sans) > 0 {
|
||||
parts := make([]string, 0, len(sans))
|
||||
for _, s := range sans {
|
||||
if s != "" {
|
||||
parts = append(parts, fmt.Sprintf("dNSName=%s", s))
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// Fallback: extract CN from SubjectDN and use as DNS SAN
|
||||
cn := extractCN(subjectDN)
|
||||
if cn != "" {
|
||||
return fmt.Sprintf("dNSName=%s", cn)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractCN extracts the CN value from a SubjectDN string like "CN=foo.bar,O=Org"
|
||||
func extractCN(subjectDN string) string {
|
||||
for _, part := range strings.Split(subjectDN, ",") {
|
||||
part = strings.TrimSpace(part)
|
||||
if strings.HasPrefix(part, "CN=") {
|
||||
return strings.TrimPrefix(part, "CN=")
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// RevokeCertificate revokes a certificate
|
||||
func (s *CertificateService) RevokeCertificate(serialNumber string, req *types.CertRevoke) error {
|
||||
reason := req.Reason
|
||||
if reason == "" {
|
||||
reason = "UNSPECIFIED"
|
||||
}
|
||||
return s.client.RevokeCertificate(req.IssuerDN, serialNumber, reason)
|
||||
}
|
||||
413
internal/service/database.go
Normal file
413
internal/service/database.go
Normal file
@@ -0,0 +1,413 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// DatabaseService handles tenant and user database operations
|
||||
type DatabaseService struct {
|
||||
pool *pgxpool.Pool
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewDatabaseService creates a new database service
|
||||
func NewDatabaseService(pool *pgxpool.Pool, logger zerolog.Logger) *DatabaseService {
|
||||
return &DatabaseService{
|
||||
pool: pool,
|
||||
logger: logger.With().Str("service", "database").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// ListTenants lists tenants with optional filters
|
||||
func (s *DatabaseService) ListTenants(ctx context.Context, params types.ListParams) ([]types.Tenant, int64, error) {
|
||||
params = types.DefaultListParams(params)
|
||||
|
||||
countQuery := `SELECT COUNT(*) FROM admin.tenants WHERE 1=1`
|
||||
listQuery := `SELECT id, customer_id, code, name, display_name, domain, logo_url, primary_color,
|
||||
max_users, max_storage_gb, max_recording_hours, is_active, metadata, created_at, updated_at
|
||||
FROM admin.tenants WHERE 1=1`
|
||||
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if params.Status != "" {
|
||||
if params.Status == "active" {
|
||||
countQuery += " AND is_active = true"
|
||||
listQuery += " AND is_active = true"
|
||||
} else if params.Status == "inactive" {
|
||||
countQuery += " AND is_active = false"
|
||||
listQuery += " AND is_active = false"
|
||||
}
|
||||
}
|
||||
if params.Search != "" {
|
||||
countQuery += fmt.Sprintf(" AND (name ILIKE $%d OR code ILIKE $%d OR domain ILIKE $%d)", argIdx, argIdx, argIdx)
|
||||
listQuery += fmt.Sprintf(" AND (name ILIKE $%d OR code ILIKE $%d OR domain ILIKE $%d)", argIdx, argIdx, argIdx)
|
||||
args = append(args, "%"+params.Search+"%")
|
||||
argIdx++
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("count query failed: %w", err)
|
||||
}
|
||||
|
||||
listQuery += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, params.Limit, params.Offset)
|
||||
|
||||
rows, err := s.pool.Query(ctx, listQuery, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("list query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
tenants := make([]types.Tenant, 0)
|
||||
for rows.Next() {
|
||||
var t types.Tenant
|
||||
var metadataJSON []byte
|
||||
if err := rows.Scan(&t.ID, &t.CustomerID, &t.Code, &t.Name, &t.DisplayName, &t.Domain,
|
||||
&t.LogoURL, &t.PrimaryColor, &t.MaxUsers, &t.MaxStorageGB, &t.MaxRecordingHours,
|
||||
&t.IsActive, &metadataJSON, &t.CreatedAt, &t.UpdatedAt); err != nil {
|
||||
return nil, 0, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
json.Unmarshal(metadataJSON, &t.Metadata)
|
||||
}
|
||||
tenants = append(tenants, t)
|
||||
}
|
||||
|
||||
return tenants, total, nil
|
||||
}
|
||||
|
||||
// GetTenant gets a tenant by ID
|
||||
func (s *DatabaseService) GetTenant(ctx context.Context, id uuid.UUID) (*types.Tenant, error) {
|
||||
var t types.Tenant
|
||||
var metadataJSON []byte
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, customer_id, code, name, display_name, domain, logo_url, primary_color,
|
||||
max_users, max_storage_gb, max_recording_hours, is_active, metadata, created_at, updated_at
|
||||
FROM admin.tenants WHERE id = $1`, id).
|
||||
Scan(&t.ID, &t.CustomerID, &t.Code, &t.Name, &t.DisplayName, &t.Domain,
|
||||
&t.LogoURL, &t.PrimaryColor, &t.MaxUsers, &t.MaxStorageGB, &t.MaxRecordingHours,
|
||||
&t.IsActive, &metadataJSON, &t.CreatedAt, &t.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
json.Unmarshal(metadataJSON, &t.Metadata)
|
||||
}
|
||||
return &t, nil
|
||||
}
|
||||
|
||||
// CreateTenant creates a new tenant
|
||||
func (s *DatabaseService) CreateTenant(ctx context.Context, req *types.TenantCreate) (*types.Tenant, error) {
|
||||
id := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
var metadataJSON []byte
|
||||
if req.Metadata != nil {
|
||||
var err error
|
||||
metadataJSON, err = json.Marshal(req.Metadata)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.pool.Exec(ctx,
|
||||
`INSERT INTO admin.tenants (id, customer_id, code, name, display_name, domain, logo_url, primary_color,
|
||||
max_users, max_storage_gb, max_recording_hours, is_active, metadata, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, true, $12, $13, $13)`,
|
||||
id, req.CustomerID, req.Code, req.Name, nilIfEmpty(req.DisplayName), nilIfEmpty(req.Domain),
|
||||
nilIfEmpty(req.LogoURL), nilIfEmpty(req.PrimaryColor),
|
||||
req.MaxUsers, req.MaxStorageGB, req.MaxRecordingHours, metadataJSON, now)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
|
||||
return s.GetTenant(ctx, id)
|
||||
}
|
||||
|
||||
// UpdateTenant updates a tenant
|
||||
func (s *DatabaseService) UpdateTenant(ctx context.Context, id uuid.UUID, req *types.TenantUpdate) (*types.Tenant, error) {
|
||||
setClauses := []string{}
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if req.Name != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("name = $%d", argIdx))
|
||||
args = append(args, *req.Name)
|
||||
argIdx++
|
||||
}
|
||||
if req.DisplayName != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("display_name = $%d", argIdx))
|
||||
args = append(args, *req.DisplayName)
|
||||
argIdx++
|
||||
}
|
||||
if req.Domain != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("domain = $%d", argIdx))
|
||||
args = append(args, *req.Domain)
|
||||
argIdx++
|
||||
}
|
||||
if req.LogoURL != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("logo_url = $%d", argIdx))
|
||||
args = append(args, *req.LogoURL)
|
||||
argIdx++
|
||||
}
|
||||
if req.PrimaryColor != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("primary_color = $%d", argIdx))
|
||||
args = append(args, *req.PrimaryColor)
|
||||
argIdx++
|
||||
}
|
||||
if req.MaxUsers != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("max_users = $%d", argIdx))
|
||||
args = append(args, *req.MaxUsers)
|
||||
argIdx++
|
||||
}
|
||||
if req.MaxStorageGB != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("max_storage_gb = $%d", argIdx))
|
||||
args = append(args, *req.MaxStorageGB)
|
||||
argIdx++
|
||||
}
|
||||
if req.MaxRecordingHours != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("max_recording_hours = $%d", argIdx))
|
||||
args = append(args, *req.MaxRecordingHours)
|
||||
argIdx++
|
||||
}
|
||||
if req.IsActive != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("is_active = $%d", argIdx))
|
||||
args = append(args, *req.IsActive)
|
||||
argIdx++
|
||||
}
|
||||
if req.Metadata != nil {
|
||||
metadataJSON, err := json.Marshal(req.Metadata)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
setClauses = append(setClauses, fmt.Sprintf("metadata = $%d", argIdx))
|
||||
args = append(args, metadataJSON)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(setClauses) == 0 {
|
||||
return s.GetTenant(ctx, id)
|
||||
}
|
||||
|
||||
setClauses = append(setClauses, fmt.Sprintf("updated_at = $%d", argIdx))
|
||||
args = append(args, time.Now().UTC())
|
||||
argIdx++
|
||||
|
||||
args = append(args, id)
|
||||
query := fmt.Sprintf("UPDATE admin.tenants SET %s WHERE id = $%d",
|
||||
join(setClauses, ", "), argIdx)
|
||||
|
||||
_, err := s.pool.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
|
||||
return s.GetTenant(ctx, id)
|
||||
}
|
||||
|
||||
// SoftDeleteTenant deactivates a tenant
|
||||
func (s *DatabaseService) SoftDeleteTenant(ctx context.Context, id uuid.UUID) error {
|
||||
_, err := s.pool.Exec(ctx,
|
||||
`UPDATE admin.tenants SET is_active = false, updated_at = $1 WHERE id = $2`,
|
||||
time.Now().UTC(), id)
|
||||
return err
|
||||
}
|
||||
|
||||
// ListUsers lists users with optional filters
|
||||
func (s *DatabaseService) ListUsers(ctx context.Context, params types.ListParams) ([]types.DBUser, int64, error) {
|
||||
params = types.DefaultListParams(params)
|
||||
|
||||
countQuery := `SELECT COUNT(*) FROM admin.users WHERE 1=1`
|
||||
listQuery := `SELECT id, gscsid, first_name, last_name, display_name, email, timezone, locale, status,
|
||||
last_login_at, last_activity_at, metadata, created_at, updated_at
|
||||
FROM admin.users WHERE 1=1`
|
||||
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if params.Status != "" {
|
||||
countQuery += fmt.Sprintf(" AND status = $%d", argIdx)
|
||||
listQuery += fmt.Sprintf(" AND status = $%d", argIdx)
|
||||
args = append(args, params.Status)
|
||||
argIdx++
|
||||
}
|
||||
if params.Search != "" {
|
||||
countQuery += fmt.Sprintf(" AND (gscsid ILIKE $%d OR display_name ILIKE $%d OR email ILIKE $%d)", argIdx, argIdx, argIdx)
|
||||
listQuery += fmt.Sprintf(" AND (gscsid ILIKE $%d OR display_name ILIKE $%d OR email ILIKE $%d)", argIdx, argIdx, argIdx)
|
||||
args = append(args, "%"+params.Search+"%")
|
||||
argIdx++
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("count query failed: %w", err)
|
||||
}
|
||||
|
||||
listQuery += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, params.Limit, params.Offset)
|
||||
|
||||
rows, err := s.pool.Query(ctx, listQuery, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("list query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
users := make([]types.DBUser, 0)
|
||||
for rows.Next() {
|
||||
var u types.DBUser
|
||||
var metadataJSON []byte
|
||||
if err := rows.Scan(&u.ID, &u.GscSID, &u.FirstName, &u.LastName, &u.DisplayName, &u.Email, &u.Timezone, &u.Locale, &u.Status,
|
||||
&u.LastLoginAt, &u.LastActivityAt, &metadataJSON, &u.CreatedAt, &u.UpdatedAt); err != nil {
|
||||
return nil, 0, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
json.Unmarshal(metadataJSON, &u.Metadata)
|
||||
}
|
||||
users = append(users, u)
|
||||
}
|
||||
|
||||
return users, total, nil
|
||||
}
|
||||
|
||||
// GetUser gets a user by ID
|
||||
func (s *DatabaseService) GetUser(ctx context.Context, id uuid.UUID) (*types.DBUser, error) {
|
||||
var u types.DBUser
|
||||
var metadataJSON []byte
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, gscsid, first_name, last_name, display_name, email, timezone, locale, status,
|
||||
last_login_at, last_activity_at, metadata, created_at, updated_at
|
||||
FROM admin.users WHERE id = $1`, id).
|
||||
Scan(&u.ID, &u.GscSID, &u.FirstName, &u.LastName, &u.DisplayName, &u.Email, &u.Timezone, &u.Locale, &u.Status,
|
||||
&u.LastLoginAt, &u.LastActivityAt, &metadataJSON, &u.CreatedAt, &u.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(metadataJSON) > 0 {
|
||||
json.Unmarshal(metadataJSON, &u.Metadata)
|
||||
}
|
||||
return &u, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new user record
|
||||
func (s *DatabaseService) CreateUser(ctx context.Context, req *types.DBUserCreate) (*types.DBUser, error) {
|
||||
id := uuid.New()
|
||||
now := time.Now().UTC()
|
||||
|
||||
var metadataJSON []byte
|
||||
if req.Metadata != nil {
|
||||
var err error
|
||||
metadataJSON, err = json.Marshal(req.Metadata)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := s.pool.Exec(ctx,
|
||||
`INSERT INTO admin.users (id, gscsid, first_name, last_name, display_name, email, timezone, locale, status, metadata, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, 'active', $9, $10, $10)`,
|
||||
id, req.GscSID, nilIfEmpty(req.FirstName), nilIfEmpty(req.LastName), nilIfEmpty(req.DisplayName), nilIfEmpty(req.Email), nilIfEmpty(req.Timezone), nilIfEmpty(req.Locale), metadataJSON, now)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
|
||||
return s.GetUser(ctx, id)
|
||||
}
|
||||
|
||||
// UpdateUser updates a user record
|
||||
func (s *DatabaseService) UpdateUser(ctx context.Context, id uuid.UUID, req *types.DBUserUpdate) (*types.DBUser, error) {
|
||||
setClauses := []string{}
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if req.Timezone != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("timezone = $%d", argIdx))
|
||||
args = append(args, *req.Timezone)
|
||||
argIdx++
|
||||
}
|
||||
if req.Locale != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("locale = $%d", argIdx))
|
||||
args = append(args, *req.Locale)
|
||||
argIdx++
|
||||
}
|
||||
if req.Status != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("status = $%d", argIdx))
|
||||
args = append(args, *req.Status)
|
||||
argIdx++
|
||||
}
|
||||
if req.LastLoginAt != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("last_login_at = $%d", argIdx))
|
||||
args = append(args, *req.LastLoginAt)
|
||||
argIdx++
|
||||
}
|
||||
if req.LastActivityAt != nil {
|
||||
setClauses = append(setClauses, fmt.Sprintf("last_activity_at = $%d", argIdx))
|
||||
args = append(args, *req.LastActivityAt)
|
||||
argIdx++
|
||||
}
|
||||
if req.Metadata != nil {
|
||||
metadataJSON, err := json.Marshal(req.Metadata)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal metadata: %w", err)
|
||||
}
|
||||
setClauses = append(setClauses, fmt.Sprintf("metadata = $%d", argIdx))
|
||||
args = append(args, metadataJSON)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if len(setClauses) == 0 {
|
||||
return s.GetUser(ctx, id)
|
||||
}
|
||||
|
||||
setClauses = append(setClauses, fmt.Sprintf("updated_at = $%d", argIdx))
|
||||
args = append(args, time.Now().UTC())
|
||||
argIdx++
|
||||
|
||||
args = append(args, id)
|
||||
query := fmt.Sprintf("UPDATE admin.users SET %s WHERE id = $%d",
|
||||
join(setClauses, ", "), argIdx)
|
||||
|
||||
_, err := s.pool.Exec(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
|
||||
return s.GetUser(ctx, id)
|
||||
}
|
||||
|
||||
// DeactivateUser deactivates a user
|
||||
func (s *DatabaseService) DeactivateUser(ctx context.Context, id uuid.UUID) error {
|
||||
now := time.Now().UTC()
|
||||
_, err := s.pool.Exec(ctx,
|
||||
`UPDATE admin.users SET status = 'inactive', updated_at = $1 WHERE id = $2`,
|
||||
now, id)
|
||||
return err
|
||||
}
|
||||
|
||||
func nilIfEmpty(s string) *string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
return &s
|
||||
}
|
||||
|
||||
func join(strs []string, sep string) string {
|
||||
result := ""
|
||||
for i, s := range strs {
|
||||
if i > 0 {
|
||||
result += sep
|
||||
}
|
||||
result += s
|
||||
}
|
||||
return result
|
||||
}
|
||||
324
internal/service/dns.go
Normal file
324
internal/service/dns.go
Normal file
@@ -0,0 +1,324 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/internal/client"
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// DNSService handles PowerDNS zone and record operations
|
||||
type DNSService struct {
|
||||
client *client.PowerDNSClient
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewDNSService creates a new DNS service
|
||||
func NewDNSService(pdnsClient *client.PowerDNSClient, logger zerolog.Logger) *DNSService {
|
||||
return &DNSService{
|
||||
client: pdnsClient,
|
||||
logger: logger.With().Str("service", "dns").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// ListZones lists all DNS zones
|
||||
func (s *DNSService) ListZones() ([]types.DNSZone, error) {
|
||||
zones, err := s.client.ListZones()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]types.DNSZone, 0, len(zones))
|
||||
for _, z := range zones {
|
||||
result = append(result, types.DNSZone{
|
||||
ID: z.ID,
|
||||
Name: z.Name,
|
||||
Kind: z.Kind,
|
||||
DNSSec: z.DNSSec,
|
||||
Serial: z.Serial,
|
||||
NotifiedSerial: z.NotifiedSerial,
|
||||
})
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetZone gets a zone with records
|
||||
func (s *DNSService) GetZone(zoneID string) (*types.DNSZone, error) {
|
||||
z, err := s.client.GetZone(zoneID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zone := &types.DNSZone{
|
||||
ID: z.ID,
|
||||
Name: z.Name,
|
||||
Kind: z.Kind,
|
||||
DNSSec: z.DNSSec,
|
||||
Serial: z.Serial,
|
||||
NotifiedSerial: z.NotifiedSerial,
|
||||
SOAEdit: z.SOAEdit,
|
||||
SOAEditAPI: z.SOAEditAPI,
|
||||
}
|
||||
|
||||
records := make([]types.DNSRecord, 0, len(z.RRSets))
|
||||
for _, rr := range z.RRSets {
|
||||
entries := make([]types.DNSRecordEntry, 0, len(rr.Records))
|
||||
for _, r := range rr.Records {
|
||||
entries = append(entries, types.DNSRecordEntry{
|
||||
Content: r.Content,
|
||||
Disabled: r.Disabled,
|
||||
})
|
||||
}
|
||||
records = append(records, types.DNSRecord{
|
||||
Name: rr.Name,
|
||||
Type: rr.Type,
|
||||
TTL: rr.TTL,
|
||||
Records: entries,
|
||||
})
|
||||
}
|
||||
zone.Records = records
|
||||
|
||||
return zone, nil
|
||||
}
|
||||
|
||||
// CreateZone creates a new DNS zone
|
||||
func (s *DNSService) CreateZone(req *types.DNSZoneCreate) (*types.DNSZone, error) {
|
||||
kind := req.Kind
|
||||
if kind == "" {
|
||||
kind = "Native"
|
||||
}
|
||||
|
||||
name := req.Name
|
||||
if !strings.HasSuffix(name, ".") {
|
||||
name += "."
|
||||
}
|
||||
|
||||
z, err := s.client.CreateZone(&client.ZoneCreate{
|
||||
Name: name,
|
||||
Kind: kind,
|
||||
Nameservers: req.Nameservers,
|
||||
Masters: req.Masters,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &types.DNSZone{
|
||||
ID: z.ID,
|
||||
Name: z.Name,
|
||||
Kind: z.Kind,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UpdateZone updates zone metadata
|
||||
func (s *DNSService) UpdateZone(zoneID string, req *types.DNSZoneUpdate) error {
|
||||
data := make(map[string]interface{})
|
||||
if req.Kind != nil {
|
||||
data["kind"] = *req.Kind
|
||||
}
|
||||
if req.Masters != nil {
|
||||
data["masters"] = req.Masters
|
||||
}
|
||||
return s.client.UpdateZone(zoneID, data)
|
||||
}
|
||||
|
||||
// DeleteZone deletes a zone
|
||||
func (s *DNSService) DeleteZone(zoneID string) error {
|
||||
return s.client.DeleteZone(zoneID)
|
||||
}
|
||||
|
||||
// NotifyZone sends NOTIFY to slaves
|
||||
func (s *DNSService) NotifyZone(zoneID string) error {
|
||||
return s.client.NotifyZone(zoneID)
|
||||
}
|
||||
|
||||
// ListRecords lists records in a zone
|
||||
func (s *DNSService) ListRecords(zoneID string) ([]types.DNSRecord, error) {
|
||||
zone, err := s.GetZone(zoneID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return zone.Records, nil
|
||||
}
|
||||
|
||||
// ChangeRecords applies record changes to a zone using PATCH semantics
|
||||
func (s *DNSService) ChangeRecords(zoneID string, changes []types.DNSRecordChange) error {
|
||||
rrsets := make([]client.RRSet, 0, len(changes))
|
||||
for _, ch := range changes {
|
||||
name := ch.Name
|
||||
if !strings.HasSuffix(name, ".") {
|
||||
name += "."
|
||||
}
|
||||
|
||||
records := make([]client.Record, 0, len(ch.Records))
|
||||
for _, r := range ch.Records {
|
||||
records = append(records, client.Record{
|
||||
Content: r.Content,
|
||||
Disabled: r.Disabled,
|
||||
})
|
||||
}
|
||||
|
||||
ttl := ch.TTL
|
||||
if ttl == 0 {
|
||||
ttl = 3600
|
||||
}
|
||||
|
||||
rrsets = append(rrsets, client.RRSet{
|
||||
Name: name,
|
||||
Type: ch.Type,
|
||||
TTL: ttl,
|
||||
ChangeType: ch.ChangeType,
|
||||
Records: records,
|
||||
})
|
||||
}
|
||||
return s.client.PatchRRSets(zoneID, rrsets)
|
||||
}
|
||||
|
||||
// SetupDomain creates a zone with standard mail DNS records (MX, SPF, DKIM, DMARC)
|
||||
func (s *DNSService) SetupDomain(req *types.DomainSetup) (*types.DNSZone, error) {
|
||||
domain := req.Domain
|
||||
if !strings.HasSuffix(domain, ".") {
|
||||
domain += "."
|
||||
}
|
||||
|
||||
// Create zone first
|
||||
zone, err := s.client.CreateZone(&client.ZoneCreate{
|
||||
Name: domain,
|
||||
Kind: "Native",
|
||||
Nameservers: []string{
|
||||
"ns1.gosec.cloud.",
|
||||
"ns2.gosec.cloud.",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create zone: %w", err)
|
||||
}
|
||||
|
||||
// Build standard mail records
|
||||
mxHost := req.MXHost
|
||||
if mxHost == "" {
|
||||
mxHost = "mail.gosec.cloud."
|
||||
}
|
||||
if !strings.HasSuffix(mxHost, ".") {
|
||||
mxHost += "."
|
||||
}
|
||||
|
||||
rrsets := []client.RRSet{
|
||||
{
|
||||
Name: domain,
|
||||
Type: "MX",
|
||||
TTL: 3600,
|
||||
ChangeType: "REPLACE",
|
||||
Records: []client.Record{{Content: "10 " + mxHost}},
|
||||
},
|
||||
}
|
||||
|
||||
// SPF record
|
||||
spf := "v=spf1"
|
||||
if len(req.SPFIncludes) > 0 {
|
||||
for _, inc := range req.SPFIncludes {
|
||||
spf += " include:" + inc
|
||||
}
|
||||
}
|
||||
spf += " mx -all"
|
||||
rrsets = append(rrsets, client.RRSet{
|
||||
Name: domain,
|
||||
Type: "TXT",
|
||||
TTL: 3600,
|
||||
ChangeType: "REPLACE",
|
||||
Records: []client.Record{{Content: fmt.Sprintf(`"%s"`, spf)}},
|
||||
})
|
||||
|
||||
// DKIM record
|
||||
if req.DKIMKey != "" {
|
||||
rrsets = append(rrsets, client.RRSet{
|
||||
Name: "default._domainkey." + domain,
|
||||
Type: "TXT",
|
||||
TTL: 3600,
|
||||
ChangeType: "REPLACE",
|
||||
Records: []client.Record{{Content: fmt.Sprintf(`"v=DKIM1; k=rsa; p=%s"`, req.DKIMKey)}},
|
||||
})
|
||||
}
|
||||
|
||||
// DMARC record
|
||||
rrsets = append(rrsets, client.RRSet{
|
||||
Name: "_dmarc." + domain,
|
||||
Type: "TXT",
|
||||
TTL: 3600,
|
||||
ChangeType: "REPLACE",
|
||||
Records: []client.Record{{Content: fmt.Sprintf(`"v=DMARC1; p=quarantine; rua=mailto:postmaster@%s"`, strings.TrimSuffix(domain, "."))}},
|
||||
})
|
||||
|
||||
if err := s.client.PatchRRSets(zone.ID, rrsets); err != nil {
|
||||
return nil, fmt.Errorf("zone created but record setup failed: %w", err)
|
||||
}
|
||||
|
||||
result := &types.DNSZone{
|
||||
ID: zone.ID,
|
||||
Name: zone.Name,
|
||||
Kind: zone.Kind,
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// VerifyDomain checks DNS propagation for a domain
|
||||
func (s *DNSService) VerifyDomain(domain string) (*types.DomainVerifyResult, error) {
|
||||
if !strings.HasSuffix(domain, ".") {
|
||||
domain += "."
|
||||
}
|
||||
|
||||
zone, err := s.client.GetZone(domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("zone not found: %w", err)
|
||||
}
|
||||
|
||||
results := make(map[string]string)
|
||||
hasMX, hasSPF, hasDMARC := false, false, false
|
||||
|
||||
for _, rr := range zone.RRSets {
|
||||
switch {
|
||||
case rr.Type == "MX" && rr.Name == domain:
|
||||
hasMX = true
|
||||
results["MX"] = "OK"
|
||||
case rr.Type == "TXT" && rr.Name == domain:
|
||||
for _, r := range rr.Records {
|
||||
if strings.Contains(r.Content, "v=spf1") {
|
||||
hasSPF = true
|
||||
results["SPF"] = "OK"
|
||||
}
|
||||
}
|
||||
case rr.Type == "TXT" && rr.Name == "_dmarc."+domain:
|
||||
hasDMARC = true
|
||||
results["DMARC"] = "OK"
|
||||
case rr.Type == "TXT" && strings.HasSuffix(rr.Name, "._domainkey."+domain):
|
||||
results["DKIM"] = "OK"
|
||||
}
|
||||
}
|
||||
|
||||
if !hasMX {
|
||||
results["MX"] = "MISSING"
|
||||
}
|
||||
if !hasSPF {
|
||||
results["SPF"] = "MISSING"
|
||||
}
|
||||
if !hasDMARC {
|
||||
results["DMARC"] = "MISSING"
|
||||
}
|
||||
|
||||
allOK := true
|
||||
for _, v := range results {
|
||||
if v != "OK" {
|
||||
allOK = false
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return &types.DomainVerifyResult{
|
||||
Domain: strings.TrimSuffix(domain, "."),
|
||||
Results: results,
|
||||
AllOK: allOK,
|
||||
}, nil
|
||||
}
|
||||
648
internal/service/ldap.go
Normal file
648
internal/service/ldap.go
Normal file
@@ -0,0 +1,648 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/internal/client"
|
||||
"github.com/gosec/gsc-ops-api/internal/schema"
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// LDAPService handles FreeIPA user and group operations
|
||||
type LDAPService struct {
|
||||
client *client.LDAPClient
|
||||
baseDN string
|
||||
logger zerolog.Logger
|
||||
registry *schema.Registry
|
||||
}
|
||||
|
||||
// NewLDAPService creates a new LDAP service
|
||||
func NewLDAPService(ldapClient *client.LDAPClient, baseDN string, logger zerolog.Logger, registry *schema.Registry) *LDAPService {
|
||||
return &LDAPService{
|
||||
client: ldapClient,
|
||||
baseDN: baseDN,
|
||||
logger: logger.With().Str("service", "ldap").Logger(),
|
||||
registry: registry,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LDAPService) userBaseDN() string {
|
||||
return "cn=users,cn=accounts," + s.baseDN
|
||||
}
|
||||
|
||||
func (s *LDAPService) groupBaseDN() string {
|
||||
return "cn=groups,cn=accounts," + s.baseDN
|
||||
}
|
||||
|
||||
func (s *LDAPService) userDN(uid string) string {
|
||||
return fmt.Sprintf("uid=%s,%s", ldap.EscapeFilter(uid), s.userBaseDN())
|
||||
}
|
||||
|
||||
func (s *LDAPService) groupDN(cn string) string {
|
||||
return fmt.Sprintf("cn=%s,%s", ldap.EscapeFilter(cn), s.groupBaseDN())
|
||||
}
|
||||
|
||||
// coreUserAttrs are the base LDAP attributes for user listing (no gsc* attrs)
|
||||
var coreUserAttrs = []string{
|
||||
"uid", "givenName", "sn", "displayName", "mail", "telephoneNumber",
|
||||
"title", "nsAccountLock", "loginShell", "homeDirectory", "memberOf",
|
||||
}
|
||||
|
||||
// userSearchAttrs returns core attrs plus all gsc* attrs for full user retrieval
|
||||
func (s *LDAPService) userSearchAttrs() []string {
|
||||
gscAttrs := s.registry.AllUserAttrs()
|
||||
attrs := make([]string, 0, len(coreUserAttrs)+len(gscAttrs)+1)
|
||||
attrs = append(attrs, coreUserAttrs...)
|
||||
attrs = append(attrs, "objectClass")
|
||||
attrs = append(attrs, gscAttrs...)
|
||||
return attrs
|
||||
}
|
||||
|
||||
var groupAttrs = []string{
|
||||
"cn", "description", "member", "gidNumber",
|
||||
}
|
||||
|
||||
// ListUsers searches for users, optionally filtering by search string,
|
||||
// service objectClasses, and/or arbitrary LDAP attribute values.
|
||||
//
|
||||
// attrFilters maps raw LDAP attribute names to match values. Values may
|
||||
// contain LDAP wildcards (e.g. "*@example.com"). The attribute name itself
|
||||
// is sanitised to prevent filter injection.
|
||||
func (s *LDAPService) ListUsers(search string, limit int, serviceFilters []string, attrFilters map[string]string) ([]types.LDAPUser, error) {
|
||||
// Start with base object class filter
|
||||
parts := []string{"(objectClass=posixAccount)"}
|
||||
|
||||
// Free-text search across core fields
|
||||
if search != "" {
|
||||
escaped := ldap.EscapeFilter(search)
|
||||
parts = append(parts, fmt.Sprintf("(|(uid=*%s*)(givenName=*%s*)(sn=*%s*)(mail=*%s*))",
|
||||
escaped, escaped, escaped, escaped))
|
||||
}
|
||||
|
||||
// Service objectClass filters
|
||||
for _, svc := range serviceFilters {
|
||||
oc := s.registry.UserOCForDomain(svc)
|
||||
if oc != "" {
|
||||
parts = append(parts, fmt.Sprintf("(objectClass=%s)", oc))
|
||||
}
|
||||
}
|
||||
|
||||
// Dynamic LDAP attribute filters
|
||||
// Collect extra attrs we need to request so the server evaluates the filter
|
||||
var extraAttrs []string
|
||||
for attr, val := range attrFilters {
|
||||
// Sanitise attribute name: only allow alphanumeric, dash, semicolon
|
||||
safe := true
|
||||
for _, ch := range attr {
|
||||
if !((ch >= 'a' && ch <= 'z') || (ch >= 'A' && ch <= 'Z') || (ch >= '0' && ch <= '9') || ch == '-' || ch == ';') {
|
||||
safe = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !safe || attr == "" {
|
||||
continue
|
||||
}
|
||||
// Escape value but preserve * wildcards for substring matching.
|
||||
// Split on *, escape each segment, rejoin with *.
|
||||
segments := strings.Split(val, "*")
|
||||
for i, seg := range segments {
|
||||
segments[i] = ldap.EscapeFilter(seg)
|
||||
}
|
||||
escapedVal := strings.Join(segments, "*")
|
||||
parts = append(parts, fmt.Sprintf("(%s=%s)", attr, escapedVal))
|
||||
extraAttrs = append(extraAttrs, attr)
|
||||
}
|
||||
|
||||
// Build final filter
|
||||
var filter string
|
||||
if len(parts) == 1 {
|
||||
filter = parts[0]
|
||||
} else {
|
||||
filter = "(&" + strings.Join(parts, "") + ")"
|
||||
}
|
||||
|
||||
// When service filters are present, fetch full gsc* attrs so the
|
||||
// response includes the services block (e.g. gscSID for chat).
|
||||
includeServices := len(serviceFilters) > 0
|
||||
var attrs []string
|
||||
if includeServices {
|
||||
attrs = s.userSearchAttrs()
|
||||
if len(extraAttrs) > 0 {
|
||||
attrs = append(attrs, extraAttrs...)
|
||||
}
|
||||
} else {
|
||||
attrs = coreUserAttrs
|
||||
if len(extraAttrs) > 0 {
|
||||
attrs = make([]string, len(coreUserAttrs), len(coreUserAttrs)+len(extraAttrs))
|
||||
copy(attrs, coreUserAttrs)
|
||||
attrs = append(attrs, extraAttrs...)
|
||||
}
|
||||
}
|
||||
|
||||
entries, err := s.client.Search(s.userBaseDN(), filter, attrs, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
users := make([]types.LDAPUser, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
users = append(users, s.entryToUser(entry, includeServices))
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// GetUser gets a user by UID with full service attributes
|
||||
func (s *LDAPService) GetUser(uid string) (*types.LDAPUser, error) {
|
||||
filter := fmt.Sprintf("(&(objectClass=posixAccount)(uid=%s))", ldap.EscapeFilter(uid))
|
||||
entry, err := s.client.SearchOne(s.userBaseDN(), filter, s.userSearchAttrs())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
user := s.entryToUser(entry, true)
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
// GetUserServices returns only service attributes for a user
|
||||
func (s *LDAPService) GetUserServices(uid string, domain string) (map[string]map[string]interface{}, error) {
|
||||
filter := fmt.Sprintf("(&(objectClass=posixAccount)(uid=%s))", ldap.EscapeFilter(uid))
|
||||
entry, err := s.client.SearchOne(s.userBaseDN(), filter, s.userSearchAttrs())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
services := s.extractServices(entry)
|
||||
if domain != "" {
|
||||
filtered := make(map[string]map[string]interface{})
|
||||
if svc, ok := services[domain]; ok {
|
||||
filtered[domain] = svc
|
||||
}
|
||||
return filtered, nil
|
||||
}
|
||||
return services, nil
|
||||
}
|
||||
|
||||
// CreateUser creates a new FreeIPA user
|
||||
func (s *LDAPService) CreateUser(req *types.LDAPUserCreate) (*types.LDAPUser, error) {
|
||||
dn := s.userDN(req.UID)
|
||||
|
||||
objectClasses := []string{"top", "person", "organizationalPerson", "inetOrgPerson", "posixAccount", "krbPrincipalAux", "ipaObject"}
|
||||
|
||||
addReq := ldap.NewAddRequest(dn, nil)
|
||||
addReq.Attribute("uid", []string{req.UID})
|
||||
addReq.Attribute("givenName", []string{req.FirstName})
|
||||
addReq.Attribute("sn", []string{req.LastName})
|
||||
addReq.Attribute("cn", []string{req.FirstName + " " + req.LastName})
|
||||
addReq.Attribute("displayName", []string{req.FirstName + " " + req.LastName})
|
||||
|
||||
if req.Email != "" {
|
||||
addReq.Attribute("mail", []string{req.Email})
|
||||
}
|
||||
if req.Phone != "" {
|
||||
addReq.Attribute("telephoneNumber", []string{req.Phone})
|
||||
}
|
||||
if req.Title != "" {
|
||||
addReq.Attribute("title", []string{req.Title})
|
||||
}
|
||||
|
||||
shell := "/bin/bash"
|
||||
if req.Shell != "" {
|
||||
shell = req.Shell
|
||||
}
|
||||
addReq.Attribute("loginShell", []string{shell})
|
||||
addReq.Attribute("homeDirectory", []string{"/home/" + req.UID})
|
||||
|
||||
// Process service attributes
|
||||
if len(req.Services) > 0 {
|
||||
svcOCs, svcAttrs, err := s.resolveServices(req.Services)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid services: %w", err)
|
||||
}
|
||||
objectClasses = append(objectClasses, svcOCs...)
|
||||
for attrName, vals := range svcAttrs {
|
||||
addReq.Attribute(attrName, vals)
|
||||
}
|
||||
// Add audit timestamps
|
||||
now := time.Now().UTC().Format("20060102150405Z")
|
||||
addReq.Attribute("gscCreatedAt", []string{now})
|
||||
addReq.Attribute("gscModifiedAt", []string{now})
|
||||
}
|
||||
|
||||
addReq.Attribute("objectClass", objectClasses)
|
||||
|
||||
if err := s.client.Add(addReq); err != nil {
|
||||
return nil, fmt.Errorf("failed to create user: %w", err)
|
||||
}
|
||||
|
||||
// Set password if provided
|
||||
if req.Password != "" {
|
||||
if err := s.client.PasswordModify(dn, req.Password); err != nil {
|
||||
s.logger.Warn().Err(err).Str("uid", req.UID).Msg("user created but password set failed")
|
||||
}
|
||||
}
|
||||
|
||||
return s.GetUser(req.UID)
|
||||
}
|
||||
|
||||
// UpdateUser updates a user's attributes
|
||||
func (s *LDAPService) UpdateUser(uid string, req *types.LDAPUserUpdate) (*types.LDAPUser, error) {
|
||||
dn := s.userDN(uid)
|
||||
modReq := ldap.NewModifyRequest(dn, nil)
|
||||
modified := false
|
||||
|
||||
if req.FirstName != nil {
|
||||
modReq.Replace("givenName", []string{*req.FirstName})
|
||||
modified = true
|
||||
}
|
||||
if req.LastName != nil {
|
||||
modReq.Replace("sn", []string{*req.LastName})
|
||||
modified = true
|
||||
}
|
||||
if req.FirstName != nil || req.LastName != nil {
|
||||
// Update display name and cn
|
||||
first, last := "", ""
|
||||
if req.FirstName != nil {
|
||||
first = *req.FirstName
|
||||
}
|
||||
if req.LastName != nil {
|
||||
last = *req.LastName
|
||||
}
|
||||
if first != "" || last != "" {
|
||||
display := strings.TrimSpace(first + " " + last)
|
||||
if display != "" {
|
||||
modReq.Replace("displayName", []string{display})
|
||||
modReq.Replace("cn", []string{display})
|
||||
}
|
||||
}
|
||||
}
|
||||
if req.Email != nil {
|
||||
modReq.Replace("mail", []string{*req.Email})
|
||||
modified = true
|
||||
}
|
||||
if req.Phone != nil {
|
||||
modReq.Replace("telephoneNumber", []string{*req.Phone})
|
||||
modified = true
|
||||
}
|
||||
if req.Title != nil {
|
||||
modReq.Replace("title", []string{*req.Title})
|
||||
modified = true
|
||||
}
|
||||
if req.Shell != nil {
|
||||
modReq.Replace("loginShell", []string{*req.Shell})
|
||||
modified = true
|
||||
}
|
||||
if req.Disabled != nil {
|
||||
if *req.Disabled {
|
||||
modReq.Replace("nsAccountLock", []string{"TRUE"})
|
||||
} else {
|
||||
modReq.Replace("nsAccountLock", []string{"FALSE"})
|
||||
}
|
||||
modified = true
|
||||
}
|
||||
|
||||
// Process service attributes
|
||||
if len(req.Services) > 0 {
|
||||
svcOCs, svcAttrs, err := s.resolveServices(req.Services)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid services: %w", err)
|
||||
}
|
||||
|
||||
// Fetch current objectClasses to determine which to add
|
||||
currentOCs, err := s.getCurrentObjectClasses(uid)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read current objectClasses: %w", err)
|
||||
}
|
||||
|
||||
currentOCSet := make(map[string]bool, len(currentOCs))
|
||||
for _, oc := range currentOCs {
|
||||
currentOCSet[oc] = true
|
||||
}
|
||||
|
||||
newOCs := make([]string, 0)
|
||||
for _, oc := range svcOCs {
|
||||
if !currentOCSet[oc] {
|
||||
newOCs = append(newOCs, oc)
|
||||
}
|
||||
}
|
||||
if len(newOCs) > 0 {
|
||||
modReq.Add("objectClass", newOCs)
|
||||
}
|
||||
|
||||
for attrName, vals := range svcAttrs {
|
||||
modReq.Replace(attrName, vals)
|
||||
}
|
||||
|
||||
// Update audit timestamp
|
||||
now := time.Now().UTC().Format("20060102150405Z")
|
||||
modReq.Replace("gscModifiedAt", []string{now})
|
||||
|
||||
modified = true
|
||||
}
|
||||
|
||||
if !modified {
|
||||
return s.GetUser(uid)
|
||||
}
|
||||
|
||||
if err := s.client.Modify(modReq); err != nil {
|
||||
return nil, fmt.Errorf("failed to update user: %w", err)
|
||||
}
|
||||
|
||||
return s.GetUser(uid)
|
||||
}
|
||||
|
||||
// DisableUser disables a user account
|
||||
func (s *LDAPService) DisableUser(uid string) error {
|
||||
dn := s.userDN(uid)
|
||||
modReq := ldap.NewModifyRequest(dn, nil)
|
||||
modReq.Replace("nsAccountLock", []string{"TRUE"})
|
||||
return s.client.Modify(modReq)
|
||||
}
|
||||
|
||||
// ResetPassword resets a user's password
|
||||
func (s *LDAPService) ResetPassword(uid, newPassword string) error {
|
||||
dn := s.userDN(uid)
|
||||
return s.client.PasswordModify(dn, newPassword)
|
||||
}
|
||||
|
||||
// GetUserGroups lists groups a user belongs to
|
||||
func (s *LDAPService) GetUserGroups(uid string) ([]string, error) {
|
||||
filter := fmt.Sprintf("(&(objectClass=posixAccount)(uid=%s))", ldap.EscapeFilter(uid))
|
||||
entry, err := s.client.SearchOne(s.userBaseDN(), filter, []string{"memberOf"})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
memberOf := entry.GetAttributeValues("memberOf")
|
||||
groups := make([]string, 0, len(memberOf))
|
||||
for _, dn := range memberOf {
|
||||
// Extract cn from DN
|
||||
parts := strings.Split(dn, ",")
|
||||
if len(parts) > 0 && strings.HasPrefix(parts[0], "cn=") {
|
||||
groups = append(groups, strings.TrimPrefix(parts[0], "cn="))
|
||||
}
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// ListGroups searches for groups
|
||||
func (s *LDAPService) ListGroups(search string, limit int) ([]types.LDAPGroup, error) {
|
||||
filter := "(objectClass=groupOfNames)"
|
||||
if search != "" {
|
||||
escaped := ldap.EscapeFilter(search)
|
||||
filter = fmt.Sprintf("(&(objectClass=groupOfNames)(|(cn=*%s*)(description=*%s*)))", escaped, escaped)
|
||||
}
|
||||
|
||||
entries, err := s.client.Search(s.groupBaseDN(), filter, groupAttrs, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
groups := make([]types.LDAPGroup, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
groups = append(groups, s.entryToGroup(entry))
|
||||
}
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// GetGroup gets a group by CN
|
||||
func (s *LDAPService) GetGroup(cn string) (*types.LDAPGroup, error) {
|
||||
filter := fmt.Sprintf("(&(objectClass=groupOfNames)(cn=%s))", ldap.EscapeFilter(cn))
|
||||
entry, err := s.client.SearchOne(s.groupBaseDN(), filter, groupAttrs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
group := s.entryToGroup(entry)
|
||||
return &group, nil
|
||||
}
|
||||
|
||||
// CreateGroup creates a new group
|
||||
func (s *LDAPService) CreateGroup(req *types.LDAPGroupCreate) (*types.LDAPGroup, error) {
|
||||
dn := s.groupDN(req.CN)
|
||||
|
||||
addReq := ldap.NewAddRequest(dn, nil)
|
||||
addReq.Attribute("objectClass", []string{"top", "groupOfNames", "posixGroup", "ipaObject"})
|
||||
addReq.Attribute("cn", []string{req.CN})
|
||||
if req.Description != "" {
|
||||
addReq.Attribute("description", []string{req.Description})
|
||||
}
|
||||
|
||||
if err := s.client.Add(addReq); err != nil {
|
||||
return nil, fmt.Errorf("failed to create group: %w", err)
|
||||
}
|
||||
|
||||
return s.GetGroup(req.CN)
|
||||
}
|
||||
|
||||
// UpdateGroup updates a group's attributes
|
||||
func (s *LDAPService) UpdateGroup(cn string, req *types.LDAPGroupUpdate) (*types.LDAPGroup, error) {
|
||||
dn := s.groupDN(cn)
|
||||
modReq := ldap.NewModifyRequest(dn, nil)
|
||||
|
||||
if req.Description != nil {
|
||||
modReq.Replace("description", []string{*req.Description})
|
||||
}
|
||||
|
||||
if err := s.client.Modify(modReq); err != nil {
|
||||
return nil, fmt.Errorf("failed to update group: %w", err)
|
||||
}
|
||||
|
||||
return s.GetGroup(cn)
|
||||
}
|
||||
|
||||
// DeleteGroup deletes a group
|
||||
func (s *LDAPService) DeleteGroup(cn string) error {
|
||||
dn := s.groupDN(cn)
|
||||
return s.client.Delete(dn)
|
||||
}
|
||||
|
||||
// GetGroupMembers lists members of a group
|
||||
func (s *LDAPService) GetGroupMembers(cn string) ([]string, error) {
|
||||
group, err := s.GetGroup(cn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if group == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return group.Members, nil
|
||||
}
|
||||
|
||||
// AddGroupMembers adds members to a group
|
||||
func (s *LDAPService) AddGroupMembers(cn string, uids []string) error {
|
||||
dn := s.groupDN(cn)
|
||||
modReq := ldap.NewModifyRequest(dn, nil)
|
||||
|
||||
for _, uid := range uids {
|
||||
memberDN := s.userDN(uid)
|
||||
modReq.Add("member", []string{memberDN})
|
||||
}
|
||||
|
||||
return s.client.Modify(modReq)
|
||||
}
|
||||
|
||||
// RemoveGroupMember removes a member from a group
|
||||
func (s *LDAPService) RemoveGroupMember(cn, uid string) error {
|
||||
dn := s.groupDN(cn)
|
||||
memberDN := s.userDN(uid)
|
||||
|
||||
modReq := ldap.NewModifyRequest(dn, nil)
|
||||
modReq.Delete("member", []string{memberDN})
|
||||
|
||||
return s.client.Modify(modReq)
|
||||
}
|
||||
|
||||
func (s *LDAPService) entryToUser(entry *ldap.Entry, includeServices bool) types.LDAPUser {
|
||||
memberOf := entry.GetAttributeValues("memberOf")
|
||||
groups := make([]string, 0, len(memberOf))
|
||||
for _, dn := range memberOf {
|
||||
parts := strings.Split(dn, ",")
|
||||
if len(parts) > 0 && strings.HasPrefix(parts[0], "cn=") {
|
||||
groups = append(groups, strings.TrimPrefix(parts[0], "cn="))
|
||||
}
|
||||
}
|
||||
|
||||
disabled := strings.EqualFold(entry.GetAttributeValue("nsAccountLock"), "TRUE")
|
||||
|
||||
user := types.LDAPUser{
|
||||
UID: entry.GetAttributeValue("uid"),
|
||||
FirstName: entry.GetAttributeValue("givenName"),
|
||||
LastName: entry.GetAttributeValue("sn"),
|
||||
DisplayName: entry.GetAttributeValue("displayName"),
|
||||
Email: entry.GetAttributeValue("mail"),
|
||||
Phone: entry.GetAttributeValue("telephoneNumber"),
|
||||
Title: entry.GetAttributeValue("title"),
|
||||
Disabled: disabled,
|
||||
Groups: groups,
|
||||
Shell: entry.GetAttributeValue("loginShell"),
|
||||
HomeDir: entry.GetAttributeValue("homeDirectory"),
|
||||
}
|
||||
|
||||
if includeServices {
|
||||
user.ObjectClasses = entry.GetAttributeValues("objectClass")
|
||||
user.Services = s.extractServices(entry)
|
||||
}
|
||||
|
||||
return user
|
||||
}
|
||||
|
||||
// extractServices extracts gsc* attributes from an LDAP entry, grouped by domain
|
||||
func (s *LDAPService) extractServices(entry *ldap.Entry) map[string]map[string]interface{} {
|
||||
services := make(map[string]map[string]interface{})
|
||||
|
||||
for _, attr := range entry.Attributes {
|
||||
def := s.registry.GetAttr(attr.Name)
|
||||
if def == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
val := s.registry.LDAPValueToGo(def, attr.Values)
|
||||
if val == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if services[def.Domain] == nil {
|
||||
services[def.Domain] = make(map[string]interface{})
|
||||
}
|
||||
services[def.Domain][def.JSONName] = val
|
||||
}
|
||||
|
||||
return services
|
||||
}
|
||||
|
||||
// resolveServices validates and converts service attributes to LDAP format.
|
||||
// Returns: required objectClasses, LDAP attribute map, or error.
|
||||
func (s *LDAPService) resolveServices(services map[string]map[string]interface{}) ([]string, map[string][]string, error) {
|
||||
ldapAttrs := make(map[string][]string)
|
||||
usedLDAPNames := make([]string, 0)
|
||||
|
||||
for domain, attrs := range services {
|
||||
domainDefs := s.registry.AttrsForDomain(domain)
|
||||
if domainDefs == nil {
|
||||
return nil, nil, fmt.Errorf("unknown service domain: %s", domain)
|
||||
}
|
||||
|
||||
for jsonName, value := range attrs {
|
||||
def := s.registry.GetAttrByJSON(domain, jsonName)
|
||||
if def == nil {
|
||||
return nil, nil, fmt.Errorf("unknown attribute %s in domain %s", jsonName, domain)
|
||||
}
|
||||
if def.ReadOnly {
|
||||
continue // skip read-only attrs silently
|
||||
}
|
||||
|
||||
vals, err := s.registry.GoValueToLDAP(def, value)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("attribute %s.%s: %w", domain, jsonName, err)
|
||||
}
|
||||
if vals != nil {
|
||||
ldapAttrs[def.LDAPName] = vals
|
||||
usedLDAPNames = append(usedLDAPNames, def.LDAPName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Determine required objectClasses
|
||||
ocs := s.registry.RequiredOCsForAttrs(usedLDAPNames)
|
||||
|
||||
// Validate that all MUST attrs for each OC are provided
|
||||
for _, ocName := range ocs {
|
||||
oc := s.registry.GetObjectClass(ocName)
|
||||
if oc == nil {
|
||||
continue
|
||||
}
|
||||
for _, must := range oc.Must {
|
||||
if _, ok := ldapAttrs[must]; !ok {
|
||||
return nil, nil, fmt.Errorf("objectClass %s requires attribute %s", ocName, must)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ocs, ldapAttrs, nil
|
||||
}
|
||||
|
||||
// getCurrentObjectClasses fetches the current objectClasses of a user entry
|
||||
func (s *LDAPService) getCurrentObjectClasses(uid string) ([]string, error) {
|
||||
filter := fmt.Sprintf("(&(objectClass=posixAccount)(uid=%s))", ldap.EscapeFilter(uid))
|
||||
entry, err := s.client.SearchOne(s.userBaseDN(), filter, []string{"objectClass"})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, fmt.Errorf("user not found: %s", uid)
|
||||
}
|
||||
return entry.GetAttributeValues("objectClass"), nil
|
||||
}
|
||||
|
||||
func (s *LDAPService) entryToGroup(entry *ldap.Entry) types.LDAPGroup {
|
||||
members := entry.GetAttributeValues("member")
|
||||
uids := make([]string, 0, len(members))
|
||||
for _, dn := range members {
|
||||
parts := strings.Split(dn, ",")
|
||||
if len(parts) > 0 && strings.HasPrefix(parts[0], "uid=") {
|
||||
uids = append(uids, strings.TrimPrefix(parts[0], "uid="))
|
||||
}
|
||||
}
|
||||
|
||||
return types.LDAPGroup{
|
||||
CN: entry.GetAttributeValue("cn"),
|
||||
Description: entry.GetAttributeValue("description"),
|
||||
Members: uids,
|
||||
GIDNumber: entry.GetAttributeValue("gidNumber"),
|
||||
}
|
||||
}
|
||||
321
internal/service/ldap_entities.go
Normal file
321
internal/service/ldap_entities.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-ldap/ldap/v3"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/internal/client"
|
||||
"github.com/gosec/gsc-ops-api/internal/schema"
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// LDAPEntityService handles generic LDAP entity CRUD operations
|
||||
type LDAPEntityService struct {
|
||||
client *client.LDAPClient
|
||||
baseDN string
|
||||
registry *schema.Registry
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewLDAPEntityService creates a new entity service
|
||||
func NewLDAPEntityService(ldapClient *client.LDAPClient, baseDN string, registry *schema.Registry, logger zerolog.Logger) *LDAPEntityService {
|
||||
return &LDAPEntityService{
|
||||
client: ldapClient,
|
||||
baseDN: baseDN,
|
||||
registry: registry,
|
||||
logger: logger.With().Str("service", "ldap-entities").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// entityBaseDN returns the full base DN for an entity type
|
||||
func (s *LDAPEntityService) entityBaseDN(et *schema.EntityTypeDef) string {
|
||||
return et.BaseDN + "," + s.baseDN
|
||||
}
|
||||
|
||||
// entityDN returns the full DN for a specific entity
|
||||
func (s *LDAPEntityService) entityDN(et *schema.EntityTypeDef, rdnValue string) string {
|
||||
return fmt.Sprintf("%s=%s,%s", et.RDNAttribute, ldap.EscapeFilter(rdnValue), s.entityBaseDN(et))
|
||||
}
|
||||
|
||||
// entityAttrs returns all searchable LDAP attribute names for an entity type
|
||||
func (s *LDAPEntityService) entityAttrs(et *schema.EntityTypeDef) []string {
|
||||
attrs := []string{"objectClass"}
|
||||
for _, ocName := range et.ObjectClasses {
|
||||
oc := s.registry.GetObjectClass(ocName)
|
||||
if oc == nil {
|
||||
continue
|
||||
}
|
||||
attrs = append(attrs, oc.Must...)
|
||||
attrs = append(attrs, oc.May...)
|
||||
}
|
||||
// Deduplicate
|
||||
seen := make(map[string]bool, len(attrs))
|
||||
unique := make([]string, 0, len(attrs))
|
||||
for _, a := range attrs {
|
||||
if !seen[a] {
|
||||
seen[a] = true
|
||||
unique = append(unique, a)
|
||||
}
|
||||
}
|
||||
return unique
|
||||
}
|
||||
|
||||
// ListEntities searches for entities of a given type
|
||||
func (s *LDAPEntityService) ListEntities(typeName, search string, limit int) ([]types.LDAPEntity, error) {
|
||||
et := s.registry.GetEntityType(typeName)
|
||||
if et == nil {
|
||||
return nil, fmt.Errorf("unknown entity type: %s", typeName)
|
||||
}
|
||||
|
||||
filter := et.SearchFilter
|
||||
if search != "" {
|
||||
escaped := ldap.EscapeFilter(search)
|
||||
// Search by RDN attribute or description
|
||||
filter = fmt.Sprintf("(&%s(|(%s=*%s*)(gscDescription=*%s*)))",
|
||||
et.SearchFilter, et.RDNAttribute, escaped, escaped)
|
||||
}
|
||||
|
||||
attrs := s.entityAttrs(et)
|
||||
entries, err := s.client.Search(s.entityBaseDN(et), filter, attrs, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entity search failed: %w", err)
|
||||
}
|
||||
|
||||
entities := make([]types.LDAPEntity, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
entities = append(entities, s.entryToEntity(entry, et))
|
||||
}
|
||||
return entities, nil
|
||||
}
|
||||
|
||||
// GetEntity retrieves a single entity by its RDN value
|
||||
func (s *LDAPEntityService) GetEntity(typeName, rdnValue string) (*types.LDAPEntity, error) {
|
||||
et := s.registry.GetEntityType(typeName)
|
||||
if et == nil {
|
||||
return nil, fmt.Errorf("unknown entity type: %s", typeName)
|
||||
}
|
||||
|
||||
filter := fmt.Sprintf("(&%s(%s=%s))",
|
||||
et.SearchFilter, et.RDNAttribute, ldap.EscapeFilter(rdnValue))
|
||||
attrs := s.entityAttrs(et)
|
||||
|
||||
entry, err := s.client.SearchOne(s.entityBaseDN(et), filter, attrs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entity lookup failed: %w", err)
|
||||
}
|
||||
if entry == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
entity := s.entryToEntity(entry, et)
|
||||
return &entity, nil
|
||||
}
|
||||
|
||||
// CreateEntity creates a new entity
|
||||
func (s *LDAPEntityService) CreateEntity(typeName string, req *types.LDAPEntityCreate) (*types.LDAPEntity, error) {
|
||||
et := s.registry.GetEntityType(typeName)
|
||||
if et == nil {
|
||||
return nil, fmt.Errorf("unknown entity type: %s", typeName)
|
||||
}
|
||||
|
||||
// Resolve attributes from JSON names to LDAP names
|
||||
ldapAttrs, err := s.resolveEntityAttrs(et, req.Attributes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Validate required attributes
|
||||
for _, reqAttr := range et.RequiredAttrs {
|
||||
if _, ok := ldapAttrs[reqAttr]; !ok {
|
||||
// Try to find JSON name for better error message
|
||||
def := s.registry.GetAttr(reqAttr)
|
||||
jsonName := reqAttr
|
||||
if def != nil {
|
||||
jsonName = def.JSONName
|
||||
}
|
||||
return nil, fmt.Errorf("required attribute missing: %s", jsonName)
|
||||
}
|
||||
}
|
||||
|
||||
// Determine RDN value
|
||||
rdnVals, ok := ldapAttrs[et.RDNAttribute]
|
||||
if !ok || len(rdnVals) == 0 {
|
||||
return nil, fmt.Errorf("RDN attribute %s is required", et.RDNAttribute)
|
||||
}
|
||||
rdnValue := rdnVals[0]
|
||||
|
||||
// Build DN
|
||||
dn := s.entityDN(et, rdnValue)
|
||||
|
||||
// Add audit timestamps
|
||||
now := time.Now().UTC().Format("20060102150405Z")
|
||||
ldapAttrs["gscCreatedAt"] = []string{now}
|
||||
ldapAttrs["gscModifiedAt"] = []string{now}
|
||||
|
||||
// Create LDAP entry
|
||||
addReq := ldap.NewAddRequest(dn, nil)
|
||||
addReq.Attribute("objectClass", et.ObjectClasses)
|
||||
|
||||
for attrName, vals := range ldapAttrs {
|
||||
addReq.Attribute(attrName, vals)
|
||||
}
|
||||
|
||||
if err := s.client.Add(addReq); err != nil {
|
||||
if ldap.IsErrorWithCode(err, ldap.LDAPResultEntryAlreadyExists) {
|
||||
return nil, fmt.Errorf("CONFLICT: entity already exists: %s=%s", et.RDNAttribute, rdnValue)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to create entity: %w", err)
|
||||
}
|
||||
|
||||
return s.GetEntity(typeName, rdnValue)
|
||||
}
|
||||
|
||||
// UpdateEntity modifies an existing entity
|
||||
func (s *LDAPEntityService) UpdateEntity(typeName, rdnValue string, req *types.LDAPEntityUpdate) (*types.LDAPEntity, error) {
|
||||
et := s.registry.GetEntityType(typeName)
|
||||
if et == nil {
|
||||
return nil, fmt.Errorf("unknown entity type: %s", typeName)
|
||||
}
|
||||
|
||||
// Resolve attributes
|
||||
ldapAttrs, err := s.resolveEntityAttrs(et, req.Attributes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(ldapAttrs) == 0 {
|
||||
return s.GetEntity(typeName, rdnValue)
|
||||
}
|
||||
|
||||
dn := s.entityDN(et, rdnValue)
|
||||
modReq := ldap.NewModifyRequest(dn, nil)
|
||||
|
||||
for attrName, vals := range ldapAttrs {
|
||||
modReq.Replace(attrName, vals)
|
||||
}
|
||||
|
||||
// Update audit timestamp
|
||||
now := time.Now().UTC().Format("20060102150405Z")
|
||||
modReq.Replace("gscModifiedAt", []string{now})
|
||||
|
||||
if err := s.client.Modify(modReq); err != nil {
|
||||
if ldap.IsErrorWithCode(err, ldap.LDAPResultNoSuchObject) {
|
||||
return nil, fmt.Errorf("NOT_FOUND: entity not found: %s=%s", et.RDNAttribute, rdnValue)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to update entity: %w", err)
|
||||
}
|
||||
|
||||
return s.GetEntity(typeName, rdnValue)
|
||||
}
|
||||
|
||||
// DeleteEntity removes an entity
|
||||
func (s *LDAPEntityService) DeleteEntity(typeName, rdnValue string) error {
|
||||
et := s.registry.GetEntityType(typeName)
|
||||
if et == nil {
|
||||
return fmt.Errorf("unknown entity type: %s", typeName)
|
||||
}
|
||||
|
||||
dn := s.entityDN(et, rdnValue)
|
||||
if err := s.client.Delete(dn); err != nil {
|
||||
if ldap.IsErrorWithCode(err, ldap.LDAPResultNoSuchObject) {
|
||||
return fmt.Errorf("NOT_FOUND: entity not found: %s=%s", et.RDNAttribute, rdnValue)
|
||||
}
|
||||
return fmt.Errorf("failed to delete entity: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveEntityAttrs converts JSON attribute names to LDAP attribute names with type conversion
|
||||
func (s *LDAPEntityService) resolveEntityAttrs(et *schema.EntityTypeDef, attrs map[string]interface{}) (map[string][]string, error) {
|
||||
ldapAttrs := make(map[string][]string)
|
||||
|
||||
for jsonName, value := range attrs {
|
||||
// Try to find by JSON name in the entity's domain
|
||||
def := s.registry.GetAttrByJSON(et.Domain, jsonName)
|
||||
if def == nil {
|
||||
// Also try common domain
|
||||
def = s.registry.GetAttrByJSON("common", jsonName)
|
||||
}
|
||||
if def == nil {
|
||||
// Try as direct LDAP name
|
||||
def = s.registry.GetAttr(jsonName)
|
||||
}
|
||||
if def == nil {
|
||||
return nil, fmt.Errorf("unknown attribute: %s", jsonName)
|
||||
}
|
||||
if def.ReadOnly {
|
||||
continue
|
||||
}
|
||||
|
||||
vals, err := s.registry.GoValueToLDAP(def, value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("attribute %s: %w", jsonName, err)
|
||||
}
|
||||
if vals != nil {
|
||||
ldapAttrs[def.LDAPName] = vals
|
||||
}
|
||||
}
|
||||
|
||||
return ldapAttrs, nil
|
||||
}
|
||||
|
||||
// entryToEntity converts an LDAP entry to a generic entity response
|
||||
func (s *LDAPEntityService) entryToEntity(entry *ldap.Entry, et *schema.EntityTypeDef) types.LDAPEntity {
|
||||
attrs := make(map[string]interface{})
|
||||
|
||||
for _, ldapAttr := range entry.Attributes {
|
||||
if ldapAttr.Name == "objectClass" {
|
||||
continue
|
||||
}
|
||||
def := s.registry.GetAttr(ldapAttr.Name)
|
||||
if def == nil {
|
||||
// Include unregistered attrs as raw strings
|
||||
if len(ldapAttr.Values) == 1 {
|
||||
attrs[ldapAttr.Name] = ldapAttr.Values[0]
|
||||
} else if len(ldapAttr.Values) > 1 {
|
||||
attrs[ldapAttr.Name] = ldapAttr.Values
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
val := s.registry.LDAPValueToGo(def, ldapAttr.Values)
|
||||
if val != nil {
|
||||
attrs[def.JSONName] = val
|
||||
}
|
||||
}
|
||||
|
||||
// Extract RDN value
|
||||
rdnValue := ""
|
||||
if et.RDNAttribute == "cn" {
|
||||
rdnValue = entry.GetAttributeValue("cn")
|
||||
} else {
|
||||
rdnValue = entry.GetAttributeValue(et.RDNAttribute)
|
||||
}
|
||||
|
||||
return types.LDAPEntity{
|
||||
DN: entry.DN,
|
||||
Type: et.Name,
|
||||
RDN: rdnValue,
|
||||
ObjectClasses: entry.GetAttributeValues("objectClass"),
|
||||
Attributes: attrs,
|
||||
}
|
||||
}
|
||||
|
||||
// ClassifyError classifies LDAP errors for HTTP status mapping
|
||||
func ClassifyError(err error) (string, string) {
|
||||
msg := err.Error()
|
||||
if strings.HasPrefix(msg, "CONFLICT:") {
|
||||
return "conflict", strings.TrimPrefix(msg, "CONFLICT: ")
|
||||
}
|
||||
if strings.HasPrefix(msg, "NOT_FOUND:") {
|
||||
return "not_found", strings.TrimPrefix(msg, "NOT_FOUND: ")
|
||||
}
|
||||
if strings.Contains(msg, "unknown entity type") || strings.Contains(msg, "unknown attribute") || strings.Contains(msg, "required attribute") {
|
||||
return "validation", msg
|
||||
}
|
||||
return "internal", msg
|
||||
}
|
||||
1091
internal/service/pbx.go
Normal file
1091
internal/service/pbx.go
Normal file
File diff suppressed because it is too large
Load Diff
514
internal/service/persona.go
Normal file
514
internal/service/persona.go
Normal file
@@ -0,0 +1,514 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// PersonaService handles persona operations against gsc_persona database
|
||||
type PersonaService struct {
|
||||
pool *pgxpool.Pool
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewPersonaService creates a new persona service
|
||||
func NewPersonaService(pool *pgxpool.Pool, logger zerolog.Logger) *PersonaService {
|
||||
return &PersonaService{
|
||||
pool: pool,
|
||||
logger: logger.With().Str("service", "persona").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// ListPersonas lists personas for a tenant with optional status filter
|
||||
func (s *PersonaService) ListPersonas(ctx context.Context, tenantID uuid.UUID, params types.ListParams) ([]types.PersonaSummary, int64, error) {
|
||||
params = types.DefaultListParams(params)
|
||||
|
||||
countQuery := `SELECT COUNT(*) FROM persona.personas WHERE tenant_id = $1`
|
||||
listQuery := `SELECT id, tenant_id, name, archetype, status, is_default, created_at, updated_at
|
||||
FROM persona.personas WHERE tenant_id = $1`
|
||||
|
||||
args := []interface{}{tenantID}
|
||||
argIdx := 2
|
||||
|
||||
if params.Status != "" {
|
||||
countQuery += fmt.Sprintf(" AND status = $%d", argIdx)
|
||||
listQuery += fmt.Sprintf(" AND status = $%d", argIdx)
|
||||
args = append(args, params.Status)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("count query failed: %w", err)
|
||||
}
|
||||
|
||||
listQuery += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, params.Limit, params.Offset)
|
||||
|
||||
rows, err := s.pool.Query(ctx, listQuery, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("list query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
personas := make([]types.PersonaSummary, 0)
|
||||
for rows.Next() {
|
||||
var p types.PersonaSummary
|
||||
if err := rows.Scan(&p.ID, &p.TenantID, &p.Name, &p.Archetype, &p.Status, &p.IsDefault, &p.CreatedAt, &p.UpdatedAt); err != nil {
|
||||
return nil, 0, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
personas = append(personas, p)
|
||||
}
|
||||
|
||||
return personas, total, nil
|
||||
}
|
||||
|
||||
// GetPersona gets a full persona configuration by ID and tenant
|
||||
func (s *PersonaService) GetPersona(ctx context.Context, id, tenantID uuid.UUID) (*types.PersonaConfig, error) {
|
||||
var p types.PersonaConfig
|
||||
var positiveRules, negativeRules, guardrailsConfig []byte
|
||||
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
SELECT id, tenant_id, name, archetype, voice_tone, mbti,
|
||||
openness, conscientiousness, extraversion, agreeableness, neuroticism,
|
||||
positive_rules, negative_rules, backstory, world_building,
|
||||
guardrails_config, topical_rails, status,
|
||||
default_model, temperature, max_tokens_per_turn,
|
||||
moral_care, moral_fairness, moral_rights,
|
||||
moral_loyalty, moral_authority, moral_sanctity,
|
||||
is_default, created_at, updated_at
|
||||
FROM persona.personas
|
||||
WHERE id = $1 AND tenant_id = $2
|
||||
`, id, tenantID).Scan(
|
||||
&p.ID, &p.TenantID, &p.Name, &p.Archetype, &p.VoiceTone, &p.MBTI,
|
||||
&p.Openness, &p.Conscientiousness, &p.Extraversion, &p.Agreeableness, &p.Neuroticism,
|
||||
&positiveRules, &negativeRules, &p.Backstory, &p.WorldBuilding,
|
||||
&guardrailsConfig, &p.TopicalRails, &p.Status,
|
||||
&p.DefaultModel, &p.Temperature, &p.MaxTokensPerTurn,
|
||||
&p.MoralCare, &p.MoralFairness, &p.MoralRights,
|
||||
&p.MoralLoyalty, &p.MoralAuthority, &p.MoralSanctity,
|
||||
&p.IsDefault, &p.CreatedAt, &p.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("not found: %w", err)
|
||||
}
|
||||
|
||||
p.PositiveRules = json.RawMessage(positiveRules)
|
||||
p.NegativeRules = json.RawMessage(negativeRules)
|
||||
p.GuardrailsConfig = json.RawMessage(guardrailsConfig)
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// CreatePersona creates a new persona
|
||||
func (s *PersonaService) CreatePersona(ctx context.Context, req *types.PersonaCreate) (*types.PersonaConfig, error) {
|
||||
positiveRules := req.PositiveRules
|
||||
if len(positiveRules) == 0 {
|
||||
positiveRules = json.RawMessage(`[]`)
|
||||
}
|
||||
negativeRules := req.NegativeRules
|
||||
if len(negativeRules) == 0 {
|
||||
negativeRules = json.RawMessage(`[]`)
|
||||
}
|
||||
guardrailsConfig := req.GuardrailsConfig
|
||||
if len(guardrailsConfig) == 0 {
|
||||
guardrailsConfig = json.RawMessage(`{}`)
|
||||
}
|
||||
|
||||
var p types.PersonaConfig
|
||||
var prOut, nrOut, gcOut []byte
|
||||
|
||||
err := s.pool.QueryRow(ctx, `
|
||||
INSERT INTO persona.personas (
|
||||
tenant_id, name, archetype, voice_tone, mbti,
|
||||
openness, conscientiousness, extraversion, agreeableness, neuroticism,
|
||||
positive_rules, negative_rules, backstory, world_building,
|
||||
guardrails_config, topical_rails,
|
||||
default_model, temperature, max_tokens_per_turn,
|
||||
moral_care, moral_fairness, moral_rights,
|
||||
moral_loyalty, moral_authority, moral_sanctity
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25)
|
||||
RETURNING id, tenant_id, name, archetype, voice_tone, mbti,
|
||||
openness, conscientiousness, extraversion, agreeableness, neuroticism,
|
||||
positive_rules, negative_rules, backstory, world_building,
|
||||
guardrails_config, topical_rails, status,
|
||||
default_model, temperature, max_tokens_per_turn,
|
||||
moral_care, moral_fairness, moral_rights,
|
||||
moral_loyalty, moral_authority, moral_sanctity,
|
||||
is_default, created_at, updated_at`,
|
||||
req.TenantID, req.Name, req.Archetype, req.VoiceTone, req.MBTI,
|
||||
req.Openness, req.Conscientiousness, req.Extraversion, req.Agreeableness, req.Neuroticism,
|
||||
positiveRules, negativeRules, req.Backstory, req.WorldBuilding,
|
||||
guardrailsConfig, req.TopicalRails,
|
||||
req.DefaultModel, req.Temperature, req.MaxTokensPerTurn,
|
||||
req.MoralCare, req.MoralFairness, req.MoralRights,
|
||||
req.MoralLoyalty, req.MoralAuthority, req.MoralSanctity,
|
||||
).Scan(
|
||||
&p.ID, &p.TenantID, &p.Name, &p.Archetype, &p.VoiceTone, &p.MBTI,
|
||||
&p.Openness, &p.Conscientiousness, &p.Extraversion, &p.Agreeableness, &p.Neuroticism,
|
||||
&prOut, &nrOut, &p.Backstory, &p.WorldBuilding,
|
||||
&gcOut, &p.TopicalRails, &p.Status,
|
||||
&p.DefaultModel, &p.Temperature, &p.MaxTokensPerTurn,
|
||||
&p.MoralCare, &p.MoralFairness, &p.MoralRights,
|
||||
&p.MoralLoyalty, &p.MoralAuthority, &p.MoralSanctity,
|
||||
&p.IsDefault, &p.CreatedAt, &p.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
|
||||
p.PositiveRules = json.RawMessage(prOut)
|
||||
p.NegativeRules = json.RawMessage(nrOut)
|
||||
p.GuardrailsConfig = json.RawMessage(gcOut)
|
||||
s.logger.Info().Str("id", p.ID.String()).Str("name", p.Name).Msg("Created persona")
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// UpdatePersona updates an existing persona
|
||||
func (s *PersonaService) UpdatePersona(ctx context.Context, id, tenantID uuid.UUID, req *types.PersonaUpdate) (*types.PersonaConfig, error) {
|
||||
setClauses := []string{}
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
addField := func(clause string, val interface{}) {
|
||||
setClauses = append(setClauses, fmt.Sprintf("%s = $%d", clause, argIdx))
|
||||
args = append(args, val)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if req.Name != nil {
|
||||
addField("name", *req.Name)
|
||||
}
|
||||
if req.Archetype != nil {
|
||||
addField("archetype", *req.Archetype)
|
||||
}
|
||||
if req.VoiceTone != nil {
|
||||
addField("voice_tone", *req.VoiceTone)
|
||||
}
|
||||
if req.MBTI != nil {
|
||||
addField("mbti", *req.MBTI)
|
||||
}
|
||||
if req.Openness != nil {
|
||||
addField("openness", *req.Openness)
|
||||
}
|
||||
if req.Conscientiousness != nil {
|
||||
addField("conscientiousness", *req.Conscientiousness)
|
||||
}
|
||||
if req.Extraversion != nil {
|
||||
addField("extraversion", *req.Extraversion)
|
||||
}
|
||||
if req.Agreeableness != nil {
|
||||
addField("agreeableness", *req.Agreeableness)
|
||||
}
|
||||
if req.Neuroticism != nil {
|
||||
addField("neuroticism", *req.Neuroticism)
|
||||
}
|
||||
if len(req.PositiveRules) > 0 {
|
||||
addField("positive_rules", req.PositiveRules)
|
||||
}
|
||||
if len(req.NegativeRules) > 0 {
|
||||
addField("negative_rules", req.NegativeRules)
|
||||
}
|
||||
if req.Backstory != nil {
|
||||
addField("backstory", *req.Backstory)
|
||||
}
|
||||
if req.WorldBuilding != nil {
|
||||
addField("world_building", *req.WorldBuilding)
|
||||
}
|
||||
if len(req.GuardrailsConfig) > 0 {
|
||||
addField("guardrails_config", req.GuardrailsConfig)
|
||||
}
|
||||
if req.TopicalRails != nil {
|
||||
addField("topical_rails", *req.TopicalRails)
|
||||
}
|
||||
if req.Status != nil {
|
||||
addField("status", *req.Status)
|
||||
}
|
||||
if req.DefaultModel != nil {
|
||||
addField("default_model", *req.DefaultModel)
|
||||
}
|
||||
if req.Temperature != nil {
|
||||
addField("temperature", *req.Temperature)
|
||||
}
|
||||
if req.MaxTokensPerTurn != nil {
|
||||
addField("max_tokens_per_turn", *req.MaxTokensPerTurn)
|
||||
}
|
||||
if req.MoralCare != nil {
|
||||
addField("moral_care", *req.MoralCare)
|
||||
}
|
||||
if req.MoralFairness != nil {
|
||||
addField("moral_fairness", *req.MoralFairness)
|
||||
}
|
||||
if req.MoralRights != nil {
|
||||
addField("moral_rights", *req.MoralRights)
|
||||
}
|
||||
if req.MoralLoyalty != nil {
|
||||
addField("moral_loyalty", *req.MoralLoyalty)
|
||||
}
|
||||
if req.MoralAuthority != nil {
|
||||
addField("moral_authority", *req.MoralAuthority)
|
||||
}
|
||||
if req.MoralSanctity != nil {
|
||||
addField("moral_sanctity", *req.MoralSanctity)
|
||||
}
|
||||
if req.IsDefault != nil {
|
||||
addField("is_default", *req.IsDefault)
|
||||
}
|
||||
|
||||
if len(setClauses) == 0 {
|
||||
return s.GetPersona(ctx, id, tenantID)
|
||||
}
|
||||
|
||||
setClauses = append(setClauses, "updated_at = NOW()")
|
||||
|
||||
query := fmt.Sprintf("UPDATE persona.personas SET %s WHERE id = $%d AND tenant_id = $%d",
|
||||
joinClauses(setClauses), argIdx, argIdx+1)
|
||||
args = append(args, id, tenantID)
|
||||
query += ` RETURNING id, tenant_id, name, archetype, voice_tone, mbti,
|
||||
openness, conscientiousness, extraversion, agreeableness, neuroticism,
|
||||
positive_rules, negative_rules, backstory, world_building,
|
||||
guardrails_config, topical_rails, status,
|
||||
default_model, temperature, max_tokens_per_turn,
|
||||
moral_care, moral_fairness, moral_rights,
|
||||
moral_loyalty, moral_authority, moral_sanctity,
|
||||
is_default, created_at, updated_at`
|
||||
|
||||
var p types.PersonaConfig
|
||||
var positiveRules, negativeRules, guardrailsConfig []byte
|
||||
err := s.pool.QueryRow(ctx, query, args...).Scan(
|
||||
&p.ID, &p.TenantID, &p.Name, &p.Archetype, &p.VoiceTone, &p.MBTI,
|
||||
&p.Openness, &p.Conscientiousness, &p.Extraversion, &p.Agreeableness, &p.Neuroticism,
|
||||
&positiveRules, &negativeRules, &p.Backstory, &p.WorldBuilding,
|
||||
&guardrailsConfig, &p.TopicalRails, &p.Status,
|
||||
&p.DefaultModel, &p.Temperature, &p.MaxTokensPerTurn,
|
||||
&p.MoralCare, &p.MoralFairness, &p.MoralRights,
|
||||
&p.MoralLoyalty, &p.MoralAuthority, &p.MoralSanctity,
|
||||
&p.IsDefault, &p.CreatedAt, &p.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
p.PositiveRules = json.RawMessage(positiveRules)
|
||||
p.NegativeRules = json.RawMessage(negativeRules)
|
||||
p.GuardrailsConfig = json.RawMessage(guardrailsConfig)
|
||||
s.logger.Info().Str("id", id.String()).Msg("Updated persona")
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// DeletePersona deletes a persona
|
||||
func (s *PersonaService) DeletePersona(ctx context.Context, id, tenantID uuid.UUID) error {
|
||||
tag, err := s.pool.Exec(ctx, `DELETE FROM persona.personas WHERE id = $1 AND tenant_id = $2`, id, tenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("persona not found")
|
||||
}
|
||||
s.logger.Info().Str("id", id.String()).Msg("Deleted persona")
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetSelfModel returns the self-model snapshot for a persona
|
||||
func (s *PersonaService) GetSelfModel(ctx context.Context, personaID, tenantID uuid.UUID) (*types.SelfModelSnapshot, error) {
|
||||
snapshot := &types.SelfModelSnapshot{
|
||||
IdentityConstraints: make([]types.IdentityConstraint, 0),
|
||||
Commitments: make([]types.PersonaCommitment, 0),
|
||||
ConscienceStandards: make([]types.ConscienceStandard, 0),
|
||||
}
|
||||
|
||||
// Identity constraints
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT constraint_type, constraint_text, description, source, strength
|
||||
FROM persona.identity_constraints
|
||||
WHERE persona_id = $1 AND tenant_id = $2 AND is_active = true
|
||||
ORDER BY strength DESC
|
||||
`, personaID, tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("constraints query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var c types.IdentityConstraint
|
||||
var strength *float64
|
||||
if err := rows.Scan(&c.ConstraintType, &c.ConstraintText, &c.Description, &c.Source, &strength); err != nil {
|
||||
return nil, fmt.Errorf("constraint scan failed: %w", err)
|
||||
}
|
||||
if strength != nil {
|
||||
c.Strength = *strength
|
||||
} else {
|
||||
c.Strength = 1.0
|
||||
}
|
||||
snapshot.IdentityConstraints = append(snapshot.IdentityConstraints, c)
|
||||
}
|
||||
|
||||
// Commitments
|
||||
//
|
||||
// persona_commitments tracks session-bound commitments the assistant
|
||||
// has made during conversation; it has no `source` or `strength`
|
||||
// columns (the active flag is `status='active'`, not `is_active`).
|
||||
// Synthesise both fields for the snapshot so the SelfModel contract
|
||||
// stays stable for callers.
|
||||
commitRows, err := s.pool.Query(ctx, `
|
||||
SELECT commitment_text, COALESCE(commitment_type, '')
|
||||
FROM persona.persona_commitments
|
||||
WHERE persona_id = $1 AND tenant_id = $2 AND status = 'active'
|
||||
ORDER BY created_at DESC
|
||||
`, personaID, tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("commitments query failed: %w", err)
|
||||
}
|
||||
defer commitRows.Close()
|
||||
|
||||
commitSource := "learned"
|
||||
for commitRows.Next() {
|
||||
var c types.PersonaCommitment
|
||||
if err := commitRows.Scan(&c.CommitmentText, &c.CommitmentType); err != nil {
|
||||
return nil, fmt.Errorf("commitment scan failed: %w", err)
|
||||
}
|
||||
c.Source = &commitSource
|
||||
c.Strength = 1.0
|
||||
snapshot.Commitments = append(snapshot.Commitments, c)
|
||||
}
|
||||
|
||||
// Conscience standards
|
||||
stdRows, err := s.pool.Query(ctx, `
|
||||
SELECT standard_text, standard_type, moral_foundation, strength
|
||||
FROM persona.conscience_standards
|
||||
WHERE persona_id = $1 AND tenant_id = $2 AND is_active = true
|
||||
ORDER BY strength DESC
|
||||
`, personaID, tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("standards query failed: %w", err)
|
||||
}
|
||||
defer stdRows.Close()
|
||||
|
||||
for stdRows.Next() {
|
||||
var s types.ConscienceStandard
|
||||
var strength *float64
|
||||
if err := stdRows.Scan(&s.StandardText, &s.StandardType, &s.MoralFoundation, &strength); err != nil {
|
||||
return nil, fmt.Errorf("standard scan failed: %w", err)
|
||||
}
|
||||
if strength != nil {
|
||||
s.Strength = *strength
|
||||
} else {
|
||||
s.Strength = 1.0
|
||||
}
|
||||
snapshot.ConscienceStandards = append(snapshot.ConscienceStandards, s)
|
||||
}
|
||||
|
||||
return snapshot, nil
|
||||
}
|
||||
|
||||
// SearchExperiences returns experiences for a persona ordered by importance
|
||||
func (s *PersonaService) SearchExperiences(ctx context.Context, personaID, tenantID uuid.UUID, limit int) ([]types.Experience, error) {
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 20
|
||||
}
|
||||
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT id, event_summary, event_type, occurred_at, place,
|
||||
actors, outcome, outcome_detail,
|
||||
emotional_valence, lesson_learned, importance_score
|
||||
FROM persona.experiences
|
||||
WHERE persona_id = $1 AND tenant_id = $2
|
||||
ORDER BY importance_score DESC, occurred_at DESC
|
||||
LIMIT $3
|
||||
`, personaID, tenantID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("experiences query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
experiences := make([]types.Experience, 0)
|
||||
for rows.Next() {
|
||||
var e types.Experience
|
||||
if err := rows.Scan(
|
||||
&e.ID, &e.EventSummary, &e.EventType, &e.OccurredAt, &e.Place,
|
||||
&e.Actors, &e.Outcome, &e.OutcomeDetail,
|
||||
&e.EmotionalValence, &e.LessonLearned, &e.ImportanceScore,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("experience scan failed: %w", err)
|
||||
}
|
||||
experiences = append(experiences, e)
|
||||
}
|
||||
|
||||
return experiences, nil
|
||||
}
|
||||
|
||||
// GetEvaluations returns evaluations for a session
|
||||
func (s *PersonaService) GetEvaluations(ctx context.Context, sessionID uuid.UUID, limit int) ([]types.Evaluation, error) {
|
||||
if limit <= 0 || limit > 100 {
|
||||
limit = 10
|
||||
}
|
||||
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT e.role_fidelity, e.voice_consistency,
|
||||
e.safety_compliance, e.character_break,
|
||||
e.drift_score, e.evaluator_model, e.evaluated_at
|
||||
FROM persona.evaluations e
|
||||
JOIN persona.messages m ON m.id = e.message_id
|
||||
WHERE m.session_id = $1
|
||||
ORDER BY e.evaluated_at DESC
|
||||
LIMIT $2
|
||||
`, sessionID, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("evaluations query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
evaluations := make([]types.Evaluation, 0)
|
||||
for rows.Next() {
|
||||
var e types.Evaluation
|
||||
if err := rows.Scan(
|
||||
&e.RoleFidelity, &e.VoiceConsistency,
|
||||
&e.SafetyCompliance, &e.CharacterBreak,
|
||||
&e.DriftScore, &e.EvaluatorModel, &e.EvaluatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("evaluation scan failed: %w", err)
|
||||
}
|
||||
evaluations = append(evaluations, e)
|
||||
}
|
||||
|
||||
return evaluations, nil
|
||||
}
|
||||
|
||||
// GetMoralPattern returns moral assessments for a session
|
||||
func (s *PersonaService) GetMoralPattern(ctx context.Context, sessionID, tenantID uuid.UUID) ([]types.MoralAssessment, error) {
|
||||
rows, err := s.pool.Query(ctx, `
|
||||
SELECT activated_foundations, assessment_text,
|
||||
has_tension, tension_foundations,
|
||||
resolution_foundation, confidence
|
||||
FROM persona.moral_assessments
|
||||
WHERE session_id = $1 AND tenant_id = $2
|
||||
ORDER BY created_at DESC
|
||||
LIMIT 5
|
||||
`, sessionID, tenantID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("moral pattern query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
assessments := make([]types.MoralAssessment, 0)
|
||||
for rows.Next() {
|
||||
var a types.MoralAssessment
|
||||
var activatedFoundations []byte
|
||||
if err := rows.Scan(
|
||||
&activatedFoundations, &a.AssessmentText,
|
||||
&a.HasTension, &a.TensionFoundations,
|
||||
&a.ResolutionFoundation, &a.Confidence,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("moral assessment scan failed: %w", err)
|
||||
}
|
||||
a.ActivatedFoundations = json.RawMessage(activatedFoundations)
|
||||
assessments = append(assessments, a)
|
||||
}
|
||||
|
||||
return assessments, nil
|
||||
}
|
||||
78
internal/service/personal_agent.go
Normal file
78
internal/service/personal_agent.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// PersonalAgentService handles personal agent config operations
|
||||
type PersonalAgentService struct {
|
||||
pool *pgxpool.Pool
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewPersonalAgentService creates a new personal agent service
|
||||
func NewPersonalAgentService(pool *pgxpool.Pool, logger zerolog.Logger) *PersonalAgentService {
|
||||
return &PersonalAgentService{
|
||||
pool: pool,
|
||||
logger: logger.With().Str("service", "personal_agent").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetConfig gets a user's personal agent config
|
||||
func (s *PersonalAgentService) GetConfig(ctx context.Context, userID, tenantID uuid.UUID) (*types.UserAgentConfig, error) {
|
||||
var c types.UserAgentConfig
|
||||
var configBytes []byte
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, user_id, tenant_id, config, created_at, updated_at
|
||||
FROM admin.user_agent_configs WHERE user_id = $1 AND tenant_id = $2`,
|
||||
userID, tenantID).
|
||||
Scan(&c.ID, &c.UserID, &c.TenantID, &configBytes, &c.CreatedAt, &c.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("not found: %w", err)
|
||||
}
|
||||
c.Config = json.RawMessage(configBytes)
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// UpsertConfig creates or updates a user's personal agent config
|
||||
func (s *PersonalAgentService) UpsertConfig(ctx context.Context, req *types.UserAgentConfigUpsert) (*types.UserAgentConfig, error) {
|
||||
var c types.UserAgentConfig
|
||||
var configBytes []byte
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`INSERT INTO admin.user_agent_configs (user_id, tenant_id, config)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (user_id, tenant_id) DO UPDATE
|
||||
SET config = EXCLUDED.config, updated_at = NOW()
|
||||
RETURNING id, user_id, tenant_id, config, created_at, updated_at`,
|
||||
req.UserID, req.TenantID, req.Config).
|
||||
Scan(&c.ID, &c.UserID, &c.TenantID, &configBytes, &c.CreatedAt, &c.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("upsert failed: %w", err)
|
||||
}
|
||||
c.Config = json.RawMessage(configBytes)
|
||||
s.logger.Info().Str("userId", req.UserID.String()).Str("tenantId", req.TenantID.String()).Msg("Upserted personal agent config")
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// DeleteConfig deletes a user's personal agent config
|
||||
func (s *PersonalAgentService) DeleteConfig(ctx context.Context, userID, tenantID uuid.UUID) error {
|
||||
tag, err := s.pool.Exec(ctx,
|
||||
`DELETE FROM admin.user_agent_configs WHERE user_id = $1 AND tenant_id = $2`,
|
||||
userID, tenantID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("config not found")
|
||||
}
|
||||
s.logger.Info().Str("userId", userID.String()).Str("tenantId", tenantID.String()).Msg("Deleted personal agent config")
|
||||
return nil
|
||||
}
|
||||
110
internal/service/pgp.go
Normal file
110
internal/service/pgp.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/internal/client"
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// PGPService handles Hockeypuck PGP key operations
|
||||
type PGPService struct {
|
||||
client *client.HockeypuckClient
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewPGPService creates a new PGP service
|
||||
func NewPGPService(hkpClient *client.HockeypuckClient, logger zerolog.Logger) *PGPService {
|
||||
return &PGPService{
|
||||
client: hkpClient,
|
||||
logger: logger.With().Str("service", "pgp").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// SearchKeys searches for PGP keys
|
||||
func (s *PGPService) SearchKeys(query string) ([]types.PGPKey, error) {
|
||||
result, err := s.client.SearchKeys(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if result == "" {
|
||||
return []types.PGPKey{}, nil
|
||||
}
|
||||
|
||||
return parseMachineReadableIndex(result), nil
|
||||
}
|
||||
|
||||
// GetKey retrieves a PGP key by key ID
|
||||
func (s *PGPService) GetKey(keyID string) (*types.PGPKey, error) {
|
||||
armoredKey, err := s.client.GetKey(keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if armoredKey == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return &types.PGPKey{
|
||||
KeyID: keyID,
|
||||
ArmoredKey: armoredKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// UploadKey uploads a PGP public key
|
||||
func (s *PGPService) UploadKey(keyText string) error {
|
||||
return s.client.UploadKey(keyText)
|
||||
}
|
||||
|
||||
// DeleteKey deletes a PGP key
|
||||
func (s *PGPService) DeleteKey(keyID string) error {
|
||||
return s.client.DeleteKey(keyID)
|
||||
}
|
||||
|
||||
// parseMachineReadableIndex parses the HKP machine-readable index format
|
||||
func parseMachineReadableIndex(data string) []types.PGPKey {
|
||||
keys := []types.PGPKey{}
|
||||
var current *types.PGPKey
|
||||
|
||||
scanner := bufio.NewScanner(strings.NewReader(data))
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fields := strings.Split(line, ":")
|
||||
|
||||
if len(fields) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch fields[0] {
|
||||
case "pub":
|
||||
if current != nil {
|
||||
keys = append(keys, *current)
|
||||
}
|
||||
current = &types.PGPKey{}
|
||||
if len(fields) > 1 {
|
||||
current.KeyID = fields[1]
|
||||
}
|
||||
if len(fields) > 2 {
|
||||
current.Algorithm = fields[2]
|
||||
}
|
||||
if len(fields) > 4 {
|
||||
current.Created = fields[4]
|
||||
}
|
||||
if len(fields) > 5 {
|
||||
current.Expires = fields[5]
|
||||
}
|
||||
case "uid":
|
||||
if current != nil && len(fields) > 1 {
|
||||
current.UIDs = append(current.UIDs, fields[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if current != nil {
|
||||
keys = append(keys, *current)
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
453
internal/service/voice_agent.go
Normal file
453
internal/service/voice_agent.go
Normal file
@@ -0,0 +1,453 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"github.com/gosec/gsc-ops-api/pkg/types"
|
||||
)
|
||||
|
||||
// VoiceAgentService handles voice agent config and session operations
|
||||
type VoiceAgentService struct {
|
||||
pool *pgxpool.Pool
|
||||
logger zerolog.Logger
|
||||
}
|
||||
|
||||
// NewVoiceAgentService creates a new voice agent service
|
||||
func NewVoiceAgentService(pool *pgxpool.Pool, logger zerolog.Logger) *VoiceAgentService {
|
||||
return &VoiceAgentService{
|
||||
pool: pool,
|
||||
logger: logger.With().Str("service", "voice_agent").Logger(),
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Voice Agent Configs
|
||||
// ============================================================================
|
||||
|
||||
// ListConfigs lists voice agent configs with optional filters
|
||||
func (s *VoiceAgentService) ListConfigs(ctx context.Context, params types.ListParams, tenantID *uuid.UUID) ([]types.VoiceAgentConfig, int64, error) {
|
||||
params = types.DefaultListParams(params)
|
||||
|
||||
countQuery := `SELECT COUNT(*) FROM voice_agent_configs WHERE 1=1`
|
||||
listQuery := `SELECT id, tenant_id, agent_id, greeting_text, goodbye_text,
|
||||
voice_id, language,
|
||||
stt_provider, stt_model, tts_provider, tts_model,
|
||||
max_call_duration_seconds, silence_timeout_seconds,
|
||||
barge_in_enabled, vad_sensitivity,
|
||||
transfer_enabled, transfer_number,
|
||||
business_hours_enabled, business_hours, after_hours_text,
|
||||
is_active, created_at, updated_at
|
||||
FROM voice_agent_configs WHERE 1=1`
|
||||
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
if tenantID != nil {
|
||||
countQuery += fmt.Sprintf(" AND tenant_id = $%d", argIdx)
|
||||
listQuery += fmt.Sprintf(" AND tenant_id = $%d", argIdx)
|
||||
args = append(args, *tenantID)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if params.Search != "" {
|
||||
countQuery += fmt.Sprintf(" AND (greeting_text ILIKE $%d OR voice_id ILIKE $%d OR language ILIKE $%d)", argIdx, argIdx, argIdx)
|
||||
listQuery += fmt.Sprintf(" AND (greeting_text ILIKE $%d OR voice_id ILIKE $%d OR language ILIKE $%d)", argIdx, argIdx, argIdx)
|
||||
args = append(args, "%"+params.Search+"%")
|
||||
argIdx++
|
||||
}
|
||||
|
||||
var total int64
|
||||
if err := s.pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("count query failed: %w", err)
|
||||
}
|
||||
|
||||
listQuery += fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", argIdx, argIdx+1)
|
||||
args = append(args, params.Limit, params.Offset)
|
||||
|
||||
rows, err := s.pool.Query(ctx, listQuery, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("list query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
configs := make([]types.VoiceAgentConfig, 0)
|
||||
for rows.Next() {
|
||||
var c types.VoiceAgentConfig
|
||||
var businessHours []byte
|
||||
if err := rows.Scan(
|
||||
&c.ID, &c.TenantID, &c.AgentID, &c.GreetingText, &c.GoodbyeText,
|
||||
&c.VoiceID, &c.Language,
|
||||
&c.STTProvider, &c.STTModel, &c.TTSProvider, &c.TTSModel,
|
||||
&c.MaxCallDurationSeconds, &c.SilenceTimeoutSeconds,
|
||||
&c.BargeInEnabled, &c.VADSensitivity,
|
||||
&c.TransferEnabled, &c.TransferNumber,
|
||||
&c.BusinessHoursEnabled, &businessHours, &c.AfterHoursText,
|
||||
&c.IsActive, &c.CreatedAt, &c.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, 0, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
c.BusinessHours = json.RawMessage(businessHours)
|
||||
configs = append(configs, c)
|
||||
}
|
||||
|
||||
return configs, total, nil
|
||||
}
|
||||
|
||||
// GetConfig gets a voice agent config by ID
|
||||
func (s *VoiceAgentService) GetConfig(ctx context.Context, id uuid.UUID) (*types.VoiceAgentConfig, error) {
|
||||
var c types.VoiceAgentConfig
|
||||
var businessHours []byte
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, tenant_id, agent_id, greeting_text, goodbye_text,
|
||||
voice_id, language,
|
||||
stt_provider, stt_model, tts_provider, tts_model,
|
||||
max_call_duration_seconds, silence_timeout_seconds,
|
||||
barge_in_enabled, vad_sensitivity,
|
||||
transfer_enabled, transfer_number,
|
||||
business_hours_enabled, business_hours, after_hours_text,
|
||||
is_active, created_at, updated_at
|
||||
FROM voice_agent_configs WHERE id = $1`, id).
|
||||
Scan(
|
||||
&c.ID, &c.TenantID, &c.AgentID, &c.GreetingText, &c.GoodbyeText,
|
||||
&c.VoiceID, &c.Language,
|
||||
&c.STTProvider, &c.STTModel, &c.TTSProvider, &c.TTSModel,
|
||||
&c.MaxCallDurationSeconds, &c.SilenceTimeoutSeconds,
|
||||
&c.BargeInEnabled, &c.VADSensitivity,
|
||||
&c.TransferEnabled, &c.TransferNumber,
|
||||
&c.BusinessHoursEnabled, &businessHours, &c.AfterHoursText,
|
||||
&c.IsActive, &c.CreatedAt, &c.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("not found: %w", err)
|
||||
}
|
||||
c.BusinessHours = json.RawMessage(businessHours)
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// CreateConfig creates a new voice agent config
|
||||
func (s *VoiceAgentService) CreateConfig(ctx context.Context, req *types.VoiceAgentConfigCreate) (*types.VoiceAgentConfig, error) {
|
||||
// Set defaults
|
||||
greeting := req.GreetingText
|
||||
if greeting == "" {
|
||||
greeting = "Hello, how can I help you today?"
|
||||
}
|
||||
goodbye := req.GoodbyeText
|
||||
if goodbye == "" {
|
||||
goodbye = "Goodbye, have a great day."
|
||||
}
|
||||
voiceID := req.VoiceID
|
||||
if voiceID == "" {
|
||||
voiceID = "alloy"
|
||||
}
|
||||
lang := req.Language
|
||||
if lang == "" {
|
||||
lang = "en"
|
||||
}
|
||||
maxDuration := 1800
|
||||
if req.MaxCallDurationSeconds != nil {
|
||||
maxDuration = *req.MaxCallDurationSeconds
|
||||
}
|
||||
silenceTimeout := 30
|
||||
if req.SilenceTimeoutSeconds != nil {
|
||||
silenceTimeout = *req.SilenceTimeoutSeconds
|
||||
}
|
||||
bargeIn := true
|
||||
if req.BargeInEnabled != nil {
|
||||
bargeIn = *req.BargeInEnabled
|
||||
}
|
||||
vadSens := req.VADSensitivity
|
||||
if vadSens == "" {
|
||||
vadSens = "medium"
|
||||
}
|
||||
transfer := true
|
||||
if req.TransferEnabled != nil {
|
||||
transfer = *req.TransferEnabled
|
||||
}
|
||||
bizHoursEnabled := false
|
||||
if req.BusinessHoursEnabled != nil {
|
||||
bizHoursEnabled = *req.BusinessHoursEnabled
|
||||
}
|
||||
bizHours := req.BusinessHours
|
||||
if len(bizHours) == 0 {
|
||||
bizHours = json.RawMessage(`{}`)
|
||||
}
|
||||
|
||||
var c types.VoiceAgentConfig
|
||||
var businessHoursOut []byte
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`INSERT INTO voice_agent_configs (
|
||||
tenant_id, agent_id, greeting_text, goodbye_text,
|
||||
voice_id, language,
|
||||
stt_provider, stt_model, tts_provider, tts_model,
|
||||
max_call_duration_seconds, silence_timeout_seconds,
|
||||
barge_in_enabled, vad_sensitivity,
|
||||
transfer_enabled, transfer_number,
|
||||
business_hours_enabled, business_hours, after_hours_text
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19)
|
||||
RETURNING id, tenant_id, agent_id, greeting_text, goodbye_text,
|
||||
voice_id, language,
|
||||
stt_provider, stt_model, tts_provider, tts_model,
|
||||
max_call_duration_seconds, silence_timeout_seconds,
|
||||
barge_in_enabled, vad_sensitivity,
|
||||
transfer_enabled, transfer_number,
|
||||
business_hours_enabled, business_hours, after_hours_text,
|
||||
is_active, created_at, updated_at`,
|
||||
req.TenantID, req.AgentID, greeting, goodbye,
|
||||
voiceID, lang,
|
||||
req.STTProvider, req.STTModel, req.TTSProvider, req.TTSModel,
|
||||
maxDuration, silenceTimeout,
|
||||
bargeIn, vadSens,
|
||||
transfer, req.TransferNumber,
|
||||
bizHoursEnabled, bizHours, req.AfterHoursText,
|
||||
).Scan(
|
||||
&c.ID, &c.TenantID, &c.AgentID, &c.GreetingText, &c.GoodbyeText,
|
||||
&c.VoiceID, &c.Language,
|
||||
&c.STTProvider, &c.STTModel, &c.TTSProvider, &c.TTSModel,
|
||||
&c.MaxCallDurationSeconds, &c.SilenceTimeoutSeconds,
|
||||
&c.BargeInEnabled, &c.VADSensitivity,
|
||||
&c.TransferEnabled, &c.TransferNumber,
|
||||
&c.BusinessHoursEnabled, &businessHoursOut, &c.AfterHoursText,
|
||||
&c.IsActive, &c.CreatedAt, &c.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("insert failed: %w", err)
|
||||
}
|
||||
c.BusinessHours = json.RawMessage(businessHoursOut)
|
||||
s.logger.Info().Str("id", c.ID.String()).Str("agentId", c.AgentID.String()).Msg("Created voice agent config")
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// UpdateConfig updates a voice agent config
|
||||
func (s *VoiceAgentService) UpdateConfig(ctx context.Context, id uuid.UUID, req *types.VoiceAgentConfigUpdate) (*types.VoiceAgentConfig, error) {
|
||||
// Build dynamic SET clause
|
||||
setClauses := []string{}
|
||||
args := []interface{}{}
|
||||
argIdx := 1
|
||||
|
||||
addField := func(clause string, val interface{}) {
|
||||
setClauses = append(setClauses, fmt.Sprintf("%s = $%d", clause, argIdx))
|
||||
args = append(args, val)
|
||||
argIdx++
|
||||
}
|
||||
|
||||
if req.GreetingText != nil {
|
||||
addField("greeting_text", *req.GreetingText)
|
||||
}
|
||||
if req.GoodbyeText != nil {
|
||||
addField("goodbye_text", *req.GoodbyeText)
|
||||
}
|
||||
if req.VoiceID != nil {
|
||||
addField("voice_id", *req.VoiceID)
|
||||
}
|
||||
if req.Language != nil {
|
||||
addField("language", *req.Language)
|
||||
}
|
||||
if req.STTProvider != nil {
|
||||
addField("stt_provider", *req.STTProvider)
|
||||
}
|
||||
if req.STTModel != nil {
|
||||
addField("stt_model", *req.STTModel)
|
||||
}
|
||||
if req.TTSProvider != nil {
|
||||
addField("tts_provider", *req.TTSProvider)
|
||||
}
|
||||
if req.TTSModel != nil {
|
||||
addField("tts_model", *req.TTSModel)
|
||||
}
|
||||
if req.MaxCallDurationSeconds != nil {
|
||||
addField("max_call_duration_seconds", *req.MaxCallDurationSeconds)
|
||||
}
|
||||
if req.SilenceTimeoutSeconds != nil {
|
||||
addField("silence_timeout_seconds", *req.SilenceTimeoutSeconds)
|
||||
}
|
||||
if req.BargeInEnabled != nil {
|
||||
addField("barge_in_enabled", *req.BargeInEnabled)
|
||||
}
|
||||
if req.VADSensitivity != nil {
|
||||
addField("vad_sensitivity", *req.VADSensitivity)
|
||||
}
|
||||
if req.TransferEnabled != nil {
|
||||
addField("transfer_enabled", *req.TransferEnabled)
|
||||
}
|
||||
if req.TransferNumber != nil {
|
||||
addField("transfer_number", *req.TransferNumber)
|
||||
}
|
||||
if req.BusinessHoursEnabled != nil {
|
||||
addField("business_hours_enabled", *req.BusinessHoursEnabled)
|
||||
}
|
||||
if len(req.BusinessHours) > 0 {
|
||||
addField("business_hours", req.BusinessHours)
|
||||
}
|
||||
if req.AfterHoursText != nil {
|
||||
addField("after_hours_text", *req.AfterHoursText)
|
||||
}
|
||||
if req.IsActive != nil {
|
||||
addField("is_active", *req.IsActive)
|
||||
}
|
||||
|
||||
if len(setClauses) == 0 {
|
||||
return s.GetConfig(ctx, id)
|
||||
}
|
||||
|
||||
// Always update updated_at
|
||||
setClauses = append(setClauses, "updated_at = NOW()")
|
||||
|
||||
query := fmt.Sprintf("UPDATE voice_agent_configs SET %s WHERE id = $%d",
|
||||
joinClauses(setClauses), argIdx)
|
||||
args = append(args, id)
|
||||
query += ` RETURNING id, tenant_id, agent_id, greeting_text, goodbye_text,
|
||||
voice_id, language,
|
||||
stt_provider, stt_model, tts_provider, tts_model,
|
||||
max_call_duration_seconds, silence_timeout_seconds,
|
||||
barge_in_enabled, vad_sensitivity,
|
||||
transfer_enabled, transfer_number,
|
||||
business_hours_enabled, business_hours, after_hours_text,
|
||||
is_active, created_at, updated_at`
|
||||
|
||||
var c types.VoiceAgentConfig
|
||||
var businessHours []byte
|
||||
err := s.pool.QueryRow(ctx, query, args...).Scan(
|
||||
&c.ID, &c.TenantID, &c.AgentID, &c.GreetingText, &c.GoodbyeText,
|
||||
&c.VoiceID, &c.Language,
|
||||
&c.STTProvider, &c.STTModel, &c.TTSProvider, &c.TTSModel,
|
||||
&c.MaxCallDurationSeconds, &c.SilenceTimeoutSeconds,
|
||||
&c.BargeInEnabled, &c.VADSensitivity,
|
||||
&c.TransferEnabled, &c.TransferNumber,
|
||||
&c.BusinessHoursEnabled, &businessHours, &c.AfterHoursText,
|
||||
&c.IsActive, &c.CreatedAt, &c.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update failed: %w", err)
|
||||
}
|
||||
c.BusinessHours = json.RawMessage(businessHours)
|
||||
s.logger.Info().Str("id", id.String()).Msg("Updated voice agent config")
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// DeleteConfig deletes a voice agent config
|
||||
func (s *VoiceAgentService) DeleteConfig(ctx context.Context, id uuid.UUID) error {
|
||||
tag, err := s.pool.Exec(ctx, `DELETE FROM voice_agent_configs WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete failed: %w", err)
|
||||
}
|
||||
if tag.RowsAffected() == 0 {
|
||||
return fmt.Errorf("config not found")
|
||||
}
|
||||
s.logger.Info().Str("id", id.String()).Msg("Deleted voice agent config")
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Voice Sessions
|
||||
// ============================================================================
|
||||
|
||||
// ListSessions lists voice sessions for a specific agent
|
||||
func (s *VoiceAgentService) ListSessions(ctx context.Context, agentID uuid.UUID, params types.ListParams) ([]types.VoiceSession, int64, error) {
|
||||
params = types.DefaultListParams(params)
|
||||
|
||||
var total int64
|
||||
if err := s.pool.QueryRow(ctx,
|
||||
`SELECT COUNT(*) FROM voice_sessions WHERE agent_id = $1`, agentID).Scan(&total); err != nil {
|
||||
return nil, 0, fmt.Errorf("count query failed: %w", err)
|
||||
}
|
||||
|
||||
rows, err := s.pool.Query(ctx,
|
||||
`SELECT id, tenant_id, agent_id, caller_number, called_number,
|
||||
asterisk_call_id, agent_session_id,
|
||||
total_turns, stt_provider, tts_provider,
|
||||
stt_audio_seconds, tts_characters,
|
||||
started_at, ended_at, end_reason, metadata, created_at
|
||||
FROM voice_sessions WHERE agent_id = $1
|
||||
ORDER BY started_at DESC LIMIT $2 OFFSET $3`,
|
||||
agentID, params.Limit, params.Offset)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("list query failed: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
sessions := make([]types.VoiceSession, 0)
|
||||
for rows.Next() {
|
||||
var vs types.VoiceSession
|
||||
var metadata []byte
|
||||
if err := rows.Scan(
|
||||
&vs.ID, &vs.TenantID, &vs.AgentID, &vs.CallerNumber, &vs.CalledNumber,
|
||||
&vs.AsteriskCallID, &vs.AgentSessionID,
|
||||
&vs.TotalTurns, &vs.STTProvider, &vs.TTSProvider,
|
||||
&vs.STTAudioSeconds, &vs.TTSCharacters,
|
||||
&vs.StartedAt, &vs.EndedAt, &vs.EndReason, &metadata, &vs.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, 0, fmt.Errorf("scan failed: %w", err)
|
||||
}
|
||||
vs.Metadata = json.RawMessage(metadata)
|
||||
sessions = append(sessions, vs)
|
||||
}
|
||||
|
||||
return sessions, total, nil
|
||||
}
|
||||
|
||||
// GetSession gets a voice session by ID, including turns
|
||||
func (s *VoiceAgentService) GetSession(ctx context.Context, sessionID uuid.UUID) (*types.VoiceSession, error) {
|
||||
var vs types.VoiceSession
|
||||
var metadata []byte
|
||||
err := s.pool.QueryRow(ctx,
|
||||
`SELECT id, tenant_id, agent_id, caller_number, called_number,
|
||||
asterisk_call_id, agent_session_id,
|
||||
total_turns, stt_provider, tts_provider,
|
||||
stt_audio_seconds, tts_characters,
|
||||
started_at, ended_at, end_reason, metadata, created_at
|
||||
FROM voice_sessions WHERE id = $1`, sessionID).
|
||||
Scan(
|
||||
&vs.ID, &vs.TenantID, &vs.AgentID, &vs.CallerNumber, &vs.CalledNumber,
|
||||
&vs.AsteriskCallID, &vs.AgentSessionID,
|
||||
&vs.TotalTurns, &vs.STTProvider, &vs.TTSProvider,
|
||||
&vs.STTAudioSeconds, &vs.TTSCharacters,
|
||||
&vs.StartedAt, &vs.EndedAt, &vs.EndReason, &metadata, &vs.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("session not found: %w", err)
|
||||
}
|
||||
vs.Metadata = json.RawMessage(metadata)
|
||||
|
||||
// Fetch turns
|
||||
turnRows, err := s.pool.Query(ctx,
|
||||
`SELECT id, session_id, turn_number, role, text,
|
||||
stt_confidence, agent_latency_ms, was_interrupted, created_at
|
||||
FROM voice_session_turns WHERE session_id = $1
|
||||
ORDER BY turn_number ASC`, sessionID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("turns query failed: %w", err)
|
||||
}
|
||||
defer turnRows.Close()
|
||||
|
||||
vs.Turns = make([]types.VoiceSessionTurn, 0)
|
||||
for turnRows.Next() {
|
||||
var t types.VoiceSessionTurn
|
||||
if err := turnRows.Scan(
|
||||
&t.ID, &t.SessionID, &t.TurnNumber, &t.Role, &t.Text,
|
||||
&t.STTConfidence, &t.AgentLatencyMs, &t.WasInterrupted, &t.CreatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("turn scan failed: %w", err)
|
||||
}
|
||||
vs.Turns = append(vs.Turns, t)
|
||||
}
|
||||
|
||||
return &vs, nil
|
||||
}
|
||||
|
||||
// joinClauses joins SQL SET clauses with commas
|
||||
func joinClauses(clauses []string) string {
|
||||
result := ""
|
||||
for i, c := range clauses {
|
||||
if i > 0 {
|
||||
result += ", "
|
||||
}
|
||||
result += c
|
||||
}
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user