package caching
import (
"context"
"fmt"
"log/slog"
"strings"
"time"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/amirasaad/fintech/pkg/registry"
)
// ExchangeCache handles bulk caching of exchange rates
// in the infrastructure layer. Rate entries and a "last updated" marker
// are stored in the registry provider it is constructed with.
type ExchangeCache struct {
// exchangeRegistry is the store entries are read from (Get) and written to (Register).
exchangeRegistry registry.Provider
// logger is tagged with component=exchange_cache by NewExchangeCache.
logger *slog.Logger
// cfg supplies the key prefix for the last-updated marker and the TTL used by IsCacheStale.
cfg *config.ExchangeRateCache
// NOTE(review): stopChan is created by NewExchangeCache but is not read or
// closed anywhere in the visible code — confirm whether it is still needed.
stopChan chan struct{} // Channel to stop background refresh
}
// NewExchangeCache builds an ExchangeCache on top of the given registry.
//
// A nil cfg is replaced by a default configuration with a 15 minute TTL and a
// nil logger falls back to slog.Default(). The stored logger is always tagged
// with component=exchange_cache.
func NewExchangeCache(
	exchangeRegistry registry.Provider,
	logger *slog.Logger,
	cfg *config.ExchangeRateCache,
) *ExchangeCache {
	if logger == nil {
		logger = slog.Default()
	}
	if cfg == nil {
		// Default TTL if not configured
		cfg = &config.ExchangeRateCache{TTL: 15 * time.Minute}
	}

	cache := &ExchangeCache{
		exchangeRegistry: exchangeRegistry,
		cfg:              cfg,
		stopChan:         make(chan struct{}),
	}
	cache.logger = logger.With(slog.String("component", "exchange_cache"))
	return cache
}
// GetLastUpdated returns the timestamp of the last rate update.
//
// The marker entry is looked up under getLastUpdatedKey. A lookup error whose
// message indicates a missing entry, a nil entry, or an entry that exposes no
// usable timestamp all yield the zero time with a nil error. Any other lookup
// error is wrapped and returned.
func (c *ExchangeCache) GetLastUpdated(ctx context.Context) (time.Time, error) {
	var zero time.Time

	entry, err := c.exchangeRegistry.Get(ctx, c.getLastUpdatedKey())
	if err != nil {
		// Missing entries are recognised by message text and reported as "never updated".
		msg := err.Error()
		for _, marker := range []string{
			"not found",
			"no such file or directory",
			"entity not found",
		} {
			if strings.Contains(msg, marker) {
				return zero, nil
			}
		}
		return zero, fmt.Errorf("failed to get last updated time: %w", err)
	}
	if entry == nil {
		return zero, nil
	}

	// Probe the entry with type assertions, most specific accessor first.
	if withTimestamp, ok := entry.(interface{ GetTimestamp() time.Time }); ok {
		return withTimestamp.GetTimestamp(), nil
	}

	// Next, try a "timestamp" metadata value; an unparsable value is skipped.
	if withMeta, ok := entry.(interface{ Metadata() map[string]string }); ok {
		if raw, present := withMeta.Metadata()["timestamp"]; present {
			if parsed, parseErr := time.Parse(time.RFC3339Nano, raw); parseErr == nil {
				return parsed, nil
			}
		}
	}

	// Finally fall back to UpdatedAt when the entry offers it.
	if withUpdatedAt, ok := entry.(interface{ UpdatedAt() time.Time }); ok {
		return withUpdatedAt.UpdatedAt(), nil
	}

	// Nothing usable: report "never updated".
	return zero, nil
}
// getLastUpdatedKey returns the key used to store the last updated timestamp:
// the configured cache prefix followed by ":last_updated".
func (c *ExchangeCache) getLastUpdatedKey() string {
	prefix := c.cfg.Prefix
	return fmt.Sprintf("%s:last_updated", prefix)
}
// IsCacheStale reports whether the cache needs a refresh, judged solely by the
// last-updated marker (individual rate entries are not inspected).
//
// Returns:
//   - bool: true when the marker is missing, older than the TTL, or older than
//     the refresh threshold (80% of the TTL)
//   - time.Duration: time left until the TTL expires (when past the threshold
//     but within the TTL), time left until the threshold (when fresh), or 0
//   - error: a wrapped lookup error that does not indicate a missing entry
func (c *ExchangeCache) IsCacheStale(ctx context.Context) (bool, time.Duration, error) {
	lastUpdated, err := c.GetLastUpdated(ctx)
	if err != nil {
		msg := err.Error()
		missing := strings.Contains(msg, "not found") ||
			strings.Contains(msg, "no such file") ||
			strings.Contains(msg, "entity not found")
		if !missing {
			return false, 0, fmt.Errorf("failed to check cache staleness: %w", err)
		}
		// A missing marker means the cache was never populated.
		return true, 0, nil
	}

	// Zero time means "never updated".
	if lastUpdated.IsZero() {
		return true, 0, nil
	}

	age := time.Since(lastUpdated)
	ttl := c.cfg.TTL

	c.logger.Debug(
		"Cache status check",
		"last_updated", lastUpdated.Format(time.RFC3339),
		"time_since_update", age.Round(time.Second).String(),
		"ttl", ttl,
	)

	// Refresh proactively once 80% of the TTL has elapsed.
	softLimit := time.Duration(float64(ttl) * 0.8)

	switch {
	case age > ttl:
		// Past the TTL: definitely stale.
		return true, 0, nil
	case age > softLimit:
		// Getting stale: report the time remaining until the TTL.
		return true, ttl - age, nil
	default:
		// Still fresh: report the time remaining until the threshold.
		return false, softLimit - age, nil
	}
}
// updateLastUpdated registers a fresh last-updated marker stamped with the
// current UTC time, under the key returned by getLastUpdatedKey.
func (c *ExchangeCache) updateLastUpdated(ctx context.Context) error {
	var (
		stamp = time.Now().UTC()
		id    = c.getLastUpdatedKey()
	)

	marker := &exchangeRateInfo{Timestamp: stamp}
	marker.BaseEntity = *registry.NewBaseEntity(id, id)
	marker.SetActive(true)
	// The timestamp is also mirrored into metadata in RFC3339Nano form.
	marker.SetMetadata("timestamp", stamp.Format(time.RFC3339Nano))

	return c.exchangeRegistry.Register(ctx, marker)
}
// CacheRates caches multiple exchange rates in a single operation.
//
// For every usable entry in rates (non-nil with Rate > 0) the direct rate is
// registered under "exr:rate:<from>:<to>" and, on success, the inverse rate is
// registered under "exr:rate:<to>:<from>". All entries share one UTC timestamp.
// After the loop the last-updated marker is refreshed exactly once.
//
// Returns nil when rates is empty. Otherwise returns the first error hit while
// registering a direct rate or, if there was none, the error from refreshing
// the last-updated marker. Failures to register an inverse rate are logged but
// not returned (best effort).
func (c *ExchangeCache) CacheRates(
	ctx context.Context,
	rates map[string]*exchange.RateInfo,
	source string,
) error {
	if len(rates) == 0 {
		return nil
	}

	// One timestamp shared by every entry written in this call.
	now := time.Now().UTC()
	stamp := now.Format(time.RFC3339Nano)

	// newEntry builds a fully populated cache entry for the from->to pair.
	newEntry := func(from, to string, value float64) *exchangeRateInfo {
		key := fmt.Sprintf("exr:rate:%s:%s", from, to)
		entry := &exchangeRateInfo{
			BaseEntity: *registry.NewBaseEntity(key, key),
			From:       from,
			To:         to,
			Rate:       value,
			Source:     source,
			Timestamp:  now,
		}
		entry.SetActive(true)
		entry.SetMetadata("source", source)
		entry.SetMetadata("rate", fmt.Sprintf("%f", value))
		entry.SetMetadata("timestamp", stamp)
		entry.SetMetadata("from", from)
		entry.SetMetadata("to", to)
		return entry
	}

	var firstErr error
	count := 0

	for to, rate := range rates {
		if rate == nil {
			c.logger.Error("Skipping nil rate", "to", to)
			continue
		}
		if rate.Rate <= 0 {
			c.logger.Warn(
				"Skipping non-positive conversion rate",
				"from", rate.FromCurrency,
				"to", to,
				"rate", rate.Rate,
			)
			continue
		}

		direct := newEntry(rate.FromCurrency, to, rate.Rate)
		if err := c.exchangeRegistry.Register(ctx, direct); err != nil {
			c.logger.Error(
				"Failed to cache rate",
				"from", rate.FromCurrency,
				"to", to,
				"error", err,
			)
			if firstErr == nil {
				firstErr = err
			}
			// Without the direct rate the inverse is not written either.
			continue
		}
		count++

		// Cache the inverse rate as well (rate is guaranteed > 0 here).
		inverse := newEntry(to, rate.FromCurrency, 1/rate.Rate)
		inverse.SetMetadata("original_currency", rate.FromCurrency)
		if err := c.exchangeRegistry.Register(ctx, inverse); err != nil {
			// Best effort: logged, but it does not fail the call.
			c.logger.Error(
				"Failed to cache inverse rate",
				"from", to,
				"to", rate.FromCurrency,
				"error", err,
			)
			continue
		}
		count++
	}

	// Refresh the last-updated marker once, synchronously, so that a failure is
	// actually reflected in the returned error. (A deferred closure cannot
	// change an unnamed result after the return value has been evaluated.)
	if err := c.updateLastUpdated(ctx); err != nil {
		c.logger.Error("Failed to update last updated timestamp", "error", err)
		if firstErr == nil {
			firstErr = fmt.Errorf("failed to update last updated timestamp: %w", err)
		}
	} else {
		c.logger.Debug("Updated exchange rate cache last updated timestamp")
	}

	c.logger.Info("Successfully cached exchange rates",
		"num_rates", count,
		"source", source,
	)

	return firstErr
}
// exchangeRateInfo is a private type used for caching exchange rates and last updated timestamp
// It implements registry.Entity
// The same type is used for the last-updated marker, in which case only the
// embedded BaseEntity and Timestamp are populated.
type exchangeRateInfo struct {
registry.BaseEntity
// From and To are the currency codes of the pair this rate converts between.
From string `json:"from,omitempty"`
To string `json:"to,omitempty"`
// Rate is the conversion factor from From to To.
Rate float64 `json:"rate,omitempty"`
// Source is the provider name passed to CacheRates.
Source string `json:"source,omitempty"`
// Timestamp is when the entry was written; CreatedAt, UpdatedAt and GetTimestamp all return it.
Timestamp time.Time `json:"timestamp"`
}
// ID returns the unique identifier for the entity: the embedded BaseEntity's
// ID when it is set, otherwise "exr:rate:<From>:<To>" when both currencies are
// known, otherwise the empty string.
func (e *exchangeRateInfo) ID() string {
	baseID := e.BaseEntity.ID()
	switch {
	case baseID != "":
		return baseID
	case e.From == "" || e.To == "":
		return ""
	default:
		return fmt.Sprintf("exr:rate:%s:%s", e.From, e.To)
	}
}
// Name returns the name of the entity, which is always the same value as ID().
func (e *exchangeRateInfo) Name() string {
return e.ID()
}
// Active returns whether the entity is active.
// It always reports true, regardless of any value set through SetActive.
func (e *exchangeRateInfo) Active() bool {
return true // Exchange rate entries are always considered active
}
// Metadata returns the metadata for the entity as a freshly built map derived
// from the struct fields (from, to, rate, source, timestamp). The timestamp is
// formatted as RFC3339 and the rate with "%f".
func (e *exchangeRateInfo) Metadata() map[string]string {
	meta := make(map[string]string, 5)
	meta["from"] = e.From
	meta["to"] = e.To
	meta["rate"] = fmt.Sprintf("%f", e.Rate)
	meta["source"] = e.Source
	meta["timestamp"] = e.Timestamp.Format(time.RFC3339)
	return meta
}
// CreatedAt returns the creation time of the entity.
// Entries carry a single Timestamp, so this is the same value as UpdatedAt.
func (e *exchangeRateInfo) CreatedAt() time.Time {
return e.Timestamp
}
// UpdatedAt returns the last update time of the entity (the Timestamp field).
func (e *exchangeRateInfo) UpdatedAt() time.Time {
return e.Timestamp
}
// Type returns the type of the entity, the constant string "exchange_rate".
func (e *exchangeRateInfo) Type() string {
return "exchange_rate"
}
// GetTimestamp returns the timestamp of the exchange rate.
// ExchangeCache.GetLastUpdated probes entries for this method first.
func (e *exchangeRateInfo) GetTimestamp() time.Time {
return e.Timestamp
}
package infra
import (
"errors"
"time"
"github.com/amirasaad/fintech/pkg/config"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
// NewDBConnection opens a GORM connection to PostgreSQL using cnf.Url.
//
// appEnv is injected by the caller: "development" enables Info-level SQL
// logging, any other value silences the GORM logger. The underlying pool is
// capped at 25 open and 25 idle connections with a one hour maximum lifetime.
//
// An error is returned when cnf.Url is empty, when the connection cannot be
// opened, or when the underlying *sql.DB cannot be obtained.
func NewDBConnection(
	cnf *config.DB,
	appEnv string,
) (*gorm.DB, error) {
	dsn := cnf.Url
	if dsn == "" {
		return nil, errors.New("DATABASE_URL is not set")
	}

	// Silent everywhere except development.
	logMode := logger.Silent
	if appEnv == "development" {
		logMode = logger.Info
	}

	gormCfg := &gorm.Config{
		Logger:                 logger.Default.LogMode(logMode),
		SkipDefaultTransaction: true,
		// TranslateError normalizes database-specific errors (PostgreSQL, MySQL, etc.)
		// into GORM generic errors (gorm.ErrDuplicatedKey, gorm.ErrRecordNotFound).
		// These are then mapped to domain errors by MapGormErrorToDomain in UoW.
		// This two-layer approach ensures database-agnostic error handling.
		TranslateError: true,
	}

	connection, err := gorm.Open(postgres.Open(dsn), gormCfg)
	if err != nil {
		return nil, err
	}

	pool, err := connection.DB()
	if err != nil {
		return nil, err
	}
	pool.SetMaxOpenConns(25)
	pool.SetMaxIdleConns(25)
	pool.SetConnMaxLifetime(1 * time.Hour)

	return connection, nil
}
//go:build !kafka
// +build !kafka
package eventbus
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
)
// KafkaEventBusConfig holds the settings passed to NewWithKafka.
// This declaration belongs to the build without the kafka tag, where the
// values are accepted but not used.
type KafkaEventBusConfig struct {
GroupID string
TopicPrefix string
// NOTE(review): DLQ presumably stands for dead-letter queue — confirm against the kafka-tagged implementation.
DLQRetryInterval time.Duration
DLQBatchSize int
}
// KafkaEventBus is the placeholder type compiled when the kafka build tag is absent.
type KafkaEventBus struct{}
// NewWithKafka always fails in this build: it returns a nil bus and an error
// telling the caller to build with -tags kafka. All arguments are ignored.
func NewWithKafka(
brokers string,
logger *slog.Logger,
config *KafkaEventBusConfig,
) (*KafkaEventBus, error) {
return nil, fmt.Errorf("kafka event bus: build with -tags kafka to enable")
}
// Register is a no-op in this build; the handler is discarded.
func (b *KafkaEventBus) Register(eventType events.EventType, handler eventbus.HandlerFunc) {
}
// Emit always returns an error in this build; the event is not delivered.
func (b *KafkaEventBus) Emit(ctx context.Context, event events.Event) error {
return fmt.Errorf("kafka event bus: build with -tags kafka to enable")
}
// Compile-time check that *KafkaEventBus satisfies eventbus.Bus.
var _ eventbus.Bus = (*KafkaEventBus)(nil)
package eventbus
import (
"context"
"github.com/amirasaad/fintech/pkg/domain/events"
"log/slog"
"sync"
"github.com/amirasaad/fintech/pkg/eventbus"
)
// EventDepthKey is the string "eventDepth".
// NOTE(review): presumably a context key for tracking nested event emission — it is not referenced in the visible code; confirm against callers.
const EventDepthKey = "eventDepth"
// MaxEventDepth is 10.
// NOTE(review): presumably the nesting limit paired with EventDepthKey — not enforced anywhere in the visible code; confirm.
const MaxEventDepth = 10
// MemoryEventBus is a simple in-memory implementation of the eventbus.Bus interface.
// Handlers are invoked synchronously by Emit, and every emitted event is also
// recorded so tests can inspect it through Published.
type MemoryEventBus struct {
// handlers maps an event type to the handlers registered for it.
handlers map[events.EventType][]eventbus.HandlerFunc
// mu guards handlers and published.
mu sync.RWMutex
logger *slog.Logger
published []events.Event // Added for testing purposes
}
// NewWithMemory creates a new in-memory event bus for event-driven
// communication.
//
// The logger is tagged with bus=memory. A nil logger falls back to
// slog.Default() instead of panicking.
func NewWithMemory(logger *slog.Logger) *MemoryEventBus {
	if logger == nil {
		logger = slog.Default()
	}
	return &MemoryEventBus{
		handlers:  make(map[events.EventType][]eventbus.HandlerFunc),
		logger:    logger.With("bus", "memory"),
		published: make([]events.Event, 0), // Start with an empty, non-nil slice
	}
}
// Register appends handler to the list of handlers for eventType.
// Handlers for one type are kept in registration order.
func (b *MemoryEventBus) Register(
	eventType events.EventType,
	handler eventbus.HandlerFunc,
) {
	b.mu.Lock()
	defer b.mu.Unlock()

	existing := b.handlers[eventType]
	b.handlers[eventType] = append(existing, handler)
}
// Emit dispatches the event to all registered handlers for its type.
//
// The event is first recorded in the published list (for tests). Each handler
// registered for the event's type is then invoked synchronously on the
// caller's goroutine, in registration order and outside the lock, so handlers
// may call back into the bus. A handler error does not stop dispatch and is
// not returned; it is logged so failures are not lost silently. Emit itself
// always returns nil.
func (b *MemoryEventBus) Emit(ctx context.Context, event events.Event) error {
	eventType := events.EventType(event.Type())

	// One critical section snapshots the handlers and records the event.
	b.mu.Lock()
	handlers := b.handlers[eventType]
	b.published = append(b.published, event)
	b.mu.Unlock()

	for _, handler := range handlers {
		if err := handler(ctx, event); err != nil && b.logger != nil {
			b.logger.Error("error handling event", "error", err, "event_type", eventType)
		}
	}
	return nil
}
// ClearPublished discards every recorded event, leaving an empty (non-nil)
// list behind. Intended for use between test cases.
func (b *MemoryEventBus) ClearPublished() {
	b.mu.Lock()
	b.published = []events.Event{}
	b.mu.Unlock()
}
// Published returns the list of published events. This is useful for testing.
//
// The result is a copy taken under the read lock, so callers can keep or
// modify it without affecting the bus. It is never nil.
func (b *MemoryEventBus) Published() []events.Event {
	b.mu.RLock()
	defer b.mu.RUnlock()

	snapshot := make([]events.Event, len(b.published))
	copy(snapshot, b.published)
	return snapshot
}
// Compile-time check that *MemoryEventBus implements the eventbus.Bus interface.
var _ eventbus.Bus = (*MemoryEventBus)(nil)
// MemoryAsyncEventBus is an in-memory event bus that delivers events
// asynchronously: Emit queues the event on eventCh and the process loop hands
// it to each registered handler in its own goroutine.
type MemoryAsyncEventBus struct {
// handlers maps an event type to the handlers registered for it; guarded by mu.
handlers map[events.EventType][]eventbus.HandlerFunc
mu sync.RWMutex
// eventCh carries each emitted event together with the context it was emitted with.
eventCh chan struct {
ctx context.Context
event events.Event
}
// log is used to report handler errors.
log *slog.Logger
}
// NewWithMemoryAsync creates a new asynchronous in-memory event bus and starts
// its background dispatch goroutine.
//
// The event queue is buffered to 100 entries. The logger is tagged with
// event-bus=memory; a nil logger falls back to slog.Default() instead of
// panicking.
func NewWithMemoryAsync(logger *slog.Logger) *MemoryAsyncEventBus {
	if logger == nil {
		logger = slog.Default()
	}

	b := &MemoryAsyncEventBus{
		handlers: make(map[events.EventType][]eventbus.HandlerFunc),
		eventCh: make(chan struct {
			ctx   context.Context
			event events.Event
		}, 100),
		log: logger.With("event-bus", "memory"),
	}

	// Start the worker only once every field it may touch has been set.
	go b.process()
	return b
}
// Register appends handler to the list of handlers for eventType.
func (b *MemoryAsyncEventBus) Register(
	eventType events.EventType,
	handler eventbus.HandlerFunc,
) {
	b.mu.Lock()
	defer b.mu.Unlock()

	existing := b.handlers[eventType]
	b.handlers[eventType] = append(existing, handler)
}
// Emit queues the event, together with ctx, for asynchronous delivery.
//
// While the queue has room the event is enqueued immediately and nil is
// returned, even if ctx is already done. When the queue is full Emit waits for
// space, but gives up and returns ctx.Err() if ctx is cancelled first. A nil
// ctx keeps the plain blocking behaviour.
func (b *MemoryAsyncEventBus) Emit(
	ctx context.Context,
	event events.Event,
) error {
	item := struct {
		ctx   context.Context
		event events.Event
	}{ctx, event}

	// Fast path: there is room in the queue.
	select {
	case b.eventCh <- item:
		return nil
	default:
	}

	if ctx == nil {
		// No context to observe; fall back to an unconditional blocking send.
		b.eventCh <- item
		return nil
	}

	// Queue is full: wait for room, but do not outlive the caller's context.
	select {
	case b.eventCh <- item:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}
// getHandlers returns a snapshot of the handlers registered for eventType,
// copied under the read lock so the caller can iterate without holding it.
func (b *MemoryAsyncEventBus) getHandlers(
	eventType events.EventType,
) []eventbus.HandlerFunc {
	b.mu.RLock()
	registered := b.handlers[eventType]
	snapshot := make([]eventbus.HandlerFunc, len(registered))
	copy(snapshot, registered)
	b.mu.RUnlock()

	return snapshot
}
// process is the dispatch loop started by NewWithMemoryAsync. It receives
// queued events until eventCh is closed and runs every handler registered for
// the event's type in its own goroutine; handler errors are logged.
func (b *MemoryAsyncEventBus) process() {
	for item := range b.eventCh {
		// Copy into per-iteration variables so the goroutines below never
		// observe a later queue item.
		ctx, evt := item.ctx, item.event
		eventType := events.EventType(evt.Type())

		for _, registered := range b.getHandlers(eventType) {
			h := registered
			go func() {
				err := h(ctx, evt)
				if err == nil {
					return
				}
				b.log.Error("error handling event", "error", err, "event_type", eventType)
			}()
		}
	}
}
// Compile-time check that *MemoryAsyncEventBus implements the eventbus.Bus interface.
var _ eventbus.Bus = (*MemoryAsyncEventBus)(nil)
//go:build !redis
// +build !redis
package eventbus
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
)
// RedisEventBusConfig holds the settings passed to NewWithRedis.
// This declaration belongs to the build without the redis tag, where the
// values are accepted but not used.
// NOTE(review): DLQ presumably stands for dead-letter queue — confirm against the redis-tagged implementation.
type RedisEventBusConfig struct {
DLQRetryInterval time.Duration
DLQBatchSize int64
DLQMaxRetries int
DLQInitialBackoff time.Duration
DLQMaxBackoff time.Duration
}
// DefaultRedisEventBusConfig returns a new configuration with a 5 minute DLQ
// retry interval and a DLQ batch size of 10; all other fields are left at
// their zero values.
func DefaultRedisEventBusConfig() *RedisEventBusConfig {
	cfg := new(RedisEventBusConfig)
	cfg.DLQRetryInterval = 5 * time.Minute
	cfg.DLQBatchSize = 10
	return cfg
}
// RedisEventBus is the placeholder type compiled when the redis build tag is absent.
type RedisEventBus struct{}
// NewWithRedis always fails in this build: it returns a nil bus and an error
// telling the caller to build with -tags redis. All arguments are ignored.
func NewWithRedis(
url string,
logger *slog.Logger,
config *RedisEventBusConfig,
) (*RedisEventBus, error) {
return nil, fmt.Errorf("redis event bus: build with -tags redis to enable")
}
// Register is a no-op in this build; the handler is discarded.
func (b *RedisEventBus) Register(eventType events.EventType, handler eventbus.HandlerFunc) {
}
// Emit always returns an error in this build; the event is not delivered.
func (b *RedisEventBus) Emit(ctx context.Context, event events.Event) error {
return fmt.Errorf("redis event bus: build with -tags redis to enable")
}
// Compile-time check that *RedisEventBus satisfies eventbus.Bus.
var _ eventbus.Bus = (*RedisEventBus)(nil)
package initializer
import (
"context"
"fmt"
"log/slog"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/amirasaad/fintech/infra"
"github.com/amirasaad/fintech/infra/caching"
infra_eventbus "github.com/amirasaad/fintech/infra/eventbus"
exchangerateapi "github.com/amirasaad/fintech/infra/provider/exchangerateapi"
stripepayment "github.com/amirasaad/fintech/infra/provider/stripepayment"
infra_repository "github.com/amirasaad/fintech/infra/repository"
currencyfixtures "github.com/amirasaad/fintech/internal/fixtures/currency"
"github.com/amirasaad/fintech/pkg/app"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/amirasaad/fintech/pkg/registry"
)
// loadCurrencyFixtures loads currency metadata from the fixture CSV into the
// registry.
//
// The CSV path is resolved relative to this source file's location (via
// runtime.Caller). A load failure is logged as a warning and nothing is
// registered. A failure to register one entity is logged and the remaining
// entities are still processed.
func loadCurrencyFixtures(ctx context.Context, registry registry.Provider, logger *slog.Logger) {
	logger.Info("Loading embedded currency metadata")

	_, thisFile, _, _ := runtime.Caller(0)
	csvPath := filepath.Join(
		filepath.Dir(thisFile),
		"../../internal/fixtures/currency/meta.csv",
	)

	entities, err := currencyfixtures.LoadCurrencyMetaCSV(csvPath)
	if err != nil {
		logger.Warn("Failed to load currency meta from CSV", "error", err)
		return
	}

	logger.Info("Loading currency meta from fixture",
		"to_register", len(entities))

	registeredCount := 0
	for _, entity := range entities {
		regErr := registry.Register(ctx, entity)
		if regErr == nil {
			registeredCount++
			continue
		}
		// Keep going: one bad currency must not block the rest.
		logger.Error("Failed to register currency", "code", entity.ID(), "error", regErr)
	}

	logger.Info("Successfully loaded currency fixtures", "registered_count", registeredCount)
}
// InitializeDependencies builds the application's dependency set from cfg.
//
// In order, it sets up the logger, the default/currency/checkout/exchange-rate
// registries (seeding currency fixtures when the currency registry reports no
// entries), the exchange rate provider and its background refresh, the
// database connection and unit of work, the event bus, and the payment provider.
//
// A failure to set up a registry, the database or the event bus aborts and is
// returned with nil deps. Failures while counting currencies or initializing
// exchange rates are only logged.
func InitializeDependencies(cfg *config.App) (
	deps *app.Deps,
	err error,
) {
	deps = &app.Deps{}
	logger := setupLogger(cfg.Log)
	deps.Logger = logger

	// Registry providers: the shared default one first, then the dedicated currency one.
	if deps.RegistryProvider, err = GetDefaultRegistry(cfg, logger); err != nil {
		return nil, fmt.Errorf("failed to initialize main registry provider: %w", err)
	}
	if deps.CurrencyRegistry, err = GetCurrencyRegistry(cfg, logger); err != nil {
		return nil, fmt.Errorf("failed to initialize currency registry provider: %w", err)
	}

	ctx := context.Background()

	// Seed fixtures only into an empty currency registry. A failed count is
	// logged and the returned count is used as-is.
	existing, countErr := deps.CurrencyRegistry.Count(ctx)
	if countErr != nil {
		logger.Warn("Failed to check currency registry count", "error", countErr)
	}
	if existing != 0 {
		logger.Info("Skipping currency fixtures load; registry not empty", "existing_count", existing)
	} else {
		loadCurrencyFixtures(ctx, deps.CurrencyRegistry, logger)
	}

	if deps.CheckoutRegistry, err = GetCheckoutRegistry(cfg, logger); err != nil {
		return nil, fmt.Errorf("failed to initialize checkout registry provider: %w", err)
	}
	if deps.ExchangeRateRegistry, err = GetExchangeRateRegistry(cfg, logger); err != nil {
		return nil, fmt.Errorf("failed to initialize exchange rate registry provider: %w", err)
	}

	// Exchange rate provider plus its cache refresh.
	exchangeProvider := exchangerateapi.NewExchangeRateAPIProvider(
		cfg.ExchangeRateAPIProviders.ExchangeRateApi,
		logger,
	)
	deps.ExchangeRateProvider = exchangeProvider

	ratesErr := initializeExchangeRates(
		ctx,
		exchangeProvider,
		deps.ExchangeRateRegistry,
		cfg.ExchangeRateCache,
		logger,
	)
	if ratesErr != nil {
		// Exchange rates are not essential for startup; log and carry on.
		logger.Error("Failed to initialize exchange rates", "error", ratesErr)
	}

	// Database and unit of work.
	db, dbErr := infra.NewDBConnection(cfg.DB, cfg.Env)
	if dbErr != nil {
		logger.Error("Failed to initialize database", "error", dbErr)
		return nil, dbErr
	}
	deps.Uow = infra_repository.NewUoW(db)

	// Event bus.
	bus, busErr := initEventBus(cfg, logger)
	if busErr != nil {
		return nil, busErr
	}
	deps.EventBus = bus

	// Payment provider, wired to the checkout-specific registry and the unit of work.
	deps.PaymentProvider = stripepayment.New(
		bus,
		deps.CheckoutRegistry,
		cfg.PaymentProviders.Stripe,
		logger,
		deps.Uow,
	)

	return deps, nil
}
// initEventBus selects and constructs the event bus named by cfg.EventBus.Driver.
//
// The driver name is trimmed and matched case-insensitively:
//   - no EventBus section, "" or "memory": asynchronous in-memory bus
//   - "redis": Redis bus, using EventBus.RedisURL or, if that is empty,
//     Redis.URL; an error is returned when neither is set
//   - "kafka": Kafka bus; an error is returned when no brokers are set
//   - anything else: an "unsupported event bus driver" error
//
// When the Redis or Kafka bus cannot be constructed, a warning is logged and
// the asynchronous in-memory bus is returned instead.
func initEventBus(cfg *config.App, logger *slog.Logger) (eventbus.Bus, error) {
	busCfg := cfg.EventBus
	if busCfg == nil {
		// No event bus section means no driver was requested.
		return infra_eventbus.NewWithMemoryAsync(logger), nil
	}

	driver := strings.ToLower(strings.TrimSpace(busCfg.Driver))
	switch driver {
	case "", "memory":
		return infra_eventbus.NewWithMemoryAsync(logger), nil

	case "redis":
		redisURL := strings.TrimSpace(busCfg.RedisURL)
		if redisURL == "" && cfg.Redis != nil {
			// Fall back to the shared Redis connection settings.
			redisURL = strings.TrimSpace(cfg.Redis.URL)
		}
		if redisURL == "" {
			return nil, fmt.Errorf("event bus redis: redis url is required")
		}
		busConfig := &infra_eventbus.RedisEventBusConfig{
			DLQRetryInterval: 5 * time.Minute,
			DLQBatchSize:     10,
		}
		bus, err := infra_eventbus.NewWithRedis(redisURL, logger, busConfig)
		if err != nil {
			logger.Warn("Redis event bus init failed, falling back to memory async", "error", err)
			return infra_eventbus.NewWithMemoryAsync(logger), nil
		}
		return bus, nil

	case "kafka":
		brokers := strings.TrimSpace(busCfg.KafkaBrokers)
		if brokers == "" {
			return nil, fmt.Errorf("event bus kafka: brokers are required")
		}
		kafkaConfig := &infra_eventbus.KafkaEventBusConfig{
			GroupID:          strings.TrimSpace(busCfg.KafkaGroupID),
			TopicPrefix:      strings.TrimSpace(busCfg.KafkaTopic),
			DLQRetryInterval: 5 * time.Minute,
			DLQBatchSize:     10,
		}
		bus, err := infra_eventbus.NewWithKafka(brokers, logger, kafkaConfig)
		if err != nil {
			logger.Warn("Kafka event bus init failed, falling back to memory async", "error", err)
			return infra_eventbus.NewWithMemoryAsync(logger), nil
		}
		return bus, nil

	default:
		return nil, fmt.Errorf("unsupported event bus driver: %s", driver)
	}
}
// initializeExchangeRates starts the background job that keeps the exchange
// rate cache populated.
//
// The job checks the cache immediately, so rates are fetched at startup when
// the cache is empty or stale, and then re-checks every 5 minutes. On each
// check it refreshes the rates only if ExchangeCache.IsCacheStale reports
// stale. The job stops when ctx is done. Failures inside the job are logged,
// never returned.
//
// An error is returned, and no job is started, when either the exchange rate
// provider or the registry provider is nil, because the job could not run
// without them.
func initializeExchangeRates(
	ctx context.Context,
	exchangeRateProvider exchange.Exchange,
	registryProvider registry.Provider,
	cfg *config.ExchangeRateCache,
	logger *slog.Logger,
) error {
	if exchangeRateProvider == nil {
		return fmt.Errorf("exchange rate provider is not configured")
	}
	if registryProvider == nil {
		// GetExchangeRateRegistry yields a nil provider when no cache config is
		// present; calling into it later would panic inside the goroutine.
		return fmt.Errorf("exchange rate registry is not configured")
	}

	// Initialize the exchange cache with the provided registry provider and config.
	exchangeCache := caching.NewExchangeCache(registryProvider, logger, cfg)

	// refreshIfStale performs one staleness check and, if needed, one refresh.
	refreshIfStale := func() {
		isStale, timeUntilRefresh, err := exchangeCache.IsCacheStale(ctx)
		if err != nil {
			logger.Warn("Failed to check cache staleness", "error", err)
			return
		}
		logger.Debug(
			"Cache staleness check",
			"is_stale", isStale,
			"time_until_refresh", timeUntilRefresh,
		)
		if !isStale {
			logger.Debug(
				"Cache is still fresh, next refresh in",
				"duration", timeUntilRefresh,
			)
			return
		}

		logger.Info("Cache is stale, refreshing exchange rates")
		if err := refreshExchangeRates(
			ctx, exchangeRateProvider, exchangeCache, logger); err != nil {
			logger.Error("Failed to refresh exchange rates", "error", err)
			return
		}
		logger.Info("Successfully refreshed exchange rates")
	}

	go func() {
		// Prime the cache right away instead of waiting for the first tick.
		refreshIfStale()

		ticker := time.NewTicker(5 * time.Minute) // Check every 5 minutes
		defer ticker.Stop()

		for {
			select {
			case <-ticker.C:
				refreshIfStale()
			case <-ctx.Done():
				return
			}
		}
	}()

	return nil
}
// refreshExchangeRates performs one refresh: it fetches the rates for the
// "USD" base currency from the provider and stores them through the cache,
// all bounded by a 30 second timeout derived from ctx.
//
// Fetch and cache failures are returned wrapped; success is logged.
func refreshExchangeRates(
	ctx context.Context,
	exchangeRateProvider exchange.Exchange,
	exchangeCache *caching.ExchangeCache,
	logger *slog.Logger,
) error {
	refreshCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	defer cancel()

	// "USD" is the base currency for every refresh.
	rates, fetchErr := exchangeRateProvider.FetchRates(refreshCtx, "USD")
	if fetchErr != nil {
		return fmt.Errorf("failed to fetch exchange rates: %w", fetchErr)
	}

	cacheErr := exchangeCache.CacheRates(
		refreshCtx,
		rates,
		exchangeRateProvider.Metadata().Name,
	)
	if cacheErr != nil {
		return fmt.Errorf("failed to cache exchange rates: %w", cacheErr)
	}

	logger.Info("Successfully cached exchange rates",
		"provider", exchangeRateProvider.Metadata().Name,
		"rates_count", len(rates),
	)
	return nil
}
package initializer
import (
"log/slog"
"time"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/registry"
)
// RegistryConfig holds configuration for creating a registry provider
type RegistryConfig struct {
// Name identifies the registry; it is passed to the builder and used in log output.
Name string
// RedisURL enables the Redis backend when non-empty.
RedisURL string
// KeyPrefix is passed to the builder as the key prefix.
KeyPrefix string
// CacheSize is the cache capacity; GetRegistryProvider raises it to 1000
// when it is <= 0 and CacheTTL is non-zero.
CacheSize int
// CacheTTL is the cache entry lifetime; callers in this file use -1 for "no expiration".
CacheTTL time.Duration
}
// GetRegistryProvider returns a registry provider built from cfg.
//
// A nil cfg is replaced by a default configuration (name "default", cache size
// 1000, TTL -1). When caching is enabled (TTL != 0) and the cache size is not
// positive, cfg.CacheSize is set to 1000 on the passed-in config.
//
// The provider is first built with Redis when cfg.RedisURL is set. If that
// build fails and a Redis URL was configured, a warning is logged and the
// build is retried without Redis. If no Redis URL was configured, the build
// error is returned.
func GetRegistryProvider(
	cfg *RegistryConfig,
	logger *slog.Logger,
) (registry.Provider, error) {
	if cfg == nil {
		cfg = &RegistryConfig{
			Name:      "default",
			CacheSize: 1000,
			CacheTTL:  -1, // No expiration
		}
	}

	// Caching enabled but no usable size: fall back to 1000 entries.
	if cfg.CacheSize <= 0 && cfg.CacheTTL != 0 {
		cfg.CacheSize = 1000
	}

	hasRedis := cfg.RedisURL != ""

	logger.Info("Creating registry provider",
		"name", cfg.Name,
		"redis_configured", hasRedis,
		"key_prefix", cfg.KeyPrefix,
		"cache_size", cfg.CacheSize,
		"cache_ttl", cfg.CacheTTL,
	)

	newProvider := func(useRedis bool) (registry.Provider, error) {
		b := registry.NewBuilder().
			WithName(cfg.Name).
			WithKeyPrefix(cfg.KeyPrefix).
			WithCache(cfg.CacheSize, cfg.CacheTTL)
		if useRedis && hasRedis {
			b = b.WithRedis(cfg.RedisURL)
		}
		return b.BuildRegistry()
	}

	provider, err := newProvider(true)
	switch {
	case err == nil:
		return provider, nil
	case !hasRedis:
		// Redis was never involved, so there is nothing to fall back from.
		return nil, err
	}

	logger.Warn(
		"Registry init failed, falling back to in-memory",
		"name", cfg.Name,
		"error", err,
	)
	return newProvider(false)
}
// GetCheckoutRegistry creates a registry provider for the checkout service.
//
// The Redis URL and key prefix come from cfg.Redis. When cfg.Redis is nil both
// are empty, so the registry is built without Redis rather than panicking on a
// nil dereference. The key prefix always ends in "checkout:".
func GetCheckoutRegistry(cfg *config.App, logger *slog.Logger) (registry.Provider, error) {
	var redisURL, keyPrefix string
	if cfg.Redis != nil {
		redisURL = cfg.Redis.URL
		keyPrefix = cfg.Redis.KeyPrefix
	}

	registryCfg := &RegistryConfig{
		Name:      "checkout",
		RedisURL:  redisURL,
		KeyPrefix: keyPrefix + "checkout:",
		CacheSize: 1000,
		CacheTTL:  -1, // No expiration for checkout sessions
	}

	return GetRegistryProvider(registryCfg, logger)
}
// GetExchangeRateRegistry creates a registry provider for the exchange rate
// service from cfg.ExchangeRateCache.
//
// When that section is nil it returns a nil provider and a nil error, so
// callers must be prepared for a nil result. An empty prefix defaults to
// "exr:rate:".
func GetExchangeRateRegistry(cfg *config.App, logger *slog.Logger) (registry.Provider, error) {
	cacheCfg := cfg.ExchangeRateCache
	if cacheCfg == nil {
		return nil, nil
	}

	registryCfg := &RegistryConfig{
		Name:      "exchange_rate",
		RedisURL:  cacheCfg.Url,
		KeyPrefix: cacheCfg.Prefix,
		CacheSize: 1000,
		CacheTTL:  cacheCfg.TTL,
	}
	if registryCfg.KeyPrefix == "" {
		registryCfg.KeyPrefix = "exr:rate:"
	}

	return GetRegistryProvider(registryCfg, logger)
}
// GetCurrencyRegistry creates a dedicated registry provider for currency data.
//
// The Redis URL comes from cfg.Redis. When cfg.Redis is nil the URL is empty,
// so the registry is built without Redis rather than panicking on a nil
// dereference.
func GetCurrencyRegistry(cfg *config.App, logger *slog.Logger) (registry.Provider, error) {
	redisURL := ""
	if cfg.Redis != nil {
		redisURL = cfg.Redis.URL
	}

	return GetRegistryProvider(
		&RegistryConfig{
			Name:      "currency",
			RedisURL:  redisURL,
			KeyPrefix: "currency:",
			CacheSize: 200, // Currency data is relatively small but frequently accessed
			CacheTTL:  -1,  // No expiration for currency data
		},
		logger,
	)
}
// GetDefaultRegistry creates the default registry provider, with the
// "registry:default:" key prefix and a 24 hour cache TTL.
//
// The Redis URL comes from cfg.Redis. When cfg.Redis is nil the URL is empty,
// so the registry is built without Redis rather than panicking on a nil
// dereference.
func GetDefaultRegistry(cfg *config.App, logger *slog.Logger) (registry.Provider, error) {
	redisURL := ""
	if cfg.Redis != nil {
		redisURL = cfg.Redis.URL
	}

	return GetRegistryProvider(
		&RegistryConfig{
			Name:      "default",
			RedisURL:  redisURL,
			KeyPrefix: "registry:default:",
			CacheSize: 1000,
			CacheTTL:  time.Hour * 24,
		},
		logger,
	)
}
package initializer
import (
"github.com/amirasaad/fintech/pkg/config"
"github.com/charmbracelet/lipgloss"
"github.com/charmbracelet/log"
"log/slog"
"os"
)
// setupLogger builds the application logger from cfg and installs it as the
// slog default.
//
// It starts from the charmbracelet/log default styles, overrides the error,
// info, warn and debug level styles with a bold, padded, coloured icon, and
// colours a fixed set of keys. cfg.Format "json" selects the JSON formatter;
// any other value uses the text formatter. Output goes to stdout with caller
// and timestamp reporting enabled.
func setupLogger(cfg *config.Log) *slog.Logger {
	// The same colour is used on light and dark backgrounds.
	adaptive := func(hex string) lipgloss.AdaptiveColor {
		return lipgloss.AdaptiveColor{Light: hex, Dark: hex}
	}
	var (
		infoTxtColor  = adaptive("#04B575")
		warnTxtColor  = adaptive("#EE6FF8")
		errorTxtColor = adaptive("#FF6B6B")
		debugTxtColor = adaptive("#7E57C2")
	)

	styles := log.DefaultStyles()

	// Every level shares one look: bold icon, horizontal padding of 1, level colour.
	levelStyle := func(icon string, color lipgloss.AdaptiveColor) lipgloss.Style {
		return lipgloss.NewStyle().
			SetString(icon).
			Bold(true).
			Padding(0, 1).
			Foreground(color)
	}
	styles.Levels[log.ErrorLevel] = levelStyle("❌", errorTxtColor)
	styles.Levels[log.InfoLevel] = levelStyle("ℹ️", infoTxtColor)
	styles.Levels[log.WarnLevel] = levelStyle("⚠️", warnTxtColor)
	styles.Levels[log.DebugLevel] = levelStyle("🐛", debugTxtColor)

	// Coloured keys with bold values.
	keyColors := []struct {
		key   string
		color lipgloss.AdaptiveColor
	}{
		{"error", errorTxtColor},
		{"info", infoTxtColor},
		{"warn", warnTxtColor},
		{"debug", debugTxtColor},
		{"prefix", debugTxtColor},
		{"caller", debugTxtColor},
		{"time", debugTxtColor},
	}
	for _, kc := range keyColors {
		styles.Keys[kc.key] = lipgloss.NewStyle().Foreground(kc.color)
		styles.Values[kc.key] = lipgloss.NewStyle().Bold(true)
	}

	// Text is the fallback for "text" and for any unrecognised format.
	formatter := log.TextFormatter
	if cfg.Format == "json" {
		formatter = log.JSONFormatter
	}

	logger := log.NewWithOptions(os.Stdout, log.Options{
		ReportCaller:    true,
		ReportTimestamp: true,
		TimeFormat:      cfg.TimeFormat,
		Level:           log.Level(cfg.Level),
		Prefix:          cfg.Prefix,
		Formatter:       formatter,
	})
	logger.SetStyles(styles)

	// Wrap as an slog.Logger and make it the process-wide default.
	slogger := slog.New(logger)
	slog.SetDefault(slogger)
	return slogger
}
package exchangerateapi
import (
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"time"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/provider/exchange"
)
// exchangeRateAPI is the exchange rate provider for the exchangerate-api.com v6 API.
type exchangeRateAPI struct {
// apiKey is the raw API key from the configuration.
apiKey string
// baseURL is the configured API URL with the API key appended as a path
// segment (see NewExchangeRateAPIProvider).
baseURL string
// httpClient is created with the configured HTTP timeout.
httpClient *http.Client
logger *slog.Logger
// timeout holds the same configured value as httpClient.Timeout.
timeout time.Duration
}
// API error types
// NOTE(review): presumably the values the API reports in the "error-type"
// field (ExchangeRateAPIResponseV6.ErrorType) — their use is not visible here; confirm.
const (
errorTypeUnsupportedCode = "unsupported-code"
errorTypeMalformedRequest = "malformed-request"
errorTypeInvalidKey = "invalid-key"
errorTypeInactiveAccount = "inactive-account"
errorTypeQuotaReached = "quota-reached"
errorTypeUnknown = "unknown-code"
)
// ExchangeRateAPIResponseV6 represents the v6 response from the ExchangeRate API
// See: https://www.exchangerate-api.com/docs/standard-requests
// Example:
// { "result": "success", "documentation": "...", "terms_of_use": "...",
// "time_last_update_unix": 1585267200, ... }
type ExchangeRateAPIResponseV6 struct {
// Result is the API's status string ("success" in the example above).
Result string `json:"result"`
Documentation string `json:"documentation"`
TermsOfUse string `json:"terms_of_use"`
TimeLastUpdateUnix int64 `json:"time_last_update_unix"`
TimeNextUpdateUnix int64 `json:"time_next_update_unix"`
TimeNextUpdateUTC string `json:"time_next_update_utc"`
BaseCode string `json:"base_code"`
// ConversionRates maps a currency code to its rate.
// NOTE(review): presumably expressed relative to BaseCode — confirm against the API docs.
ConversionRates map[string]float64 `json:"conversion_rates"`
// Error fields (if any)
ErrorType string `json:"error-type,omitempty"`
}
// NewExchangeRateAPIProvider builds an exchangerate-api.com provider from cfg.
//
// The request base URL is "<cfg.ApiUrl>/<cfg.ApiKey>" and cfg.HTTPTimeout is
// applied to the underlying HTTP client. A nil logger is replaced by
// slog.Default(). cfg must not be nil.
func NewExchangeRateAPIProvider(
	cfg *config.ExchangeRateApi,
	logger *slog.Logger,
) *exchangeRateAPI {
	if logger == nil {
		logger = slog.Default()
	}

	client := &http.Client{Timeout: cfg.HTTPTimeout}
	provider := &exchangeRateAPI{
		apiKey:     cfg.ApiKey,
		baseURL:    fmt.Sprintf("%s/%s", cfg.ApiUrl, cfg.ApiKey),
		httpClient: client,
		logger:     logger,
		timeout:    cfg.HTTPTimeout,
	}
	return provider
}
// FetchRate fetches the current exchange rate for a single currency pair.
//
// It requests the "latest" table for the base currency from and picks the
// entry for to out of the returned conversion rates. The request is bound to
// ctx, so cancelling ctx or exceeding its deadline aborts the HTTP call.
//
// An error is returned when the request cannot be built or sent, when the
// server answers with a status other than 200, when the body cannot be
// decoded, when the API result is not "success", or when to is missing from
// the response.
func (p *exchangeRateAPI) FetchRate(
	ctx context.Context,
	from, to string,
) (*exchange.RateInfo, error) {
	url := fmt.Sprintf("%s/%s/%s", p.baseURL, "latest", from)

	// Build the request with ctx so caller cancellation and deadlines are
	// honoured, matching FetchRates.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	if p.apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+p.apiKey)
	}

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer func() {
		if cerr := resp.Body.Close(); cerr != nil {
			p.logger.Warn(
				"Failed to close response body",
				"error", cerr,
			)
		}
	}()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body))
	}

	var apiResp ExchangeRateAPIResponseV6
	if err = json.NewDecoder(resp.Body).Decode(&apiResp); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}
	if apiResp.Result != "success" {
		return nil, fmt.Errorf("API returned result=%s", apiResp.Result)
	}

	rate, exists := apiResp.ConversionRates[to]
	if !exists {
		return nil, fmt.Errorf("rate for %s not found in response", to)
	}
	return &exchange.RateInfo{
		FromCurrency: from,
		ToCurrency:   to,
		Rate:         rate,
		Timestamp:    time.Now(),
		Provider:     p.Metadata().Name,
	}, nil
}
// FetchRates fetches every rate the API reports for the base currency from,
// using a single request to the "latest" endpoint.
//
// The result is keyed by target currency code. An error is returned when the
// request fails, the body cannot be read or decoded, the API reports a result
// other than "success", or the response contains no conversion rates.
func (p *exchangeRateAPI) FetchRates(
	ctx context.Context,
	from string,
) (map[string]*exchange.RateInfo, error) {
	endpoint := fmt.Sprintf("%s/latest/%s", p.baseURL, from)
	req, err := http.NewRequestWithContext(ctx, "GET", endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Accept", "application/json")

	resp, err := p.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to make request: %w", err)
	}
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil {
			p.logger.Warn("failed to close response body", "error", closeErr)
		}
	}()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	var payload ExchangeRateAPIResponseV6
	if err := json.Unmarshal(raw, &payload); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	// Translate the documented error types into dedicated messages; anything
	// else (including "unknown-code" and an empty type) uses the generic form.
	if payload.Result != "success" {
		switch payload.ErrorType {
		case errorTypeUnsupportedCode:
			return nil, fmt.Errorf("unsupported currency code")
		case errorTypeMalformedRequest:
			return nil, fmt.Errorf("malformed request")
		case errorTypeInvalidKey:
			return nil, fmt.Errorf("invalid API key")
		case errorTypeInactiveAccount:
			return nil, fmt.Errorf("inactive account")
		case errorTypeQuotaReached:
			return nil, fmt.Errorf("API quota reached")
		default:
			return nil, fmt.Errorf("API error: %s", payload.ErrorType)
		}
	}

	if len(payload.ConversionRates) == 0 {
		return nil, fmt.Errorf("none of the requested currencies were found in the response")
	}

	// All entries share one timestamp and one provider name.
	fetchedAt := time.Now()
	providerName := p.Metadata().Name
	rates := make(map[string]*exchange.RateInfo, len(payload.ConversionRates))
	for code, value := range payload.ConversionRates {
		rates[code] = &exchange.RateInfo{
			FromCurrency: from,
			ToCurrency:   code,
			Rate:         value,
			Timestamp:    fetchedAt,
			Provider:     providerName,
		}
	}
	return rates, nil
}
// IsSupported reports whether the provider supports the given currency pair.
//
// The provider does not validate pairs locally: every input, including empty
// or identical codes, is reported as supported. Pairs the API does not know
// surface as errors from the fetch methods.
func (p *exchangeRateAPI) IsSupported(from string, to string) bool {
	return true
}
// Metadata describes this provider: its name is "exchangerate-api" and its
// version is "v6". Other metadata fields are left at their zero values.
func (p *exchangeRateAPI) Metadata() exchange.ProviderMetadata {
	var meta exchange.ProviderMetadata
	meta.Name = "exchangerate-api"
	meta.Version = "v6" // Assuming a version, adjust as needed
	return meta
}
// CheckHealth reports whether the provider is currently available.
// It currently performs no request and always returns nil; ctx is unused.
func (p *exchangeRateAPI) CheckHealth(ctx context.Context) error {
// Placeholder: no health request is issued yet.
return nil
}
// Compile-time check that *exchangeRateAPI implements exchange.Exchange.
var _ exchange.Exchange = (*exchangeRateAPI)(nil)
// SupportedPairs returns the currency pairs advertised by this provider.
//
// The list is hard-coded in "FROM/TO" form; it is not derived from the API or
// from configuration.
func (p *exchangeRateAPI) SupportedPairs() []string {
	pairs := []string{
		"USD/EUR", "EUR/USD",
		"USD/GBP", "GBP/USD",
		"USD/JPY", "JPY/USD",
	}
	return pairs
}
package exchangerateapi
import (
"context"
"time"
"github.com/amirasaad/fintech/pkg/provider/exchange"
)
// fakeExchangeRate is a stateless exchange.Exchange implementation that
// returns fixed data (rate 1, provider "fake") without any network access.
type fakeExchangeRate struct {
}
// NewFakeExchangeRate returns a new fake exchange-rate provider.
func NewFakeExchangeRate() *fakeExchangeRate {
	return new(fakeExchangeRate)
}
// CheckHealth implements exchange.Exchange.
// The fake is always healthy: it returns nil and ignores ctx.
func (f *fakeExchangeRate) CheckHealth(ctx context.Context) error {
return nil
}
// FetchRate implements exchange.Exchange.
// It never fails and always reports a rate of 1 for the requested pair,
// stamped with the current time and the provider name "fake".
func (f *fakeExchangeRate) FetchRate(
	ctx context.Context,
	from string,
	to string,
) (*exchange.RateInfo, error) {
	info := new(exchange.RateInfo)
	info.FromCurrency = from
	info.ToCurrency = to
	info.Rate = 1
	info.Timestamp = time.Now()
	info.Provider = "fake"
	return info, nil
}
// FetchRates implements exchange.Exchange.
// It never fails and always returns a single entry, keyed "EUR", with a rate
// of 1 from the requested base currency.
func (f *fakeExchangeRate) FetchRates(
	ctx context.Context,
	from string,
) (map[string]*exchange.RateInfo, error) {
	eur := &exchange.RateInfo{
		FromCurrency: from,
		ToCurrency:   "EUR",
		Rate:         1,
		Timestamp:    time.Now(),
		Provider:     "fake",
	}
	rates := make(map[string]*exchange.RateInfo, 1)
	rates["EUR"] = eur
	return rates, nil
}
// IsSupported implements exchange.Exchange.
// The fake reports every pair as supported, whatever from and to are.
func (f *fakeExchangeRate) IsSupported(from string, to string) bool {
return true
}
// SupportedPairs implements exchange.Exchange.
// It returns the single entry "EUR".
// NOTE(review): this is a bare currency code, whereas the real provider in
// this package returns "FROM/TO" pairs — confirm which form callers expect.
func (f *fakeExchangeRate) SupportedPairs() []string {
return []string{"EUR"}
}
// Metadata implements exchange.Exchange.
// It identifies the provider as "fake", version "v1".
func (f *fakeExchangeRate) Metadata() exchange.ProviderMetadata {
	var meta exchange.ProviderMetadata
	meta.Name = "fake"
	meta.Version = "v1"
	return meta
}
// Compile-time check that *fakeExchangeRate implements exchange.Exchange.
var _ exchange.Exchange = &fakeExchangeRate{}
package mockpayment
import (
"context"
"sync"
"time"
"github.com/amirasaad/fintech/pkg/provider/payment"
)
// mockPayment is the in-memory record kept for one simulated payment.
type mockPayment struct {
// status starts as payment.PaymentPending and is later set to payment.PaymentCompleted.
status payment.PaymentStatus
}
// MockPaymentProvider simulates a payment provider for tests and local development.
//
// Usage:
// - InitiateDeposit/InitiateWithdraw simulate async payment completion after a short delay.
// - GetPaymentStatus can be polled until PaymentCompleted is returned.
// - This is NOT for production use. Real payment providers use webhooks or callbacks.
//
// In tests, the service will poll GetPaymentStatus until completion,
// simulating a real-world async flow.
//
// See pkg/service/account/account.go for example usage.
// For production, use a real provider and event-driven confirmation.
//
// NOTE(review): InitiateDeposit, InitiateWithdraw and GetPaymentStatus are not
// defined next to this type; the methods declared here are InitiatePayment,
// HandleWebhook and InitiatePayout — confirm and update the usage notes.
type MockPaymentProvider struct {
// mu guards payments.
mu sync.Mutex
// payments is keyed by the transaction ID rendered as a string.
payments map[string]*mockPayment
}
// NewMockPaymentProvider returns a MockPaymentProvider with an empty,
// ready-to-use payment table.
func NewMockPaymentProvider() *MockPaymentProvider {
	provider := new(MockPaymentProvider)
	provider.payments = make(map[string]*mockPayment)
	return provider
}
// InitiatePayment simulates initiating a payment.
//
// The payment is recorded as pending under params.TransactionID and a pending
// response is returned immediately. A background goroutine marks the payment
// as completed after two seconds. ctx is not consulted; params must not be nil.
func (m *MockPaymentProvider) InitiatePayment(
	ctx context.Context,
	params *payment.InitiatePaymentParams,
) (*payment.InitiatePaymentResponse, error) {
	// Resolve the key once, before returning, so the background goroutine never
	// reads caller-owned params after this call has returned.
	key := params.TransactionID.String()

	m.mu.Lock()
	m.payments[key] = &mockPayment{
		status: payment.PaymentPending,
	}
	m.mu.Unlock()

	// Simulate async completion
	go func() {
		time.Sleep(2 * time.Second)
		m.mu.Lock()
		defer m.mu.Unlock()
		// Guard the lookup so a missing entry cannot cause a nil dereference.
		if p, ok := m.payments[key]; ok && p != nil {
			p.status = payment.PaymentCompleted
		}
	}()

	return &payment.InitiatePaymentResponse{
		Status: payment.PaymentPending,
	}, nil
}
// HandleWebhook handles payment webhook events.
// The mock does nothing with its arguments: it returns a nil event and a nil
// error, so callers must tolerate a nil *payment.PaymentEvent.
func (m *MockPaymentProvider) HandleWebhook(
ctx context.Context,
payload []byte,
signature string,
) (*payment.PaymentEvent, error) {
// In a real implementation, this would verify the webhook signature
// and parse the payload to return the appropriate PaymentEvent
return nil, nil
}
// InitiatePayout simulates a payout to a connected account.
//
// It always succeeds and returns a response with fixed mock IDs, status
// "completed", a zero fee, the amount and currency echoed from params, and an
// estimated arrival 24 hours from now (Unix seconds). No state is recorded.
func (m *MockPaymentProvider) InitiatePayout(
	ctx context.Context,
	params *payment.InitiatePayoutParams,
) (*payment.InitiatePayoutResponse, error) {
	arrival := time.Now().Add(24 * time.Hour).Unix()
	resp := &payment.InitiatePayoutResponse{
		PayoutID:             "mock_payout_id",
		PaymentProviderID:    "mock_provider_id",
		Status:               payment.PaymentStatus("completed"),
		Amount:               params.Amount,
		Currency:             params.Currency,
		FeeAmount:            0,
		FeeCurrency:          params.Currency,
		EstimatedArrivalDate: arrival,
	}
	return resp, nil
}
package stripepayment
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"log/slog"
"maps"
"net/http"
"net/url"
"strings"
"time"
"unicode"
"github.com/amirasaad/fintech/pkg/service/checkout"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/registry"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/stripe/stripe-go/v82/webhook"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/provider/payment"
"github.com/google/uuid"
"github.com/stripe/stripe-go/v82"
)
// CheckoutSession represents a Stripe Checkout session.
// It is the trimmed-down view of a created session that createCheckoutSession
// hands back to its caller.
type CheckoutSession struct {
// ID is the Stripe session ID.
ID string
// PaymentID is the payment intent ID; it stays empty when the created
// session carries no payment intent.
PaymentID string
// URL is the session URL reported by Stripe.
URL string
// AmountTotal is copied from the Stripe session's AmountTotal.
AmountTotal int64
// Currency is the Stripe session currency converted to a plain string.
Currency string
}
// StripePaymentProvider implements Payment using Stripe API.
type StripePaymentProvider struct {
// bus is used by the webhook handlers to emit domain events.
bus eventbus.Bus
// client is the Stripe API client built in New.
client *stripe.Client
// checkoutService stores and looks up the internal checkout session records.
checkoutService *checkout.Service
// cfg holds the Stripe settings (API key, signing secret, redirect paths, TLS option).
cfg *config.Stripe
logger *slog.Logger
// webhookHandlers maps a Stripe event type string to its handler.
webhookHandlers map[string]webhookHandler
// uow is stored by New.
uow repository.UnitOfWork
}
// webhookHandler processes one parsed Stripe event. It receives the request
// context, the event and a logger, and returns the resulting payment event
// (which may be nil) or an error.
type webhookHandler func(context.Context, stripe.Event, *slog.Logger) (*payment.PaymentEvent, error)
// New builds a StripePaymentProvider.
//
// The Stripe client is created from cfg.ApiKey on top of an HTTP client with
// a 30-second timeout. When cfg.SkipTLSVerify is set, TLS certificate
// verification is disabled for that client and a warning is logged; this is
// meant for development only. checkoutProvider is the registry backing the
// checkout session records. All webhook handlers are registered before the
// provider is returned.
func New(
	bus eventbus.Bus,
	checkoutProvider registry.Provider,
	cfg *config.Stripe,
	logger *slog.Logger,
	uow repository.UnitOfWork,
) *StripePaymentProvider {
	httpClient := &http.Client{Timeout: 30 * time.Second}
	if cfg.SkipTLSVerify {
		logger.Warn("⚠️ TLS verification is disabled for Stripe API calls - development mode only")
		insecureTLS := &tls.Config{InsecureSkipVerify: true}
		httpClient.Transport = &http.Transport{TLSClientConfig: insecureTLS}
	}

	// Route every Stripe API call through the custom HTTP client.
	stripeClient := stripe.NewClient(
		cfg.ApiKey,
		stripe.WithBackends(stripe.NewBackends(httpClient)),
	)

	p := &StripePaymentProvider{
		bus:             bus,
		client:          stripeClient,
		cfg:             cfg,
		checkoutService: checkout.New(checkoutProvider, logger),
		logger:          logger,
		webhookHandlers: make(map[string]webhookHandler),
		uow:             uow,
	}
	p.initializeWebhookHandlers()
	return p
}
// initializeWebhookHandlers replaces the handler table with a fresh map from
// Stripe event type to handler. Event types missing from the table are
// rejected by HandleWebhook as unhandled.
func (s *StripePaymentProvider) initializeWebhookHandlers() {
	s.webhookHandlers = map[string]webhookHandler{
		// Payment intent events
		"payment_intent.succeeded":      s.handlePaymentIntentSucceeded,
		"payment_intent.payment_failed": s.handlePaymentIntentFailed,

		// Checkout session events
		"checkout.session.completed": s.handleCheckoutSessionCompleted,
		"checkout.session.expired":   s.handleCheckoutSessionExpired,

		// Transfer events
		"transfer.created":  s.handleTransferCreated,
		"transfer.failed":   s.handleTransferFailed,
		"transfer.reversed": s.handleTransferReversed,

		// Charge events (both types share one handler)
		"charge.succeeded": s.handleChargeSucceeded,
		"charge.updated":   s.handleChargeSucceeded,

		// Account events
		"account.updated":                s.handleAccountUpdated,
		"account.application.authorized": s.handleAccountApplicationAuthorized,
		"capability.updated":             s.handleCapabilityUpdated,

		// Payout events
		// "payout.paid":   s.handlePayoutPaid,
		// "payout.failed": s.handlePayoutFailed,
	}
}
// handleAccountUpdated handles account.updated webhook events.
//
// When the account reports that its details have been submitted, the user ID
// is read from the account metadata key "user_id" and a
// UserOnboardingCompleted event is emitted. It never produces a payment
// event: the first return value is always nil.
func (s *StripePaymentProvider) handleAccountUpdated(
	ctx context.Context,
	event stripe.Event,
	log *slog.Logger,
) (*payment.PaymentEvent, error) {
	var acct stripe.Account
	if err := json.Unmarshal(event.Data.Raw, &acct); err != nil {
		return nil, fmt.Errorf("error parsing account: %v", err)
	}
	log.Info("Account updated",
		"account_id", acct.ID,
		"details_submitted", acct.DetailsSubmitted,
	)

	// Nothing to announce until the account details have been submitted.
	if !acct.DetailsSubmitted {
		return nil, nil
	}

	userID, err := uuid.Parse(acct.Metadata["user_id"])
	if err != nil {
		return nil, fmt.Errorf("error parsing user_id from account metadata: %v", err)
	}
	completed := events.NewUserOnboardingCompleted(userID, acct.ID)
	if emitErr := s.bus.Emit(ctx, completed); emitErr != nil {
		log.Error("failed to emit UserOnboardingCompleted event", "error", emitErr)
		return nil, fmt.Errorf("failed to emit UserOnboardingCompleted event: %w", emitErr)
	}
	return nil, nil
}
// handleAccountApplicationAuthorized handles account.application.authorized
// webhook events.
//
// The authorization is treated as the user completing onboarding: a
// UserOnboardingCompleted event is emitted, using event.Account both as the
// user ID (parsed as a UUID) and as the account ID. It never produces a
// payment event.
//
// NOTE(review): event.Account is parsed as a UUID here; confirm that this
// field really carries the internal user ID rather than a Stripe account ID.
func (s *StripePaymentProvider) handleAccountApplicationAuthorized(
	ctx context.Context,
	event stripe.Event,
	log *slog.Logger,
) (*payment.PaymentEvent, error) {
	var application stripe.Application
	if err := json.Unmarshal(event.Data.Raw, &application); err != nil {
		return nil, fmt.Errorf("error parsing application: %v", err)
	}
	log.Info("Account application authorized",
		"application_id", application.ID,
	)

	connectedAccount := event.Account
	userID, err := uuid.Parse(connectedAccount)
	if err != nil {
		return nil, fmt.Errorf("error parsing user_id from event account: %v", err)
	}

	completed := events.NewUserOnboardingCompleted(userID, connectedAccount)
	if emitErr := s.bus.Emit(ctx, completed); emitErr != nil {
		log.Error("failed to emit UserOnboardingCompleted event", "error", emitErr)
		return nil, fmt.Errorf("failed to emit UserOnboardingCompleted event: %w", emitErr)
	}
	return nil, nil
}
// handleCapabilityUpdated handles capability.updated webhook events.
//
// When the "transfers" capability becomes active, the user ID is read from
// the owning account's metadata key "user_id" and a UserOnboardingCompleted
// event is emitted. Any other capability or status is only logged. It never
// produces a payment event.
//
// An error is returned when the event carries no data, the payload cannot be
// parsed, an active transfers capability has no account attached, the
// user_id metadata is not a valid UUID, or the event cannot be emitted.
func (s *StripePaymentProvider) handleCapabilityUpdated(
	ctx context.Context,
	event stripe.Event,
	log *slog.Logger,
) (*payment.PaymentEvent, error) {
	// A payload without a "data" object leaves event.Data nil.
	if event.Data == nil {
		return nil, fmt.Errorf("error parsing capability: event data is nil")
	}
	var capability stripe.Capability
	if err := json.Unmarshal(event.Data.Raw, &capability); err != nil {
		return nil, fmt.Errorf("error parsing capability: %v", err)
	}

	// The account reference is a pointer and may be absent from the payload.
	var accountID string
	if capability.Account != nil {
		accountID = capability.Account.ID
	}
	log.Info("Capability updated",
		"capability_id", capability.ID,
		"status", capability.Status,
		"account", accountID,
	)

	if capability.ID != "transfers" || capability.Status != stripe.CapabilityStatusActive {
		return nil, nil
	}
	if capability.Account == nil {
		return nil, fmt.Errorf("capability %q has no account attached", capability.ID)
	}

	userID, err := uuid.Parse(capability.Account.Metadata["user_id"])
	if err != nil {
		return nil, fmt.Errorf("error parsing user_id from account metadata: %v", err)
	}
	// Emit a custom event to notify the system that the user has completed onboarding.
	onboardingCompletedEvent := events.NewUserOnboardingCompleted(userID, accountID)
	if err := s.bus.Emit(ctx, onboardingCompletedEvent); err != nil {
		log.Error("failed to emit UserOnboardingCompleted event", "error", err)
		return nil, fmt.Errorf("failed to emit UserOnboardingCompleted event: %w", err)
	}
	return nil, nil
}
// InitiatePayment starts a deposit payment through Stripe Checkout.
//
// It creates a Stripe Checkout session for the requested amount and then
// stores an internal checkout session record with a 24-hour lifetime. On
// success the response has status PaymentPending and carries the payment ID
// of the created session. Either step failing is logged and returned as a
// wrapped error.
func (s *StripePaymentProvider) InitiatePayment(
	ctx context.Context,
	params *payment.InitiatePaymentParams,
) (*payment.InitiatePaymentResponse, error) {
	s.logger.Debug("🔵 InitiatePayment called",
		"transaction_id", params.TransactionID,
		"amount", params.Amount,
		"currency", params.Currency,
	)

	reqLog := s.logger.With(
		"handler", "stripe.InitiatePayment",
		"user_id", params.UserID,
		"account_id", params.AccountID,
		"amount", params.Amount,
		"currency", params.Currency,
	)
	reqLog.Info("🛒 [START] InitiatePayment")

	// Step 1: create the session on the Stripe side.
	stripeSession, err := s.createCheckoutSession(
		ctx,
		params.UserID,
		params.AccountID,
		params.TransactionID,
		params.Amount,
		params.Currency,
		"Payment for deposit",
	)
	if err != nil {
		reqLog.Error("failed to create checkout session", "error", err)
		return nil, fmt.Errorf("failed to create checkout session: %w", err)
	}

	// Step 2: persist the matching internal record.
	const sessionTTL = 24 * time.Hour
	if _, err = s.checkoutService.CreateSession(
		ctx,
		stripeSession.ID,
		stripeSession.PaymentID,
		params.TransactionID,
		params.UserID,
		params.AccountID,
		params.Amount,
		params.Currency,
		stripeSession.URL,
		sessionTTL,
	); err != nil {
		reqLog.Error("failed to create checkout session record", "error", err)
		return nil, fmt.Errorf("failed to create checkout session record: %w", err)
	}

	reqLog.Info(
		"🛒 Creating checkout session",
		"user_id", params.UserID,
		"account_id", params.AccountID,
		"transaction_id", params.TransactionID,
		"amount", params.Amount,
		"currency", params.Currency,
	)

	resp := &payment.InitiatePaymentResponse{
		Status:    payment.PaymentPending,
		PaymentID: stripeSession.PaymentID,
	}
	return resp, nil
}
// VerifyWebhookSignature checks payload against the Stripe signature header
// using the configured signing secret.
//
// It returns an error when no signing secret is configured or when Stripe's
// event construction rejects the payload/header pair; on success the
// verification is logged and nil is returned.
func (s *StripePaymentProvider) VerifyWebhookSignature(payload []byte, header string) error {
	secret := s.cfg.SigningSecret
	if secret == "" {
		return fmt.Errorf("webhook signing secret not configured")
	}
	if _, err := webhook.ConstructEvent(payload, header, secret); err != nil {
		return fmt.Errorf("error verifying webhook signature: %v", err)
	}
	s.logger.Info("Webhook signature verified", "signature", header)
	return nil
}
// HandleWebhook processes an incoming Stripe webhook.
//
// The signature is verified first, then the payload is decoded into a Stripe
// event and dispatched to the handler registered for its type. An error is
// returned when verification or decoding fails, or when no handler is
// registered for the event type; otherwise the handler's result is returned.
func (s *StripePaymentProvider) HandleWebhook(
	ctx context.Context,
	payload []byte,
	signature string,
) (*payment.PaymentEvent, error) {
	log := s.logger.With("method", "HandleWebhook")

	if verifyErr := s.VerifyWebhookSignature(payload, signature); verifyErr != nil {
		log.Error("Failed to verify webhook signature", "error", verifyErr)
		return nil, fmt.Errorf("webhook signature verification failed: %v", verifyErr)
	}

	var event stripe.Event
	if parseErr := json.Unmarshal(payload, &event); parseErr != nil {
		log.Error("Failed to parse webhook event", "error", parseErr)
		return nil, fmt.Errorf("error parsing webhook event: %v", parseErr)
	}
	log.Info("Received webhook event",
		"type", event.Type,
		"id", event.ID,
	)

	// Dispatch on the event type string.
	if handle, found := s.webhookHandlers[string(event.Type)]; found {
		return handle(ctx, event, log)
	}
	log.Warn("No handler found for event type", "type", event.Type)
	return nil, fmt.Errorf("unhandled event type: %s", event.Type)
}
// handleTransferCreated handles transfer.created webhook events.
//
// It decodes the transfer, reads the user, account and transaction IDs from
// the transfer metadata, emits a payment-completed event carrying the negated
// transfer amount, and returns a PaymentEvent describing the transfer.
// An error is returned when the payload cannot be parsed, the amount cannot
// be converted, the completed-event payload cannot be built, or the emit fails.
func (s *StripePaymentProvider) handleTransferCreated(
ctx context.Context,
event stripe.Event,
log *slog.Logger,
) (*payment.PaymentEvent, error) {
log.Debug("🔵 handleTransferCreated called",
"event_id", event.ID,
"event_type", event.Type,
"event_data", string(event.Data.Raw),
)
var transfer stripe.Transfer
if err := json.Unmarshal(event.Data.Raw, &transfer); err != nil {
return nil, fmt.Errorf("error parsing transfer: %v", err)
}
log.Info("Transfer created",
"transfer_id", transfer.ID,
"amount", transfer.Amount,
"currency", transfer.Currency,
)
// Get metadata safely
// (fall back to an empty map so the lookups below work without nil checks)
metadata := make(map[string]string)
if transfer.Metadata != nil {
metadata = transfer.Metadata
}
// Parse user and account IDs from metadata
// Parse errors are deliberately ignored here; a missing or malformed ID is
// carried forward as whatever value uuid.Parse returned.
userID, _ := uuid.Parse(metadata["user_id"])
accountID, _ := uuid.Parse(metadata["account_id"])
transactionID, _ := uuid.Parse(metadata["transaction_id"])
// Create money amount - convert from cents to dollars for money package
amount, err := s.parseAmount(transfer.Amount, string(transfer.Currency))
if err != nil {
return nil, fmt.Errorf("error creating money amount: %v", err)
}
// Determine payment status based on transfer status
// NOTE(review): this status only affects the returned PaymentEvent; the
// payment-completed event below is emitted regardless of it — confirm intended.
status := payment.PaymentCompleted
if transfer.Reversed {
status = payment.PaymentFailed
} else if transfer.AmountReversed > 0 {
status = payment.PaymentStatus("partially_reversed")
}
// Build the payment completed event
metadataInfo := &metadataInfo{
UserID: userID,
AccountID: accountID,
TransactionID: transactionID,
PaymentID: transfer.ID,
}
// Create the payment completed event
// The amount is negated before it is handed to the event builder.
pc := s.buildPaymentCompletedEventPayload(amount.Negate(), transfer.ID, metadataInfo, log)
if pc == nil {
return nil, fmt.Errorf("failed to build payment completed event payload")
}
// Log the event details
log.Debug("🔵 Emitting PaymentCompleted event",
"event_id", pc.ID,
"transaction_id", transactionID,
"transfer_id", transfer.ID,
"amount", amount.Amount(),
"currency", amount.Currency(),
)
// Emit the event
if err := s.bus.Emit(ctx, pc); err != nil {
log.Error("🔴 Failed to emit PaymentCompleted event",
"error", err,
"event_id", pc.ID,
)
return nil, fmt.Errorf("failed to emit payment completed event: %v", err)
}
log.Info("✅ PaymentCompleted event emitted successfully",
"event_id", pc.ID,
"transaction_id", transactionID,
"transfer_id", transfer.ID,
)
// The returned event reports the non-negated amount and does not set Currency.
payoutEvent := &payment.PaymentEvent{
ID: transfer.ID,
Status: status,
Amount: amount.Amount(),
UserID: userID,
AccountID: accountID,
TransactionID: transactionID,
Metadata: metadata,
}
return payoutEvent, nil
}
// handleTransferFailed handles transfer.failed webhook events.
//
// It decodes the transfer, logs the failure (using the metadata key
// "failure_reason", or "transfer failed" when that is empty) and returns a
// PaymentEvent with status PaymentFailed. IDs come from the transfer
// metadata; parse errors on them are ignored. No event is emitted on the bus.
func (s *StripePaymentProvider) handleTransferFailed(
	ctx context.Context,
	event stripe.Event,
	log *slog.Logger,
) (*payment.PaymentEvent, error) {
	var transfer stripe.Transfer
	if err := json.Unmarshal(event.Data.Raw, &transfer); err != nil {
		return nil, fmt.Errorf("error parsing transfer: %v", err)
	}

	// Work on a non-nil map so the lookups below need no nil checks.
	meta := transfer.Metadata
	if meta == nil {
		meta = make(map[string]string)
	}

	userID, _ := uuid.Parse(meta["user_id"])
	accountID, _ := uuid.Parse(meta["account_id"])

	reason := meta["failure_reason"]
	if reason == "" {
		reason = "transfer failed"
	}
	log.Error("Transfer failed",
		"transfer_id", transfer.ID,
		"amount", transfer.Amount,
		"currency", transfer.Currency,
		"failure_reason", reason,
	)

	amount, err := s.parseAmount(transfer.Amount, string(transfer.Currency))
	if err != nil {
		return nil, fmt.Errorf("error creating money amount: %v", err)
	}

	// The transaction ID is optional; it is only parsed when present.
	transactionID := uuid.Nil
	if rawTxID := meta["transaction_id"]; rawTxID != "" {
		transactionID, _ = uuid.Parse(rawTxID)
	}

	return &payment.PaymentEvent{
		ID:            transfer.ID,
		Status:        payment.PaymentFailed,
		Amount:        amount.Amount(),
		UserID:        userID,
		AccountID:     accountID,
		TransactionID: transactionID,
		Metadata:      meta,
	}, nil
}
// handleTransferReversed handles transfer.reversed webhook events.
//
// It decodes the transfer, logs the reversal and returns a PaymentEvent with
// status PaymentFailed. The event metadata gets a "failure_reason" entry
// stating whether the reversal was full or partial. IDs come from the
// transfer metadata; parse errors on them are ignored. No event is emitted on
// the bus.
func (s *StripePaymentProvider) handleTransferReversed(
	ctx context.Context,
	event stripe.Event,
	log *slog.Logger,
) (*payment.PaymentEvent, error) {
	var transfer stripe.Transfer
	if err := json.Unmarshal(event.Data.Raw, &transfer); err != nil {
		return nil, fmt.Errorf("error parsing transfer: %v", err)
	}
	log.Warn("Transfer reversed",
		"transfer_id", transfer.ID,
		"amount", transfer.Amount,
		"currency", transfer.Currency,
	)

	// Work on a non-nil map: it is read below and written to at the end.
	meta := transfer.Metadata
	if meta == nil {
		meta = make(map[string]string)
	}

	userID, _ := uuid.Parse(meta["user_id"])
	accountID, _ := uuid.Parse(meta["account_id"])

	amount, err := s.parseAmount(transfer.Amount, string(transfer.Currency))
	if err != nil {
		return nil, fmt.Errorf("error creating money amount: %v", err)
	}

	// The transaction ID is optional; it is only parsed when present.
	transactionID := uuid.Nil
	if rawTxID := meta["transaction_id"]; rawTxID != "" {
		transactionID, _ = uuid.Parse(rawTxID)
	}

	// A reversal counts as partial only when some, but not all, of the amount
	// has been reversed.
	reversed, total := transfer.AmountReversed, transfer.Amount
	reason := "transfer fully reversed"
	if reversed > 0 && reversed < total {
		reason = fmt.Sprintf("transfer partially reversed: %d/%d", reversed, total)
	}

	// There is no dedicated "reversed" status, so the event is reported as failed.
	result := &payment.PaymentEvent{
		ID:            transfer.ID,
		Status:        payment.PaymentFailed,
		Amount:        amount.Amount(),
		UserID:        userID,
		AccountID:     accountID,
		TransactionID: transactionID,
		Metadata:      meta,
	}
	result.Metadata["failure_reason"] = reason
	return result, nil
}
// createCheckoutSession creates a Stripe Checkout session in payment mode for
// a single card-paid line item of the given amount, currency and description.
//
// The user, account and transaction IDs plus the amount and currency are
// attached as metadata to both the session and its payment intent. When ctx
// carries a non-empty string under the key "user_email", it is sent as the
// customer email. The returned CheckoutSession has PaymentID set only when
// Stripe returned a payment intent with the session.
func (s *StripePaymentProvider) createCheckoutSession(
	ctx context.Context,
	userID, accountID, transactionID uuid.UUID,
	amount int64,
	currency string,
	description string,
) (*CheckoutSession, error) {
	successURL := s.ensureAbsoluteURL(s.cfg.SuccessPath)
	cancelURL := s.ensureAbsoluteURL(s.cfg.CancelPath)

	// One metadata map is shared by the session and the payment intent.
	meta := map[string]string{
		"user_id":        userID.String(),
		"account_id":     accountID.String(),
		"transaction_id": transactionID.String(),
		"amount":         fmt.Sprintf("%d", amount),
		"currency":       currency,
	}

	lineItem := &stripe.CheckoutSessionCreateLineItemParams{
		PriceData: &stripe.CheckoutSessionCreateLineItemPriceDataParams{
			Currency: stripe.String(currency),
			ProductData: &stripe.CheckoutSessionCreateLineItemPriceDataProductDataParams{
				Name: stripe.String(description),
			},
			UnitAmount: stripe.Int64(amount),
		},
		Quantity: stripe.Int64(1),
	}

	createParams := &stripe.CheckoutSessionCreateParams{
		PaymentMethodTypes: stripe.StringSlice([]string{"card"}),
		Mode:               stripe.String(string(stripe.CheckoutSessionModePayment)),
		SuccessURL:         stripe.String(successURL),
		CancelURL:          stripe.String(cancelURL),
		Metadata:           meta,
		PaymentIntentData: &stripe.CheckoutSessionCreatePaymentIntentDataParams{
			Metadata: meta,
		},
		LineItems: []*stripe.CheckoutSessionCreateLineItemParams{lineItem},
	}

	// Optional customer email supplied through the context.
	if email, ok := ctx.Value("user_email").(string); ok && email != "" {
		createParams.CustomerEmail = stripe.String(email)
	}

	created, err := s.client.V1CheckoutSessions.Create(ctx, createParams)
	if err != nil {
		s.logger.Error("failed to create checkout session", "error", err)
		return nil, fmt.Errorf("failed to create checkout session: %w", err)
	}
	s.logger.Info(
		"✅ Created checkout session",
		"session_id", created.ID,
		"url", created.URL,
	)

	result := &CheckoutSession{
		ID:          created.ID,
		URL:         created.URL,
		AmountTotal: created.AmountTotal,
		Currency:    string(created.Currency),
	}
	// The payment intent may be absent; leave PaymentID empty in that case.
	if pi := created.PaymentIntent; pi != nil {
		result.PaymentID = pi.ID
	}
	return result, nil
}
// handleCheckoutSessionCompleted handles the checkout.session.completed event.
//
// It decodes the Stripe session, loads the matching internal session record,
// emits a PaymentProcessed event carrying the session subtotal and the
// payment intent ID, and returns a PaymentEvent with status PaymentCompleted.
//
// An error is returned when the event carries no data, the payload cannot be
// parsed, the session has no payment intent, the internal record cannot be
// loaded, the amount cannot be converted, or the event cannot be emitted.
func (s *StripePaymentProvider) handleCheckoutSessionCompleted(
	ctx context.Context,
	event stripe.Event,
	log *slog.Logger,
) (*payment.PaymentEvent, error) {
	// A payload without a "data" object leaves event.Data nil.
	if event.Data == nil {
		err := fmt.Errorf("error parsing checkout.session.completed: event data is nil")
		log.Error(err.Error())
		return nil, err
	}
	var session stripe.CheckoutSession
	if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
		log.Error(
			"parsing checkout.session.completed",
			"error", err,
		)
		return nil, fmt.Errorf(
			"error parsing checkout.session.completed: %w", err)
	}
	// Everything below needs the payment intent ID, and the field is a pointer
	// that may be nil, so reject such sessions instead of panicking.
	if session.PaymentIntent == nil {
		err := fmt.Errorf("checkout session %s has no payment intent", session.ID)
		log.Error(err.Error(), "checkout_session_id", session.ID)
		return nil, err
	}
	paymentIntentID := session.PaymentIntent.ID
	log = log.With(
		"checkout_session_id", session.ID,
		"payment_intent_id", paymentIntentID,
	)
	se, err := s.checkoutService.GetSession(ctx, session.ID)
	if err != nil {
		return nil, err
	}
	amount, err := s.parseAmount(session.AmountSubtotal, string(session.Currency))
	if err != nil {
		log.Error(
			"error parsing amount",
			"error", err,
		)
		return nil, fmt.Errorf("error parsing amount: %w", err)
	}
	if err := s.bus.Emit(
		ctx,
		events.NewPaymentProcessed(
			&events.FlowEvent{
				ID:            uuid.New(),
				UserID:        se.UserID,
				AccountID:     se.AccountID,
				FlowType:      "payment",
				CorrelationID: uuid.New(),
			}, func(pp *events.PaymentProcessed) {
				pp.TransactionID = se.TransactionID
				// Take the address of a fresh copy, not of the shared local.
				paymentID := paymentIntentID
				pp.PaymentID = &paymentID
				log.Info("Emitting ", "event_type", pp.Type())
			},
		).WithAmount(amount).WithPaymentID(paymentIntentID),
	); err != nil {
		log.Error(
			"error emitting payment processed event",
			"error", err,
		)
		return nil, fmt.Errorf("error emitting payment processed event: %w", err)
	}
	log.Info(
		"✅ Checkout pi and transaction updated successfully",
		"transaction_id", se.TransactionID,
		"checkout_session_id", session.ID,
		"payment_intent_id", paymentIntentID,
	)
	// NOTE(review): Amount is taken from the payment intent embedded in the
	// session payload; confirm AmountReceived is populated there (it may be
	// zero when only the intent ID is present).
	return &payment.PaymentEvent{
		ID:        paymentIntentID,
		Status:    payment.PaymentCompleted,
		Amount:    session.PaymentIntent.AmountReceived,
		Currency:  string(session.Currency),
		UserID:    se.UserID,
		AccountID: se.AccountID,
	}, nil
}
// handleCheckoutSessionExpired handles the checkout.session.expired event.
//
// It decodes the Stripe session, validates the "transaction_id" metadata
// entry and marks the internal checkout session record as "expired". It never
// produces a payment event: on success both return values are nil.
//
// An error is returned when the event carries no data, the payload cannot be
// parsed, the transaction ID is not a valid UUID, or the status update fails.
func (s *StripePaymentProvider) handleCheckoutSessionExpired(
	ctx context.Context,
	event stripe.Event,
	log *slog.Logger,
) (*payment.PaymentEvent, error) {
	// A payload without a "data" object leaves event.Data nil.
	if event.Data == nil {
		err := fmt.Errorf("error parsing checkout.session.expired: event data is nil")
		log.Error(err.Error())
		return nil, err
	}
	var session stripe.CheckoutSession
	if err := json.Unmarshal(event.Data.Raw, &session); err != nil {
		log.Error(
			"parsing checkout.session.expired",
			"error", err,
		)
		return nil, fmt.Errorf(
			"error parsing checkout.session.expired: %w", err)
	}
	// The payment intent is a pointer and may be nil for a session that
	// expired; it is only used for logging, so fall back to an empty ID
	// instead of dereferencing nil.
	var paymentIntentID string
	if session.PaymentIntent != nil {
		paymentIntentID = session.PaymentIntent.ID
	}
	log = log.With(
		"checkout_session_id", session.ID,
		"payment_intent_id", paymentIntentID,
	)
	// Get transaction ID from metadata
	transactionID, err := uuid.Parse(session.Metadata["transaction_id"])
	if err != nil {
		log.Error(
			"invalid transaction_id in metadata",
			"error", err,
			"metadata", session.Metadata,
		)
		return nil, fmt.Errorf("invalid transaction_id in metadata: %w", err)
	}
	// Update the checkout session status to expired
	if err := s.checkoutService.UpdateStatus(
		ctx,
		session.ID,
		"expired",
	); err != nil {
		log.Error(
			"updating checkout session status to expired",
			"error", err,
			"transaction_id", transactionID,
		)
		return nil, fmt.Errorf("error updating session status: %w", err)
	}
	log.Info(
		"⏰ Checkout session and transaction updated to expired",
		"transaction_id", transactionID,
	)
	return nil, nil
}
// handlePaymentIntentSucceeded processes a payment_intent.succeeded event.
//
// Steps, in order: decode the payment intent from the event payload, validate
// its ID and metadata, convert the received amount into a money value, emit a
// PaymentCompleted event on the bus, and return a PaymentEvent describing the
// completed payment. Any failure is logged and returned as an error with a
// nil event.
func (s *StripePaymentProvider) handlePaymentIntentSucceeded(
	ctx context.Context,
	event stripe.Event,
	log *slog.Logger,
) (
	*payment.PaymentEvent,
	error,
) {
	const op = "stripe.handlePaymentIntentSucceeded"

	// fail logs err using whichever logger is current when it is called, then
	// returns it with a nil event.
	fail := func(err error) (*payment.PaymentEvent, error) {
		log.Error(err.Error())
		return nil, err
	}

	if event.Data == nil || event.Data.Raw == nil {
		return fail(fmt.Errorf("%s: event data is nil", op))
	}

	var intent stripe.PaymentIntent
	if decodeErr := json.Unmarshal(event.Data.Raw, &intent); decodeErr != nil {
		return fail(fmt.Errorf("%s: failed to unmarshal payment intent: %w", op, decodeErr))
	}
	if intent.ID == "" {
		return fail(fmt.Errorf("%s: payment intent ID is empty", op))
	}

	// From here on every log line carries the payment intent ID.
	log = log.With("payment_intent_id", intent.ID)
	log.Info("💰 Handling payment_intent.succeeded event", "payment_intent_id", intent.ID)

	if intent.Metadata == nil {
		return fail(fmt.Errorf("%s: payment intent metadata is nil", op))
	}
	parsed, err := s.parseAndValidateMetadata(intent.Metadata, log)
	if err != nil {
		return fail(fmt.Errorf("%s: invalid metadata: %w", op, err))
	}
	metadataCopy := s.copyMetadata(intent.Metadata)

	currency := strings.ToUpper(string(intent.Currency))
	if currency == "" {
		return fail(fmt.Errorf("%s: currency code is empty", op))
	}

	received := intent.AmountReceived
	amount, err := s.parseAmount(received, currency)
	if err != nil {
		log.Error("failed to create money amount",
			"error", err,
			"amount", received,
			"currency", currency)
		return nil, fmt.Errorf("failed to create money amount: %w", err)
	}

	// Build and publish the PaymentCompleted event.
	completed := s.buildPaymentCompletedEventPayload(amount, intent.ID, parsed, log)
	if completed == nil {
		return fail(fmt.Errorf("failed to build payment completed event payload"))
	}
	if emitErr := s.bus.Emit(ctx, completed); emitErr != nil {
		log.Error("error emitting payment completed event", "error", emitErr)
		return nil, fmt.Errorf("error emitting payment completed event: %w", emitErr)
	}

	log.Info("✅ Payment intent processed and transaction updated successfully",
		"transaction_id", parsed.TransactionID, "payment_id", intent.ID)

	result := &payment.PaymentEvent{
		ID:        intent.ID,
		Status:    payment.PaymentCompleted,
		Amount:    received,
		Currency:  string(intent.Currency),
		UserID:    parsed.UserID,
		AccountID: parsed.AccountID,
		Metadata:  metadataCopy,
	}
	return result, nil
}
// getFeeFromBalanceTransaction fetches the balance transaction identified by
// balanceTxID from Stripe and returns its fee together with the upper-cased
// currency code. When the lookup fails, a warning is logged and the result is
// (0, "", err).
func (s *StripePaymentProvider) getFeeFromBalanceTransaction(
	ctx context.Context,
	log *slog.Logger,
	balanceTxID string,
) (int64, string, error) {
	balanceTx, err := s.client.V1BalanceTransactions.Retrieve(ctx, balanceTxID, nil)
	if err != nil {
		log.Warn(
			"Failed to retrieve balance transaction",
			"error", err,
			"balance_transaction_id", balanceTxID,
		)
		return 0, "", err
	}
	log.Debug("Retrieved balance transaction", "balance_transaction", balanceTx)

	fee, currency := balanceTx.Fee, strings.ToUpper(string(balanceTx.Currency))
	log.Info("Retrieved fee from balance transaction",
		"fee_amount", fee,
		"fee_currency", currency,
		"balance_transaction_id", balanceTxID,
	)
	return fee, currency, nil
}
// metadataInfo holds parsed metadata fields.
//
// parseAndValidateMetadata fills UserID, AccountID, TransactionID and Currency
// from the "user_id", "account_id", "transaction_id" and "currency" metadata
// keys. It does not set PaymentID.
type metadataInfo struct {
// UserID is the parsed "user_id" metadata value.
UserID uuid.UUID
// AccountID is the parsed "account_id" metadata value.
AccountID uuid.UUID
// TransactionID is the parsed "transaction_id" metadata value.
TransactionID uuid.UUID
// PaymentID, when non-empty, replaces the payment ID on the event built by
// buildPaymentCompletedEventPayload.
PaymentID string
// Currency is the trimmed, upper-cased "currency" metadata value.
Currency string
}
// parseAndValidateMetadata checks that meta carries non-empty "user_id",
// "account_id", "transaction_id" and "currency" entries, parses the three IDs
// as UUIDs and normalizes the currency (trimmed, upper-cased, three letters).
//
// It returns the parsed values, or an error (also logged) naming the first
// problem found. All missing fields are reported together.
func (s *StripePaymentProvider) parseAndValidateMetadata(
	meta map[string]string,
	log *slog.Logger,
) (*metadataInfo, error) {
	const op = "stripe.parseAndValidateMetadata"

	// Collect every absent or empty required field so the error lists them all.
	var missingFields []string
	for _, field := range []string{"user_id", "account_id", "transaction_id", "currency"} {
		if value, present := meta[field]; !present || value == "" {
			missingFields = append(missingFields, field)
		}
	}
	if len(missingFields) != 0 {
		err := fmt.Errorf("%s: missing required metadata fields: %v", op, missingFields)
		log.Error(err.Error(), "metadata", meta)
		return nil, err
	}

	// parseID parses meta[key] as a UUID, logging and wrapping any failure.
	parseID := func(key string) (uuid.UUID, error) {
		raw := meta[key]
		id, parseErr := uuid.Parse(raw)
		if parseErr != nil {
			wrapped := fmt.Errorf("%s: invalid %s in metadata: %w", op, key, parseErr)
			log.Error(wrapped.Error(), key, raw)
			return id, wrapped
		}
		return id, nil
	}

	info := &metadataInfo{}
	var err error
	if info.UserID, err = parseID("user_id"); err != nil {
		return nil, err
	}
	if info.AccountID, err = parseID("account_id"); err != nil {
		return nil, err
	}
	if info.TransactionID, err = parseID("transaction_id"); err != nil {
		return nil, err
	}

	// A whitespace-only currency passes the presence check above, so it is
	// rejected here after trimming.
	code := strings.TrimSpace(meta["currency"])
	if code == "" {
		err := fmt.Errorf("%s: currency code is empty", op)
		log.Error(err.Error())
		return nil, err
	}
	code = strings.ToUpper(code)

	// Basic ISO 4217 shape check: exactly three letters.
	if wellFormed := len(code) == 3 && isAlpha(code); !wellFormed {
		err := fmt.Errorf("%s: invalid currency code format: %s", op, code)
		log.Error(err.Error())
		return nil, err
	}
	info.Currency = code
	return info, nil
}
// copyMetadata returns a new map holding every entry of meta whose key is not
// the empty string. A nil meta yields an empty, non-nil map.
func (s *StripePaymentProvider) copyMetadata(
	meta map[string]string,
) map[string]string {
	// len and range are both safe on a nil map, so no separate nil branch is
	// needed.
	out := make(map[string]string, len(meta))
	for key, value := range meta {
		if key == "" {
			continue
		}
		out[key] = value
	}
	return out
}
// isAlpha reports whether s consists solely of ASCII letters (A-Z, a-z).
//
// Callers pair it with a byte-length check to validate ISO 4217 currency
// codes. Accepting any Unicode letter would let multi-byte input such as "éa"
// (3 bytes, 2 runes) pass a `len(code) == 3` check, so non-ASCII runes are
// rejected. The empty string returns true; callers check the length
// separately.
func isAlpha(s string) bool {
	for _, r := range s {
		if r > unicode.MaxASCII || !unicode.IsLetter(r) {
			return false
		}
	}
	return true
}
// parseAmount turns a Stripe amount, expressed in the currency's smallest
// unit, into a money.Money value.
//
// It rejects negative amounts, empty currencies and currency codes that are
// not three letters after trimming and upper-casing. Every failure is logged
// through s.logger and returned as an error.
func (s *StripePaymentProvider) parseAmount(
	amount int64,
	currency string,
) (*money.Money, error) {
	const op = "stripe.parseAmount"

	// reject logs err (with optional structured attributes) and returns it.
	reject := func(err error, attrs ...any) (*money.Money, error) {
		s.logger.Error(err.Error(), attrs...)
		return nil, err
	}

	switch {
	case amount < 0:
		return reject(fmt.Errorf("%s: amount cannot be negative: %d", op, amount))
	case currency == "":
		return reject(fmt.Errorf("%s: currency cannot be empty", op))
	}

	// Normalize, then apply the basic ISO 4217 shape check (three letters).
	code := strings.ToUpper(strings.TrimSpace(currency))
	if wellFormed := len(code) == 3 && isAlpha(code); !wellFormed {
		return reject(fmt.Errorf(
			"%s: invalid currency code format: %s (must be 3 uppercase letters)",
			op,
			code,
		))
	}

	// amount is in the smallest unit (e.g. cents for USD).
	result, err := money.NewFromSmallestUnit(amount, money.Code(code))
	if err != nil {
		return reject(
			fmt.Errorf(
				"%s: failed to create money amount from %d %s: %w",
				op,
				amount,
				code,
				err,
			),
			"amount", amount,
			"currency", code,
		)
	}
	return result, nil
}
// parseProviderFeeAmount converts a provider fee, given in the smallest
// currency unit, into a money.Money value. An empty currency is rejected up
// front; otherwise the currency is upper-cased and the conversion is delegated
// to parseAmount. Failures are logged on log and returned wrapped.
func (s *StripePaymentProvider) parseProviderFeeAmount(
	feeAmount int64,
	cur string,
	log *slog.Logger,
) (*money.Money, error) {
	if cur == "" {
		cause := fmt.Errorf("empty currency code provided for fee amount %d", feeAmount)
		log.Error("invalid currency code", "error", cause)
		return nil, fmt.Errorf("invalid currency code: %w", cause)
	}

	code := strings.ToUpper(cur)

	// Attach the fee details to every subsequent log line.
	feeLog := log.With(
		"fee_amount", feeAmount,
		"fee_currency", code,
	)

	parsed, err := s.parseAmount(feeAmount, code)
	if err == nil {
		feeLog.Debug("successfully parsed provider fee")
		return parsed, nil
	}

	cause := fmt.Errorf("invalid fee amount %d %s: %w", feeAmount, code, err)
	feeLog.Error("error parsing fee amount", "error", cause)
	return nil, fmt.Errorf("error parsing fee amount: %w", cause)
}
// buildPaymentCompletedEventPayload creates a PaymentCompleted event
// with the given amount and metadata.
//
// The event's flow carries a fresh event ID, the user and account from meta,
// and the transaction ID as correlation ID. The payment ID defaults to
// paymentID and is replaced by meta.PaymentID when that field is non-empty.
//
// It returns nil when amount or meta is nil, which is the failure signal the
// caller checks for; previously such input caused a nil dereference instead.
func (s *StripePaymentProvider) buildPaymentCompletedEventPayload(
	amount *money.Money,
	paymentID string,
	meta *metadataInfo,
	log *slog.Logger,
) *events.PaymentCompleted {
	if amount == nil || meta == nil {
		log.Warn("cannot build payment completed event: amount or metadata is nil",
			"payment_id", paymentID,
		)
		return nil
	}

	// Create a new PaymentCompleted event with minimal required fields
	pc := events.NewPaymentCompleted(
		&events.FlowEvent{
			ID:            uuid.New(),
			FlowType:      "payment",
			UserID:        meta.UserID,
			AccountID:     meta.AccountID,
			CorrelationID: meta.TransactionID,
			Timestamp:     time.Now(),
		},
		func(pc *events.PaymentCompleted) {
			pc.TransactionID = meta.TransactionID
			pc.PaymentID = &paymentID
			pc.Amount = amount
			pc.Status = "completed"
		},
	)

	// A payment ID carried in the metadata takes precedence over the argument.
	if meta.PaymentID != "" {
		pc.PaymentID = &meta.PaymentID
	}

	log.Info("built payment completed event",
		"event_id", pc.ID,
		"transaction_id", meta.TransactionID,
		"amount", amount.Amount(),
		"currency", amount.Currency(),
	)
	return pc
}
// handlePaymentIntentFailed handles the payment_intent.payment_failed event.
//
// It decodes the payment intent from the event payload, re-fetches it from
// Stripe, parses the user_id, account_id and transaction_id metadata as UUIDs,
// emits a PaymentFailed event on the bus and returns a PaymentEvent with
// status PaymentFailed. Any failure is logged and returned with a nil event.
func (s *StripePaymentProvider) handlePaymentIntentFailed(
	ctx context.Context,
	event stripe.Event, log *slog.Logger) (*payment.PaymentEvent, error) {
	if event.Data == nil || event.Data.Raw == nil {
		err := fmt.Errorf("error parsing payment_intent.payment_failed: event data is nil")
		log.Error(
			"error parsing payment_intent.payment_failed",
			"error", err,
		)
		return nil, err
	}

	var paymentIntent stripe.PaymentIntent
	if err := json.Unmarshal(event.Data.Raw, &paymentIntent); err != nil {
		log.Error(
			"error parsing payment_intent.payment_failed",
			"error", err,
		)
		return nil, fmt.Errorf(
			"error parsing payment_intent.payment_failed: %w", err)
	}
	log = log.With("payment_intent_id", paymentIntent.ID)

	// Get the payment intent details. The caller's ctx is used (rather than a
	// background context) so cancellation and deadlines reach the Stripe call.
	pi, err := s.client.V1PaymentIntents.Retrieve(
		ctx,
		paymentIntent.ID,
		nil,
	)
	if err != nil {
		log.Error(
			"error retrieving payment intent",
			"error", err,
		)
		return nil, fmt.Errorf("error retrieving payment intent: %w", err)
	}

	// Get the user ID, account ID, and transaction ID from metadata
	userID, err := uuid.Parse(paymentIntent.Metadata["user_id"])
	if err != nil {
		log.Error(
			"invalid user_id in metadata",
			"error", err,
			"metadata", paymentIntent.Metadata,
		)
		return nil, fmt.Errorf("invalid user_id in metadata: %w", err)
	}
	accountID, err := uuid.Parse(paymentIntent.Metadata["account_id"])
	if err != nil {
		log.Error(
			"invalid account_id in metadata",
			"error", err,
			"metadata", paymentIntent.Metadata,
		)
		return nil, fmt.Errorf("invalid account_id in metadata: %w", err)
	}
	transactionID, err := uuid.Parse(
		paymentIntent.Metadata["transaction_id"])
	if err != nil {
		log.Error(
			"invalid transaction_id in metadata",
			"error", err,
			"metadata", paymentIntent.Metadata,
		)
		return nil, fmt.Errorf("invalid transaction_id in metadata: %w", err)
	}

	// Create metadata map from payment intent metadata
	metadata := make(map[string]string)
	maps.Copy(metadata, paymentIntent.Metadata)

	// NOTE(review): here the transaction ID is the event ID and the correlation
	// ID is random, the reverse of the PaymentCompleted and FeesCalculated
	// events built in this file — confirm this is intended.
	if err := s.bus.Emit(ctx, events.NewPaymentFailed(
		&events.FlowEvent{
			ID:            transactionID,
			UserID:        userID,
			AccountID:     accountID,
			FlowType:      "payment",
			CorrelationID: uuid.New(),
		},
		events.WithFailedPaymentID(&pi.ID),
	)); err != nil {
		log.Error(
			"error emitting payment failed event",
			"error", err,
		)
		return nil, fmt.Errorf("error emitting payment failed event: %w", err)
	}

	log.Info(
		"✅ Payment intent failed and transaction updated",
		"transaction_id", transactionID,
		"payment_id", paymentIntent.ID,
	)
	return &payment.PaymentEvent{
		ID:        pi.ID,
		Status:    payment.PaymentFailed,
		Amount:    pi.Amount,
		Currency:  string(pi.Currency),
		UserID:    userID,
		AccountID: accountID,
		Metadata:  metadata,
	}, nil
}
// ensureAbsoluteURL is meant to make a URL absolute. As written it parses path
// but returns it unchanged in every case (an empty path yields ""); no base URL
// is prepended to relative paths.
func (s *StripePaymentProvider) ensureAbsoluteURL(path string) string {
	if path == "" {
		return ""
	}
	// Unparseable input and already-absolute URLs are passed through untouched.
	if parsed, err := url.Parse(path); err != nil || parsed.IsAbs() {
		return path
	}
	// Relative paths are currently returned as-is as well.
	return path
}
// handleChargeSucceeded handles the charge.succeeded event.
//
// It decodes the charge, looks up the Stripe fee on the charge's balance
// transaction (falling back to a zero fee in the charge currency when there is
// no balance transaction or the lookup fails) and, when the charge metadata is
// valid, emits a FeesCalculated event.
//
// Fee reporting is best-effort: a failure to build or emit the fee event is
// logged but does not fail the handler. The result is (nil, nil) unless the
// payload is missing or cannot be decoded.
func (s *StripePaymentProvider) handleChargeSucceeded(
	ctx context.Context,
	event stripe.Event,
	logger *slog.Logger,
) (
	*payment.PaymentEvent,
	error,
) {
	if event.Data == nil || event.Data.Raw == nil {
		err := fmt.Errorf("error parsing charge.succeeded: event data is nil")
		logger.Error(
			"error parsing charge.succeeded",
			"error", err,
		)
		return nil, err
	}

	var charge stripe.Charge
	if err := json.Unmarshal(event.Data.Raw, &charge); err != nil {
		logger.Error(
			"error parsing charge.succeeded",
			"error", err,
		)
		return nil, fmt.Errorf(
			"error parsing charge.succeeded: %w", err)
	}

	// Always attempt to retrieve the Stripe fee from the balance transaction.
	balanceTxID := ""
	if charge.BalanceTransaction != nil {
		balanceTxID = charge.BalanceTransaction.ID
	}

	feeAmount := int64(0)
	feeCurrency := string(charge.Currency)
	if balanceTxID != "" {
		var feeErr error
		feeAmount, feeCurrency, feeErr = s.getFeeFromBalanceTransaction(ctx, logger, balanceTxID)
		if feeErr != nil {
			logger.Warn("Failed to retrieve fee from balance transaction", "error", feeErr)
			feeAmount = 0
			feeCurrency = string(charge.Currency)
		}
	} else {
		logger.Warn("No balance transaction found on Charge, defaulting fee to 0")
	}
	if feeCurrency == "" {
		feeCurrency = string(charge.Currency)
	}
	feeCurrency = strings.ToUpper(feeCurrency)

	logger = logger.With("charge_id", charge.ID, "balance_transaction_id", balanceTxID)
	logger.Info("✅ Charge succeeded", "fee_amount", feeAmount, "fee_currency", feeCurrency)

	// Process and emit fee if metadata is valid
	feeEvent, err := s.createFeeEvent(
		charge.Metadata,
		feeAmount,
		feeCurrency,
		logger,
	)
	if err != nil {
		// Previously dropped silently; surface why no fee event was produced.
		logger.Warn("Skipping FeesCalculated event", "error", err)
		return nil, nil
	}

	logger.Info("Emitting FeesCalculated event", "event", feeEvent)
	if err := s.bus.Emit(ctx, feeEvent); err != nil {
		// Best-effort: record the failure instead of discarding it, but do not
		// fail the webhook.
		logger.Error("error emitting FeesCalculated event", "error", err)
	}
	return nil, nil
}
// createFeeEvent builds a FeesCalculated event for a provider fee.
//
// metadata must contain "user_id", "account_id" and "transaction_id" entries
// that parse as UUIDs; feeAmount and feeCurrency are converted through
// parseProviderFeeAmount. An error is returned when the metadata is empty, an
// ID is invalid, or the fee cannot be converted.
func (s *StripePaymentProvider) createFeeEvent(
	metadata map[string]string,
	feeAmount int64,
	feeCurrency string,
	logger *slog.Logger,
) (*events.FeesCalculated, error) {
	if len(metadata) == 0 {
		return nil, fmt.Errorf("missing required metadata")
	}

	// Parse the three IDs in a fixed order so the first invalid one is reported.
	var ids [3]uuid.UUID
	for i, key := range [...]string{"user_id", "account_id", "transaction_id"} {
		parsed, err := uuid.Parse(metadata[key])
		if err != nil {
			logger.Error("Failed to parse "+key+" from metadata", "error", err)
			return nil, fmt.Errorf("invalid %s: %w", key, err)
		}
		ids[i] = parsed
	}
	userID, accountID, transactionID := ids[0], ids[1], ids[2]

	fee, err := s.parseProviderFeeAmount(feeAmount, feeCurrency, logger)
	if err != nil {
		logger.Error("Failed to parse provider fee amount",
			"amount", feeAmount,
			"currency", feeCurrency,
			"error", err)
		return nil, fmt.Errorf("invalid fee amount: %w", err)
	}

	logger.Debug("Creating fee event",
		"user_id", userID,
		"account_id", accountID,
		"transaction_id", transactionID,
		"fee_amount", fee.Amount(),
		"fee_currency", fee.Currency().String())

	// The transaction ID doubles as the flow's correlation ID.
	flow := &events.FlowEvent{
		ID:            uuid.New(),
		UserID:        userID,
		AccountID:     accountID,
		FlowType:      "payment",
		CorrelationID: transactionID,
		Timestamp:     time.Now(),
	}
	return events.NewFeesCalculated(
		flow,
		events.WithFeeAmountValue(fee),
		events.WithFeeTransactionID(transactionID),
		events.WithFeeType(account.FeeProvider),
	), nil
}
// InitiatePayout implements payment.Payment interface.
//
// It creates a Stripe transfer of params.Amount in params.Currency to the
// account identified by params.PaymentProviderID, tagging the transfer with
// the user, account and transaction IDs plus any extra params.Metadata.
//
// The returned status is PaymentFailed when the transfer is reversed,
// PaymentCompleted when it carries a destination payment with an ID, and
// PaymentPending otherwise. The estimated arrival is the transfer creation
// time plus two days. A nil params or a failed transfer creation yields an
// error.
func (s *StripePaymentProvider) InitiatePayout(
	ctx context.Context,
	params *payment.InitiatePayoutParams,
) (*payment.InitiatePayoutResponse, error) {
	// Guard against a nil request instead of panicking on the first field read.
	if params == nil {
		return nil, fmt.Errorf("failed to create transfer: payout params are nil")
	}

	s.logger.Info("Initiating payout",
		"user_id", params.UserID,
		"amount", params.Amount,
		"currency", params.Currency,
		"destination_type", params.Destination.Type,
		"destination_id", params.PaymentProviderID,
	)

	// Create the transfer to the connected account
	transferParams := &stripe.TransferCreateParams{
		Amount:      stripe.Int64(params.Amount),
		Currency:    stripe.String(params.Currency),
		Destination: stripe.String(params.PaymentProviderID),
		Description: stripe.String(params.Description),
	}

	// Add metadata
	transferParams.AddMetadata("user_id", params.UserID.String())
	transferParams.AddMetadata("account_id", params.AccountID.String())
	transferParams.AddMetadata("transaction_id", params.TransactionID.String())

	// Add any additional metadata
	for k, v := range params.Metadata {
		transferParams.AddMetadata(k, v)
	}

	// Execute the transfer
	transfer, err := s.client.V1Transfers.Create(ctx, transferParams)
	if err != nil {
		s.logger.Error("failed to create transfer",
			"error", err,
			"user_id", params.UserID,
			"account_id", params.AccountID,
			"stripe_account_id", params.PaymentProviderID)
		return nil, fmt.Errorf("failed to create transfer: %w", err)
	}

	// Determine the status based on the transfer status
	status := payment.PaymentPending
	if transfer.Reversed {
		status = payment.PaymentFailed
	} else if transfer.DestinationPayment != nil && len(transfer.DestinationPayment.ID) > 0 {
		// If we have a destination payment, the transfer was successful
		status = payment.PaymentCompleted
	}

	// Get the fee amount if available: the amount by which the destination
	// payment exceeds the transfer amount, clamped at zero.
	feeAmount := int64(0)
	if transfer.DestinationPayment != nil {
		feeAmount = max(transfer.DestinationPayment.Amount-transfer.Amount, 0)
	}

	return &payment.InitiatePayoutResponse{
		PayoutID:             transfer.ID,
		PaymentProviderID:    params.PaymentProviderID,
		Status:               status,
		Amount:               transfer.Amount,
		Currency:             string(transfer.Currency),
		FeeAmount:            feeAmount,
		FeeCurrency:          string(transfer.Currency),
		EstimatedArrivalDate: transfer.Created + 2*24*60*60, // Default to 2 days from creation
	}, nil
}
package provider_types
import (
"log/slog"
exchangerateapi "github.com/amirasaad/fintech/infra/provider/exchangerateapi"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/domain"
"github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/amirasaad/fintech/pkg/registry"
exchangescv "github.com/amirasaad/fintech/pkg/service/exchange"
)
// ExchangeRateCurrencyConverter is a type alias for exchange.Exchange.
//
// Deprecated: Use exchange.Exchange interface directly.
type ExchangeRateCurrencyConverter = exchange.Exchange
// exchangeRateService provides real-time exchange rates with caching and fallback providers.
//
// The fields are populated by NewExchangeRateService; the GetRate and GetRates
// methods on this type are stubs that do not read them.
//
// Deprecated: Use exchange.Service from github.com/amirasaad/fintech/pkg/service/exchange instead.
type exchangeRateService struct {
// providers is the list of exchange-rate providers given at construction.
providers []exchange.Exchange
// cache is the registry provider given at construction.
cache registry.Provider
// logger is the logger given at construction.
logger *slog.Logger
// cfg is the provider configuration given at construction.
cfg *config.ExchangeRateProviders
}
// NewExchangeRateService returns an exchangeRateService holding the given
// providers, cache, logger and configuration. The arguments are stored as-is,
// without validation.
//
// Deprecated: Use exchange.New from github.com/amirasaad/fintech/pkg/service/exchange instead.
func NewExchangeRateService(
	providers []exchange.Exchange,
	cache registry.Provider,
	logger *slog.Logger,
	cfg *config.ExchangeRateProviders,
) *exchangeRateService {
	svc := new(exchangeRateService)
	svc.providers = providers
	svc.cache = cache
	svc.logger = logger
	svc.cfg = cfg
	return svc
}
// GetRate is a stub: it ignores from and to and always returns a nil result
// with domain.ErrExchangeRateUnavailable. The implementation moved to
// exchange/service.go.
//
// Deprecated: Use exchange.Service.GetRate instead.
func (s *exchangeRateService) GetRate(from, to string) (*domain.ConversionInfo, error) {
	var none *domain.ConversionInfo // always nil: this stub never produces a rate
	return none, domain.ErrExchangeRateUnavailable
}
// GetRates is a stub: it ignores from and to and always returns a nil map with
// domain.ErrExchangeRateUnavailable. The implementation moved to
// exchange/service.go.
//
// Deprecated: Use exchange.Service.GetRates instead.
func (s *exchangeRateService) GetRates(
	from string,
	to []string,
) (map[string]*domain.ConversionInfo, error) {
	var none map[string]*domain.ConversionInfo // always nil: this stub never produces rates
	return none, domain.ErrExchangeRateUnavailable
}
// NewExchangeRateCurrencyConverter returns an exchange-rate API provider built
// from an empty config.ExchangeRateApi and the given logger. The
// exchangeRateService and fallback arguments are ignored.
//
// Deprecated: Use NewExchangeRateAPIProvider instead.
func NewExchangeRateCurrencyConverter(
	exchangeRateService *exchangescv.Service,
	fallback ExchangeRateCurrencyConverter,
	logger *slog.Logger,
) ExchangeRateCurrencyConverter {
	emptyCfg := &config.ExchangeRateApi{}
	return exchangerateapi.NewExchangeRateAPIProvider(emptyCfg, logger)
}
package account
import (
"github.com/amirasaad/fintech/infra/repository/transaction"
"github.com/google/uuid"
"gorm.io/gorm"
)
// Account represents an account record in the database.
//
// It embeds gorm.Model and additionally declares its own UUID ID field as the
// primary key.
type Account struct {
gorm.Model
// ID is the account's UUID primary key.
ID uuid.UUID `gorm:"type:uuid;primary_key"`
// UserID is the UUID of the owning user.
UserID uuid.UUID `gorm:"type:uuid"`
// Balance is the stored integer balance.
// NOTE(review): presumably in the currency's smallest unit — confirm.
Balance int64
// Currency is a 3-character code, non-null, defaulting to 'USD'.
Currency string `gorm:"type:varchar(3);not null;default:'USD'"`
// Transactions are the transaction records associated with the account.
Transactions []transaction.Transaction
}
// TableName returns the database table name used for Account records.
func (Account) TableName() string {
	const table = "accounts"
	return table
}
package account
import (
"context"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
"gorm.io/gorm"
)
// repository is the GORM-backed account repository.
type repository struct {
// db is the GORM handle every query is issued through.
db *gorm.DB
}
// New returns an account repository that issues its queries through db.
func New(db *gorm.DB) *repository {
	r := new(repository)
	r.db = db
	return r
}
// Create implements account.Repository. It maps the DTO to an Account model
// and inserts it, returning any database error.
func (r *repository) Create(
	ctx context.Context,
	create dto.AccountCreate,
) error {
	model := mapCreateDTOToModel(create)
	result := r.db.WithContext(ctx).Create(&model)
	return result.Error
}
// Update implements account.Repository. It applies the fields set in update to
// the account row whose id matches, returning any database error.
func (r *repository) Update(
	ctx context.Context,
	id uuid.UUID,
	update dto.AccountUpdate,
) error {
	fields := mapUpdateDTOToModel(update)
	target := r.db.WithContext(ctx).Model(&Account{}).Where("id = ?", id)
	return target.Updates(fields).Error
}
// Get implements account.Repository. It loads the first account matching id
// and returns it as a read DTO; a database error is returned unchanged with a
// nil result.
func (r *repository) Get(
	ctx context.Context,
	id uuid.UUID,
) (*dto.AccountRead, error) {
	var model Account
	result := r.db.WithContext(ctx).First(&model, "id = ?", id)
	if result.Error != nil {
		return nil, result.Error
	}
	return mapModelToDTO(&model), nil
}
// ListByUser implements account.Repository. It returns a read DTO for every
// account whose user_id equals userID; the slice is empty (not nil) when there
// are none.
func (r *repository) ListByUser(
	ctx context.Context,
	userID uuid.UUID,
) ([]*dto.AccountRead, error) {
	var rows []Account
	query := r.db.WithContext(ctx).Where("user_id = ?", userID)
	if err := query.Find(&rows).Error; err != nil {
		return nil, err
	}

	out := make([]*dto.AccountRead, len(rows))
	for i := range rows {
		// Index into rows so each DTO is built from the slice element itself.
		out[i] = mapModelToDTO(&rows[i])
	}
	return out, nil
}
// mapCreateDTOToModel converts an AccountCreate DTO into an Account model.
// ID, UserID and Currency are copied from the DTO; the balance starts at 0.
func mapCreateDTOToModel(create dto.AccountCreate) Account {
	var acct Account
	acct.ID = create.ID
	acct.UserID = create.UserID
	acct.Balance = 0
	acct.Currency = create.Currency
	return acct
}
// mapUpdateDTOToModel converts an AccountUpdate DTO into the column/value map
// passed to GORM's Updates. Only fields set on the DTO are included; at
// present that is just "balance" (no status column is mapped).
func mapUpdateDTOToModel(update dto.AccountUpdate) map[string]any {
	updates := map[string]any{}
	if balance := update.Balance; balance != nil {
		updates["balance"] = *balance
	}
	return updates
}
// mapModelToDTO converts an Account model into its read DTO. The stored
// balance and currency are combined into a money value, from which the DTO's
// float balance and currency string are taken.
func mapModelToDTO(acct *Account) *dto.AccountRead {
	balance := money.NewFromData(acct.Balance, acct.Currency)

	read := new(dto.AccountRead)
	read.ID = acct.ID
	read.UserID = acct.UserID
	read.Balance = balance.AmountFloat()
	read.Currency = balance.Currency().String()
	read.CreatedAt = acct.CreatedAt
	return read
}
package repository
import (
"errors"
"github.com/amirasaad/fintech/pkg/domain"
"gorm.io/gorm"
)
// MapGormErrorToDomain converts GORM errors to domain errors.
// This keeps infrastructure concerns (database errors) within the infrastructure layer.
//
// Mapping:
//   - nil                     -> nil
//   - gorm.ErrDuplicatedKey   -> domain.ErrAlreadyExists
//   - gorm.ErrRecordNotFound  -> domain.ErrNotFound
//   - anything else           -> err, unchanged
//
// errors.Is already walks the entire wrap chain of err, so wrapped GORM errors
// are matched without unwrapping by hand. A duplicate-key match takes
// precedence over a not-found match when both are present in the chain.
func MapGormErrorToDomain(err error) error {
	switch {
	case err == nil:
		return nil
	case errors.Is(err, gorm.ErrDuplicatedKey):
		return domain.ErrAlreadyExists
	case errors.Is(err, gorm.ErrRecordNotFound):
		return domain.ErrNotFound
	// Add more GORM error mappings as needed, e.g.
	// case errors.Is(err, gorm.ErrForeignKeyViolated):
	//	return domain.ErrInvalidReference
	default:
		// Return original error if no mapping found
		return err
	}
}
// WrapError runs op and passes its result through MapGormErrorToDomain, so
// repository methods get domain errors without repeating the mapping call.
//
// Usage:
//
//	err := WrapError(func() error {
//		return r.db.WithContext(ctx).Create(user).Error
//	})
func WrapError(op func() error) error {
	opErr := op()
	return MapGormErrorToDomain(opErr)
}
package transaction
import (
"github.com/google/uuid"
"gorm.io/gorm"
)
// Transaction represents a persisted financial transaction.
//
// It embeds gorm.Model and additionally declares its own UUID ID field as the
// primary key.
type Transaction struct {
gorm.Model
// ID is the transaction's UUID primary key.
ID uuid.UUID `gorm:"type:uuid;primary_key"`
// AccountID is the UUID of the account the transaction belongs to.
AccountID uuid.UUID `gorm:"type:uuid"`
// UserID is the UUID of the user the transaction belongs to.
UserID uuid.UUID `gorm:"type:uuid"`
// Amount is the stored integer amount; the repository's read mapper treats it
// as a smallest-currency-unit value.
Amount int64
// Currency is a 3-character code, non-null, defaulting to 'USD'.
Currency string `gorm:"type:varchar(3);not null;default:'USD'"`
// Balance is a stored integer balance.
// NOTE(review): presumably the account balance after this transaction — confirm.
Balance int64
// Status is non-null and defaults to 'pending'.
Status string `gorm:"type:varchar(32);not null;default:'pending'"`
// PaymentID is an optional, indexed payment identifier (nil when absent).
PaymentID *string `gorm:"type:varchar(64);column:payment_id;index"`
// Conversion fields (nullable when no conversion occurs)
OriginalAmount *float64 `gorm:"type:decimal(20,8)"`
OriginalCurrency *string `gorm:"type:varchar(3)"`
ConversionRate *float64 `gorm:"type:decimal(20,8)"`
// MoneySource indicates the origin of funds (e.g., Cash, BankAccount, Stripe, etc.)
MoneySource string `gorm:"type:varchar(64);not null;default:'Internal'"`
// ExternalTargetMasked is stored in the external_target_masked column.
// NOTE(review): presumably a masked external destination identifier — confirm.
ExternalTargetMasked string `gorm:"type:varchar(128);column:external_target_masked"`
// TargetCurrency is the currency the account is credited in (for multi-currency deposits)
TargetCurrency string `gorm:"type:varchar(8);column:target_currency"`
// Fee is the transaction fee in the smallest currency unit (e.g., cents)
Fee *int64 `gorm:"type:bigint;default:0"`
}
// TableName returns the database table name used for Transaction records.
func (Transaction) TableName() string {
	const table = "transactions"
	return table
}
package transaction // import alias for infra/repository/transaction
import (
"context"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/money"
repo "github.com/amirasaad/fintech/pkg/repository/transaction"
"github.com/google/uuid"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
// repository is the GORM-backed transaction repository.
type repository struct {
// db is the GORM handle every query is issued through.
db *gorm.DB
}
// New returns a transaction repository that issues its queries through db.
func New(db *gorm.DB) repo.Repository {
	r := new(repository)
	r.db = db
	return r
}
// Create implements transaction.Repository. It maps the DTO to a Transaction
// model and inserts it, returning any database error.
func (r *repository) Create(
	ctx context.Context,
	create dto.TransactionCreate,
) error {
	model := mapCreateDTOToModel(create)
	result := r.db.WithContext(ctx).Create(&model)
	return result.Error
}
// Update implements transaction.Repository. It applies the fields set in
// update to the transaction row whose id matches, returning any database
// error.
func (r *repository) Update(
	ctx context.Context,
	id uuid.UUID,
	update dto.TransactionUpdate,
) error {
	fields := mapUpdateDTOToModel(update)
	target := r.db.WithContext(ctx).Model(&Transaction{}).Where("id = ?", id)
	return target.Updates(fields).Error
}
// PartialUpdate implements transaction.Repository. It performs the same
// operation as Update: only the fields set in update are written to the
// transaction row whose id matches.
func (r *repository) PartialUpdate(
	ctx context.Context,
	id uuid.UUID,
	update dto.TransactionUpdate,
) error {
	fields := mapUpdateDTOToModel(update)
	target := r.db.WithContext(ctx).Model(&Transaction{}).Where("id = ?", id)
	return target.Updates(fields).Error
}
// UpsertByPaymentID implements transaction.Repository. It inserts the
// transaction built from create; on a payment_id conflict the existing row's
// status and amount columns are updated instead. A non-empty paymentID
// overrides the payment ID taken from the DTO.
func (r *repository) UpsertByPaymentID(
	ctx context.Context,
	paymentID string,
	create dto.TransactionCreate,
) error {
	model := mapCreateDTOToModel(create)
	if paymentID != "" {
		model.PaymentID = &paymentID
	}

	onConflict := clause.OnConflict{
		Columns:   []clause.Column{{Name: "payment_id"}},
		DoUpdates: clause.AssignmentColumns([]string{"status", "amount"}),
	}
	return r.db.WithContext(ctx).Clauses(onConflict).Create(&model).Error
}
// Get implements transaction.Repository. It loads the first transaction
// matching id and returns it as a read DTO; a database error is returned
// unchanged with a nil result.
func (r *repository) Get(
	ctx context.Context,
	id uuid.UUID,
) (*dto.TransactionRead, error) {
	var model Transaction
	result := r.db.WithContext(ctx).First(&model, "id = ?", id)
	if result.Error != nil {
		return nil, result.Error
	}
	return mapModelToReadDTO(&model), nil
}
// GetByPaymentID implements transaction.Repository. It loads the first
// transaction whose payment_id equals paymentID and returns it as a read DTO;
// a database error is returned unchanged with a nil result.
func (r *repository) GetByPaymentID(
	ctx context.Context,
	paymentID string,
) (*dto.TransactionRead, error) {
	var model Transaction
	query := r.db.WithContext(ctx).Where("payment_id = ?", paymentID)
	if err := query.First(&model).Error; err != nil {
		return nil, err
	}
	return mapModelToReadDTO(&model), nil
}
// ListByUser implements transaction.Repository. It returns a read DTO for
// every transaction whose user_id equals userID; the slice is empty (not nil)
// when there are none.
func (r *repository) ListByUser(
	ctx context.Context,
	userID uuid.UUID,
) ([]*dto.TransactionRead, error) {
	var rows []Transaction
	query := r.db.WithContext(ctx).Where("user_id = ?", userID)
	if err := query.Find(&rows).Error; err != nil {
		return nil, err
	}

	out := make([]*dto.TransactionRead, len(rows))
	for i := range rows {
		out[i] = mapModelToReadDTO(&rows[i])
	}
	return out, nil
}
// ListByAccount implements transaction.Repository. It returns a read DTO for
// every transaction whose account_id equals accountID; the slice is empty (not
// nil) when there are none.
func (r *repository) ListByAccount(
	ctx context.Context,
	accountID uuid.UUID,
) ([]*dto.TransactionRead, error) {
	var rows []Transaction
	query := r.db.WithContext(ctx).Where("account_id = ?", accountID)
	if err := query.Find(&rows).Error; err != nil {
		return nil, err
	}

	out := make([]*dto.TransactionRead, len(rows))
	for i := range rows {
		out[i] = mapModelToReadDTO(&rows[i])
	}
	return out, nil
}
// --- Mappers ---
// mapCreateDTOToModel converts a TransactionCreate DTO into a Transaction
// model. ID, UserID, AccountID, Amount, Status and MoneySource are copied;
// PaymentID is copied only when it is non-nil and non-empty. Other model
// fields keep their zero values.
func mapCreateDTOToModel(create dto.TransactionCreate) Transaction {
	var model Transaction
	model.ID = create.ID
	model.UserID = create.UserID
	model.AccountID = create.AccountID
	model.Amount = create.Amount
	model.Status = create.Status
	model.MoneySource = create.MoneySource

	if pid := create.PaymentID; pid != nil && *pid != "" {
		model.PaymentID = pid
	}
	return model
}
// mapUpdateDTOToModel converts a TransactionUpdate DTO into the column/value
// map passed to GORM's Updates. Only fields set (non-nil) on the DTO are
// included. conversion_rate and original_amount are stored as the DTO's
// pointers; every other column receives the dereferenced value.
func mapUpdateDTOToModel(update dto.TransactionUpdate) map[string]any {
	updates := map[string]any{}

	if v := update.Status; v != nil {
		updates["status"] = *v
	}
	if v := update.Amount; v != nil {
		updates["amount"] = *v
	}
	if v := update.Currency; v != nil {
		updates["currency"] = *v
	}
	if v := update.Balance; v != nil {
		updates["balance"] = *v
	}
	if v := update.PaymentID; v != nil {
		updates["payment_id"] = *v
	}
	if v := update.Fee; v != nil {
		updates["fee"] = *v
	}
	if v := update.ConversionRate; v != nil {
		updates["conversion_rate"] = v
	}
	if v := update.OriginalAmount; v != nil {
		updates["original_amount"] = v
	}
	if v := update.OriginalCurrency; v != nil {
		updates["original_currency"] = *v
	}
	return updates
}
// mapModelToReadDTO converts a Transaction model into its read DTO. The stored
// amount is interpreted as a smallest-unit value in the row's currency and
// exposed as a float.
//
// It panics when money.NewFromSmallestUnit rejects the stored amount/currency
// pair.
func mapModelToReadDTO(tx *Transaction) *dto.TransactionRead {
	amount, err := money.NewFromSmallestUnit(tx.Amount, money.Code(tx.Currency))
	if err != nil {
		panic(err)
	}

	read := new(dto.TransactionRead)
	read.ID = tx.ID
	read.UserID = tx.UserID
	read.AccountID = tx.AccountID
	read.Amount = amount.AmountFloat()
	read.Currency = tx.Currency
	read.Status = tx.Status
	read.CreatedAt = tx.CreatedAt
	if pid := tx.PaymentID; pid != nil {
		read.PaymentID = pid
	}
	return read
}
package repository
import (
"context"
"fmt"
repoaccount "github.com/amirasaad/fintech/infra/repository/account"
repotransaction "github.com/amirasaad/fintech/infra/repository/transaction"
repouser "github.com/amirasaad/fintech/infra/repository/user"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/amirasaad/fintech/pkg/repository/account"
"github.com/amirasaad/fintech/pkg/repository/transaction"
"github.com/amirasaad/fintech/pkg/repository/user"
"gorm.io/gorm"
)
// UoW provides transaction boundary and repository access in one abstraction.
//
// Why is GetRepository part of UoW?
// - Ensures all repositories use the same DB session/transaction for true atomicity.
// - Keeps service code clean and focused on business logic.
// - Centralizes repository wiring and registry for maintainability.
// - Prevents accidental use of the wrong DB session (which would break transactionality).
// - Is idiomatic for Go UoW patterns and easy to mock in tests.
type UoW struct {
// db is the base GORM handle; GetRepository uses it when tx is nil.
db *gorm.DB
// tx is the transaction session set by Do; it is nil outside a transaction.
tx *gorm.DB
// repoMap maps a repository type key (a typed nil interface pointer such as
// (*account.Repository)(nil)) to a factory that builds the repository on the
// given session.
repoMap map[any]func(*gorm.DB) any
}
// NewUoW returns a UoW bound to db with factories registered for the account,
// transaction and user repositories.
func NewUoW(db *gorm.DB) *UoW {
	factories := make(map[any]func(*gorm.DB) any, 3)
	factories[(*account.Repository)(nil)] = func(session *gorm.DB) any {
		return repoaccount.New(session)
	}
	factories[(*transaction.Repository)(nil)] = func(session *gorm.DB) any {
		return repotransaction.New(session)
	}
	factories[(*user.Repository)(nil)] = func(session *gorm.DB) any {
		return repouser.New(session)
	}

	uow := new(UoW)
	uow.db = db
	uow.repoMap = factories
	return uow
}
// Do runs the given function in a transaction boundary, providing a UoW with repository access.
// Automatically maps GORM errors to domain errors.
//
// fn receives a UoW bound to the transaction session, so repositories obtained
// from it take part in the same transaction. The transactional UoW inherits
// this UoW's repository registry; previously the registry was dropped, so any
// factory present only in repoMap could not be resolved inside a transaction.
func (u *UoW) Do(ctx context.Context, fn func(uow repository.UnitOfWork) error) error {
	return WrapError(func() error {
		return u.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
			txnUow := &UoW{
				db:      u.db,
				tx:      tx,
				repoMap: u.repoMap,
			}
			return fn(txnUow)
		})
	})
}
// GetRepository provides generic, type-safe access to repositories using the transaction session.
// This method is maintained for backward compatibility
// but is deprecated in favor of type-safe methods.
//
// repoType is a typed nil interface pointer such as (*account.Repository)(nil).
// The account, transaction and user repositories are resolved directly; any
// other key is looked up in the registry. An unknown key yields an error.
//
// This method is part of UoW to guarantee that all repository operations within a transaction
// use the same DB session, ensuring atomicity and consistency. It also centralizes repository
// construction and makes testing and extension easier.
func (u *UoW) GetRepository(repoType any) (any, error) {
	// Use transaction DB if available, otherwise use main DB
	dbToUse := u.tx
	if dbToUse == nil {
		dbToUse = u.db
	}

	switch repoType {
	case (*account.Repository)(nil):
		return repoaccount.New(dbToUse), nil
	case (*transaction.Repository)(nil):
		return repotransaction.New(dbToUse), nil
	case (*user.Repository)(nil):
		return repouser.New(dbToUse), nil
	default:
		// Reading a nil registry is safe and simply finds nothing.
		if repo, ok := u.repoMap[repoType]; ok {
			return repo(dbToUse), nil
		}
		// The message no longer ends with a stray ", ".
		return nil, fmt.Errorf("unsupported repository type: %T", repoType)
	}
}
package user
import (
"time"
"github.com/google/uuid"
"gorm.io/gorm"
)
// User represents a user record in the database.
//
//revive:disable
type User struct {
// NOTE(review): gorm.Model also declares ID, CreatedAt, UpdatedAt and
// DeletedAt; the explicit fields below shadow the embedded ones. Confirm the
// embed is still wanted.
gorm.Model
// ID is the UUID primary key; the column default is uuid_generate_v4().
ID uuid.UUID `gorm:"type:uuid;primary_key;default:uuid_generate_v4()"`
// Username is unique, required, and limited to 50 characters.
Username string `gorm:"uniqueIndex;not null;size:50" validate:"required,min=3,max=50"`
// Email is unique, required, and limited to 255 characters.
Email string `gorm:"uniqueIndex;not null;size:255" validate:"required,email"`
// Password is required; the repository exposes it to callers as HashedPassword.
Password string `gorm:"not null" validate:"required,min=6"`
// Names is an optional free-form name field.
Names string `gorm:"size:255"`
// StripeConnectAccountID is the indexed Stripe Connect account identifier.
StripeConnectAccountID string `gorm:"size:255;index"`
// StripeConnectOnboardingCompleted defaults to false at the column level.
StripeConnectOnboardingCompleted bool `gorm:"default:false"`
// StripeConnectAccountStatus is a short status string; the repository sets it to "active" when onboarding completes.
StripeConnectAccountStatus string `gorm:"size:50"`
CreatedAt time.Time
UpdatedAt time.Time
// DeletedAt is indexed and uses GORM's DeletedAt type.
DeletedAt gorm.DeletedAt `gorm:"index"`
}
//revive:enable
// TableName returns the database table name used for the User model, "users".
func (User) TableName() string {
return "users"
}
package user
import (
"context"
"errors"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/repository/user"
"github.com/google/uuid"
"gorm.io/gorm"
)
// repository is the GORM-backed implementation of user.Repository.
type repository struct {
// db is the GORM handle every query in this repository runs against.
db *gorm.DB
}
// New returns a user.Repository whose queries run against db.
func New(db *gorm.DB) user.Repository {
return &repository{db: db}
}
// GetByEmail loads the first user whose email column equals email and returns
// it as a read DTO. Any query error, including a missing record, is returned
// unchanged.
func (r *repository) GetByEmail(ctx context.Context, email string) (*dto.UserRead, error) {
	var record User
	query := r.db.WithContext(ctx).Where("email = ?", email)
	if err := query.First(&record).Error; err != nil {
		return nil, err
	}
	return mapModelToDTO(&record), nil
}
// GetByUsername loads the first user whose username column equals username
// and returns it as a read DTO. Any query error, including a missing record,
// is returned unchanged.
func (r *repository) GetByUsername(ctx context.Context, username string) (*dto.UserRead, error) {
	var record User
	query := r.db.WithContext(ctx).Where("username = ?", username)
	if err := query.First(&record).Error; err != nil {
		return nil, err
	}
	return mapModelToDTO(&record), nil
}
// List returns one page of users as read DTOs.
//
// The query skips (page-1)*pageSize rows and is limited to pageSize rows. The
// returned slice is non-nil even when no rows match.
func (r *repository) List(ctx context.Context, page, pageSize int) ([]*dto.UserRead, error) {
	var rows []User
	offset := (page - 1) * pageSize
	err := r.db.WithContext(ctx).Offset(offset).Limit(pageSize).Find(&rows).Error
	if err != nil {
		return nil, err
	}

	out := make([]*dto.UserRead, len(rows))
	for i := range rows {
		out[i] = mapModelToDTO(&rows[i])
	}
	return out, nil
}
// Update applies the non-nil fields of uu to the user row identified by id.
//
// Only username, email, names, password and stripe_connect_account_id can be
// changed here. When uu carries no non-nil field, no query is issued and nil
// is returned.
func (r *repository) Update(ctx context.Context, id uuid.UUID, uu *dto.UserUpdate) error {
	changes := map[string]any{}
	if v := uu.Username; v != nil {
		changes["username"] = *v
	}
	if v := uu.Email; v != nil {
		changes["email"] = *v
	}
	if v := uu.Names; v != nil {
		changes["names"] = *v
	}
	if v := uu.Password; v != nil {
		changes["password"] = *v
	}
	if v := uu.StripeConnectAccountID; v != nil {
		changes["stripe_connect_account_id"] = *v
	}

	// Nothing to write.
	if len(changes) == 0 {
		return nil
	}

	scoped := r.db.WithContext(ctx).Model(&User{}).Where("id = ?", id)
	return scoped.Updates(changes).Error
}
// Create inserts a new user row built from the ID, Username, Email, Password
// and Names fields of create, returning the insert error if any.
func (r *repository) Create(ctx context.Context, create *dto.UserCreate) error {
	record := User{
		ID:       create.ID,
		Username: create.Username,
		Email:    create.Email,
		Password: create.Password,
		Names:    create.Names,
	}
	return r.db.WithContext(ctx).Create(&record).Error
}
// Get loads the user with the given id.
//
// A missing record is not an error: the method returns (nil, nil) in that
// case. Any other query error is returned unchanged.
func (r *repository) Get(ctx context.Context, id uuid.UUID) (*dto.UserRead, error) {
	var record User
	err := r.db.WithContext(ctx).First(&record, "id = ?", id).Error
	switch {
	case err == nil:
		return mapModelToDTO(&record), nil
	case errors.Is(err, gorm.ErrRecordNotFound):
		return nil, nil
	default:
		return nil, err
	}
}
// Delete issues a GORM delete for the user row matching id and returns the
// resulting error, if any.
func (r *repository) Delete(ctx context.Context, id uuid.UUID) error {
	result := r.db.WithContext(ctx).Delete(&User{}, "id = ?", id)
	return result.Error
}
// Exists reports whether at least one user row has the given id.
func (r *repository) Exists(ctx context.Context, id uuid.UUID) (bool, error) {
	var matches int64
	scoped := r.db.WithContext(ctx).Model(&User{}).Where("id = ?", id)
	if err := scoped.Count(&matches).Error; err != nil {
		return false, err
	}
	return matches > 0, nil
}
// ExistsByEmail reports whether at least one user row has the given email.
func (r *repository) ExistsByEmail(ctx context.Context, email string) (bool, error) {
	var matches int64
	scoped := r.db.WithContext(ctx).Model(&User{}).Where("email = ?", email)
	if err := scoped.Count(&matches).Error; err != nil {
		return false, err
	}
	return matches > 0, nil
}
// ExistsByUsername reports whether at least one user row has the given
// username.
func (r *repository) ExistsByUsername(ctx context.Context, username string) (bool, error) {
	var matches int64
	scoped := r.db.WithContext(ctx).Model(&User{}).Where("username = ?", username)
	err := scoped.Count(&matches).Error
	if err != nil {
		return false, err
	}
	return matches > 0, nil
}
// Stripe Connect related methods

// GetStripeAccountID returns the stripe_connect_account_id stored for userID.
//
// Only that column is selected. A missing user yields ("", nil); any other
// query error is returned unchanged.
func (r *repository) GetStripeAccountID(ctx context.Context, userID uuid.UUID) (string, error) {
	var record User
	err := r.db.WithContext(ctx).
		Select("stripe_connect_account_id").
		Where("id = ?", userID).
		First(&record).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return "", nil
	}
	if err != nil {
		return "", err
	}
	return record.StripeConnectAccountID, nil
}
// UpdateStripeAccount stores accountID as the user's Stripe Connect account.
//
// When onboardingComplete is true the onboarding flag is also set to true and
// the account status to "active"; when it is false those two columns are left
// untouched.
func (r *repository) UpdateStripeAccount(
	ctx context.Context,
	userID uuid.UUID,
	accountID string,
	onboardingComplete bool,
) error {
	changes := make(map[string]any, 3)
	changes["stripe_connect_account_id"] = accountID
	if onboardingComplete {
		changes["stripe_connect_onboarding_completed"] = true
		changes["stripe_connect_account_status"] = "active"
	}

	scoped := r.db.WithContext(ctx).Model(&User{}).Where("id = ?", userID)
	return scoped.Updates(changes).Error
}
// GetStripeOnboardingStatus returns the stripe_connect_onboarding_completed
// flag stored for userID.
//
// Only that column is selected. A missing user yields (false, nil); any other
// query error is returned unchanged.
func (r *repository) GetStripeOnboardingStatus(ctx context.Context, userID uuid.UUID) (bool, error) {
	var record User
	err := r.db.WithContext(ctx).
		Select("stripe_connect_onboarding_completed").
		Where("id = ?", userID).
		First(&record).Error
	if errors.Is(err, gorm.ErrRecordNotFound) {
		return false, nil
	}
	if err != nil {
		return false, err
	}
	return record.StripeConnectOnboardingCompleted, nil
}
// UpdateStripeOnboardingStatus writes completed to the user's onboarding flag.
//
// When completed is true the account status is additionally set to "active";
// when it is false the status column is left untouched.
func (r *repository) UpdateStripeOnboardingStatus(
	ctx context.Context,
	userID uuid.UUID,
	completed bool,
) error {
	changes := make(map[string]any, 2)
	changes["stripe_connect_onboarding_completed"] = completed
	if completed {
		changes["stripe_connect_account_status"] = "active"
	}

	scoped := r.db.WithContext(ctx).Model(&User{}).Where("id = ?", userID)
	return scoped.Updates(changes).Error
}
// mapModelToDTO copies the identity, credential, name, Stripe account ID and
// timestamp fields of a User row into a freshly allocated dto.UserRead. The
// stored Password value is exposed as HashedPassword.
func mapModelToDTO(user *User) *dto.UserRead {
	out := new(dto.UserRead)
	out.ID = user.ID
	out.Username = user.Username
	out.Email = user.Email
	out.HashedPassword = user.Password
	out.Names = user.Names
	out.StripeConnectAccountID = user.StripeConnectAccountID
	out.CreatedAt = user.CreatedAt
	out.UpdatedAt = user.UpdatedAt
	return out
}
// Compile-time check that *repository satisfies user.Repository.
var _ user.Repository = (*repository)(nil)
package app
import (
"log/slog"
"github.com/amirasaad/fintech/pkg/service/checkout"
exchangeSvc "github.com/amirasaad/fintech/pkg/service/exchange"
"github.com/amirasaad/fintech/pkg/service/stripeconnect"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/amirasaad/fintech/pkg/provider/payment"
"github.com/amirasaad/fintech/pkg/registry"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/amirasaad/fintech/pkg/service/account"
"github.com/amirasaad/fintech/pkg/service/auth"
currencyScv "github.com/amirasaad/fintech/pkg/service/currency"
userSvc "github.com/amirasaad/fintech/pkg/service/user"
)
// Deps bundles the externally constructed dependencies consumed by New and by
// the event-bus setup methods on App.
type Deps struct {
// Registry providers
RegistryProvider registry.Provider // Main registry provider
CurrencyRegistry registry.Provider // For currency service
CheckoutRegistry registry.Provider // For checkout service
ExchangeRateRegistry registry.Provider // For exchange rate service and the conversion handler
// Other dependencies
ExchangeRateProvider exchange.Exchange // Passed to the exchange rate service and the conversion handler
PaymentProvider payment.Payment // Passed to the payment, deposit and withdraw handlers
Uow repository.UnitOfWork // Unit of work shared by services and handlers
EventBus eventbus.Bus // Bus on which all handlers are registered
Logger *slog.Logger // Logger handed to every service and handler
}
// App holds the application's configuration, its dependencies, and the
// services built from them by New.
type App struct {
Deps *Deps
Config *config.App
AuthService *auth.Service
UserService *userSvc.Service
AccountService *account.Service
CurrencyService *currencyScv.Service
CheckoutService *checkout.Service
ExchangeRateService *exchangeSvc.Service
// StripeConnectService is only set by New when Stripe configuration and a
// unit of work are both present; otherwise it stays nil.
StripeConnectService stripeconnect.Service
}
// New builds an App from deps and cfg.
//
// It registers all event handlers on the bus first, then constructs the
// services in this order: auth (JWT when cfg.Auth.Strategy is "jwt", basic
// otherwise), Stripe Connect (only when Stripe configuration and a unit of
// work are both present), user, account, currency, checkout and exchange rate.
func New(deps *Deps, cfg *config.App) *App {
	a := &App{Deps: deps, Config: cfg}
	a.setupEventBus()

	switch cfg.Auth.Strategy {
	case "jwt":
		a.AuthService = auth.NewWithJWT(deps.Uow, cfg.Auth.Jwt, deps.Logger)
	default:
		a.AuthService = auth.NewWithBasic(deps.Uow, deps.Logger)
	}

	// Stripe Connect is optional; the field stays nil when it is not configured.
	hasStripeCfg := cfg.PaymentProviders != nil && cfg.PaymentProviders.Stripe != nil
	if hasStripeCfg && deps.Uow != nil {
		a.StripeConnectService = stripeconnect.New(deps.Uow, deps.Logger, cfg.PaymentProviders.Stripe)
		deps.Logger.Info("Stripe Connect service initialized")
	}

	a.UserService = userSvc.New(deps.Uow, deps.Logger)
	// The account service receives the (possibly nil) Stripe Connect service.
	a.AccountService = account.New(deps.EventBus, deps.Uow, deps.Logger, a.StripeConnectService)

	// Registry-backed services each get their own registry provider.
	a.CurrencyService = currencyScv.New(deps.CurrencyRegistry, deps.Logger)
	a.CheckoutService = checkout.New(deps.CheckoutRegistry, deps.Logger)
	a.ExchangeRateService = exchangeSvc.New(deps.ExchangeRateRegistry, deps.ExchangeRateProvider, deps.Logger)

	return a
}
// Package app provides functionality for setting up and configuring the event Bus
// with all necessary event handlers for the application.
package app
import (
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/account/deposit"
"github.com/amirasaad/fintech/pkg/handler/account/transfer"
"github.com/amirasaad/fintech/pkg/handler/account/withdraw"
handlercommon "github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/handler/conversion"
"github.com/amirasaad/fintech/pkg/handler/fees"
"github.com/amirasaad/fintech/pkg/handler/payment"
"github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/amirasaad/fintech/pkg/repository"
)
// setupEventBus registers every event handler group on the application's
// event bus, in the order conversion, deposit, withdraw, payment, transfer,
// fees, user.
func (a *App) setupEventBus() {
	deps := a.Deps
	bus, uow, logger := deps.EventBus, deps.Uow, deps.Logger

	a.setupConversionHandlers(bus, uow, deps.ExchangeRateProvider, logger)
	a.setupDepositHandlers(bus, uow, logger)
	a.setupWithdrawHandlers(bus, uow, logger)
	a.setupPaymentHandlers(bus, uow, logger)
	a.setupTransferHandlers(bus, uow, logger)
	a.setupFeesHandlers(bus, uow, logger)
	a.setupUserHandlers(bus, uow, logger)
}
// setupUserHandlers is currently a no-op: the only registration it contains
// (for the user-onboarding-completed event) is commented out, so none of its
// parameters are used.
func (a *App) setupUserHandlers(
bus eventbus.Bus,
uow repository.UnitOfWork,
logger *slog.Logger,
) {
// bus.Register(
// events.EventTypeUserOnboardingCompleted,
// user.HandleUserOnboardingCompleted(
// bus,
// uow,
// logger,
// ),
//)
}
// setupWithdrawHandlers registers the withdraw flow handlers: requested,
// currency-converted and validated. The validated handler is given the
// application's payment provider and the logger taken from a.Deps.
func (a *App) setupWithdrawHandlers(bus eventbus.Bus, uow repository.UnitOfWork, logger *slog.Logger) {
	onRequested := withdraw.HandleRequested(bus, uow, logger)
	bus.Register(events.EventTypeWithdrawRequested, onRequested)

	onConverted := withdraw.HandleCurrencyConverted(bus, uow, logger)
	bus.Register(events.EventTypeWithdrawCurrencyConverted, onConverted)

	onValidated := withdraw.HandleValidated(bus, uow, a.Deps.PaymentProvider, a.Deps.Logger)
	bus.Register(events.EventTypeWithdrawValidated, onValidated)
}
// setupPaymentHandlers registers the payment initiated, processed and
// completed handlers, each wrapped in idempotency middleware with its own
// tracker and key extractor.
func (a *App) setupPaymentHandlers(bus eventbus.Bus, uow repository.UnitOfWork, logger *slog.Logger) {
	// One tracker per handler, all created before any registration.
	trackInitiated := handlercommon.NewIdempotencyTracker()
	trackProcessed := handlercommon.NewIdempotencyTracker()
	trackCompleted := handlercommon.NewIdempotencyTracker()

	onInitiated := handlercommon.WithIdempotency(
		payment.HandleInitiated(bus, a.Deps.PaymentProvider, logger),
		trackInitiated,
		payment.ExtractPaymentInitiatedKey,
		"HandleInitiated",
		logger,
	)
	bus.Register(events.EventTypePaymentInitiated, onInitiated)

	onProcessed := handlercommon.WithIdempotency(
		payment.HandleProcessed(uow, logger),
		trackProcessed,
		payment.ExtractPaymentProcessedKey,
		"HandleProcessed",
		logger,
	)
	bus.Register(events.EventTypePaymentProcessed, onProcessed)

	onCompleted := handlercommon.WithIdempotency(
		payment.HandleCompleted(bus, uow, logger),
		trackCompleted,
		payment.ExtractPaymentCompletedKey,
		"HandleCompleted",
		logger,
	)
	bus.Register(events.EventTypePaymentCompleted, onCompleted)
}
// setupFeesHandlers registers the handler for the fees-calculated event.
func (a *App) setupFeesHandlers(bus eventbus.Bus, uow repository.UnitOfWork, logger *slog.Logger) {
	onCalculated := fees.HandleCalculated(uow, logger)
	bus.Register(events.EventTypeFeesCalculated, onCalculated)
}
// setupTransferHandlers registers the transfer flow handlers: requested,
// currency-converted and completed.
func (a *App) setupTransferHandlers(bus eventbus.Bus, uow repository.UnitOfWork, logger *slog.Logger) {
	onRequested := transfer.HandleRequested(bus, uow, logger)
	bus.Register(events.EventTypeTransferRequested, onRequested)

	onConverted := transfer.HandleCurrencyConverted(bus, uow, logger)
	bus.Register(events.EventTypeTransferCurrencyConverted, onConverted)

	onCompleted := transfer.HandleCompleted(bus, uow, logger)
	bus.Register(events.EventTypeTransferCompleted, onCompleted)
}
// setupDepositHandlers registers the deposit flow handlers: requested,
// currency-converted and validated. The validated handler is given the
// application's payment provider and the logger taken from a.Deps.
func (a *App) setupDepositHandlers(bus eventbus.Bus, uow repository.UnitOfWork, logger *slog.Logger) {
	onRequested := deposit.HandleRequested(bus, uow, logger)
	bus.Register(events.EventTypeDepositRequested, onRequested)

	onConverted := deposit.HandleCurrencyConverted(bus, uow, logger)
	bus.Register(events.EventTypeDepositCurrencyConverted, onConverted)

	onValidated := deposit.HandleValidated(bus, uow, a.Deps.PaymentProvider, a.Deps.Logger)
	bus.Register(events.EventTypeDepositValidated, onValidated)
}
// setupConversionHandlers registers the single generic handler for currency
// conversion requests.
//
// The handler receives the exchange rate registry from a.Deps, the supplied
// exchange rate provider, and one event factory per flow ("deposit",
// "withdraw", "transfer"). The uow parameter is accepted for signature parity
// with the other setup methods but is not used here.
func (a *App) setupConversionHandlers(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	exchangeRateProvider exchange.Exchange,
	logger *slog.Logger,
) {
	factories := make(map[string]conversion.EventFactory, 3)
	factories["deposit"] = &conversion.DepositEventFactory{}
	factories["withdraw"] = &conversion.WithdrawEventFactory{}
	factories["transfer"] = &conversion.TransferEventFactory{}

	onRequested := conversion.HandleRequested(
		bus,
		a.Deps.ExchangeRateRegistry,
		exchangeRateProvider,
		logger,
		factories,
	)
	bus.Register(events.EventTypeCurrencyConversionRequested, onRequested)
}
package config
import (
"log/slog"
"github.com/joho/godotenv"
"github.com/kelseyhightower/envconfig"
)
// Load reads environment files into the process environment and then builds
// the application configuration from environment variables.
//
// With no arguments it attempts the default .env via godotenv.Load. With
// arguments, each path is resolved through FindEnvTest in order and the first
// one that loads successfully (via godotenv.Overload) is used. If none can be
// found or loaded, the default .env is attempted as a fallback. A missing
// default .env only produces a warning; configuration is still read from the
// environment.
func Load(envFilePath ...string) (*App, error) {
	logger := slog.Default()
	logger.Info("Loading environment variables")

	for _, path := range envFilePath {
		logger.Debug("Looking for environment file", "path", path)
		foundPath, err := FindEnvTest(path)
		if err != nil {
			logger.Debug("Environment file not found", "path", path, "error", err)
			continue
		}
		logger.Info("Loading environment from file", "path", foundPath)
		if err := godotenv.Overload(foundPath); err != nil {
			logger.Error("Failed to load environment file", "path", foundPath, "error", err)
			continue
		}
		// First file that loads wins.
		return loadFromEnv()
	}

	// Either no paths were given or none of them could be used.
	if len(envFilePath) == 0 {
		logger.Debug("No environment file specified, trying default .env")
	} else {
		logger.Info("No valid environment files found, using default .env")
	}
	if err := godotenv.Load(); err != nil {
		logger.Warn("No .env file found in current directory")
	}
	return loadFromEnv()
}
// loadFromEnv populates an App configuration from environment variables using
// envconfig with an empty prefix.
//
// After processing it defaults Env to "development" when empty and replaces a
// nil EventBus section with an empty one, then logs a summary of the loaded
// configuration with the database URL and exchange API key masked.
func loadFromEnv() (*App, error) {
	cfg := &App{}
	if err := envconfig.Process("", cfg); err != nil {
		return nil, err
	}

	// Fill in defaults for values left unset.
	if cfg.Env == "" {
		cfg.Env = "development"
	}
	if cfg.EventBus == nil {
		cfg.EventBus = &EventBus{}
	}

	log := slog.Default()
	log.Info("Environment variables loaded from .env file")
	exchangeAPI := cfg.ExchangeRateAPIProviders.ExchangeRateApi
	log.Info("App config loaded",
		"env", cfg.Env,
		"event_bus_driver", cfg.EventBus.Driver,
		"rate_limit_max_requests", cfg.RateLimit.MaxRequests,
		"rate_limit_window", cfg.RateLimit.Window,
		"db", maskValue(cfg.DB.Url),
		"auth_strategy", cfg.Auth.Strategy,
		"auth_jwt_expiry", cfg.Auth.Jwt.Expiry,
		"exchange_cache_ttl", cfg.ExchangeRateCache.TTL,
		"exchange_api_url", exchangeAPI.ApiUrl,
		"exchange_api_key", maskValue(exchangeAPI.ApiKey),
	)
	return cfg, nil
}
// maskValue hides the middle of key for logging.
//
// Values of six bytes or fewer are replaced entirely by "****". Longer values
// keep their first two and last four bytes with "****" in between.
func maskValue(key string) string {
	const mask = "****"
	n := len(key)
	if n > 6 {
		return key[:2] + mask + key[n-4:]
	}
	return mask
}
package config
import (
"os"
"path/filepath"
)
// FindEnvTest looks for filename in the current working directory and then in
// each parent directory in turn, returning the path of the first match.
//
// An empty filename means ".env". When the filesystem root is reached without
// a match, os.ErrNotExist is returned. An error from os.Getwd is returned
// unchanged.
func FindEnvTest(filename string) (string, error) {
	if filename == "" {
		filename = ".env"
	}

	dir, err := os.Getwd()
	if err != nil {
		return "", err
	}

	for {
		path := filepath.Join(dir, filename)
		if _, statErr := os.Stat(path); statErr == nil {
			return path, nil
		}
		// filepath.Dir returns its argument unchanged at the root.
		up := filepath.Dir(dir)
		if up == dir {
			return "", os.ErrNotExist
		}
		dir = up
	}
}
// Package currency provides functionality for working with currency codes and metadata.
// It includes validation, formatting, and conversion utilities for ISO 4217 currency codes.
//
// Deprecated: This package is deprecated and will be removed in a future release.
// Please use the money package instead: github.com/amirasaad/fintech/pkg/money
package currency
import (
"context"
"errors"
"fmt"
"log/slog"
"maps"
"os"
"regexp"
"strconv"
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/registry"
)
// Common errors returned by currency validation and registry operations.
var (
	// ErrInvalidCode is returned when a currency code is not well-formed.
	//
	// Deprecated: Use money.ErrInvalidCurrency instead.
	ErrInvalidCode = money.ErrInvalidCurrency
	// ErrUnsupported indicates a currency that is not supported.
	ErrUnsupported = errors.New("unsupported currency")
	// ErrInvalidDecimals is returned when a decimals value is not an integer
	// in the range [0, MaxDecimals]. The message is derived from MaxDecimals
	// so it always states the bound that validation actually enforces.
	ErrInvalidDecimals = fmt.Errorf("invalid decimals: must be between 0 and %d", MaxDecimals)
	// ErrInvalidSymbol is returned when a symbol is empty or too long.
	ErrInvalidSymbol = errors.New(
		"invalid symbol: must not be empty and max 10 characters")
	// ErrCurrencyNotFound indicates a currency that is not registered.
	ErrCurrencyNotFound = errors.New("currency not found")
	// ErrCurrencyExists indicates a currency that is already registered.
	ErrCurrencyExists = errors.New("currency already exists")
)
const (
// DefaultCode is the fallback currency code (USD)
DefaultCode = "USD"
// DefaultDecimals is the default number of decimal places for currencies
DefaultDecimals = 2
// MaxDecimals is the maximum number of decimal places allowed; validation
// rejects any decimals value above it.
MaxDecimals = 18
// MaxSymbolLength is the maximum length for currency symbols, compared
// against len(symbol) during validation.
MaxSymbolLength = 10
// Default is the default currency code (USD)
//
// Deprecated: Use money.USD from the money package instead.
Default = USD
)
// Meta holds currency-specific metadata.
//
// Deprecated: this package is deprecated; use the money package
// (github.com/amirasaad/fintech/pkg/money) instead.
type Meta struct {
Code string `json:"code"` // currency code, validated as three uppercase letters
Name string `json:"name"` // display name; must be non-empty to pass validation
Symbol string `json:"symbol"` // display symbol; must be non-empty and at most MaxSymbolLength bytes
Decimals int `json:"decimals"` // number of decimal places, 0..MaxDecimals
Country string `json:"country,omitempty"`
Region string `json:"region,omitempty"`
Active bool `json:"active"`
Metadata map[string]string `json:"metadata,omitempty"` // custom key/value pairs merged into Entity.Metadata
Created time.Time `json:"created"` // overwritten with the current time by NewEntity
Updated time.Time `json:"updated"` // overwritten with the current time by NewEntity
}
// Entity wraps a currency Meta so it can be stored in a registry; it embeds
// *registry.BaseEntity and implements the registry.Entity interface.
//
// Deprecated: this package is deprecated; use the money package instead.
type Entity struct {
*registry.BaseEntity
// meta is the currency metadata returned by the accessor methods.
meta Meta
}
// NewEntity wraps meta in a registry entity whose base entity is built from
// the currency code and name.
//
// The Created and Updated timestamps of meta are both overwritten with the
// current time; any values supplied by the caller are discarded.
func NewEntity(meta Meta) *Entity {
	stamp := time.Now()
	meta.Created, meta.Updated = stamp, stamp

	base := registry.NewBaseEntity(meta.Code, meta.Name)
	return &Entity{BaseEntity: base, meta: meta}
}
// Code returns the currency code stored in the entity's metadata.
func (c *Entity) Code() string {
return c.meta.Code
}
// Name returns the currency name stored in the entity's metadata.
func (c *Entity) Name() string {
return c.meta.Name
}
// Active reports the Active flag stored in the entity's metadata.
func (c *Entity) Active() bool {
return c.meta.Active
}
// Metadata returns a new string map describing the currency.
//
// The map holds the core fields (code, symbol, decimals, country, region,
// active, created, updated) with timestamps formatted as RFC 3339. Custom
// metadata entries are copied in last, so a custom key that collides with a
// core key replaces the core value.
func (c *Entity) Metadata() map[string]string {
	m := c.meta
	out := map[string]string{
		"code":     m.Code,
		"symbol":   m.Symbol,
		"decimals": strconv.Itoa(m.Decimals),
		"country":  m.Country,
		"region":   m.Region,
		"active":   strconv.FormatBool(m.Active),
		"created":  m.Created.Format(time.RFC3339),
		"updated":  m.Updated.Format(time.RFC3339),
	}
	maps.Copy(out, m.Metadata)
	return out
}
// CreatedAt returns the Created timestamp stored in the entity's metadata.
func (c *Entity) CreatedAt() time.Time {
return c.meta.Created
}
// UpdatedAt returns the Updated timestamp stored in the entity's metadata.
func (c *Entity) UpdatedAt() time.Time {
return c.meta.Updated
}
// Meta returns the entity's currency metadata by value. The struct is copied,
// but its Metadata map field still refers to the same underlying map.
func (c *Entity) Meta() Meta {
return c.meta
}
// Validator validates currency entities and currency metadata maps; it
// implements registry.Validator and carries no state.
type Validator struct{}
// NewCurrencyValidator returns a new, stateless currency Validator.
func NewCurrencyValidator() *Validator {
return &Validator{}
}
// Validate checks that entity describes a valid currency.
//
// A *Entity is validated through its typed Meta (code format, decimals range,
// symbol and non-empty name). Any other entity is validated through its string
// metadata map, which must be non-empty and is checked by ValidateMetadata:
// the code, symbol and decimals fields are required and must be well-formed.
//
// It returns ErrInvalidCode, ErrInvalidDecimals or ErrInvalidSymbol for the
// corresponding failures, or a descriptive error when metadata is missing.
func (cv *Validator) Validate(ctx context.Context, entity registry.Entity) error {
	if currencyEntity, ok := entity.(*Entity); ok {
		return validateMeta(currencyEntity.Meta())
	}

	// Not a *Entity (for example a BaseEntity): fall back to the metadata map.
	metadata := entity.Metadata()
	if len(metadata) == 0 {
		return fmt.Errorf("invalid entity type: expected *Entity or entity with metadata")
	}

	// Share the metadata rules with ValidateMetadata instead of duplicating them.
	return cv.ValidateMetadata(ctx, metadata)
}
// ValidateMetadata checks a currency's string metadata map.
//
// The "code", "symbol" and "decimals" keys must be present and non-empty.
// The code must be well-formed, decimals must parse as an integer within
// [0, MaxDecimals], and the symbol must be at most MaxSymbolLength bytes.
// It returns ErrInvalidCode, ErrInvalidDecimals or ErrInvalidSymbol for the
// corresponding failures, and a descriptive error for a missing field.
func (cv *Validator) ValidateMetadata(ctx context.Context, metadata map[string]string) error {
	// A missing key reads as "", so one comparison covers absent and empty.
	for _, field := range [...]string{"code", "symbol", "decimals"} {
		if metadata[field] == "" {
			return fmt.Errorf("required metadata field missing: %s", field)
		}
	}

	if !isValidCurrencyCode(metadata["code"]) {
		return ErrInvalidCode
	}

	decimals, err := strconv.Atoi(metadata["decimals"])
	if err != nil || decimals < 0 || decimals > MaxDecimals {
		return ErrInvalidDecimals
	}

	// Emptiness was already rejected above; only the length remains to check.
	if len(metadata["symbol"]) > MaxSymbolLength {
		return ErrInvalidSymbol
	}
	return nil
}
// validateMeta checks a typed currency Meta.
//
// Checks run in order and the first failure wins: code format, decimals within
// [0, MaxDecimals], symbol non-empty and at most MaxSymbolLength bytes, and a
// non-empty name.
func validateMeta(meta Meta) error {
	switch {
	case !isValidCurrencyCode(meta.Code):
		return ErrInvalidCode
	case meta.Decimals < 0 || meta.Decimals > MaxDecimals:
		return ErrInvalidDecimals
	case meta.Symbol == "" || len(meta.Symbol) > MaxSymbolLength:
		return ErrInvalidSymbol
	case meta.Name == "":
		return errors.New("currency name cannot be empty")
	default:
		return nil
	}
}
// currencyCodeFormatRE matches exactly three uppercase ASCII letters. It is
// compiled once at package initialization so IsValidFormat does not pay the
// compilation cost on every call.
var currencyCodeFormatRE = regexp.MustCompile(`^[A-Z]{3}$`)

// IsValidFormat reports whether code is a well-formed ISO 4217 currency code,
// that is, exactly three uppercase ASCII letters. It checks the format only,
// not whether the code is registered.
func IsValidFormat(code string) bool {
	return currencyCodeFormatRE.MatchString(code)
}
// isValidCurrencyCode reports whether code is three uppercase letters. It is
// a package-internal alias that delegates directly to IsValidFormat.
func isValidCurrencyCode(code string) bool {
return IsValidFormat(code)
}
// Registry provides currency-specific operations using the registry system
type Registry struct {
// registry is the underlying provider every operation delegates to.
registry registry.Provider
// ctx is the context captured at construction time and passed to every
// call on the underlying provider.
ctx context.Context
}
// New creates a currency registry and registers the default currencies in it.
//
// Optional string parameters are read positionally: params[0] is the Redis
// URL and params[1] is the key prefix; anything beyond that is ignored. With a
// non-empty Redis URL the registry is created through the registry factory
// from a builder-produced config (with the key prefix applied when non-empty).
// Otherwise an in-memory enhanced registry is used, with the currency
// validator and a memory cache with a 10 minute TTL attached.
//
// An error is returned when the Redis-backed registry cannot be created or
// when registering the defaults fails.
func New(ctx context.Context, params ...string) (*Registry, error) {
	var redisURL, keyPrefix string
	if len(params) > 0 {
		redisURL = params[0]
	}
	if len(params) > 1 {
		keyPrefix = params[1]
	}

	// Currency-specific base configuration.
	cfg := registry.Config{
		Name:             "currency-registry",
		MaxEntities:      1000,
		EnableEvents:     true,
		EnableValidation: true,
		CacheSize:        100,
	}

	var (
		reg registry.Provider
		err error
	)
	if redisURL == "" {
		// No Redis: in-memory registry with validation and a memory cache.
		reg = registry.NewEnhanced(cfg).
			WithValidator(NewCurrencyValidator()).
			WithCache(registry.NewMemoryCache(10 * time.Minute))
	} else {
		builder := registry.NewBuilder().
			WithName(cfg.Name).
			WithMaxEntities(cfg.MaxEntities).
			WithRedis(redisURL).
			WithCache(100, 10*time.Minute) // cache size and TTL
		if keyPrefix != "" {
			builder = builder.WithKeyPrefix(keyPrefix)
		}
		// The builder's config replaces the base config from here on.
		cfg = builder.Build()

		reg, err = registry.NewFactory().Create(ctx, cfg)
		if err != nil {
			return nil, fmt.Errorf("failed to create Redis-backed registry: %w", err)
		}
	}

	cr := &Registry{registry: reg, ctx: ctx}
	if err := cr.registerDefaults(); err != nil {
		return nil, fmt.Errorf("failed to register default currencies: %w", err)
	}
	return cr, nil
}
// NewRegistryWithPersistence creates a currency registry backed by file
// persistence at persistencePath.
//
// If a non-empty Redis URL is supplied as the first optional argument, the
// registry is created through the registry factory with Redis settings and
// file persistence. Otherwise an in-memory enhanced registry is used and any
// entities the persistence layer can load are registered into it.
//
// Default currencies are registered only when the registry reports a count of
// zero afterwards. An error is returned when registry creation, registration
// of a loaded entity, or registration of the defaults fails.
func NewRegistryWithPersistence(
ctx context.Context, persistencePath string, redisURL ...string,
) (*Registry, error) {
config := registry.Config{
Name: "currency-registry",
MaxEntities: 1000,
EnableEvents: true,
EnableValidation: true,
CacheSize: 100,
CacheTTL: 10 * time.Minute,
EnablePersistence: true,
PersistencePath: persistencePath,
AutoSaveInterval: time.Minute,
RedisKeyPrefix: "currency",
}
var reg registry.Provider
var err error
// Use Redis if URL is provided
if len(redisURL) > 0 && redisURL[0] != "" {
// Configure Redis settings
builder := registry.NewBuilder().
WithName(config.Name).
WithMaxEntities(config.MaxEntities).
WithRedis(redisURL[0]).
WithCache(100, 10*time.Minute). // Cache size and TTL
WithPersistence(persistencePath, time.Minute) // Auto-save interval
// Get the config with Redis settings
// NOTE(review): this replaces the config literal above; fields not set on
// the builder (e.g. EnableEvents, RedisKeyPrefix) presumably take builder
// defaults — confirm against registry.Builder.
config = builder.Build()
// Create registry with Redis cache and persistence
// NOTE(review): no currency validator is attached on this path here —
// confirm whether the factory is expected to provide one.
factory := registry.NewFactory()
reg, err = factory.CreateWithPersistence(
ctx,
config,
registry.NewFilePersistence(persistencePath),
)
if err != nil {
return nil, fmt.Errorf(
"failed to create Redis-backed registry with persistence: %w",
err,
)
}
} else {
// Fall back to in-memory cache with file persistence
enhanced := registry.NewEnhanced(config)
// Return values of the With* calls are discarded here; presumably they
// configure enhanced in place — TODO confirm.
enhanced.WithValidator(NewCurrencyValidator())
enhanced.WithCache(registry.NewMemoryCache(10 * time.Minute))
// Add persistence
persistence := registry.NewFilePersistence(persistencePath)
enhanced.WithPersistence(persistence)
reg = enhanced
// Load existing entities
// A Load error is ignored: the registry then simply starts without
// previously persisted entities.
if entities, err := persistence.Load(ctx); err == nil {
for _, entity := range entities {
if err := reg.Register(ctx, entity); err != nil {
return nil, fmt.Errorf("failed to load entity %s: %w", entity.ID(), err)
}
}
}
}
cr := &Registry{
registry: reg,
ctx: ctx,
}
// Only register defaults if no entities were loaded
// The Count error is discarded; if Count fails and returns 0 the defaults
// are registered as though the registry were empty.
if count, _ := reg.Count(ctx); count == 0 {
if err := cr.registerDefaults(); err != nil {
return nil, fmt.Errorf("failed to register default currencies: %w", err)
}
}
return cr, nil
}
// registerDefaults registers the built-in set of twelve active currencies
// (USD, EUR, GBP, JPY, CAD, AUD, CHF, CNY, INR, BRL, KWD, EGP) in declaration
// order. It stops at the first failure and returns an error naming the
// currency code that could not be registered.
func (cr *Registry) registerDefaults() error {
	defaults := []Meta{
		{Code: "USD", Name: "US Dollar", Symbol: "$", Decimals: 2,
			Country: "United States", Region: "North America", Active: true},
		{Code: "EUR", Name: "Euro", Symbol: "€", Decimals: 2,
			Country: "European Union", Region: "Europe", Active: true},
		{Code: "GBP", Name: "British Pound", Symbol: "£", Decimals: 2,
			Country: "United Kingdom", Region: "Europe", Active: true},
		{Code: "JPY", Name: "Japanese Yen", Symbol: "¥", Decimals: 0,
			Country: "Japan", Region: "Asia", Active: true},
		{Code: "CAD", Name: "Canadian Dollar", Symbol: "C$", Decimals: 2,
			Country: "Canada", Region: "North America", Active: true},
		{Code: "AUD", Name: "Australian Dollar", Symbol: "A$", Decimals: 2,
			Country: "Australia", Region: "Oceania", Active: true},
		{Code: "CHF", Name: "Swiss Franc", Symbol: "CHF", Decimals: 2,
			Country: "Switzerland", Region: "Europe", Active: true},
		{Code: "CNY", Name: "Chinese Yuan", Symbol: "¥", Decimals: 2,
			Country: "China", Region: "Asia", Active: true},
		{Code: "INR", Name: "Indian Rupee", Symbol: "₹", Decimals: 2,
			Country: "India", Region: "Asia", Active: true},
		{Code: "BRL", Name: "Brazilian Real", Symbol: "R$", Decimals: 2,
			Country: "Brazil", Region: "South America", Active: true},
		{Code: "KWD", Name: "Kuwaiti Dinar", Symbol: "د.ك", Decimals: 3,
			Country: "Kuwait", Region: "Middle East", Active: true},
		{Code: "EGP", Name: "Egyptian Pound", Symbol: "£", Decimals: 2,
			Country: "Egypt", Region: "Africa", Active: true},
	}

	for i := range defaults {
		entry := defaults[i]
		err := cr.Register(entry)
		if err != nil {
			return fmt.Errorf("failed to register %s: %w", entry.Code, err)
		}
	}
	return nil
}
// Register validates meta, wraps it in a new entity and registers that entity
// with the underlying registry.
//
// A validation failure is returned wrapped with "validation failed"; a
// registry failure is returned wrapped with "failed to register currency".
func (cr *Registry) Register(meta Meta) error {
	if verr := validateMeta(meta); verr != nil {
		return fmt.Errorf("validation failed: %w", verr)
	}

	rerr := cr.registry.Register(cr.ctx, NewEntity(meta))
	if rerr != nil {
		return fmt.Errorf("failed to register currency: %w", rerr)
	}
	return nil
}
// Get returns the metadata of the currency registered under code.
//
// When the stored entity is a *Entity its typed Meta is returned. For any
// other entity type a Meta is rebuilt from the entity's name and string
// metadata; decimals and active values that fail to parse fall back to zero
// and false. A lookup failure is returned wrapped with "currency not found".
func (cr *Registry) Get(code string) (Meta, error) {
	entity, err := cr.registry.Get(cr.ctx, code)
	if err != nil {
		return Meta{}, fmt.Errorf("currency not found: %w", err)
	}

	if typed, ok := entity.(*Entity); ok {
		return typed.Meta(), nil
	}

	// Fallback: reconstruct from the generic metadata map.
	fields := entity.Metadata()
	decimals, _ := strconv.Atoi(fields["decimals"])
	active, _ := strconv.ParseBool(fields["active"])
	rebuilt := Meta{
		Code:     fields["code"],
		Name:     entity.Name(),
		Symbol:   fields["symbol"],
		Decimals: decimals,
		Country:  fields["country"],
		Region:   fields["region"],
		Active:   active,
	}
	return rebuilt, nil
}
// IsSupported reports whether code is registered and its entity is active.
// It returns false when the code is not registered or the lookup fails.
func (cr *Registry) IsSupported(code string) bool {
	if !cr.registry.IsRegistered(cr.ctx, code) {
		return false
	}
	entity, err := cr.registry.Get(cr.ctx, code)
	if err != nil {
		return false
	}
	// Interface dispatch reaches (*Entity).Active when the stored value is a
	// *Entity, so no type assertion is needed.
	return entity.Active()
}
// ListSupported returns the codes of all active currencies in the order the
// registry lists them. For a *Entity the currency code is used; for any other
// entity type the entity ID is used instead.
func (cr *Registry) ListSupported() ([]string, error) {
	entities, err := cr.registry.ListActive(cr.ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to list currencies: %w", err)
	}

	codes := make([]string, 0, len(entities))
	for _, entity := range entities {
		switch e := entity.(type) {
		case *Entity:
			codes = append(codes, e.Code())
		default:
			codes = append(codes, entity.ID())
		}
	}
	return codes, nil
}
// ListAll returns the metadata of every registered currency, active or not,
// in the order the registry lists them.
//
// A *Entity contributes its typed Meta. Any other entity type is converted
// from its name and string metadata; decimals and active values that fail to
// parse fall back to zero and false.
func (cr *Registry) ListAll() ([]Meta, error) {
	entities, err := cr.registry.List(cr.ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to list currencies: %w", err)
	}

	out := make([]Meta, 0, len(entities))
	for _, entity := range entities {
		if typed, ok := entity.(*Entity); ok {
			out = append(out, typed.Meta())
			continue
		}

		// Fallback conversion from the generic metadata map.
		fields := entity.Metadata()
		decimals, _ := strconv.Atoi(fields["decimals"])
		active, _ := strconv.ParseBool(fields["active"])
		out = append(out, Meta{
			Code:     fields["code"],
			Name:     entity.Name(),
			Symbol:   fields["symbol"],
			Decimals: decimals,
			Country:  fields["country"],
			Region:   fields["region"],
			Active:   active,
		})
	}
	return out, nil
}
// Unregister removes the currency registered under code from the underlying
// registry, wrapping any error the registry reports.
func (cr *Registry) Unregister(code string) error {
	err := cr.registry.Unregister(cr.ctx, code)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to unregister currency: %w", err)
}
// Activate asks the underlying registry to activate the currency registered
// under code, wrapping any error the registry reports.
func (cr *Registry) Activate(code string) error {
	err := cr.registry.Activate(cr.ctx, code)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to activate currency: %w", err)
}
// Deactivate asks the underlying registry to deactivate the currency
// registered under code, wrapping any error the registry reports.
func (cr *Registry) Deactivate(code string) error {
	err := cr.registry.Deactivate(cr.ctx, code)
	if err == nil {
		return nil
	}
	return fmt.Errorf("failed to deactivate currency: %w", err)
}
// Count returns the total number of registered currencies, exactly as
// reported by the underlying registry provider (the error is not wrapped).
func (cr *Registry) Count() (int, error) {
return cr.registry.Count(cr.ctx)
}
// CountActive returns the number of active currencies, exactly as reported
// by the underlying registry provider (the error is not wrapped).
func (cr *Registry) CountActive() (int, error) {
return cr.registry.CountActive(cr.ctx)
}
// Search forwards query to the underlying registry's Search and converts
// every returned entity into currency metadata. The result is nil when the
// registry returns no entities.
func (cr *Registry) Search(query string) ([]Meta, error) {
	found, err := cr.registry.Search(cr.ctx, query)
	if err != nil {
		return nil, fmt.Errorf("failed to search currencies: %w", err)
	}

	var matches []Meta
	for _, item := range found {
		native, isNative := item.(*Entity)
		if isNative {
			matches = append(matches, native.Meta())
			continue
		}

		// Generic entities are rebuilt from their string metadata map; parse
		// errors are ignored and leave the zero value in place.
		props := item.Metadata()
		decimals, _ := strconv.Atoi(props["decimals"])
		active, _ := strconv.ParseBool(props["active"])
		matches = append(matches, Meta{
			Code:     props["code"],
			Name:     item.Name(),
			Symbol:   props["symbol"],
			Decimals: decimals,
			Country:  props["country"],
			Region:   props["region"],
			Active:   active,
		})
	}
	return matches, nil
}
// SearchByRegion returns the metadata of every entity whose "region" metadata
// matches region, as selected by the underlying registry's SearchByMetadata.
// The result is nil when the registry returns no entities.
func (cr *Registry) SearchByRegion(region string) ([]Meta, error) {
	filter := map[string]string{"region": region}
	found, err := cr.registry.SearchByMetadata(cr.ctx, filter)
	if err != nil {
		return nil, fmt.Errorf("failed to search currencies by region: %w", err)
	}

	var matches []Meta
	for _, item := range found {
		native, isNative := item.(*Entity)
		if isNative {
			matches = append(matches, native.Meta())
			continue
		}

		// Generic entities are rebuilt from their string metadata map; parse
		// errors are ignored and leave the zero value in place.
		props := item.Metadata()
		decimals, _ := strconv.Atoi(props["decimals"])
		active, _ := strconv.ParseBool(props["active"])
		matches = append(matches, Meta{
			Code:     props["code"],
			Name:     item.Name(),
			Symbol:   props["symbol"],
			Decimals: decimals,
			Country:  props["country"],
			Region:   props["region"],
			Active:   active,
		})
	}
	return matches, nil
}
// GetRegistry returns the underlying registry provider that backs this
// currency registry.
func (cr *Registry) GetRegistry() registry.Provider {
return cr.registry
}
// SetRegistry replaces the underlying registry provider with reg. No nil
// check is performed; all other methods dereference the provider directly.
func (cr *Registry) SetRegistry(reg registry.Provider) {
cr.registry = reg
}
// globalCurrencyRegistry is the package-wide currency registry used by the
// package-level convenience functions. It is assigned by init,
// InitializeGlobalRegistry, and (lazily) GetGlobalRegistry; none of those
// assignments is guarded by a lock.
var globalCurrencyRegistry *Registry
// GetGlobalRegistry returns the global currency registry instance.
//
// When the global instance is nil it is created on the spot via
// New(context.Background()). If that creation fails the error is logged and
// the function panics, because the application cannot function without a
// currency registry. The lazy path performs no locking.
func GetGlobalRegistry() *Registry {
	if globalCurrencyRegistry != nil {
		return globalCurrencyRegistry
	}

	// Not initialized yet: fall back to a registry built with default options.
	var err error
	if globalCurrencyRegistry, err = New(context.Background()); err != nil {
		// Log before panicking so the failure is visible in structured logs too.
		logger := slog.Default()
		logger.Error("Failed to initialize fallback currency registry",
			"error", err,
			"package", "currency",
			"function", "GetGlobalRegistry",
		)
		panic(fmt.Sprintf("currency: failed to initialize fallback registry: %v", err))
	}
	return globalCurrencyRegistry
}
// InitializeGlobalRegistry initializes the global currency registry
// with optional Redis configuration.
// If redisURL is provided, it will be used to configure Redis caching.
// If keyPrefix is provided, it will be used as the Redis key prefix.
// If redisURL is empty, an in-memory cache will be used.
//
// This function should be called during application startup.
func InitializeGlobalRegistry(ctx context.Context, redisURL ...string) error {
var err error
if len(redisURL) > 0 && redisURL[0] != "" {
// Initialize with Redis cache and optional key prefix
if len(redisURL) > 1 {
globalCurrencyRegistry, err = New(ctx, redisURL[0], redisURL[1])
} else {
globalCurrencyRegistry, err = New(ctx, redisURL[0])
}
} else {
// Initialize with in-memory cache
globalCurrencyRegistry, err = New(ctx)
}
return err
}
// init installs a default global registry, built with
// New(context.Background()), so the package-level helpers are usable even
// when InitializeGlobalRegistry is never called.
//
// NOTE: This package is deprecated. Panicking here is intentional: if the
// currency registry cannot be created the application cannot function, and
// failing fast at startup is preferred over silent failures later.
func init() {
	reg, err := New(context.Background())
	globalCurrencyRegistry = reg
	if err == nil {
		return
	}

	// Write to stderr directly: slog may not be configured yet during init.
	fmt.Fprintf(os.Stderr, "ERROR: currency: failed to initialize global registry: %v\n", err)
	panic(fmt.Sprintf("currency: failed to initialize global registry: %v", err))
}
// Register adds or updates a currency in the global registry. It is a
// convenience wrapper around (*Registry).Register and returns that method's
// error unchanged.
func Register(meta Meta) error {
return globalCurrencyRegistry.Register(meta)
}
// Get returns currency metadata for the given code from the global registry.
// It is declared as a variable rather than a function, so it can be
// reassigned (presumably to stub lookups in tests; confirm against callers).
var Get = getCurrencyInternal
// getCurrencyInternal is the default implementation behind Get; it delegates
// to the global registry's Get method.
func getCurrencyInternal(code string) (Meta, error) {
return globalCurrencyRegistry.Get(code)
}
// IsSupported reports whether code is registered and active in the global
// registry.
func IsSupported(code string) bool {
return globalCurrencyRegistry.IsSupported(code)
}
// ListSupported returns the codes of all active currencies in the global
// registry.
func ListSupported() ([]string, error) {
return globalCurrencyRegistry.ListSupported()
}
// ListAll returns the metadata of every currency in the global registry,
// active or not.
func ListAll() ([]Meta, error) {
return globalCurrencyRegistry.ListAll()
}
// Unregister removes the currency with the given code from the global
// registry.
func Unregister(code string) error {
return globalCurrencyRegistry.Unregister(code)
}
// Count returns the total number of currencies in the global registry.
func Count() (int, error) {
return globalCurrencyRegistry.Count()
}
// CountActive returns the number of active currencies in the global registry.
func CountActive() (int, error) {
return globalCurrencyRegistry.CountActive()
}
// Search forwards query to the global registry's Search method.
func Search(query string) ([]Meta, error) {
return globalCurrencyRegistry.Search(query)
}
// SearchByRegion returns the currencies in the global registry whose region
// metadata matches region.
func SearchByRegion(region string) ([]Meta, error) {
return globalCurrencyRegistry.SearchByRegion(region)
}
// Legacy registers a currency using the legacy (code, meta) call shape.
//
// Only Symbol and Decimals are taken from meta; Code and Name are both set
// to code and the currency is registered as active. A registration failure is
// logged as a warning instead of being returned or panicking, for backward
// compatibility.
//
// Deprecated: use Register instead.
func Legacy(code string, meta Meta) {
	err := Register(Meta{
		Code:     code,
		Name:     code,
		Symbol:   meta.Symbol,
		Decimals: meta.Decimals,
		Active:   true,
	})
	if err == nil {
		return
	}
	slog.Default().Warn("failed to register currency", "code", code, "error", err)
}
// GetLegacy returns the metadata for code, never an error. When the lookup
// fails it returns an inactive placeholder whose Code, Name and Symbol are
// all code and whose Decimals is DefaultDecimals.
//
// Deprecated: use Get instead.
func GetLegacy(code string) Meta {
	if meta, err := Get(code); err == nil {
		return meta
	}

	// Unknown currency: hand back a placeholder for backward compatibility.
	return Meta{
		Code:     code,
		Name:     code,
		Symbol:   code,
		Decimals: DefaultDecimals,
		Active:   false,
	}
}
// IsSupportedLegacy is a legacy alias that returns IsSupported(code).
//
// Deprecated: use IsSupported instead.
func IsSupportedLegacy(code string) bool {
return IsSupported(code)
}
// ListSupportedLegacy returns the supported currency codes, swallowing any
// error: on failure it returns an empty, non-nil slice.
//
// Deprecated: use ListSupported instead.
func ListSupportedLegacy() []string {
	if codes, err := ListSupported(); err == nil {
		return codes
	}
	return []string{}
}
// UnregisterLegacy removes the currency with the given code and reports
// whether the removal succeeded; the underlying error is discarded.
//
// Deprecated: use Unregister instead.
func UnregisterLegacy(code string) bool {
	return Unregister(code) == nil
}
// CountLegacy returns the number of registered currencies, or 0 when the
// count cannot be obtained; the underlying error is discarded.
//
// Deprecated: use Count instead.
func CountLegacy() int {
	if total, err := Count(); err == nil {
		return total
	}
	return 0
}
package account
import (
"errors"
"fmt"
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// Sentinel errors for account operations. They are plain errors.New values,
// so callers can match them with errors.Is.
var (
// ErrDepositAmountExceedsMaxSafeInt is returned when a deposit would cause the
// account balance to overflow.
ErrDepositAmountExceedsMaxSafeInt = errors.New(
"deposit amount exceeds maximum safe integer value")
// ErrTransactionAmountMustBePositive is returned when a transaction amount is not positive.
ErrTransactionAmountMustBePositive = errors.New(
"transaction amount must be positive")
// ErrInsufficientFunds is returned when an account has insufficient funds for a
// withdrawal or transfer.
ErrInsufficientFunds = errors.New("insufficient funds")
// ErrAccountNotFound is returned when an account cannot be found.
ErrAccountNotFound = errors.New("account not found")
// ErrTransactionNotFound is returned when a transaction cannot be found.
ErrTransactionNotFound = errors.New("transaction not found")
// ErrCannotTransferToSameAccount is returned when a transfer
// is attempted from an account to itself.
ErrCannotTransferToSameAccount = errors.New("cannot transfer to same account")
// ErrNilAccount is returned when a nil account
// is provided to a transfer or other operation.
ErrNilAccount = errors.New("nil account")
// ErrNotOwner is returned when a user attempts to
// perform an action on an account they do not own.
ErrNotOwner = errors.New("not owner")
// ErrCurrencyMismatch is returned when there is
// a currency mismatch between accounts or transactions.
ErrCurrencyMismatch = errors.New("currency mismatch")
)
// Account represents a user's financial account, encapsulating its balance and ownership.
// It acts as an aggregate root for balance-related validation.
//
// Intended invariants:
// - An account must always have a valid owner (UserID); Build rejects uuid.Nil.
// - The account's balance is represented by a Money value object, ensuring currency consistency.
// - The balance should never be negative; ValidateWithdraw and ValidateTransfer
// reject amounts larger than the balance.
//
// NOTE(review): the struct contains no mutex or other synchronization and its
// fields are exported, so concurrent use must be coordinated by the caller.
type Account struct {
ID uuid.UUID
UserID uuid.UUID // Owner of the account.
Balance *money.Money // Account balance as a Money value object.
UpdatedAt time.Time
CreatedAt time.Time
}
// Builder provides a fluent API for constructing Account instances.
// This pattern is particularly useful for setting optional parameters and ensuring
// that only valid accounts are constructed. Obtain one with New, chain the
// With* setters, and finish with Build.
type Builder struct {
id uuid.UUID
userID uuid.UUID // Required; Build fails while this is uuid.Nil.
balance int64 // Passed to money.NewFromSmallestUnit by Build.
currency money.Code // Falls back to money.DefaultCode when empty.
updatedAt time.Time
createdAt time.Time
}
// New returns a Builder pre-populated with a freshly generated ID, the
// default currency code, and the current time as the creation timestamp.
// The user ID is left unset and must be supplied before calling Build.
func New() *Builder {
	b := new(Builder)
	b.id = uuid.New()
	b.currency = money.DefaultCode
	b.createdAt = time.Now()
	return b
}
// WithID sets the ID for the account being built, replacing the one generated
// by New. It returns the builder to allow chaining.
func (b *Builder) WithID(id uuid.UUID) *Builder {
b.id = id
return b
}
// WithUserID sets the user ID for the account being built. This is a mandatory
// field: Build returns an error while it is uuid.Nil. It returns the builder
// to allow chaining.
func (b *Builder) WithUserID(userID uuid.UUID) *Builder {
b.userID = userID
return b
}
// WithCurrency sets the currency for the account being built.
//
// For backward compatibility currencyCode may be a string, a money.Code, or a
// money.Currency (whose Code field is used). A value of any other type,
// including nil, selects money.DefaultCode. It returns the builder to allow
// chaining.
func (b *Builder) WithCurrency(currencyCode interface{}) *Builder {
	// Start from the default and override it for the recognised types.
	var resolved money.Code = money.DefaultCode
	switch v := currencyCode.(type) {
	case money.Code:
		resolved = v
	case money.Currency:
		resolved = v.Code
	case string:
		resolved = money.Code(v)
	}
	b.currency = resolved
	return b
}
// WithBalance sets the initial balance for the account. The value is later
// handed to money.NewFromSmallestUnit by Build, so it is expressed in the
// currency's smallest unit. This should only be used for hydrating an existing
// account from a data store or for test setup.
func (b *Builder) WithBalance(balance int64) *Builder {
b.balance = balance
return b
}
// WithCreatedAt sets the creation timestamp, replacing the time.Now() value
// recorded by New. This is primarily for hydrating an existing account from a
// data store.
func (b *Builder) WithCreatedAt(t time.Time) *Builder {
b.createdAt = t
return b
}
// WithUpdatedAt sets the last-updated timestamp (zero unless set). This is
// primarily for hydrating an existing account from a data store.
func (b *Builder) WithUpdatedAt(t time.Time) *Builder {
b.updatedAt = t
return b
}
// Build finalizes the construction of the Account.
//
// It fails when no user ID was supplied, and when the configured balance and
// currency cannot be turned into a Money value. An empty currency is replaced
// with money.DefaultCode (the builder itself is updated as well).
func (b *Builder) Build() (*Account, error) {
	if b.userID == uuid.Nil {
		return nil, errors.New("user ID is required")
	}
	if b.currency == "" {
		b.currency = money.DefaultCode
	}

	// The builder's balance is expressed in the currency's smallest unit.
	initial, err := money.NewFromSmallestUnit(b.balance, b.currency)
	if err != nil {
		return nil, fmt.Errorf("invalid balance: %w", err)
	}

	acct := new(Account)
	acct.ID = b.id
	acct.UserID = b.userID
	acct.Balance = initial
	acct.UpdatedAt = b.updatedAt
	acct.CreatedAt = b.createdAt
	return acct, nil
}
// SetCurrency switches the account to currency c by replacing its balance
// with a zero amount in that currency.
// This is typically only used during account creation or migration.
//
// It returns an error when the current balance is non-zero, or when a zero
// Money value cannot be created for c. In both cases the existing balance is
// left untouched. An account whose Balance is nil is treated as empty.
func (a *Account) SetCurrency(c money.Code) error {
	if a.Balance != nil && a.Balance.Amount() != 0 {
		return errors.New("cannot change currency of account with non-zero balance")
	}

	// Build the replacement first so a failure cannot clobber a.Balance.
	zero, err := money.NewFromSmallestUnit(0, c)
	if err != nil {
		return err
	}
	a.Balance = zero
	return nil
}
// validate performs the ownership check shared by account operations: it
// returns ErrNotOwner unless userID is the account's owner.
func (a *Account) validate(userID uuid.UUID) error {
	if a.UserID == userID {
		return nil
	}
	return ErrNotOwner
}
// validateAmount returns ErrTransactionAmountMustBePositive unless amount
// reports itself as positive.
func (a *Account) validateAmount(amount *money.Money) error {
	if amount.IsPositive() {
		return nil
	}
	return ErrTransactionAmountMustBePositive
}
// ValidateDeposit checks the business rules for a deposit without changing
// the account. The checks run in this order:
//   - userID must own the account (ErrNotOwner);
//   - amount must be positive (ErrTransactionAmountMustBePositive);
//   - amount must be in the account's currency (ErrCurrencyMismatch).
func (a *Account) ValidateDeposit(userID uuid.UUID, amount *money.Money) (err error) {
	if err = a.validate(userID); err != nil {
		return err
	}
	if err = a.validateAmount(amount); err != nil {
		return err
	}
	if a.Balance.IsSameCurrency(amount) {
		return nil
	}
	return ErrCurrencyMismatch
}
// ValidateWithdraw checks the business rules for a withdrawal. It only
// validates; the account is not modified. The checks run in this order:
//   - userID must own the account (ErrNotOwner);
//   - amount must be positive (ErrTransactionAmountMustBePositive);
//   - the balance comparison must succeed; any error from
//     Balance.GreaterThan is returned unchanged;
//   - the balance must be greater than or equal to amount
//     (ErrInsufficientFunds).
//
// It returns nil when the withdrawal is allowed.
func (a *Account) ValidateWithdraw(userID uuid.UUID, amount *money.Money) error {
	if a.UserID != userID {
		return ErrNotOwner
	}
	if err := a.validateAmount(amount); err != nil {
		return err
	}

	// Withdrawing the exact balance is allowed, so "greater than" alone is
	// not the whole test.
	exceeds, err := a.Balance.GreaterThan(amount)
	if err != nil {
		return err
	}
	if exceeds || a.Balance.Equals(amount) {
		return nil
	}
	return ErrInsufficientFunds
}
// ValidateTransfer checks that moving amount from this account to dest is
// allowed. It only validates; neither account is modified. The checks run in
// this order:
//   - both accounts must be non-nil (ErrNilAccount);
//   - source and destination must differ (ErrCannotTransferToSameAccount);
//   - senderUserID must own the source account (ErrNotOwner);
//   - amount must be positive (ErrTransactionAmountMustBePositive);
//   - amount must match the currency of both balances (ErrCurrencyMismatch);
//   - the source balance must be at least amount (ErrInsufficientFunds).
//
// receiverUserID is accepted but not consulted by any check.
func (a *Account) ValidateTransfer(
	senderUserID, receiverUserID uuid.UUID,
	dest *Account,
	amount *money.Money,
) error {
	switch {
	case a == nil || dest == nil:
		return ErrNilAccount
	case a.ID == dest.ID:
		return ErrCannotTransferToSameAccount
	case a.UserID != senderUserID:
		return ErrNotOwner
	case !amount.IsPositive():
		return ErrTransactionAmountMustBePositive
	case !(a.Balance.IsSameCurrency(amount) && dest.Balance.IsSameCurrency(amount)):
		return ErrCurrencyMismatch
	}

	// Transferring the exact balance is allowed.
	exceeds, err := a.Balance.GreaterThan(amount)
	if err != nil {
		return err
	}
	if exceeds || a.Balance.Equals(amount) {
		return nil
	}
	return ErrInsufficientFunds
}
package account
import (
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// MoneySource represents the origin of funds for a transaction. It is a
// string type; the known values are the MoneySource* constants.
type MoneySource string
// Money source constants define the origin of funds for transactions.
const (
MoneySourceInternal MoneySource = "Internal"
MoneySourceBankAccount MoneySource = "BankAccount"
MoneySourceCard MoneySource = "Card"
MoneySourceCash MoneySource = "Cash"
MoneySourceExternalWallet MoneySource = "ExternalWallet"
)
// TransactionStatus represents the status of a transaction in the payment lifecycle.
// It is a string type; the known values are the TransactionStatus* constants.
type TransactionStatus string
// Transaction status constants define the lifecycle of a transaction.
const (
// TransactionStatusPending indicates that a transaction
// has been initiated and is awaiting completion.
TransactionStatusPending TransactionStatus = "pending"
// TransactionStatusCompleted indicates that a transaction
// has been completed successfully.
TransactionStatusCompleted TransactionStatus = "completed"
// TransactionStatusFailed indicates that a transaction
// has failed.
TransactionStatusFailed TransactionStatus = "failed"
)
// ExternalTarget represents the destination for an external withdrawal,
// such as a bank account or wallet. All fields are plain strings; which of
// them must be populated for a given withdrawal is not enforced by this type.
type ExternalTarget struct {
BankAccountNumber string
RoutingNumber string
ExternalWalletAddress string
}
// Transaction represents a financial transaction, capturing all details of a
// single ledger entry.
// It acts as a value object within the domain.
type Transaction struct {
ID uuid.UUID
UserID uuid.UUID
AccountID uuid.UUID
Amount money.Money
Balance money.Money // A snapshot of the account balance at the time of the transaction.
MoneySource MoneySource // The origin of the funds (e.g., Cash, BankAccount, Card).
Status TransactionStatus // Lifecycle state; one of the TransactionStatus* constants.
CreatedAt time.Time
}
// NewTransactionFromData assembles a Transaction from already-known values.
//
// It is meant for repositories hydrating a domain object from a data store
// and for test fixtures. No validation is performed, so it should not be used
// in business logic. Status is not a parameter and is left at its zero value.
func NewTransactionFromData(
	id, userID, accountID uuid.UUID,
	amount money.Money,
	balance money.Money,
	moneySource MoneySource,
	created time.Time,
) *Transaction {
	tx := new(Transaction)
	tx.ID = id
	tx.UserID = userID
	tx.AccountID = accountID
	tx.Amount = amount
	tx.Balance = balance
	tx.MoneySource = moneySource
	tx.CreatedAt = created
	return tx
}
package events
import (
"github.com/google/uuid"
"time"
)
// FlowEventOpt is a functional option that mutates a FlowEvent while it is
// being constructed by NewFlowEvent.
type FlowEventOpt func(*FlowEvent)
// NewFlowEvent creates a FlowEvent and applies opts to it in order.
//
// Before the options run, ID, UserID, AccountID and CorrelationID are each
// set to a freshly generated UUID, FlowType is empty, and Timestamp is the
// current time; options are expected to overwrite whatever should differ.
func NewFlowEvent(opts ...FlowEventOpt) *FlowEvent {
	event := new(FlowEvent) // FlowType starts out empty.
	event.ID = uuid.New()
	event.UserID = uuid.New()
	event.AccountID = uuid.New()
	event.CorrelationID = uuid.New()
	event.Timestamp = time.Now()

	for i := range opts {
		opts[i](event)
	}
	return event
}
// WithUserID sets the event's UserID in place and returns the same event so
// calls can be chained.
func (e *FlowEvent) WithUserID(userID uuid.UUID) *FlowEvent {
e.UserID = userID
return e
}
// WithID sets the event's ID in place and returns the same event so calls can
// be chained.
func (e *FlowEvent) WithID(id uuid.UUID) *FlowEvent {
e.ID = id
return e
}
// WithAccountID sets the event's AccountID in place and returns the same
// event so calls can be chained.
func (e *FlowEvent) WithAccountID(accountID uuid.UUID) *FlowEvent {
e.AccountID = accountID
return e
}
// WithCorrelationID sets the event's CorrelationID in place and returns the
// same event so calls can be chained.
func (e *FlowEvent) WithCorrelationID(correlationID uuid.UUID) *FlowEvent {
e.CorrelationID = correlationID
return e
}
// WithFlowType sets the event's FlowType in place and returns the same event
// so calls can be chained.
func (e *FlowEvent) WithFlowType(flowType string) *FlowEvent {
e.FlowType = flowType
return e
}
package events
import (
"encoding/json"
"fmt"
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/google/uuid"
)
// CurrencyConversionRequested is an agnostic event
// for requesting currency conversion in any business flow.
//
// OriginalRequest is excluded from automatic JSON encoding. MarshalJSON
// serializes it into RequestType/RequestPayload, and UnmarshalJSON rebuilds it
// from those two fields.
type CurrencyConversionRequested struct {
FlowEvent
OriginalRequest Event `json:"-"` // Encoded/decoded manually in MarshalJSON/UnmarshalJSON.
Amount *money.Money
To money.Code
TransactionID uuid.UUID
// Used for JSON serialization
RequestType string `json:"requestType,omitempty"` // Go type name of OriginalRequest (fmt "%T").
RequestPayload json.RawMessage `json:"requestPayload,omitempty"` // JSON encoding of OriginalRequest.
}
// Type returns the event type identifier for CurrencyConversionRequested.
func (e CurrencyConversionRequested) Type() string {
return EventTypeCurrencyConversionRequested.String()
}
// MarshalJSON encodes the event as JSON.
//
// When OriginalRequest is set, its Go type name (fmt "%T") and its JSON
// encoding are written to the requestType and requestPayload fields so that
// UnmarshalJSON can rebuild it. The receiver itself is never modified.
func (e CurrencyConversionRequested) MarshalJSON() ([]byte, error) {
	// Alias has the same fields but not this method, which prevents
	// json.Marshal from recursing back into MarshalJSON.
	type Alias CurrencyConversionRequested

	// Work on a private copy so the caller's value stays untouched.
	snapshot := e
	if original := snapshot.OriginalRequest; original != nil {
		payload, err := json.Marshal(original)
		if err != nil {
			return nil, fmt.Errorf("failed to marshal original request: %w", err)
		}
		// Record the type name so the payload can be decoded again later.
		snapshot.RequestType = fmt.Sprintf("%T", original)
		snapshot.RequestPayload = payload
	}

	return json.Marshal(&struct {
		*Alias
	}{
		Alias: (*Alias)(&snapshot),
	})
}
// UnmarshalJSON decodes the event from JSON.
//
// After the regular fields are decoded, OriginalRequest is rebuilt from
// RequestType and RequestPayload when both are present. Only deposit,
// withdraw and transfer requests are recognised; any other request type
// yields an error.
func (e *CurrencyConversionRequested) UnmarshalJSON(data []byte) error {
	// Alias has the same fields but not this method, which prevents
	// json.Unmarshal from recursing back into UnmarshalJSON.
	type Alias CurrencyConversionRequested
	target := struct {
		*Alias
	}{
		Alias: (*Alias)(e),
	}
	if err := json.Unmarshal(data, &target); err != nil {
		return fmt.Errorf("failed to unmarshal CurrencyConversionRequested: %w", err)
	}

	// Nothing to rebuild unless both the type name and the payload are there.
	if e.RequestType == "" || len(e.RequestPayload) == 0 {
		return nil
	}

	switch e.RequestType {
	case "*events.DepositRequested", "events.DepositRequested":
		deposit := &DepositRequested{}
		if err := json.Unmarshal(e.RequestPayload, deposit); err != nil {
			return fmt.Errorf("failed to unmarshal DepositRequested: %w", err)
		}
		e.OriginalRequest = deposit
	case "*events.WithdrawRequested", "events.WithdrawRequested":
		withdraw := &WithdrawRequested{}
		if err := json.Unmarshal(e.RequestPayload, withdraw); err != nil {
			return fmt.Errorf("failed to unmarshal WithdrawRequested: %w", err)
		}
		e.OriginalRequest = withdraw
	case "*events.TransferRequested", "events.TransferRequested":
		transfer := &TransferRequested{}
		if err := json.Unmarshal(e.RequestPayload, transfer); err != nil {
			return fmt.Errorf("failed to unmarshal TransferRequested: %w", err)
		}
		e.OriginalRequest = transfer
	default:
		return fmt.Errorf("unsupported request type: %s", e.RequestType)
	}
	return nil
}
// CurrencyConverted is an agnostic event for reporting
// the successful result of a currency conversion.
//
// Note that TransactionID is declared both here and on the embedded
// CurrencyConversionRequested. The plain selector e.TransactionID refers to
// the field declared here; the embedded one must be reached explicitly via
// e.CurrencyConversionRequested.TransactionID.
type CurrencyConverted struct {
CurrencyConversionRequested
TransactionID uuid.UUID
ConvertedAmount *money.Money
ConversionInfo *exchange.RateInfo `json:"conversionInfo"`
}
// Type returns the event type identifier for CurrencyConverted.
func (e CurrencyConverted) Type() string { return EventTypeCurrencyConverted.String() }
// MarshalJSON implements custom JSON marshaling for CurrencyConverted.
//
// The event is flattened into a single JSON object with explicit keys. The
// transaction ID of the embedded CurrencyConversionRequested is written as
// "requestTransactionId" and the event's own TransactionID as
// "transactionId". When OriginalRequest is set, its Go type name and JSON
// encoding replace the stored RequestType/RequestPayload values.
func (e CurrencyConverted) MarshalJSON() ([]byte, error) {
	// Create an auxiliary structure to explicitly handle all fields
	aux := struct {
		// Embedded CurrencyConversionRequested fields
		ID                   uuid.UUID       `json:"id"`
		FlowType             string          `json:"flowType"`
		UserID               uuid.UUID       `json:"userId"`
		AccountID            uuid.UUID       `json:"accountId"`
		CorrelationID        uuid.UUID       `json:"correlationId"`
		Timestamp            time.Time       `json:"timestamp"`
		Amount               *money.Money    `json:"amount"`
		To                   money.Code      `json:"to"`
		RequestTransactionID uuid.UUID       `json:"requestTransactionId"` // From embedded CCR
		RequestType          string          `json:"requestType,omitempty"`
		RequestPayload       json.RawMessage `json:"requestPayload,omitempty"`
		// CurrencyConverted specific fields
		TransactionID   uuid.UUID          `json:"transactionId"`
		ConvertedAmount *money.Money       `json:"convertedAmount"`
		ConversionInfo  *exchange.RateInfo `json:"conversionInfo"`
	}{
		// Copy from embedded CurrencyConversionRequested
		ID:            e.ID,
		FlowType:      e.FlowType,
		UserID:        e.UserID,
		AccountID:     e.AccountID,
		CorrelationID: e.CorrelationID,
		Timestamp:     e.Timestamp,
		Amount:        e.Amount,
		To:            e.To,
		// e.TransactionID would select the outer field (it is declared at a
		// shallower depth), so the embedded one must be named explicitly.
		RequestTransactionID: e.CurrencyConversionRequested.TransactionID,
		RequestType:          e.RequestType,
		RequestPayload:       e.RequestPayload,
		// Copy CurrencyConverted specific fields
		TransactionID:   e.TransactionID,
		ConvertedAmount: e.ConvertedAmount, // Keep as pointer
		ConversionInfo:  e.ConversionInfo,
	}
	// Handle OriginalRequest marshaling
	if e.OriginalRequest != nil {
		aux.RequestType = fmt.Sprintf("%T", e.OriginalRequest)
		var err error
		aux.RequestPayload, err = json.Marshal(e.OriginalRequest)
		if err != nil {
			return nil, fmt.Errorf(
				"failed to marshal original request in CurrencyConverted: %w",
				err,
			)
		}
	}
	return json.Marshal(aux)
}

// UnmarshalJSON implements custom JSON unmarshaling for CurrencyConverted.
//
// It reads the flat object produced by MarshalJSON. "requestTransactionId"
// is restored into the embedded CurrencyConversionRequested and
// "transactionId" into the event's own TransactionID. A missing or null
// "conversionInfo" leaves ConversionInfo untouched. OriginalRequest is
// rebuilt from requestType/requestPayload when both are present; only
// deposit, withdraw and transfer requests are recognised.
func (e *CurrencyConverted) UnmarshalJSON(data []byte) error {
	// Create an auxiliary structure to match the marshaling format
	aux := struct {
		// Embedded CurrencyConversionRequested fields
		ID                   uuid.UUID       `json:"id"`
		FlowType             string          `json:"flowType"`
		UserID               uuid.UUID       `json:"userId"`
		AccountID            uuid.UUID       `json:"accountId"`
		CorrelationID        uuid.UUID       `json:"correlationId"`
		Timestamp            time.Time       `json:"timestamp"`
		Amount               money.Money     `json:"amount"`
		To                   money.Code      `json:"to"`
		RequestTransactionID uuid.UUID       `json:"requestTransactionId"` // From embedded CCR
		RequestType          string          `json:"requestType,omitempty"`
		RequestPayload       json.RawMessage `json:"requestPayload,omitempty"`
		// CurrencyConverted specific fields
		TransactionID   uuid.UUID       `json:"transactionId"`
		ConvertedAmount *money.Money    `json:"convertedAmount"`
		ConversionInfo  json.RawMessage `json:"conversionInfo"`
	}{}
	// Unmarshal the main fields
	if err := json.Unmarshal(data, &aux); err != nil {
		return fmt.Errorf("failed to unmarshal CurrencyConverted: %w", err)
	}
	// Copy fields to the embedded CurrencyConversionRequested
	e.ID = aux.ID
	e.FlowType = aux.FlowType
	e.UserID = aux.UserID
	e.AccountID = aux.AccountID
	e.CorrelationID = aux.CorrelationID
	e.Timestamp = aux.Timestamp
	// A zero-valued amount (including JSON null) is represented as nil.
	if aux.Amount != (money.Money{}) {
		amount := aux.Amount // Create a copy
		e.Amount = &amount
	} else {
		e.Amount = nil
	}
	e.To = aux.To
	// The embedded request's transaction ID must be addressed explicitly;
	// the plain selector e.TransactionID refers to the outer field.
	e.CurrencyConversionRequested.TransactionID = aux.RequestTransactionID
	e.RequestType = aux.RequestType
	e.RequestPayload = aux.RequestPayload
	// Copy CurrencyConverted specific fields
	e.TransactionID = aux.TransactionID
	e.ConvertedAmount = aux.ConvertedAmount
	// Parse ConversionInfo if present. MarshalJSON writes a nil pointer as
	// the literal null, which must not turn into an empty RateInfo.
	if len(aux.ConversionInfo) > 0 && string(aux.ConversionInfo) != "null" {
		var info exchange.RateInfo
		if err := json.Unmarshal(aux.ConversionInfo, &info); err != nil {
			return fmt.Errorf("failed to unmarshal ConversionInfo: %w", err)
		}
		e.ConversionInfo = &info
	}
	// Handle the OriginalRequest reconstruction
	if aux.RequestType != "" && len(aux.RequestPayload) > 0 {
		// Create a new instance of the appropriate type based on RequestType
		var request Event
		switch aux.RequestType {
		case "*events.DepositRequested", "events.DepositRequested":
			req := &DepositRequested{}
			if err := json.Unmarshal(aux.RequestPayload, req); err != nil {
				return fmt.Errorf(
					"failed to unmarshal DepositRequested in CurrencyConverted: %w",
					err,
				)
			}
			request = req
		case "*events.WithdrawRequested", "events.WithdrawRequested":
			req := &WithdrawRequested{}
			if err := json.Unmarshal(aux.RequestPayload, req); err != nil {
				return fmt.Errorf(
					"failed to unmarshal WithdrawRequested in CurrencyConverted: %w",
					err,
				)
			}
			request = req
		case "*events.TransferRequested", "events.TransferRequested":
			req := &TransferRequested{}
			if err := json.Unmarshal(aux.RequestPayload, req); err != nil {
				return fmt.Errorf(
					"failed to unmarshal TransferRequested in CurrencyConverted: %w",
					err,
				)
			}
			request = req
		default:
			return fmt.Errorf(
				"unsupported request type in CurrencyConverted: %s",
				aux.RequestType,
			)
		}
		e.OriginalRequest = request
	}
	return nil
}
// CurrencyConversionFailed is an event for reporting a failed currency conversion.
// Unlike CurrencyConversionRequested, Amount is stored by value here.
type CurrencyConversionFailed struct {
FlowEvent
TransactionID uuid.UUID
Amount money.Money
To money.Code
Reason string // Free-form description of why the conversion failed.
}
// Type returns the event type identifier for CurrencyConversionFailed.
func (e CurrencyConversionFailed) Type() string {
return EventTypeCurrencyConversionFailed.String()
}
package events
import (
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// CurrencyConversionRequestedOpt is a functional option applied by
// NewCurrencyConversionRequested to the event it is building.
type CurrencyConversionRequestedOpt func(*CurrencyConversionRequested)
// WithConversionAmount sets the amount for the CurrencyConversionRequested.
// The pointer is stored as-is, not copied.
func WithConversionAmount(amount *money.Money) CurrencyConversionRequestedOpt {
return func(e *CurrencyConversionRequested) { e.Amount = amount }
}
// WithConversionTo sets the target currency for the
// CurrencyConversionRequested.
func WithConversionTo(currency money.Code) CurrencyConversionRequestedOpt {
return func(e *CurrencyConversionRequested) { e.To = currency }
}
// WithConversionTransactionID sets the transaction ID for the
// CurrencyConversionRequested.
func WithConversionTransactionID(id uuid.UUID) CurrencyConversionRequestedOpt {
return func(e *CurrencyConversionRequested) { e.TransactionID = id }
}
// NewCurrencyConversionRequested builds a CurrencyConversionRequested from
// the flow event fe and the originating request or.
//
// The new event copies fe but receives its own freshly generated ID and the
// current time as Timestamp. opts are applied afterwards, in order, so they
// can override either value.
func NewCurrencyConversionRequested(
	fe FlowEvent,
	or Event,
	opts ...CurrencyConversionRequestedOpt,
) *CurrencyConversionRequested {
	requested := new(CurrencyConversionRequested)
	requested.FlowEvent = fe
	requested.OriginalRequest = or

	// Give the new event its own identity rather than reusing fe's.
	requested.ID = uuid.New()
	requested.Timestamp = time.Now()

	for i := range opts {
		opts[i](requested)
	}
	return requested
}
// CurrencyConvertedOpt is a functional option applied by NewCurrencyConverted
// to the event it is building.
type CurrencyConvertedOpt func(*CurrencyConverted)
// NewCurrencyConverted builds a CurrencyConverted that embeds a copy of ccr.
//
// The new event receives its own freshly generated ID and the current time
// as Timestamp. opts are applied afterwards, in order. ccr must not be nil;
// it is dereferenced unconditionally.
func NewCurrencyConverted(
	ccr *CurrencyConversionRequested,
	opts ...CurrencyConvertedOpt,
) *CurrencyConverted {
	converted := new(CurrencyConverted)
	converted.CurrencyConversionRequested = *ccr

	// Give the new event its own identity rather than reusing the request's.
	converted.ID = uuid.New()
	converted.Timestamp = time.Now()

	for i := range opts {
		opts[i](converted)
	}
	return converted
}
package events
import (
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// DepositRequested is emitted after deposit validation and persistence.
type DepositRequested struct {
FlowEvent
Amount *money.Money
Source string
TransactionID uuid.UUID
}
// Type returns the event type identifier for DepositRequested.
func (e DepositRequested) Type() string { return EventTypeDepositRequested.String() }
// Validate currently performs no checks and always returns nil.
func (e DepositRequested) Validate() error {
return nil
}
// DepositCurrencyConverted is emitted after currency conversion for deposit.
// It wraps CurrencyConverted without adding fields; only Type differs.
type DepositCurrencyConverted struct {
CurrencyConverted
}
// Type returns the event type identifier for DepositCurrencyConverted.
func (e DepositCurrencyConverted) Type() string {
return EventTypeDepositCurrencyConverted.String()
}
// DepositValidated is emitted after business validation for deposit.
// It wraps DepositCurrencyConverted without adding fields; only Type differs.
type DepositValidated struct {
DepositCurrencyConverted
}
// Type returns the event type identifier for DepositValidated.
func (e DepositValidated) Type() string { return EventTypeDepositValidated.String() }
// DepositFailed is emitted when a deposit fails. It embeds the
// DepositRequested it relates to and adds the failure reason.
type DepositFailed struct {
DepositRequested
Reason string
}
// Type returns the event type identifier for DepositFailed.
func (e DepositFailed) Type() string { return EventTypeDepositFailed.String() }
package events
import (
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// DepositRequestedOpt is a functional option applied by NewDepositRequested
// to the DepositRequested it is building.
type DepositRequestedOpt func(*DepositRequested)
// WithDepositAmount sets the deposit amount. The pointer is stored as-is,
// not copied.
func WithDepositAmount(m *money.Money) DepositRequestedOpt {
return func(e *DepositRequested) { e.Amount = m }
}
// WithDepositTimestamp sets the deposit event's timestamp.
func WithDepositTimestamp(ts time.Time) DepositRequestedOpt {
return func(e *DepositRequested) { e.Timestamp = ts }
}
// WithDepositID sets the deposit event's ID.
func WithDepositID(id uuid.UUID) DepositRequestedOpt {
return func(e *DepositRequested) { e.ID = id }
}
// WithDepositFlowEvent replaces the deposit's entire embedded FlowEvent with
// fe, overwriting any flow-event field set before this option runs.
func WithDepositFlowEvent(fe FlowEvent) DepositRequestedOpt {
return func(e *DepositRequested) { e.FlowEvent = fe }
}
// WithDepositTransactionID sets the transaction ID for the deposit.
func WithDepositTransactionID(id uuid.UUID) DepositRequestedOpt {
return func(e *DepositRequested) { e.TransactionID = id }
}
// WithDepositSource is a test helper to set the source on a DepositRequested event.
func WithDepositSource(source string) DepositRequestedOpt {
return func(e *DepositRequested) {
e.Source = source
}
}
// NewDepositRequested creates a DepositRequested event for the given user,
// account and correlation IDs.
//
// Before opts run, the event has a fresh ID, flow type "deposit", the current
// time as Timestamp, a fresh TransactionID, and a zero USD amount. opts are
// then applied in order and may override any of these.
func NewDepositRequested(
	userID, accountID, correlationID uuid.UUID,
	opts ...DepositRequestedOpt,
) *DepositRequested {
	flow := FlowEvent{
		ID:            uuid.New(),
		FlowType:      "deposit",
		UserID:        userID,
		AccountID:     accountID,
		CorrelationID: correlationID,
		Timestamp:     time.Now(),
	}

	requested := new(DepositRequested)
	requested.FlowEvent = flow
	requested.TransactionID = uuid.New()
	requested.Amount = money.Zero(money.USD)

	for i := range opts {
		opts[i](requested)
	}
	return requested
}
// DepositCurrencyConvertedOpt mutates a DepositCurrencyConverted event while
// it is being constructed.
type DepositCurrencyConvertedOpt func(*DepositCurrencyConverted)

// NewDepositCurrencyConverted wraps a copy of cc in a DepositCurrencyConverted
// event, assigns it a fresh ID and the current time as Timestamp, and then
// applies opts in order. cc must not be nil.
func NewDepositCurrencyConverted(
	cc *CurrencyConverted,
	opts ...DepositCurrencyConvertedOpt,
) *DepositCurrencyConverted {
	converted := new(DepositCurrencyConverted)
	converted.CurrencyConverted = *cc
	converted.ID = uuid.New()
	converted.Timestamp = time.Now()
	for i := range opts {
		opts[i](converted)
	}
	return converted
}
// DepositValidatedOpt mutates a DepositValidated event while it is being
// constructed.
type DepositValidatedOpt func(*DepositValidated)

// NewDepositValidated wraps a copy of dcv in a DepositValidated event and
// assigns it a fresh ID and the current time as Timestamp. dcv must not be
// nil.
func NewDepositValidated(dcv *DepositCurrencyConverted) *DepositValidated {
	validated := new(DepositValidated)
	validated.DepositCurrencyConverted = *dcv
	validated.ID = uuid.New()
	validated.Timestamp = time.Now()
	return validated
}
// DepositFailedOpt mutates a DepositFailed event while it is being
// constructed.
type DepositFailedOpt func(*DepositFailed)

// WithFailureReason returns an option that assigns reason to the event's
// Reason.
func WithFailureReason(reason string) DepositFailedOpt {
	apply := func(failed *DepositFailed) {
		failed.Reason = reason
	}
	return apply
}
// WithDepositFailedTransactionID returns an option that assigns id to the
// failed deposit event's TransactionID.
func WithDepositFailedTransactionID(id uuid.UUID) DepositFailedOpt {
	apply := func(failed *DepositFailed) {
		failed.TransactionID = id
	}
	return apply
}
// NewDepositFailed builds a DepositFailed event from a copy of dr and the
// given reason. The event receives a fresh ID and the current time as
// Timestamp before opts are applied in order. dr must not be nil.
func NewDepositFailed(
	dr *DepositRequested,
	reason string,
	opts ...DepositFailedOpt,
) *DepositFailed {
	failed := new(DepositFailed)
	failed.DepositRequested = *dr
	failed.Reason = reason
	failed.ID = uuid.New()
	failed.Timestamp = time.Now()
	for i := range opts {
		opts[i](failed)
	}
	return failed
}
package events
// EventType is the string identifier of an event kind in the system.
type EventType string

// Known event type identifiers, grouped by flow. Each value has the form
// "<Flow>.<Stage>".
const (
	// Payment flow.
	EventTypePaymentInitiated EventType = "Payment.Initiated"
	EventTypePaymentProcessed EventType = "Payment.Processed"
	EventTypePaymentCompleted EventType = "Payment.Completed"
	EventTypePaymentFailed    EventType = "Payment.Failed"

	// Deposit flow.
	EventTypeDepositRequested         EventType = "Deposit.Requested"
	EventTypeDepositCurrencyConverted EventType = "Deposit.CurrencyConverted"
	EventTypeDepositValidated         EventType = "Deposit.Validated"
	EventTypeDepositFailed            EventType = "Deposit.Failed"

	// Withdraw flow.
	EventTypeWithdrawRequested         EventType = "Withdraw.Requested"
	EventTypeWithdrawCurrencyConverted EventType = "Withdraw.CurrencyConverted"
	EventTypeWithdrawValidated         EventType = "Withdraw.Validated"
	EventTypeWithdrawFailed            EventType = "Withdraw.Failed"

	// User events.
	EventTypeUserOnboardingCompleted EventType = "User.OnboardingCompleted"

	// Transfer flow.
	EventTypeTransferRequested         EventType = "Transfer.Requested"
	EventTypeTransferCurrencyConverted EventType = "Transfer.CurrencyConverted"
	EventTypeTransferValidated         EventType = "Transfer.Validated"
	EventTypeTransferPaid              EventType = "Transfer.Paid"
	EventTypeTransferCompleted         EventType = "Transfer.Completed"
	EventTypeTransferFailed            EventType = "Transfer.Failed"

	// Fee calculation.
	EventTypeFeesCalculated EventType = "Fees.Calculated"

	// Currency conversion flow.
	EventTypeCurrencyConversionRequested EventType = "CurrencyConversion.Requested"
	EventTypeCurrencyConverted           EventType = "CurrencyConversion.Converted"
	EventTypeCurrencyConversionFailed    EventType = "CurrencyConversion.Failed"
)

// String converts the event type to its underlying string value.
func (et EventType) String() string {
	name := string(et)
	return name
}
package events
import (
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/google/uuid"
)
// FeesCalculated is emitted after all fees for a transaction have been
// calculated. It carries the transaction's ID and the resulting Fee.
type FeesCalculated struct {
	FlowEvent
	TransactionID uuid.UUID
	Fee           account.Fee
}

// Type returns the event type identifier used for FeesCalculated events.
func (e FeesCalculated) Type() string {
	kind := EventTypeFeesCalculated
	return kind.String()
}
package events
import (
"time"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// FeesCalculatedOpt mutates a FeesCalculated event while it is being
// constructed.
type FeesCalculatedOpt func(*FeesCalculated)

// WithFeeTransactionID returns an option that assigns id to the event's
// TransactionID.
func WithFeeTransactionID(id uuid.UUID) FeesCalculatedOpt {
	apply := func(fc *FeesCalculated) {
		fc.TransactionID = id
	}
	return apply
}
// WithFee returns an option that replaces the event's whole Fee value with
// fee.
func WithFee(fee account.Fee) FeesCalculatedOpt {
	apply := func(fc *FeesCalculated) {
		fc.Fee = fee
	}
	return apply
}
// NewFeesCalculated builds a FeesCalculated event from a copy of ef, assigns
// it a fresh ID and the current time as Timestamp, and then applies opts in
// order. ef must not be nil.
func NewFeesCalculated(ef *FlowEvent, opts ...FeesCalculatedOpt) *FeesCalculated {
	calculated := new(FeesCalculated)
	calculated.FlowEvent = *ef

	// Defaults that options are allowed to override.
	calculated.ID = uuid.New()
	calculated.Timestamp = time.Now()

	for i := range opts {
		opts[i](calculated)
	}
	return calculated
}
// WithFeeType returns an option that assigns feeType to the event's Fee.Type,
// leaving the rest of the Fee untouched.
func WithFeeType(feeType account.FeeType) FeesCalculatedOpt {
	apply := func(fc *FeesCalculated) { fc.Fee.Type = feeType }
	return apply
}
// WithFeeAmountValue returns an option that assigns amount to the event's
// Fee.Amount, leaving the rest of the Fee untouched.
func WithFeeAmountValue(amount *money.Money) FeesCalculatedOpt {
	apply := func(fc *FeesCalculated) { fc.Fee.Amount = amount }
	return apply
}
package events
import (
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// PaymentInitiated is emitted after payment initiation with a provider
// (event-driven workflow).
type PaymentInitiated struct {
	FlowEvent
	Amount        *money.Money
	TransactionID uuid.UUID
	PaymentID     *string // Pointer to allow NULL in database
	Status        string
}

// Type returns the event type identifier used for PaymentInitiated events.
func (e PaymentInitiated) Type() string {
	kind := EventTypePaymentInitiated
	return kind.String()
}

// WithAmount sets Amount on the receiver and returns the receiver so calls
// can be chained.
func (e *PaymentInitiated) WithAmount(m *money.Money) *PaymentInitiated {
	e.Amount = m
	return e
}

// WithTransactionID sets TransactionID on the receiver and returns the
// receiver so calls can be chained.
func (e *PaymentInitiated) WithTransactionID(id uuid.UUID) *PaymentInitiated {
	e.TransactionID = id
	return e
}

// WithPaymentID sets PaymentID on the receiver and returns the receiver. An
// empty id clears PaymentID (sets it to nil).
func (e *PaymentInitiated) WithPaymentID(id string) *PaymentInitiated {
	if id == "" {
		e.PaymentID = nil
		return e
	}
	e.PaymentID = &id
	return e
}

// WithStatus sets Status on the receiver and returns the receiver so calls
// can be chained.
func (e *PaymentInitiated) WithStatus(status string) *PaymentInitiated {
	e.Status = status
	return e
}
// PaymentFailed is emitted when payment fails. It embeds PaymentInitiated and
// adds the failure Reason.
type PaymentFailed struct {
	PaymentInitiated
	Reason string
}

// Type returns the event type identifier used for PaymentFailed events.
func (e *PaymentFailed) Type() string {
	kind := EventTypePaymentFailed
	return kind.String()
}

// WithReason sets Reason on the receiver and returns the receiver so calls
// can be chained.
func (e *PaymentFailed) WithReason(reason string) *PaymentFailed {
	e.Reason = reason
	return e
}
// PaymentProcessed is the event for a processed payment. It embeds
// PaymentInitiated and adds no fields of its own.
type PaymentProcessed struct {
	PaymentInitiated
}

// Type returns the event type identifier used for PaymentProcessed events.
func (e *PaymentProcessed) Type() string {
	kind := EventTypePaymentProcessed
	return kind.String()
}

// WithAmount sets Amount on the receiver and returns the receiver so calls
// can be chained.
func (e *PaymentProcessed) WithAmount(m *money.Money) *PaymentProcessed {
	e.Amount = m
	return e
}
// PaymentCompleted is the event for a completed payment. It embeds
// PaymentInitiated and adds no fields of its own.
type PaymentCompleted struct {
	PaymentInitiated
}

// Type returns the event type identifier used for PaymentCompleted events.
func (e PaymentCompleted) Type() string {
	kind := EventTypePaymentCompleted
	return kind.String()
}
package events
import (
"time"
"github.com/google/uuid"
)
// PaymentInitiatedOpt mutates a PaymentInitiated event while it is being
// constructed.
type PaymentInitiatedOpt func(*PaymentInitiated)

// WithPaymentTransactionID returns an option that assigns id to the event's
// TransactionID.
func WithPaymentTransactionID(id uuid.UUID) PaymentInitiatedOpt {
	apply := func(initiated *PaymentInitiated) {
		initiated.TransactionID = id
	}
	return apply
}
// WithInitiatedPaymentID returns an option that sets the PaymentInitiated
// event's PaymentID. An empty paymentID clears the field (sets it to nil).
func WithInitiatedPaymentID(paymentID string) PaymentInitiatedOpt {
	apply := func(initiated *PaymentInitiated) {
		if paymentID == "" {
			initiated.PaymentID = nil
			return
		}
		initiated.PaymentID = &paymentID
	}
	return apply
}
// WithInitiatedPaymentStatus returns an option that assigns status to the
// PaymentInitiated event's Status.
func WithInitiatedPaymentStatus(status string) PaymentInitiatedOpt {
	apply := func(initiated *PaymentInitiated) {
		initiated.Status = status
	}
	return apply
}
// WithFlowEvent returns an option that replaces the PaymentInitiated event's
// embedded FlowEvent with fe.
func WithFlowEvent(fe FlowEvent) PaymentInitiatedOpt {
	apply := func(initiated *PaymentInitiated) { initiated.FlowEvent = fe }
	return apply
}
// NewPaymentInitiated builds a PaymentInitiated event from a copy of fe with
// Status "initiated", a fresh ID and the current time as Timestamp, and then
// applies opts in order. fe must not be nil.
func NewPaymentInitiated(fe *FlowEvent, opts ...PaymentInitiatedOpt) *PaymentInitiated {
	initiated := new(PaymentInitiated)
	initiated.FlowEvent = *fe
	initiated.Status = "initiated"
	initiated.ID = uuid.New()
	initiated.Timestamp = time.Now()
	for i := range opts {
		opts[i](initiated)
	}
	return initiated
}
// PaymentProcessedOpt mutates a PaymentProcessed event while it is being
// constructed.
type PaymentProcessedOpt func(*PaymentProcessed)

// NewPaymentProcessed builds a PaymentProcessed event whose embedded
// FlowEvent is a copy of ef, assigns it a fresh ID and the current time as
// Timestamp, and then applies opts in order. ef must not be nil.
func NewPaymentProcessed(
	ef *FlowEvent,
	opts ...PaymentProcessedOpt,
) *PaymentProcessed {
	processed := new(PaymentProcessed)
	processed.PaymentInitiated.FlowEvent = *ef
	processed.ID = uuid.New()
	processed.Timestamp = time.Now()
	for i := range opts {
		opts[i](processed)
	}
	return processed
}
// PaymentCompletedOpt mutates a PaymentCompleted event while it is being
// constructed.
type PaymentCompletedOpt func(*PaymentCompleted)

// WithPaymentID returns an option that assigns paymentID (which may be nil)
// to the PaymentCompleted event's PaymentID.
func WithPaymentID(paymentID *string) PaymentCompletedOpt {
	apply := func(completed *PaymentCompleted) { completed.PaymentID = paymentID }
	return apply
}
// WithCorrelationID returns an option that assigns correlationID to the
// PaymentCompleted event's CorrelationID.
func WithCorrelationID(correlationID uuid.UUID) PaymentCompletedOpt {
	apply := func(completed *PaymentCompleted) {
		completed.CorrelationID = correlationID
	}
	return apply
}
// NewPaymentCompleted builds a PaymentCompleted event whose embedded
// FlowEvent is a copy of ef, assigns it a fresh ID and the current time as
// Timestamp, and then applies opts in order. ef must not be nil.
func NewPaymentCompleted(
	ef *FlowEvent,
	opts ...PaymentCompletedOpt,
) *PaymentCompleted {
	completed := new(PaymentCompleted)
	completed.PaymentInitiated.FlowEvent = *ef
	completed.ID = uuid.New()
	completed.Timestamp = time.Now()
	for i := range opts {
		opts[i](completed)
	}
	return completed
}
// PaymentFailedOpt mutates a PaymentFailed event while it is being
// constructed.
type PaymentFailedOpt func(*PaymentFailed)

// WithFailedPaymentID returns an option that assigns paymentID (which may be
// nil) to the PaymentFailed event's PaymentID.
func WithFailedPaymentID(paymentID *string) PaymentFailedOpt {
	apply := func(failed *PaymentFailed) { failed.PaymentID = paymentID }
	return apply
}
// NewPaymentFailed builds a PaymentFailed event whose embedded FlowEvent is a
// copy of ef, assigns it a fresh ID and the current time as Timestamp, and
// then applies opts in order. ef must not be nil.
func NewPaymentFailed(
	ef *FlowEvent,
	opts ...PaymentFailedOpt,
) *PaymentFailed {
	failed := new(PaymentFailed)
	failed.PaymentInitiated.FlowEvent = *ef
	failed.ID = uuid.New()
	failed.Timestamp = time.Now()
	for i := range opts {
		opts[i](failed)
	}
	return failed
}
package events
import (
"fmt"
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// TransferRequested is emitted after transfer validation and persistence.
//
// Timestamp is declared on the struct itself and therefore shadows the
// Timestamp promoted from the embedded FlowEvent.
type TransferRequested struct {
	FlowEvent
	Amount        *money.Money
	Source        string
	DestAccountID uuid.UUID
	Timestamp     time.Time
	TransactionID uuid.UUID
	Fee           int64
}

// Type returns the event type identifier used for TransferRequested events.
func (e *TransferRequested) Type() string {
	return EventTypeTransferRequested.String()
}

// deriveTransferRequested builds a new TransferRequested (fresh ID and
// timestamp, via NewTransferRequested) that carries over the payload of src
// and then applies override on top of it.
func deriveTransferRequested(
	src *TransferRequested,
	override TransferRequestedOpt,
) *TransferRequested {
	opts := make([]TransferRequestedOpt, 0, 4)
	opts = append(opts,
		WithTransferDestAccountID(src.DestAccountID),
		WithTransferFee(src.Fee),
	)
	// Only replay a real amount so the constructor's zero-USD default
	// survives when the source event never had one.
	if src.Amount != nil {
		opts = append(opts, WithTransferRequestedAmount(src.Amount))
	}
	opts = append(opts, override)

	next := NewTransferRequested(src.UserID, src.AccountID, src.CorrelationID, opts...)
	next.Source = src.Source
	next.TransactionID = src.TransactionID
	return next
}

// WithDestAccountID returns a new TransferRequested with DestAccountID set to
// id. The receiver is left untouched. Amount, Fee, Source and TransactionID
// are carried over to the new event, so calls can be chained without losing
// values set earlier.
func (e *TransferRequested) WithDestAccountID(id uuid.UUID) *TransferRequested {
	return deriveTransferRequested(e, WithTransferDestAccountID(id))
}

// WithAmount returns a new TransferRequested with Amount set to m. The
// receiver is left untouched. DestAccountID, Fee, Source and TransactionID
// are carried over to the new event, so calls can be chained without losing
// values set earlier.
func (e *TransferRequested) WithAmount(m *money.Money) *TransferRequested {
	return deriveTransferRequested(e, WithTransferRequestedAmount(m))
}

// Validate checks if the event is valid: AccountID, UserID and DestAccountID
// must be set and Amount must be present, non-zero and non-negative. It
// returns an error describing the event otherwise.
func (e *TransferRequested) Validate() error {
	missingID := e.AccountID == uuid.Nil ||
		e.UserID == uuid.Nil ||
		e.DestAccountID == uuid.Nil
	// The nil check comes first so a missing amount is reported as a
	// validation error instead of being dereferenced.
	badAmount := e.Amount == nil || e.Amount.IsZero() || e.Amount.IsNegative()
	if missingID || badAmount {
		return fmt.Errorf("malformed validated event: %+v", e)
	}
	return nil
}
// TransferCurrencyConverted is emitted after currency conversion for
// transfer. It embeds CurrencyConverted and adds no fields of its own.
type TransferCurrencyConverted struct {
	CurrencyConverted
}

// Type returns the event type identifier used for TransferCurrencyConverted
// events.
func (e TransferCurrencyConverted) Type() string {
	kind := EventTypeTransferCurrencyConverted
	return kind.String()
}
// TransferValidated is emitted after business validation for transfer. It
// embeds TransferCurrencyConverted and adds no fields of its own.
type TransferValidated struct {
	TransferCurrencyConverted
}

// Type returns the event type identifier used for TransferValidated events.
func (e TransferValidated) Type() string {
	kind := EventTypeTransferValidated
	return kind.String()
}
// TransferCompleted is emitted when transfer is fully completed. It embeds
// TransferValidated and adds no fields of its own.
type TransferCompleted struct {
	TransferValidated
}

// Type returns the event type identifier used for TransferCompleted events.
func (e TransferCompleted) Type() string {
	kind := EventTypeTransferCompleted
	return kind.String()
}
// TransferFailed is emitted when transfer fails. It embeds the
// TransferRequested event it relates to and adds the failure Reason.
type TransferFailed struct {
	TransferRequested
	Reason string
}

// Type returns the event type identifier used for TransferFailed events.
func (e TransferFailed) Type() string {
	kind := EventTypeTransferFailed
	return kind.String()
}
package events
import (
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// --- TransferRequested ---

// TransferRequestedOpt mutates a TransferRequested event while it is being
// constructed.
type TransferRequestedOpt func(*TransferRequested)

// WithTransferRequestedAmount returns an option that assigns m to the event's
// Amount.
func WithTransferRequestedAmount(m *money.Money) TransferRequestedOpt {
	apply := func(req *TransferRequested) {
		req.Amount = m
	}
	return apply
}
// WithTransferFee returns an option that assigns fee to the event's Fee.
func WithTransferFee(fee int64) TransferRequestedOpt {
	apply := func(req *TransferRequested) {
		req.Fee = fee
	}
	return apply
}
// WithTransferDestAccountID returns an option that assigns id to the event's
// DestAccountID.
func WithTransferDestAccountID(id uuid.UUID) TransferRequestedOpt {
	apply := func(req *TransferRequested) {
		req.DestAccountID = id
	}
	return apply
}
// NewTransferRequested builds a TransferRequested event for the given user,
// account and correlation ID.
//
// Defaults: a fresh event ID, FlowType "transfer", a zero USD Amount and the
// current time in the struct's own Timestamp field (the embedded FlowEvent's
// Timestamp is not set here). Options are applied afterwards, in order.
func NewTransferRequested(
	userID, accountID, correlationID uuid.UUID,
	opts ...TransferRequestedOpt,
) *TransferRequested {
	flow := FlowEvent{
		ID:            uuid.New(),
		FlowType:      "transfer",
		UserID:        userID,
		AccountID:     accountID,
		CorrelationID: correlationID,
	}
	requested := &TransferRequested{
		FlowEvent: flow,
		Amount:    money.Zero(money.USD),
		Timestamp: time.Now(),
	}
	for i := range opts {
		opts[i](requested)
	}
	return requested
}
// TransferCurrencyConvertedOpt mutates a TransferCurrencyConverted event
// while it is being constructed.
type TransferCurrencyConvertedOpt func(*TransferCurrencyConverted)

// NewTransferCurrencyConverted wraps a copy of cc in a
// TransferCurrencyConverted event and applies opts in order. cc must not be
// nil.
//
// NOTE(review): unlike the deposit and withdraw counterparts, this
// constructor keeps the ID and Timestamp copied from cc rather than
// assigning fresh ones — confirm this is intentional.
func NewTransferCurrencyConverted(
	cc *CurrencyConverted,
	opts ...TransferCurrencyConvertedOpt,
) *TransferCurrencyConverted {
	converted := new(TransferCurrencyConverted)
	converted.CurrencyConverted = *cc
	for i := range opts {
		opts[i](converted)
	}
	return converted
}
// --- TransferValidated ---

// TransferValidatedOpt mutates a TransferValidated event while it is being
// constructed.
type TransferValidatedOpt func(*TransferValidated)

// NewTransferValidated wraps a copy of tr in a TransferValidated event, sets
// Timestamp to the current time and applies opts in order. tr must not be
// nil. The ID copied from tr is kept as-is.
func NewTransferValidated(
	tr *TransferCurrencyConverted,
	opts ...TransferValidatedOpt,
) *TransferValidated {
	validated := new(TransferValidated)
	validated.TransferCurrencyConverted = *tr
	validated.Timestamp = time.Now()
	for i := range opts {
		opts[i](validated)
	}
	return validated
}
// --- TransferFailed ---

// TransferFailedOpt mutates a TransferFailed event while it is being
// constructed.
type TransferFailedOpt func(*TransferFailed)

// WithReason returns an option that assigns reason to the TransferFailed
// event's Reason.
func WithReason(reason string) TransferFailedOpt {
	apply := func(failed *TransferFailed) {
		failed.Reason = reason
	}
	return apply
}
// NewTransferFailed builds a TransferFailed event from a copy of tr and the
// given reason. The event receives a fresh ID and the current time as
// Timestamp before opts are applied in order. tr must not be nil.
func NewTransferFailed(
	tr *TransferRequested,
	reason string,
	opts ...TransferFailedOpt,
) *TransferFailed {
	failed := new(TransferFailed)
	failed.TransferRequested = *tr
	failed.Reason = reason
	failed.ID = uuid.New()
	failed.Timestamp = time.Now()
	for i := range opts {
		opts[i](failed)
	}
	return failed
}
// --- TransferCompleted factory ---

// TransferCompletedOpt mutates a TransferCompleted event while it is being
// constructed.
type TransferCompletedOpt func(*TransferCompleted)

// WithTransferAmount returns an option that assigns m to the
// TransferCompleted event's Amount.
func WithTransferAmount(m *money.Money) TransferCompletedOpt {
	apply := func(completed *TransferCompleted) {
		completed.Amount = m
	}
	return apply
}
// NewTransferCompleted builds a TransferCompleted event from tr. The
// innermost CurrencyConverted is created by NewCurrencyConverted from a
// conversion request derived from tr's FlowEvent and tr itself. The event
// then receives a fresh ID and the current time as Timestamp before opts are
// applied in order. tr must not be nil.
func NewTransferCompleted(
	tr *TransferRequested,
	opts ...TransferCompletedOpt,
) *TransferCompleted {
	conversion := NewCurrencyConverted(
		NewCurrencyConversionRequested(tr.FlowEvent, tr),
	)
	completed := new(TransferCompleted)
	completed.TransferValidated.TransferCurrencyConverted.CurrencyConverted = *conversion
	completed.ID = uuid.New()
	completed.Timestamp = time.Now()
	for i := range opts {
		opts[i](completed)
	}
	return completed
}
package events
// EventTypes maps event type constants to their respective constructor
// functions. Each constructor returns a pointer to a fresh zero-valued event
// of the matching concrete type.
//
// Every event type declared in this package that has a concrete event struct
// should be registered here; Payment.Failed and User.OnboardingCompleted were
// previously missing.
var EventTypes = map[EventType]func() Event{
	// Payment events.
	EventTypePaymentInitiated: func() Event { return &PaymentInitiated{} },
	EventTypePaymentCompleted: func() Event { return &PaymentCompleted{} },
	EventTypePaymentProcessed: func() Event { return &PaymentProcessed{} },
	EventTypePaymentFailed:    func() Event { return &PaymentFailed{} },

	// Deposit events.
	EventTypeDepositRequested: func() Event { return &DepositRequested{} },
	EventTypeDepositCurrencyConverted: func() Event {
		return &DepositCurrencyConverted{}
	},
	EventTypeDepositValidated: func() Event { return &DepositValidated{} },
	EventTypeDepositFailed:    func() Event { return &DepositFailed{} },

	// Withdraw events.
	EventTypeWithdrawRequested: func() Event { return &WithdrawRequested{} },
	EventTypeWithdrawCurrencyConverted: func() Event {
		return &WithdrawCurrencyConverted{}
	},
	EventTypeWithdrawValidated: func() Event { return &WithdrawValidated{} },
	EventTypeWithdrawFailed:    func() Event { return &WithdrawFailed{} },

	// User events.
	EventTypeUserOnboardingCompleted: func() Event {
		return &UserOnboardingCompleted{}
	},

	// Transfer events.
	EventTypeTransferRequested: func() Event { return &TransferRequested{} },
	EventTypeTransferCurrencyConverted: func() Event {
		return &TransferCurrencyConverted{}
	},
	EventTypeTransferValidated: func() Event { return &TransferValidated{} },
	EventTypeTransferCompleted: func() Event { return &TransferCompleted{} },
	EventTypeTransferFailed:    func() Event { return &TransferFailed{} },

	// Currency conversion events.
	EventTypeCurrencyConversionRequested: func() Event {
		return &CurrencyConversionRequested{}
	},
	EventTypeCurrencyConverted: func() Event { return &CurrencyConverted{} },
	EventTypeCurrencyConversionFailed: func() Event {
		return &CurrencyConversionFailed{}
	},

	// Fee events.
	EventTypeFeesCalculated: func() Event { return &FeesCalculated{} },
}
package events
import (
"fmt"
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// WithdrawRequested is emitted when a withdrawal is requested (pure
// event-driven domain).
//
// ID and Timestamp are declared on the struct itself and therefore shadow
// the fields of the same name promoted from the embedded FlowEvent.
type WithdrawRequested struct {
	FlowEvent
	ID                    uuid.UUID
	TransactionID         uuid.UUID
	Amount                *money.Money
	BankAccountNumber     string
	RoutingNumber         string
	ExternalWalletAddress string
	Timestamp             time.Time
	PaymentID             string // Added for payment provider integration
	Fee                   int64
}

// Type returns the event type identifier used for WithdrawRequested events.
func (e *WithdrawRequested) Type() string {
	return EventTypeWithdrawRequested.String()
}

// Validate performs business validation on the withdrawal request.
//
// It returns an error when AccountID or UserID is unset, or when Amount is
// missing, zero or negative; otherwise it returns nil.
func (e *WithdrawRequested) Validate() error {
	if e.AccountID == uuid.Nil {
		return fmt.Errorf("account ID cannot be nil")
	}
	if e.UserID == uuid.Nil {
		return fmt.Errorf("user ID cannot be nil")
	}
	// A missing amount is treated like a zero amount so it is rejected here
	// instead of being dereferenced.
	if e.Amount == nil || e.Amount.IsZero() {
		return fmt.Errorf("amount cannot be zero")
	}
	if e.Amount.IsNegative() {
		return fmt.Errorf("amount must be positive")
	}
	return nil
}
// WithdrawCurrencyConverted is emitted after currency conversion for
// withdraw. It embeds CurrencyConverted and adds no fields of its own.
type WithdrawCurrencyConverted struct {
	CurrencyConverted
}

// Type returns the event type identifier used for WithdrawCurrencyConverted
// events.
func (e WithdrawCurrencyConverted) Type() string {
	kind := EventTypeWithdrawCurrencyConverted
	return kind.String()
}
// WithdrawValidated is emitted after business validation for withdraw. It
// embeds WithdrawCurrencyConverted and adds no fields of its own.
type WithdrawValidated struct {
	WithdrawCurrencyConverted
}

// Type returns the event type identifier used for WithdrawValidated events.
func (e WithdrawValidated) Type() string {
	kind := EventTypeWithdrawValidated
	return kind.String()
}
// WithdrawFailed is emitted when any part of the withdrawal flow fails. It
// embeds the WithdrawRequested event it relates to and adds the failure
// Reason.
type WithdrawFailed struct {
	WithdrawRequested
	Reason string
}

// Type returns the event type identifier used for WithdrawFailed events.
func (e WithdrawFailed) Type() string {
	kind := EventTypeWithdrawFailed
	return kind.String()
}
// UserOnboardingCompleted is emitted when a user completes the Stripe
// onboarding process. It carries the user's Stripe account ID.
type UserOnboardingCompleted struct {
	FlowEvent
	StripeAccountID string
}

// Type returns the event type identifier used for UserOnboardingCompleted
// events.
func (e UserOnboardingCompleted) Type() string {
	kind := EventTypeUserOnboardingCompleted
	return kind.String()
}
package events
import (
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
)
// --- WithdrawRequested ---

// WithdrawRequestedOpt mutates a WithdrawRequested event while it is being
// constructed.
type WithdrawRequestedOpt func(*WithdrawRequested)

// WithWithdrawAmount returns an option that assigns m to the event's Amount.
func WithWithdrawAmount(m *money.Money) WithdrawRequestedOpt {
	apply := func(req *WithdrawRequested) {
		req.Amount = m
	}
	return apply
}
// WithWithdrawTimestamp returns an option that assigns ts to the event's own
// Timestamp field.
func WithWithdrawTimestamp(ts time.Time) WithdrawRequestedOpt {
	apply := func(req *WithdrawRequested) {
		req.Timestamp = ts
	}
	return apply
}
// WithWithdrawID returns an option that assigns id to the event's own ID
// field.
func WithWithdrawID(id uuid.UUID) WithdrawRequestedOpt {
	apply := func(req *WithdrawRequested) {
		req.ID = id
	}
	return apply
}
// WithWithdrawFlowEvent returns an option that replaces the event's embedded
// FlowEvent with fe.
func WithWithdrawFlowEvent(fe FlowEvent) WithdrawRequestedOpt {
	apply := func(req *WithdrawRequested) {
		req.FlowEvent = fe
	}
	return apply
}
// WithWithdrawBankAccountNumber returns an option that assigns accountNumber
// to the withdraw request's BankAccountNumber.
func WithWithdrawBankAccountNumber(accountNumber string) WithdrawRequestedOpt {
	apply := func(req *WithdrawRequested) {
		req.BankAccountNumber = accountNumber
	}
	return apply
}
// NewWithdrawRequested builds a WithdrawRequested event for the given user,
// account and correlation ID.
//
// Defaults: a fresh event ID, FlowType "withdraw", the current time as
// Timestamp and a zero USD Amount. WithdrawRequested declares its own ID and
// Timestamp fields, which shadow the ones in the embedded FlowEvent; both
// copies are initialised to the same values so that reading wr.ID or
// wr.Timestamp does not yield a zero value. Options are applied afterwards,
// in order, and may override any of these defaults.
func NewWithdrawRequested(
	userID, accountID, correlationID uuid.UUID,
	opts ...WithdrawRequestedOpt,
) *WithdrawRequested {
	id := uuid.New()
	now := time.Now()
	wr := &WithdrawRequested{
		FlowEvent: FlowEvent{
			ID:            id,
			FlowType:      "withdraw",
			UserID:        userID,
			AccountID:     accountID,
			CorrelationID: correlationID,
			Timestamp:     now,
		},
		// Mirror the identity into the shadowing top-level fields.
		ID:        id,
		Timestamp: now,
		Amount:    money.Zero(money.USD),
	}
	for _, opt := range opts {
		opt(wr)
	}
	return wr
}
// WithdrawCurrencyConvertedOpt mutates a WithdrawCurrencyConverted event
// while it is being constructed.
type WithdrawCurrencyConvertedOpt func(*WithdrawCurrencyConverted)

// NewWithdrawCurrencyConverted creates a new WithdrawCurrencyConverted event.
// It wraps a copy of cc, assigns a fresh ID and the current time as
// Timestamp, and then applies opts in order. cc must not be nil.
//
// All fields of cc, including TransactionID, are propagated by the struct
// copy itself; the former "if TransactionID is nil, copy it from cc" branch
// assigned the field to itself and has been removed.
func NewWithdrawCurrencyConverted(
	cc *CurrencyConverted,
	opts ...WithdrawCurrencyConvertedOpt,
) *WithdrawCurrencyConverted {
	wcc := &WithdrawCurrencyConverted{
		CurrencyConverted: *cc,
	}
	wcc.ID = uuid.New()
	wcc.Timestamp = time.Now()
	for _, opt := range opts {
		opt(wcc)
	}
	return wcc
}
// WithdrawValidatedOpt mutates a WithdrawValidated event while it is being
// constructed.
type WithdrawValidatedOpt func(*WithdrawValidated)

// NewWithdrawValidated wraps a copy of cc in a WithdrawValidated event,
// assigns it a fresh ID and the current time as Timestamp, and then applies
// opts in order. cc must not be nil.
func NewWithdrawValidated(
	cc *WithdrawCurrencyConverted,
	opts ...WithdrawValidatedOpt,
) *WithdrawValidated {
	validated := new(WithdrawValidated)
	validated.WithdrawCurrencyConverted = *cc
	validated.ID = uuid.New()
	validated.Timestamp = time.Now()
	for i := range opts {
		opts[i](validated)
	}
	return validated
}
// --- WithdrawFailed ---

// WithdrawFailedOpt mutates a WithdrawFailed event while it is being
// constructed.
type WithdrawFailedOpt func(*WithdrawFailed)

// WithWithdrawFailureReason returns an option that assigns reason to the
// event's Reason.
func WithWithdrawFailureReason(reason string) WithdrawFailedOpt {
	apply := func(failed *WithdrawFailed) {
		failed.Reason = reason
	}
	return apply
}
// NewWithdrawFailed builds a WithdrawFailed event from a copy of wr and the
// given reason. The event receives a fresh ID and the current time as
// Timestamp before opts are applied in order. wr must not be nil.
func NewWithdrawFailed(
	wr *WithdrawRequested,
	reason string,
	opts ...WithdrawFailedOpt,
) *WithdrawFailed {
	failed := new(WithdrawFailed)
	failed.WithdrawRequested = *wr
	failed.Reason = reason
	failed.ID = uuid.New()
	failed.Timestamp = time.Now()
	for i := range opts {
		opts[i](failed)
	}
	return failed
}
// NewUserOnboardingCompleted builds a UserOnboardingCompleted event for
// userID with a fresh event ID, the current time as Timestamp and the given
// Stripe account ID. No other FlowEvent fields are populated.
func NewUserOnboardingCompleted(userID uuid.UUID, stripeAccountID string) *UserOnboardingCompleted {
	flow := FlowEvent{
		ID:        uuid.New(),
		UserID:    userID,
		Timestamp: time.Now(),
	}
	completed := &UserOnboardingCompleted{FlowEvent: flow}
	completed.StripeAccountID = stripeAccountID
	return completed
}
package domain
import (
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/domain/common"
"github.com/amirasaad/fintech/pkg/domain/user"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/provider/exchange"
)
// Type aliases for backward compatibility.

// Account is an alias for account.Account.
//
// Deprecated: Use account.Account directly.
type Account = account.Account
// Transaction is an alias for account.Transaction.
//
// Deprecated: Use account.Transaction directly.
type Transaction = account.Transaction
// User is an alias for user.User.
type User = user.User
// Error aliases for backward compatibility. Each variable re-exports the
// sentinel error named on its right-hand side.
var (
// Account errors.

// Deprecated: Use account.ErrDepositAmountExceedsMaxSafeInt directly.
ErrDepositAmountExceedsMaxSafeInt = account.ErrDepositAmountExceedsMaxSafeInt
// Deprecated: Use account.ErrTransactionAmountMustBePositive directly.
ErrTransactionAmountMustBePositive = account.ErrTransactionAmountMustBePositive
// Deprecated: Use account.ErrInsufficientFunds directly.
ErrInsufficientFunds = account.ErrInsufficientFunds
// Deprecated: Use account.ErrAccountNotFound directly.
ErrAccountNotFound = account.ErrAccountNotFound
// Deprecated: Use common.ErrInvalidCurrencyCode directly.
ErrInvalidCurrencyCode = common.ErrInvalidCurrencyCode
// Deprecated: Use user.ErrUserUnauthorized directly.
ErrUserUnauthorized = user.ErrUserUnauthorized
// Currency-related errors.

// Deprecated: Use exchange.ErrProviderUnavailable directly.
ErrExchangeRateUnavailable = exchange.ErrProviderUnavailable
// Deprecated: Use exchange.ErrUnsupportedPair directly.
ErrUnsupportedCurrencyPair = exchange.ErrUnsupportedPair
)
// ConversionInfo is an alias for exchange.RateInfo.
//
// Deprecated: Use exchange.RateInfo directly.
type ConversionInfo = exchange.RateInfo
// ExchangeRate is an alias for exchange.RateInfo (the same underlying type as
// ConversionInfo).
//
// Deprecated: Use exchange.RateInfo directly.
type ExchangeRate = exchange.RateInfo
// NewMoney forwards to money.New with the same arguments and returns its
// results unchanged.
//
// Deprecated: use money.New
func NewMoney(amount float64, currencyCode money.Code) (m *money.Money, err error) {
	m, err = money.New(amount, currencyCode)
	return m, err
}
package user
import (
"errors"
"time"
"github.com/amirasaad/fintech/pkg/utils"
"github.com/google/uuid"
)
// Sentinel errors of the user package.
var (
// ErrUserNotFound is returned when a user cannot be found in the
// repository.
ErrUserNotFound = errors.New("user not found")
// ErrUserUnauthorized is returned when a user is not authorized to perform
// the requested operation.
ErrUserUnauthorized = errors.New("user unauthorized")
)
// User represents a user in the system.
//
// NOTE(review): Password carries a `json:"password"` tag, so it is included
// whenever a User is marshalled to JSON. New stores a hashed value here, but
// confirm that exposing it in JSON output is intended.
type User struct {
ID uuid.UUID `json:"id"`
Username string `json:"username"`
Email string `json:"email"`
Password string `json:"password"` // hashed by New; stored verbatim by NewUserFromData
Names string `json:"names"` // not populated by either constructor in this file
CreatedAt time.Time `json:"created"`
UpdatedAt time.Time `json:"updated"`
}
// New creates a new User with a fresh ID, a hashed password and the current
// UTC time as both CreatedAt and UpdatedAt.
//
// It returns an error when username or email is empty, or when hashing the
// password fails (the hashing error is returned unchanged).
func New(username, email, password string) (*User, error) {
	if username == "" {
		return nil, errors.New("username cannot be empty")
	}
	if email == "" {
		return nil, errors.New("email cannot be empty")
	}
	hashedPassword, err := utils.HashPassword(password)
	if err != nil {
		return nil, err
	}
	// Read the clock once so CreatedAt and UpdatedAt are exactly equal for a
	// freshly created user.
	now := time.Now().UTC()
	return &User{
		ID:        uuid.New(),
		Username:  username,
		Email:     email,
		Password:  hashedPassword,
		CreatedAt: now,
		UpdatedAt: now,
	}, nil
}
// NewUserFromData assembles a User from already-persisted values (used for
// DB hydration). The arguments are stored verbatim: the password is not
// hashed and no validation is performed.
func NewUserFromData(
	id uuid.UUID,
	username, email, password string,
	created, updated time.Time,
) *User {
	hydrated := new(User)
	hydrated.ID = id
	hydrated.Username = username
	hydrated.Email = email
	hydrated.Password = password
	hydrated.CreatedAt = created
	hydrated.UpdatedAt = updated
	return hydrated
}
package core
import "errors"
// Common sentinel errors for exchange operations. Compare against them with
// errors.Is so that wrapped errors are matched as well.
var (
// ErrInvalidRate indicates that an invalid exchange rate was provided.
ErrInvalidRate = errors.New("invalid exchange rate")
// ErrUnsupportedCurrencyPair indicates that the currency pair is not supported.
ErrUnsupportedCurrencyPair = errors.New("unsupported currency pair")
// ErrProviderUnavailable indicates that the rate provider is not available.
ErrProviderUnavailable = errors.New("rate provider unavailable")
// ErrRateNotFound indicates that the requested rate was not found.
ErrRateNotFound = errors.New("exchange rate not found")
// ErrInvalidAmount indicates that an invalid amount was provided.
ErrInvalidAmount = errors.New("invalid amount")
)
// ProviderError represents an error from a rate provider. It records the
// provider's name and the underlying error.
type ProviderError struct {
	Provider string
	Err      error
}

// Error formats the error as "provider <name>: <cause>". When no underlying
// error is set, the cause is reported as "unknown error" instead of
// dereferencing a nil error.
func (e *ProviderError) Error() string {
	if e.Err == nil {
		return "provider " + e.Provider + ": unknown error"
	}
	return "provider " + e.Provider + ": " + e.Err.Error()
}

// Unwrap returns the underlying error so errors.Is and errors.As can inspect
// the chain.
func (e *ProviderError) Unwrap() error {
	return e.Err
}

// IsProviderError reports whether err is, or wraps, a *ProviderError. A nil
// err yields false.
func IsProviderError(err error) bool {
	var target *ProviderError
	// errors.As walks the Unwrap chain, so a ProviderError wrapped with
	// fmt.Errorf("...: %w", err) is still detected.
	return errors.As(err, &target)
}
package service
import (
"context"
"errors"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/exchange/core"
"github.com/amirasaad/fintech/pkg/provider/exchange"
)
// Service handles currency exchange operations using a provider and an
// optional cache.
type Service struct {
	provider exchange.Exchange // Single provider that may be a composite
	cache    *exchange.Cache   // Optional cache
	logger   *slog.Logger      // Logger for the service
}

// New creates a new exchange service with the given provider and cache. A
// nil logger is replaced by slog.Default().
func New(
	provider exchange.Exchange,
	cache *exchange.Cache,
	logger *slog.Logger,
) *Service {
	svc := &Service{provider: provider, cache: cache, logger: logger}
	if svc.logger == nil {
		svc.logger = slog.Default()
	}
	return svc
}
// Convert converts amount from one currency to another using the rate
// returned by GetRate.
//
// It returns core.ErrInvalidAmount when amount is zero or negative, and a
// wrapped error when the rate cannot be obtained. On success the result
// holds the original amount, the converted amount (amount * rate), the rate
// and its source.
func (s *Service) Convert(
	ctx context.Context,
	from, to string,
	amount float64,
) (*core.ConversionResult, error) {
	if amount <= 0 {
		return nil, core.ErrInvalidAmount
	}

	rate, err := s.GetRate(ctx, from, to)
	if err != nil {
		return nil, fmt.Errorf("failed to get exchange rate: %w", err)
	}

	result := new(core.ConversionResult)
	result.FromAmount = amount
	result.ToAmount = amount * rate.Value
	result.Rate = rate.Value
	result.Source = rate.Source
	return result, nil
}
// GetRate gets the exchange rate between two currencies.
//
// When a cache is configured it is consulted first and a hit is returned
// immediately. Otherwise (or on any cache error) the rate is fetched from
// the provider and, if a cache is configured, stored in it. A failure to
// store is only logged; a provider failure is returned wrapped.
func (s *Service) GetRate(
	ctx context.Context,
	from, to string,
) (*core.Rate, error) {
	hasCache := s.cache != nil

	if hasCache {
		cached, cacheErr := s.cache.GetRate(ctx, from, to)
		if cacheErr == nil {
			return &core.Rate{
				From:      from,
				To:        to,
				Value:     cached.Rate,
				Timestamp: cached.Timestamp,
				Source:    cached.Provider,
			}, nil
		}
	}

	// Cache miss or no cache: ask the provider.
	rateInfo, err := s.provider.FetchRate(ctx, from, to)
	if err != nil {
		return nil, fmt.Errorf("provider error: %w", err)
	}

	fetched := &core.Rate{
		From:      from,
		To:        to,
		Value:     rateInfo.Rate,
		Timestamp: rateInfo.Timestamp,
		Source:    rateInfo.Provider,
	}

	if hasCache {
		// Best effort: a cache write failure must not fail the lookup.
		if storeErr := s.cache.StoreRate(ctx, rateInfo); storeErr != nil {
			s.logger.Error("failed to cache rate", "error", storeErr)
		}
	}
	return fetched, nil
}
// GetRates gets multiple exchange rates in a single request.
//
// The cache (when configured) is consulted first. Every requested currency
// the cache did not supply a rate for — whether the cache reported it as a
// nil entry or simply left it out of its result — is then looked up via one
// provider FetchRates call. All rates returned by the provider are written
// back to the cache on a best-effort basis (failures are only logged).
//
// The returned map has exactly one entry per currency in to; a nil value
// means no rate could be found for that currency. An error is returned when
// to is empty or when the provider call fails.
func (s *Service) GetRates(
	ctx context.Context,
	from string,
	to []string,
) (map[string]*core.Rate, error) {
	if len(to) == 0 {
		return nil, errors.New("no target currencies provided")
	}

	// Check cache first. A cache error is treated as "nothing cached".
	cachedRates := make(map[string]*core.Rate, len(to))
	if s.cache != nil {
		if rates, err := s.cache.BatchGetRates(ctx, from, to); err == nil {
			for currency, rate := range rates {
				if rate == nil {
					continue
				}
				cachedRates[currency] = &core.Rate{
					From:      from,
					To:        currency,
					Value:     rate.Rate,
					Timestamp: rate.Timestamp,
					Source:    rate.Provider,
				}
			}
		}
	}

	// Derive the missing set from the request itself rather than from the
	// cache's result map, so currencies the cache omitted are fetched too.
	var toFetch []string
	for _, currency := range to {
		if _, ok := cachedRates[currency]; !ok {
			toFetch = append(toFetch, currency)
		}
	}

	// Fetch missing rates from provider
	fetchedRates := make(map[string]*core.Rate)
	if len(toFetch) > 0 {
		rates, err := s.provider.FetchRates(ctx, from)
		if err != nil {
			return nil, fmt.Errorf("failed to fetch rates: %w", err)
		}
		// Convert RateInfo to core.Rate and update cache
		for currency, rateInfo := range rates {
			fetchedRates[currency] = &core.Rate{
				From:      from,
				To:        currency,
				Value:     rateInfo.Rate,
				Timestamp: rateInfo.Timestamp,
				Source:    rateInfo.Provider,
			}
			if s.cache != nil {
				if err := s.cache.StoreRate(ctx, rateInfo); err != nil {
					s.logger.Error("failed to cache rate", "error", err)
				}
			}
		}
	}

	// Merge cached and fetched rates, restricted to what was requested.
	result := make(map[string]*core.Rate, len(to))
	for _, currency := range to {
		if rate, exists := cachedRates[currency]; exists {
			result[currency] = rate
		} else if rate, exists := fetchedRates[currency]; exists {
			result[currency] = rate
		} else {
			result[currency] = nil // Indicate missing rate
		}
	}
	return result, nil
}
package deposit
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/mapper"
"github.com/amirasaad/fintech/pkg/repository"
)
// HandleCurrencyConverted performs domain validation after currency
// conversion for deposits.
//
// The returned handler processes *events.DepositCurrencyConverted events:
//   - events of any other type, events without an OriginalRequest, and
//     events whose OriginalRequest is not a *events.DepositRequested are
//     logged and skipped (nil is returned);
//   - the account is loaded and the deposit is validated against it with the
//     converted amount;
//   - on validation failure a DepositFailed event is emitted and nil is
//     returned; a failure to emit it is logged rather than dropped silently;
//   - on success a DepositValidated event is emitted.
//
// Errors from obtaining the repository, loading or mapping the account, or
// emitting DepositValidated are returned to the caller.
func HandleCurrencyConverted(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	logger *slog.Logger,
) eventbus.HandlerFunc {
	return func(
		ctx context.Context,
		e events.Event,
	) error {
		log := logger.With(
			"handler", "deposit.HandleCurrencyConverted",
			"event_type", e.Type(),
		)
		log.Info("🟢 [START] Processing DepositCurrencyConverted event")

		dcc, ok := e.(*events.DepositCurrencyConverted)
		if !ok {
			log.Warn(
				"🚫 [SKIP] Skipping: unexpected event type",
				"event", e,
			)
			return nil
		}
		log = log.With(
			"user_id", dcc.UserID,
			"account_id", dcc.AccountID,
			"transaction_id", dcc.TransactionID,
			"correlation_id", dcc.CorrelationID,
		)

		accRepo, err := common.GetAccountRepository(uow, log)
		if err != nil {
			log.Error(
				"Failed to get account repository",
				"error", err,
			)
			return err
		}

		// Get account for validation
		accountID := dcc.AccountID
		userID := dcc.UserID

		// Log detailed information about the event for debugging
		log.Debug(
			"Processing DepositCurrencyConverted event",
			"event_type", fmt.Sprintf("%T", dcc),
			"transaction_id", dcc.TransactionID,
			"user_id", userID,
			"account_id", accountID,
			"converted_amount", dcc.ConvertedAmount,
			"has_original_request", dcc.OriginalRequest != nil,
			"original_request_type", fmt.Sprintf("%T", dcc.OriginalRequest),
		)

		// Check if OriginalRequest is nil
		if dcc.OriginalRequest == nil {
			log.Warn(
				"[SKIP] Original request is missing",
				"event_id", dcc.ID,
			)
			return nil
		}

		// Type assert the OriginalRequest to DepositRequested
		dr, ok := dcc.OriginalRequest.(*events.DepositRequested)
		if !ok {
			log.Warn(
				"[SKIP] Unexpected original request type",
				"original_request_type", fmt.Sprintf("%T", dcc.OriginalRequest),
			)
			return nil
		}

		// Get the account
		accRead, err := accRepo.Get(ctx, accountID)
		if err != nil {
			log.Warn(
				"Failed to get account",
				"error", err,
				"account_id", accountID,
			)
			return fmt.Errorf("failed to get account: %w", err)
		}
		acc, err := mapper.MapAccountReadToDomain(accRead)
		if err != nil {
			log.Error(
				"Failed to map account read to domain",
				"error", err,
			)
			return fmt.Errorf("failed to map account read to domain: %w", err)
		}

		// Perform domain validation
		if err := acc.ValidateDeposit(userID, dcc.ConvertedAmount); err != nil {
			log.Warn(
				"Domain validation failed",
				"error", err,
			)
			// Create the failed event
			df := events.NewDepositFailed(dr, err.Error())
			// Emitting the failure stays best-effort (the handler still
			// returns nil), but a lost DepositFailed must be visible in logs.
			if emitErr := bus.Emit(ctx, df); emitErr != nil {
				log.Error(
					"Failed to emit",
					"event_type", df.Type(),
					"error", emitErr,
				)
			}
			return nil
		}

		dv := events.NewDepositValidated(dcc)
		log.Info(
			"✅ [SUCCESS] Domain validation passed, emitting",
			"event_type", dv.Type(),
		)
		if err := bus.Emit(ctx, dv); err != nil {
			log.Warn(
				"Failed to emit",
				"event_type", dv.Type(),
				"error", err,
			)
			return fmt.Errorf("failed to emit %s: %w", dv.Type(), err)
		}
		log.Info(
			"📤 [EMITTED] event",
			"event_id", dv.ID,
			"event_type", dv.Type(),
			"transaction_id", dv.TransactionID,
			"correlation_id", dv.CorrelationID,
		)
		return nil
	}
}
package deposit
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/google/uuid"
)
// HandleRequested builds the handler for DepositRequested events.
//
// The handler validates the request, loads the target account, persists a
// "created" transaction through persistDepositTransaction and finally emits a
// CurrencyConversionRequested event whose target currency is the account's.
// Validation and persistence failures are reported through a DepositFailed
// event and nil is returned; an unexpected event type or a repository/account
// lookup failure is returned as an error. A failure to emit the conversion
// request is logged and nil is returned.
func HandleRequested(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	logger *slog.Logger,
) func(ctx context.Context, e events.Event) error {
	return func(ctx context.Context, e events.Event) error {
		log := logger.With("handler", "deposit.HandleRequested", "event_type", e.Type())
		log.Info("🟢 [START] Processing DepositRequested event")

		req, isDeposit := e.(*events.DepositRequested)
		if !isDeposit {
			log.Error("❌ [ERROR] Unexpected event type", "expected", "DepositRequested", "got", e.Type())
			return fmt.Errorf("unexpected event type: %s", e.Type())
		}
		log = log.With(
			"user_id", req.UserID,
			"account_id", req.AccountID,
			"amount", req.Amount.String(),
			"correlation_id", req.CorrelationID,
		)

		// notifyFailure publishes a DepositFailed event; a failure to publish
		// is only logged because the handler has nothing else to fall back on.
		notifyFailure := func(reason string) {
			failed := events.NewDepositFailed(req, reason)
			if emitErr := bus.Emit(ctx, failed); emitErr != nil {
				log.Error("❌ [ERROR] Failed to emit DepositFailed event", "error", emitErr)
			}
		}

		if validationErr := req.Validate(); validationErr != nil {
			log.Error("❌ [ERROR] Deposit validation failed", "error", validationErr)
			notifyFailure(validationErr.Error())
			return nil
		}

		accounts, err := common.GetAccountRepository(uow, log)
		if err != nil {
			log.Error("❌ [ERROR] Failed to get account repository", "error", err)
			return fmt.Errorf("failed to get account repository: %w", err)
		}
		target, err := accounts.Get(ctx, req.AccountID)
		if err != nil {
			log.Error("❌ [ERROR] Failed to get account", "error", err)
			return fmt.Errorf("failed to get account: %w", err)
		}

		// Assign a transaction ID when the request did not carry one.
		if req.TransactionID == uuid.Nil {
			req.TransactionID = uuid.New()
		}
		transactionID := req.TransactionID

		if persistErr := persistDepositTransaction(ctx, uow, req, log); persistErr != nil {
			log.Error(
				"❌ [ERROR] Failed to persist deposit transaction",
				"error", persistErr,
				"transaction_id", transactionID,
			)
			notifyFailure(fmt.Sprintf("failed to persist transaction: %v", persistErr))
			return nil
		}
		log.Info("✅ [SUCCESS] Deposit validated and persisted", "transaction_id", transactionID)

		log.Debug(
			"🔧 CurrencyConversionRequested event created",
			"deposit_request", fmt.Sprintf("%+v", *req),
			"original_request_type", fmt.Sprintf("%T", *req),
		)
		// NOTE(review): the request is handed over by value (*req), whereas the
		// deposit currency-converted handler asserts a *events.DepositRequested;
		// confirm the conversion step preserves a type that assertion accepts.
		conversion := events.NewCurrencyConversionRequested(
			req.FlowEvent,
			*req,
			events.WithConversionAmount(req.Amount),
			events.WithConversionTo(money.Code(target.Currency)),
			events.WithConversionTransactionID(transactionID),
		)
		log.Debug(
			"🔧 CurrencyConversionRequested event created",
			"ccr_original_request_nil", conversion.OriginalRequest == nil,
			"ccr_original_request_type", fmt.Sprintf("%T", conversion.OriginalRequest),
			"ccr_transaction_id", conversion.TransactionID,
		)

		if emitErr := bus.Emit(ctx, conversion); emitErr != nil {
			log.Error("Failed to emit CurrencyConversionRequested event", "error", emitErr)
			return nil
		}
		log.Info("📤 [EMITTED] event", "event_id", conversion.ID, "event_type", conversion.Type())
		return nil
	}
}
// persistDepositTransaction writes a "created" deposit transaction for dr.
// The write runs inside uow.Do; any repository error is wrapped and returned.
func persistDepositTransaction(
	ctx context.Context,
	uow repository.UnitOfWork,
	dr *events.DepositRequested,
	logger *slog.Logger,
) error {
	store := func(work repository.UnitOfWork) error {
		transactions, err := common.GetTransactionRepository(work, logger)
		if err != nil {
			return fmt.Errorf("failed to get transaction repository: %w", err)
		}

		// PaymentID is intentionally left unset to prevent unique constraint violations.
		record := dto.TransactionCreate{
			ID:          dr.TransactionID,
			UserID:      dr.UserID,
			AccountID:   dr.AccountID,
			Amount:      dr.Amount.Amount(),
			Status:      "created",
			MoneySource: "deposit",
			Currency:    dr.Amount.Currency().String(),
		}
		if err = transactions.Create(ctx, record); err != nil {
			return fmt.Errorf("failed to create transaction: %w", err)
		}
		return nil
	}
	return uow.Do(ctx, store)
}
package deposit
import (
"context"
"errors"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/provider/payment"
"github.com/amirasaad/fintech/pkg/repository"
)
// HandleValidated builds the handler for DepositValidated events.
//
// It creates a PaymentInitiated event from the validated deposit, copying the
// transaction ID, converted amount, user ID, account ID and correlation ID,
// and emits it on bus. An unexpected event type or an emit failure is returned
// as an error. uow and paymentProvider are accepted but not used by the
// handler body.
func HandleValidated(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	paymentProvider payment.Payment,
	logger *slog.Logger,
) eventbus.HandlerFunc {
	return func(ctx context.Context, e events.Event) error {
		log := logger.With("handler", "deposit.HandleValidated", "event_type", e.Type())
		log.Info("🟢 [START] Processing DepositValidated event")

		validated, isValidated := e.(*events.DepositValidated)
		if !isValidated {
			log.Error("unexpected event type", "event_type", fmt.Sprintf("%T", e))
			return errors.New("unexpected event type")
		}
		log = log.With(
			"user_id", validated.UserID,
			"account_id", validated.AccountID,
			"transaction_id", validated.TransactionID,
			"correlation_id", validated.CorrelationID,
		)

		// Carry the identifying fields of the deposit over to the payment event.
		copyDepositFields := func(target *events.PaymentInitiated) {
			target.TransactionID = validated.TransactionID
			target.Amount = validated.ConvertedAmount
			target.UserID = validated.UserID
			target.AccountID = validated.AccountID
			target.CorrelationID = validated.CorrelationID
		}
		initiated := events.NewPaymentInitiated(&validated.FlowEvent, copyDepositFields)

		log.Info("📤 [EMIT] Emitting event", "event_type", initiated.Type())
		if emitErr := bus.Emit(ctx, initiated); emitErr != nil {
			log.Warn("Failed to emit", "event_type", initiated.Type(), "error", emitErr)
			return fmt.Errorf("failed to emit %s: %w", initiated.Type(), emitErr)
		}
		log.Info(
			"📤 [EMITTED] event",
			"event_id", initiated.ID,
			"event_type", initiated.Type(),
			"payment_id", initiated.PaymentID,
			"status", initiated.Status,
		)
		return nil
	}
}
package transfer
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/google/uuid"
)
// HandleCompleted handles the final, atomic persistence of a transfer.
//
// The returned handler expects a *events.TransferCompleted whose
// OriginalRequest is a *events.TransferRequested. Inside a single uow.Do call
// it loads both accounts, subtracts the transfer amount from the source
// balance, adds it to the destination balance, writes both balances back and
// sets the status of the outgoing transaction (tr.TransactionID) to
// "completed".
//
// It returns an error for an unexpected event type, an unexpected original
// request type, or a malformed request (nil account IDs or a zero amount).
// When the unit of work fails, the event built by events.NewTransferFailed is
// emitted and the result of that emit is returned. On success, the result of
// emitting the event built by events.NewTransferCompleted is returned.
func HandleCompleted(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	logger *slog.Logger,
) func(
	ctx context.Context,
	e events.Event,
) error {
	return func(
		ctx context.Context,
		e events.Event,
	) error {
		log := logger.With(
			"handler", "transfer.HandleCompleted",
			"event_type", e.Type(),
		)
		// 1. Defensive: Check event type and structure
		te, ok := e.(*events.TransferCompleted)
		if !ok {
			log.Error(
				"❌ [DISCARD] Unexpected event type",
				"event", e,
			)
			return fmt.Errorf("unexpected event type: %T", e)
		}
		tr, ok := te.OriginalRequest.(*events.TransferRequested)
		if !ok {
			log.Error(
				"❌ [DISCARD] Unexpected original request type",
				"event", te,
			)
			// Report the type of the offending original request; formatting te
			// itself would always print *events.TransferCompleted.
			return fmt.Errorf("unexpected original request type: %T", te.OriginalRequest)
		}
		log = log.With("correlation_id", tr.CorrelationID)
		log.Info(
			"🟢 [START] Received event",
			"event", te,
		)
		if tr.AccountID == uuid.Nil || tr.DestAccountID == uuid.Nil || tr.Amount.IsZero() {
			log.Error(
				"❌ [DISCARD] Malformed final persistence event",
				"event", te,
			)
			return fmt.Errorf("malformed final persistence event: %v", te)
		}
		// 2. Atomic Final HandleCompleted
		// NOTE(review): txInID is generated and logged, but no incoming
		// transaction is created with it in this block — confirm intent.
		txInID := uuid.New()
		txOutID := tr.TransactionID
		if err := uow.Do(ctx, func(uow repository.UnitOfWork) error {
			txRepo, err := common.GetTransactionRepository(uow, log)
			if err != nil {
				return fmt.Errorf("failed to get transaction repo: %w", err)
			}
			accRepo, err := common.GetAccountRepository(uow, log)
			if err != nil {
				return fmt.Errorf("failed to get account repo: %w", err)
			}
			sourceAcc, err := accRepo.Get(ctx, tr.AccountID)
			if err != nil {
				return fmt.Errorf("could not find source account: %w", err)
			}
			destAcc, err := accRepo.Get(ctx, tr.DestAccountID)
			if err != nil {
				return fmt.Errorf("could not find destination account: %w", err)
			}
			// NOTE(review): both balances are interpreted in the currency of
			// the requested amount, not in each account's own currency —
			// verify this is correct for cross-currency transfers.
			sourceBalance, err := money.New(sourceAcc.Balance, tr.Amount.Currency())
			if err != nil {
				return fmt.Errorf("could not create money for source balance: %w", err)
			}
			destBalance, err := money.New(destAcc.Balance, tr.Amount.Currency())
			if err != nil {
				return fmt.Errorf("could not create money for dest balance: %w", err)
			}
			newSourceMoney, err := sourceBalance.Subtract(tr.Amount)
			if err != nil {
				return fmt.Errorf("could not subtract from source balance: %w", err)
			}
			newDestMoney, err := destBalance.Add(tr.Amount)
			if err != nil {
				return fmt.Errorf("could not add to dest balance: %w", err)
			}
			newSourceBalance := newSourceMoney.Amount()
			newDestBalance := newDestMoney.Amount()
			if err := accRepo.Update(
				ctx,
				tr.AccountID,
				dto.AccountUpdate{Balance: &newSourceBalance},
			); err != nil {
				return fmt.Errorf("failed to debit source account: %w", err)
			}
			if err := accRepo.Update(
				ctx,
				tr.DestAccountID,
				dto.AccountUpdate{Balance: &newDestBalance},
			); err != nil {
				return fmt.Errorf("failed to credit destination account: %w", err)
			}
			completedStatus := "completed"
			if err := txRepo.Update(
				ctx,
				txOutID,
				dto.TransactionUpdate{Status: &completedStatus},
			); err != nil {
				return fmt.Errorf(
					"failed to update transaction status to completed: %w", err,
				)
			}
			return nil
		}); err != nil {
			log.Error(
				"❌ [ERROR] Final persistence transaction failed",
				"error", err,
			)
			// The persistence error is surfaced through the TransferFailed
			// event; only a failure to emit that event reaches the caller.
			tf := events.NewTransferFailed(tr, "PersistenceFailed: "+err.Error())
			return bus.Emit(ctx, tf)
		}
		log.Info(
			"✅ [SUCCESS] Final transfer persistence complete",
			"tx_out_id", txOutID,
			"tx_in_id", txInID,
		)
		// NOTE(review): this handler consumes *events.TransferCompleted and
		// emits the event built by NewTransferCompleted — confirm the bus
		// wiring cannot route it back into this handler.
		tc := events.NewTransferCompleted(tr)
		log.Info(
			"📤 [EMIT] Emitting event",
			"event_type", tc.Type(),
		)
		return bus.Emit(ctx, tc)
	}
}
package transfer
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/mapper"
"github.com/amirasaad/fintech/pkg/repository"
)
// HandleCurrencyConverted performs domain validation after currency conversion for transfers.
//
// The returned handler expects a *events.TransferCurrencyConverted whose
// OriginalRequest is a *events.TransferRequested; any other type is returned
// as an error. It loads the source and destination accounts, maps them to
// domain models and runs ValidateTransfer on the source account with the
// converted amount.
//
// A missing account or a failed validation emits the event built by
// events.NewTransferFailed and returns the result of that emit. On success the
// event built by events.NewTransferValidated is emitted and the emit result is
// returned.
func HandleCurrencyConverted(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	logger *slog.Logger,
) func(
	ctx context.Context,
	e events.Event,
) error {
	return func(
		ctx context.Context,
		e events.Event,
	) error {
		log := logger.With(
			"handler", "transfer.HandleCurrencyConverted",
			"event_type", e.Type(),
		)
		log.Info(
			"🟢 [HANDLER] HandleCurrencyConverted received event",
			"event_type", e.Type(),
		)
		// 1. Defensive: Check event type and structure
		tcc, ok := e.(*events.TransferCurrencyConverted)
		if !ok {
			log.Error(
				"unexpected event type",
				"event", e,
			)
			return fmt.Errorf("unexpected event type: %T", e)
		}
		log = log.With(
			"user_id", tcc.UserID,
			"account_id", tcc.AccountID,
			"transaction_id", tcc.TransactionID,
			"correlation_id", tcc.CorrelationID,
		)
		// Resolve the original request once, before any failure path needs it.
		// Using an unchecked assertion on a failure path would panic when the
		// original request is nil or of another type.
		tr, ok := tcc.OriginalRequest.(*events.TransferRequested)
		if !ok {
			log.Error(
				"unexpected event type",
				"event", tcc.OriginalRequest,
			)
			return fmt.Errorf("unexpected event type: %T", tcc.OriginalRequest)
		}
		// 2. Get account repository
		accRepo, err := common.GetAccountRepository(uow, log)
		if err != nil {
			log.Error(
				"failed to get account repository",
				"error", err,
			)
			return err
		}
		// Get source account DTO
		sourceAccDto, err := accRepo.Get(ctx, tcc.AccountID)
		if err != nil {
			log.Warn(
				"source account not found",
				"account_id", tcc.AccountID,
				"error", err,
			)
			return bus.Emit(ctx, events.NewTransferFailed(
				tr,
				"source account not found: "+err.Error(),
			))
		}
		// Map DTO to domain model
		sourceAcc, err := mapper.MapAccountReadToDomain(sourceAccDto)
		if err != nil {
			log.Error(
				"failed to map account read to domain",
				"error", err,
			)
			return fmt.Errorf("failed to map account read to domain: %w", err)
		}
		// Get destination account DTO
		destAccDto, err := accRepo.Get(ctx, tr.DestAccountID)
		if err != nil {
			log.Warn(
				"destination account not found",
				"account_id", tr.DestAccountID,
				"error", err,
			)
			return bus.Emit(ctx, events.NewTransferFailed(
				tr,
				"destination account not found: "+err.Error(),
			))
		}
		// Map DTO to domain model
		destAcc, err := mapper.MapAccountReadToDomain(destAccDto)
		if err != nil {
			log.Error(
				"failed to map destination account read to domain",
				"error", err,
			)
			return fmt.Errorf("failed to map destination account read to domain: %w", err)
		}
		// Perform domain validation
		if err := sourceAcc.ValidateTransfer(
			tcc.UserID,
			// Pass the user ID of the destination account owner for validation
			destAcc.UserID,
			destAcc,
			tcc.ConvertedAmount,
		); err != nil {
			log.Warn(
				"domain validation failed",
				"reason", err,
			)
			return bus.Emit(ctx, events.NewTransferFailed(
				tr,
				err.Error(),
			))
		}
		// 3. Emit success event
		tv := events.NewTransferValidated(
			tcc,
		)
		log.Info(
			"✅ [SUCCESS] Domain validation passed, emitting",
			"event_type", tv.Type(),
		)
		return bus.Emit(ctx, tv)
	}
}
package transfer
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/repository"
)
// HandleRequested builds the handler for TransferRequested events.
//
// Within one unit of work it looks up the destination account and records the
// outgoing leg of the transfer as a "pending" transaction with a negated
// amount. It then emits a CurrencyConversionRequested event whose target
// currency is the destination account's. A failure to emit that event is only
// logged; type, validation and persistence failures are returned.
func HandleRequested(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	logger *slog.Logger,
) func(ctx context.Context, e events.Event) error {
	return func(ctx context.Context, e events.Event) error {
		log := logger.With("handler", "HandleRequested", "event_type", e.Type())

		req, isTransfer := e.(*events.TransferRequested)
		if !isTransfer {
			log.Error("❌ [DISCARD] Unexpected event type", "event", e)
			return fmt.Errorf("unexpected event type: %T", e)
		}
		log = log.With("correlation_id", req.CorrelationID)
		log.Info("🟢 [START] Received event", "event", req)

		if validationErr := req.Validate(); validationErr != nil {
			log.Error("❌ [DISCARD] Malformed validated event", "error", validationErr)
			return validationErr
		}

		// 2. Persist initial transaction (tx_out) atomically.
		// The event ID is used as the ID of the outgoing transaction.
		transactionID := req.ID
		var destination *dto.AccountRead
		persistPending := func(work repository.UnitOfWork) error {
			transactions, err := common.GetTransactionRepository(work, log)
			if err != nil {
				return fmt.Errorf("failed to get repo: %w", err)
			}
			accounts, err := common.GetAccountRepository(work, log)
			if err != nil {
				return fmt.Errorf("failed to get account repo: %w", err)
			}
			if destination, err = accounts.Get(ctx, req.DestAccountID); err != nil {
				return fmt.Errorf("failed to get destination account: %w", err)
			}
			pending := dto.TransactionCreate{
				ID:          transactionID,
				UserID:      req.UserID,
				AccountID:   req.AccountID,
				Amount:      req.Amount.Negate().Amount(),
				Currency:    req.Amount.Currency().String(),
				Status:      "pending",
				MoneySource: "transfer",
			}
			return transactions.Create(ctx, pending)
		}
		if err := uow.Do(ctx, persistPending); err != nil {
			log.Error("❌ [ERROR] Failed to create initial transaction", "error", err)
			return err
		}
		log.Info("✅ [SUCCESS] Initial 'pending' transaction created", "transaction_id", transactionID)

		// 3. Emit event to trigger currency conversion
		conversion := events.NewCurrencyConversionRequested(
			req.FlowEvent,
			req,
			events.WithConversionAmount(req.Amount),
			events.WithConversionTo(money.Code(destination.Currency)),
			events.WithConversionTransactionID(transactionID),
		)
		log.Info("📤 [EMIT] Emitting event", "event_type", conversion.Type())
		if emitErr := bus.Emit(ctx, conversion); emitErr != nil {
			// The pending transaction is already stored; the emit failure is
			// logged and the handler still reports success.
			log.Error("❌ [ERROR] Failed to emit CurrencyConversionRequested", "error", emitErr)
		}
		return nil
	}
}
package withdraw
import (
"context"
"errors"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/mapper"
"github.com/amirasaad/fintech/pkg/repository"
)
// HandleCurrencyConverted performs domain validation after currency conversion for withdrawals.
//
// The returned handler returns nil without emitting anything when the event is
// not a *events.WithdrawCurrencyConverted, when its OriginalRequest is not a
// *events.WithdrawRequested, or when its FlowType is not "withdraw". Otherwise
// it loads the account, maps it to the domain model and runs ValidateWithdraw
// with the converted amount.
//
// A validation failure emits the event built by events.NewWithdrawFailed and
// returns the result of that emit. On success the event built by
// events.NewWithdrawValidated is emitted and the emit result is returned.
// Repository, lookup and mapping failures are returned as errors;
// account.ErrAccountNotFound is returned when no account was loaded.
func HandleCurrencyConverted(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	logger *slog.Logger,
) func(
	ctx context.Context,
	e events.Event,
) error {
	return func(ctx context.Context, e events.Event) error {
		log := logger.With(
			"handler", "withdraw.CurrencyConverted",
			"event_type", e.Type(),
		)
		log.Info("🟢 [START] Received event", "event", e)
		wcc, ok := e.(*events.WithdrawCurrencyConverted)
		if !ok {
			log.Debug(
				"🚫 skipping: unexpected event type in WithdrawCurrencyConverted",
				"event", e,
			)
			return nil
		}
		wr, ok := wcc.OriginalRequest.(*events.WithdrawRequested)
		if !ok {
			log.Debug(
				"🚫 skipping: unexpected event type in WithdrawCurrencyConverted",
				"event", e,
			)
			return nil
		}
		log = log.With(
			"user_id", wcc.UserID,
			"account_id", wcc.AccountID,
			"transaction_id", wcc.TransactionID,
			"correlation_id", wcc.CorrelationID,
		)
		if wcc.FlowType != "withdraw" {
			log.Debug(
				"🚫 skipping: not a withdraw flow",
				"flow_type", wcc.FlowType,
			)
			return nil
		}
		accRepo, err := common.GetAccountRepository(uow, log)
		if err != nil {
			// Keep the established message, but join the underlying cause so
			// it is not lost and stays inspectable with errors.Is / errors.As.
			return errors.Join(errors.New("invalid account repository type"), err)
		}
		accRead, err := accRepo.Get(ctx, wcc.AccountID)
		// A not-found error is not returned here; it falls through to the nil
		// check below, which returns account.ErrAccountNotFound.
		if err != nil && !errors.Is(err, account.ErrAccountNotFound) {
			log.Error(
				"failed to get account",
				"error", err,
				"account_id", wcc.AccountID,
			)
			return err
		}
		if accRead == nil {
			log.Error(
				"account not found",
				"account_id", wcc.AccountID,
			)
			return account.ErrAccountNotFound
		}
		acc, err := mapper.MapAccountReadToDomain(accRead)
		if err != nil {
			log.Error(
				"failed to map account read to domain",
				"error", err,
			)
			return err
		}
		// Perform domain validation
		if err := acc.ValidateWithdraw(wcc.UserID, wcc.ConvertedAmount); err != nil {
			log.Error(
				"domain validation failed",
				"transaction_id", wcc.TransactionID,
				"error", err,
				"user_id", wcc.UserID,
				"account_id", wcc.AccountID,
				"amount", wcc.ConvertedAmount.String(),
			)
			wf := events.NewWithdrawFailed(
				wr,
				err.Error(),
			)
			return bus.Emit(ctx, wf)
		}
		log.Info(
			"✅ [SUCCESS] Domain validation passed, emitting WithdrawBusinessValidated",
			"user_id", wcc.UserID,
			"account_id", wcc.AccountID,
			"amount", wcc.ConvertedAmount.Amount(),
			"currency", wcc.ConvertedAmount.Currency().String(),
			"correlation_id", wcc.CorrelationID,
		)
		// Emit the validated event built from the converted event
		wv := events.NewWithdrawValidated(wcc)
		log.Info(
			"📤 [EMIT] Emitting event",
			"event_type", wv.Type(),
			"correlation_id", wcc.CorrelationID.String(),
		)
		return bus.Emit(ctx, wv)
	}
}
package withdraw
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/google/uuid"
)
// HandleRequested builds the handler for WithdrawRequested events.
//
// The handler validates the request, persists a "created" transaction under a
// freshly generated ID via persistWithdrawTransaction, loads the account and
// emits a CurrencyConversionRequested event whose target currency is the
// account's. Validation and persistence failures are reported through a
// WithdrawFailed event and nil is returned; an unexpected event type, an
// account lookup failure or a failure to emit the conversion request is
// returned as an error.
func HandleRequested(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	logger *slog.Logger,
) func(ctx context.Context, e events.Event) error {
	return func(ctx context.Context, e events.Event) error {
		log := logger.With("handler", "withdraw.HandleRequested", "event_type", e.Type())
		log.Info("🟢 [START] Processing WithdrawRequested event")

		req, isWithdraw := e.(*events.WithdrawRequested)
		if !isWithdraw {
			log.Error("❌ [ERROR] Unexpected event type", "expected", "WithdrawRequested", "got", e.Type())
			return fmt.Errorf("unexpected event type: %s", e.Type())
		}
		log = log.With(
			"user_id", req.UserID,
			"account_id", req.AccountID,
			"transaction_id", req.TransactionID,
			"amount", req.Amount,
			"correlation_id", req.CorrelationID,
		)

		// notifyFailure publishes a WithdrawFailed event; a failure to publish
		// is only logged.
		notifyFailure := func(reason string) {
			failed := events.NewWithdrawFailed(req, reason)
			if emitErr := bus.Emit(ctx, failed); emitErr != nil {
				log.Error("❌ [ERROR] Failed to emit WithdrawFailed event", "error", emitErr)
			}
		}

		if validationErr := req.Validate(); validationErr != nil {
			log.Error("❌ [ERROR] Withdraw validation failed", "error", validationErr)
			notifyFailure(validationErr.Error())
			return nil
		}

		// NOTE(review): a new ID is always generated here, independent of
		// req.TransactionID, and the transaction is stored before the account
		// lookup below — confirm both are intended.
		transactionID := uuid.New()
		if persistErr := persistWithdrawTransaction(ctx, uow, req, transactionID, log); persistErr != nil {
			log.Error(
				"❌ [ERROR] Failed to persist withdraw transaction",
				"error", persistErr,
				"transaction_id", transactionID,
			)
			notifyFailure(fmt.Sprintf("failed to persist transaction: %v", persistErr))
			return nil
		}
		log.Info("✅ [SUCCESS] Withdraw validated and persisted", "transaction_id", transactionID)

		accounts, err := common.GetAccountRepository(uow, log)
		if err != nil {
			log.Error("❌ [ERROR] Failed to get account repository", "error", err)
			return fmt.Errorf("failed to get account repository: %w", err)
		}
		source, err := accounts.Get(ctx, req.AccountID)
		if err != nil {
			log.Error("❌ [ERROR] Failed to get account", "error", err)
			return fmt.Errorf("failed to get account: %w", err)
		}
		if source == nil {
			log.Error("❌ [ERROR] Account not found", "account_id", req.AccountID)
			return fmt.Errorf("account not found: %s", req.AccountID)
		}

		conversion := events.NewCurrencyConversionRequested(
			req.FlowEvent,
			req,
			events.WithConversionTransactionID(transactionID),
			events.WithConversionAmount(req.Amount),
			events.WithConversionTo(money.Code(source.Currency)),
		)
		if emitErr := bus.Emit(ctx, conversion); emitErr != nil {
			log.Error("❌ [ERROR] Failed to emit CurrencyConversionRequested event", "error", emitErr)
			return fmt.Errorf("failed to emit CurrencyConversionRequested event: %w", emitErr)
		}
		log.Info("📤 [EMITTED] event", "event_id", conversion.ID, "event_type", conversion.Type())
		return nil
	}
}
// persistWithdrawTransaction writes a "created" withdraw transaction with ID
// txID for wr. The amount is stored negated. The write runs inside uow.Do and
// any repository error is wrapped and returned.
func persistWithdrawTransaction(
	ctx context.Context,
	uow repository.UnitOfWork,
	wr *events.WithdrawRequested,
	txID uuid.UUID,
	log *slog.Logger,
) error {
	store := func(work repository.UnitOfWork) error {
		transactions, err := common.GetTransactionRepository(work, log)
		if err != nil {
			return fmt.Errorf("failed to get transaction repository: %w", err)
		}

		record := dto.TransactionCreate{
			ID:          txID,
			UserID:      wr.UserID,
			AccountID:   wr.AccountID,
			Amount:      wr.Amount.Negate().Amount(),
			Currency:    wr.Amount.Currency().String(),
			Status:      "created",
			MoneySource: "withdraw",
		}
		if err = transactions.Create(ctx, record); err != nil {
			return fmt.Errorf("failed to create transaction: %w", err)
		}
		return nil
	}
	return uow.Do(ctx, store)
}
package withdraw
import (
"context"
"errors"
"fmt"
"log/slog"
"strings"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/provider/payment"
"github.com/amirasaad/fintech/pkg/repository"
)
// HandleValidated handles WithdrawValidated events by initiating a payout.
// It's responsible for starting the external payout process to the user's connected account.
// The function follows these steps:
// 1. Validates the withdrawal request
// 2. Retrieves user's Stripe Connect account
// 3. Prepares and initiates the payout
// 4. Emits appropriate events for the transaction lifecycle
//
// Outcomes visible in this block:
//   - an unexpected event type, a missing or invalid original request, a
//     non-positive converted amount, a user repository/lookup failure, or a
//     user name with fewer than two space-separated parts returns an error
//     without emitting any event;
//   - a payout failure emits the event built by events.NewWithdrawFailed and
//     still returns an error;
//   - on success the user's StripeConnectAccountID is updated from the payout
//     result and a PaymentProcessed event is emitted.
func HandleValidated(
bus eventbus.Bus,
uow repository.UnitOfWork,
paymentProvider payment.Payment,
logger *slog.Logger,
) eventbus.HandlerFunc {
return func(ctx context.Context, e events.Event) error {
log := logger.With(
"handler", "withdraw.HandleValidated",
"event_type", e.Type(),
)
log.Info("🟢 [START] Processing WithdrawValidated event")
// Type assert to get the withdraw validated event
wv, ok := e.(*events.WithdrawValidated)
if !ok {
err := fmt.Errorf("expected WithdrawValidated event, got %T", e)
log.Error("unexpected event type", "error", err)
return err
}
// Get the original withdraw request to access bank details
req, ok := wv.OriginalRequest.(*events.WithdrawRequested)
if !ok || req == nil {
err := errors.New("missing or invalid original withdraw request")
log.Error("invalid original request", "error", err)
return fmt.Errorf("invalid withdrawal request: %w", err)
}
// Validate the withdrawal amount is positive
if wv.ConvertedAmount.Amount() <= 0 {
err := fmt.Errorf("invalid withdrawal amount: %d", wv.ConvertedAmount.Amount())
log.Error("validation failed", "error", err)
return err
}
// From here on every log line carries the identifiers of this withdrawal.
log = log.With(
"user_id", wv.UserID,
"account_id", wv.AccountID,
"transaction_id", wv.TransactionID,
"correlation_id", wv.CorrelationID,
)
userRepo, err := common.GetUserRepository(uow, log)
if err != nil {
log.Error("Failed to get user repo", "error", err)
return err
}
// Get user details to check for Stripe Connect account
user, err := userRepo.Get(ctx, req.UserID)
if err != nil {
err = fmt.Errorf("failed to get user details: %w", err)
log.Error("validation failed", "error", err)
return err
}
// Get the user's full name for the payout
// An empty Names value is accepted and leaves both name parts empty; a
// non-empty value must split into at least two parts on single spaces.
// Only the first two parts are used.
var firstName, lastName string
if user.Names != "" {
names := strings.Split(user.Names, " ")
if len(names) < 2 {
return fmt.Errorf("user names are required to create a Stripe Connect account")
}
firstName = names[0]
lastName = names[1]
}
// Prepare the payout parameters
description := fmt.Sprintf("Withdrawal from account %s", wv.AccountID)
// Bank details are reduced before going into metadata: the account number
// to its last four characters, the routing number masked except its last four.
metadata := map[string]string{
"correlation_id": wv.CorrelationID.String(),
"flow_type": "withdraw",
"stripe_account_id": user.StripeConnectAccountID,
"bank_account_last4": lastFourDigits(req.BankAccountNumber),
"bank_routing": maskSensitive(req.RoutingNumber, 4),
"user_id": wv.UserID.String(),
"account_id": wv.AccountID.String(),
"user_email": user.Email,
"user_first_name": firstName,
"user_last_name": lastName,
"amount": fmt.Sprintf("%.2f", wv.ConvertedAmount.AmountFloat()),
"currency": wv.ConvertedAmount.Currency().String(),
}
// The payout itself carries the unmasked bank details and a lower-cased currency code.
payoutParams := &payment.InitiatePayoutParams{
UserID: wv.UserID,
AccountID: wv.AccountID,
PaymentProviderID: user.StripeConnectAccountID,
TransactionID: wv.TransactionID,
Amount: wv.ConvertedAmount.Amount(),
Currency: strings.ToLower(wv.ConvertedAmount.Currency().String()),
Description: description,
Metadata: metadata,
Destination: payment.PayoutDestination{
Type: payment.PayoutDestinationBankAccount,
BankAccount: &payment.BankAccountDetails{
AccountNumber: req.BankAccountNumber,
RoutingNumber: req.RoutingNumber,
},
},
}
// Log the payout initiation attempt
// NOTE(review): the logged amount divides by 100, which presumes the amount
// is expressed in hundredths of the currency unit — confirm for currencies
// that do not have two decimal places.
log.Info("Initiating payout",
"amount", fmt.Sprintf("%.2f", float64(payoutParams.Amount)/100),
"currency", payoutParams.Currency,
"destination_type", payoutParams.Destination.Type,
)
// Initiate the payout
payout, err := paymentProvider.InitiatePayout(ctx, payoutParams)
if err != nil {
log.Error("Failed to initiate payout", "error", err)
// Emit WithdrawFailed event with detailed error information
errMsg := fmt.Sprintf("payout initiation failed: %v", err)
wf := events.NewWithdrawFailed(
req,
errMsg,
events.WithWithdrawFailureReason(err.Error()),
)
if emitErr := bus.Emit(ctx, wf); emitErr != nil {
log.Error(
"Failed to emit WithdrawFailed event",
"error", emitErr,
"original_error", err,
)
// Preserve both errors in the error chain for proper error inspection
// The emit error is more critical (we couldn't notify about the failure),
// but we also preserve the original payout error for context
return errors.Join(
fmt.Errorf("failed to emit WithdrawFailed event: %w", emitErr),
fmt.Errorf("original payout error: %w", err),
)
}
// Return a user-friendly error
return fmt.Errorf("could not process withdrawal: %w", err)
}
// Store the provider ID returned by the payout on the user record.
// NOTE(review): if this update fails the payout has already been initiated,
// yet the handler returns before emitting PaymentProcessed — confirm that
// this is the intended recovery behaviour.
if err := userRepo.Update(ctx, wv.UserID, &dto.UserUpdate{
StripeConnectAccountID: &payout.PaymentProviderID,
}); err != nil {
log.Error("Failed to update user", "error", err)
return fmt.Errorf("failed to update user: %w", err)
}
log.Info("Payout initiated successfully",
"payout_id", payout.PayoutID,
"status", payout.Status,
)
// Prepare payment processed event with all required details
paymentID := payout.PayoutID
paymentStatus := string(payout.Status)
// Create a copy of the flow event with updated fields
flowEvent := wv.FlowEvent
flowEvent.FlowType = "withdraw"
// Create a new PaymentProcessed event with the payout details
// without chaining methods that change the static type
pp := events.NewPaymentProcessed(&flowEvent, func(pp *events.PaymentProcessed) {
pp.WithPaymentID(paymentID)
pp.WithStatus(paymentStatus)
pp.WithAmount(wv.ConvertedAmount)
pp.WithTransactionID(wv.TransactionID)
})
// Emit the payment processed event
if err := bus.Emit(ctx, pp); err != nil {
log.Error("Failed to emit Payment.Processed event",
"error", err,
"payment_id", paymentID,
)
return fmt.Errorf("failed to emit Payment.Processed event: %w", err)
}
log.Info("📤 [EMITTED] event",
"event_id", pp.ID,
"event_type", pp.Type(),
"payment_id", paymentID,
"status", paymentStatus,
)
return nil
}
}
// lastFourDigits returns the trailing four bytes of accountNumber.
// Inputs of four bytes or fewer are returned unchanged.
func lastFourDigits(accountNumber string) string {
	const keep = 4
	if n := len(accountNumber); n > keep {
		return accountNumber[n-keep:]
	}
	return accountNumber
}
// maskSensitive masks all but the last visibleChars bytes of s with '*'.
//
// When s is no longer than visibleChars, the whole string is masked so that a
// short value is never revealed in full. A negative visibleChars is treated as
// 0 (everything masked); without that guard the slice expression below would
// go out of range and panic.
func maskSensitive(s string, visibleChars int) string {
	if visibleChars < 0 {
		visibleChars = 0
	}
	if len(s) <= visibleChars {
		return strings.Repeat("*", len(s))
	}
	hidden := len(s) - visibleChars
	return strings.Repeat("*", hidden) + s[hidden:]
}
package common
import (
"context"
"errors"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/amirasaad/fintech/pkg/repository/account"
"github.com/amirasaad/fintech/pkg/repository/transaction"
"github.com/amirasaad/fintech/pkg/repository/user"
"github.com/google/uuid"
"gorm.io/gorm"
)
// ErrInvalidRepositoryType is returned by the Get*Repository helpers in this
// package when the value handed back by the unit of work does not implement
// the requested repository interface.
var ErrInvalidRepositoryType = errors.New("invalid repository type")
// GetAccountRepository resolves the account repository from uow.
// A lookup failure is logged on log and returned unchanged;
// ErrInvalidRepositoryType is returned when the resolved value does not
// implement account.Repository.
func GetAccountRepository(
	uow repository.UnitOfWork,
	log *slog.Logger,
) (
	account.Repository,
	error,
) {
	resolved, err := uow.GetRepository((*account.Repository)(nil))
	if err != nil {
		log.Error("failed to get account repository", "error", err)
		return nil, err
	}
	switch repo := resolved.(type) {
	case account.Repository:
		return repo, nil
	default:
		return nil, ErrInvalidRepositoryType
	}
}
// GetTransactionRepository resolves the transaction repository from uow.
// A lookup failure is logged on log and returned unchanged;
// ErrInvalidRepositoryType is returned when the resolved value does not
// implement transaction.Repository.
func GetTransactionRepository(
	uow repository.UnitOfWork,
	log *slog.Logger,
) (
	transaction.Repository,
	error,
) {
	resolved, err := uow.GetRepository((*transaction.Repository)(nil))
	if err != nil {
		log.Error("failed to get transaction repository", "error", err)
		return nil, err
	}
	switch repo := resolved.(type) {
	case transaction.Repository:
		return repo, nil
	default:
		return nil, ErrInvalidRepositoryType
	}
}
// GetUserRepository resolves the user repository from uow.
//
// A lookup failure is logged on log and returned unchanged. When the unit of
// work yields a value that does not satisfy user.Repository, the result is
// ErrInvalidRepositoryType.
func GetUserRepository(
	uow repository.UnitOfWork,
	log *slog.Logger,
) (user.Repository, error) {
	resolved, err := uow.GetRepository((*user.Repository)(nil))
	if err != nil {
		log.Error("failed to get user repository", "error", err)
		return nil, err
	}

	repo, ok := resolved.(user.Repository)
	if !ok {
		return nil, ErrInvalidRepositoryType
	}
	return repo, nil
}
// TransactionLookupResult carries the outcome of LookupTransactionByPaymentOrID.
type TransactionLookupResult struct {
// Transaction is the record that was located; it stays nil when nothing was found.
Transaction *dto.TransactionRead
// TransactionID starts as the ID supplied by the caller and is replaced by
// the record's own ID when the match was made by payment ID.
TransactionID uuid.UUID
// Found reports whether a transaction was located.
Found bool
// Error is set only for lookup failures other than gorm.ErrRecordNotFound.
Error error
}
// LookupTransactionByPaymentOrID resolves a transaction from a payment ID
// and/or a transaction ID.
//
// Resolution order:
//   - a non-empty paymentID is tried first via GetByPaymentID; if that reports
//     gorm.ErrRecordNotFound and transactionID is set, Get(transactionID) is
//     tried as a fallback;
//   - otherwise a non-nil transactionID is looked up directly;
//   - with neither identifier, a warning is logged and nothing is found.
//
// gorm.ErrRecordNotFound is reported as Found == false with a warning on log;
// any other repository error is wrapped into the result's Error field.
func LookupTransactionByPaymentOrID(
	ctx context.Context,
	txRepo transaction.Repository,
	paymentID *string,
	transactionID uuid.UUID,
	log *slog.Logger,
) TransactionLookupResult {
	hasPaymentID := paymentID != nil && *paymentID != ""
	hasTransactionID := transactionID != uuid.Nil
	out := TransactionLookupResult{TransactionID: transactionID}

	switch {
	case hasPaymentID:
		record, lookupErr := txRepo.GetByPaymentID(ctx, *paymentID)
		if lookupErr == nil {
			out.Transaction, out.TransactionID, out.Found = record, record.ID, true
			return out
		}

		// Fall back to the transaction ID only when the payment ID was simply unknown.
		if hasTransactionID && errors.Is(lookupErr, gorm.ErrRecordNotFound) {
			record, lookupErr = txRepo.Get(ctx, transactionID)
			if lookupErr == nil {
				out.Transaction, out.Found = record, true
				return out
			}
		}

		if !errors.Is(lookupErr, gorm.ErrRecordNotFound) {
			out.Error = fmt.Errorf("failed to get transaction: %w", lookupErr)
			return out
		}
		log.Warn(
			"⚠️ Transaction not found",
			"payment_id", *paymentID,
			"transaction_id", transactionID,
		)
		out.Found = false
		return out

	case hasTransactionID:
		record, lookupErr := txRepo.Get(ctx, transactionID)
		if lookupErr == nil {
			out.Transaction, out.Found = record, true
			return out
		}
		if !errors.Is(lookupErr, gorm.ErrRecordNotFound) {
			out.Error = fmt.Errorf("failed to get transaction: %w", lookupErr)
			return out
		}
		log.Warn("⚠️ Transaction not found", "transaction_id", transactionID)
		out.Found = false
		return out

	default:
		log.Warn("⚠️ No transaction identifiers provided")
		out.Found = false
		return out
	}
}
package common
import (
"context"
"log/slog"
"sync"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
"golang.org/x/sync/singleflight"
)
// KeyExtractor derives the idempotency key for an event.
// WithIdempotency treats an empty key as "no key" and runs the handler
// without any idempotency check.
type KeyExtractor func(events.Event) string
// IdempotencyTracker tracks processed events by key.
type IdempotencyTracker struct {
// processed holds the keys recorded via Store; the stored values are empty structs.
processed sync.Map
// inflight is the singleflight group WithIdempotency runs handlers through,
// keyed by idempotency key.
inflight singleflight.Group
}
// NewIdempotencyTracker returns a tracker with no keys recorded.
func NewIdempotencyTracker() *IdempotencyTracker {
	return new(IdempotencyTracker)
}
// Store records key as processed.
func (t *IdempotencyTracker) Store(key string) {
	var marker struct{}
	t.processed.Store(key, marker)
}
// Delete forgets key, so IsProcessed reports false for it afterwards.
func (t *IdempotencyTracker) Delete(key string) {
	// The previous value and presence flag are of no interest here.
	_, _ = t.processed.LoadAndDelete(key)
}
// IsProcessed reports whether key has been recorded with Store and not
// removed again with Delete.
func (t *IdempotencyTracker) IsProcessed(key string) bool {
	_, seen := t.processed.Load(key)
	return seen
}
// WithIdempotency decorates handler so that an event whose key is already
// recorded in tracker is skipped.
//
// For each event the key is obtained from keyExtractor:
//   - an empty key bypasses the tracker and calls handler directly;
//   - a key already marked as processed is logged and skipped with a nil error;
//   - otherwise handler runs through the tracker's singleflight group under
//     that key, and the key is stored only when handler returns nil.
//
// handlerName is attached to the skip log record. A nil logger falls back to
// slog.Default().
func WithIdempotency(
	handler eventbus.HandlerFunc,
	tracker *IdempotencyTracker,
	keyExtractor KeyExtractor,
	handlerName string,
	logger *slog.Logger,
) eventbus.HandlerFunc {
	if logger == nil {
		logger = slog.Default()
	}

	return func(ctx context.Context, e events.Event) error {
		idemKey := keyExtractor(e)
		if idemKey == "" {
			// Nothing to deduplicate on: run the handler as-is.
			return handler(ctx, e)
		}

		scoped := logger.With(
			"handler", handlerName,
			"event_type", e.Type(),
			"idempotency_key", idemKey,
		)

		// Fast path: the key was recorded by an earlier successful run.
		if tracker.IsProcessed(idemKey) {
			scoped.Info("🔁 [SKIP] Event already processed")
			return nil
		}

		runOnce := func() (any, error) {
			// Re-check inside the group: another call may have finished in between.
			if tracker.IsProcessed(idemKey) {
				return nil, nil
			}
			runErr := handler(ctx, e)
			if runErr == nil {
				tracker.Store(idemKey)
			}
			return nil, runErr
		}

		_, err, _ := tracker.inflight.Do(idemKey, runOnce)
		return err
	}
}
// Package conversion handles currency conversion events and persistence logic.
package conversion
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/google/uuid"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/repository"
)
// HandleCurrencyConverted persists CurrencyConverted events.
func HandleCurrencyConverted(
uow repository.UnitOfWork,
logger *slog.Logger) func(
context.Context,
events.Event,
) error {
return func(ctx context.Context, e events.Event) error {
log := logger.With(
"handler", "conversion.HandleCurrencyConverted",
"event_type", e.Type(),
)
log.Info("🟢 [START] Event received event")
cc, ok := e.(*events.CurrencyConverted)
if !ok {
log.Warn("unexpected event",
"event", e,
"event_type", fmt.Sprintf("%T", e),
)
// return nil to skip processing
return nil
}
// Validate TransactionID
if cc.TransactionID == uuid.Nil {
log.Warn("TransactionID is nil in CurrencyConverted event",
"user_id", cc.UserID,
"account_id", cc.AccountID,
"correlation_id", cc.CorrelationID,
)
return nil
}
log = log.With(
"user_id", cc.UserID,
"account_id", cc.AccountID,
"transaction_id", cc.TransactionID,
"correlation_id", cc.CorrelationID,
)
log.Info(
"💾 [PROGRESS] persisting conversion data",
"transaction_id", cc.TransactionID,
)
// Validate that we have the required data before persisting
if cc.ConversionInfo == nil {
log.Warn("ConversionInfo is nil, cannot persist conversion data")
return nil
}
// Persist conversion result (stubbed for now)
if err := uow.Do(ctx, func(uow repository.UnitOfWork) error {
transactionRepo, err := common.GetTransactionRepository(uow, log)
if err != nil {
return err
}
// Create money object for transaction amount
amount := cc.ConvertedAmount.Amount()
currency := cc.ConvertedAmount.Currency().String()
return transactionRepo.Update(ctx, cc.TransactionID, dto.TransactionUpdate{
Amount: &amount,
Currency: ¤cy,
OriginalCurrency: &cc.ConversionInfo.FromCurrency,
TargetCurrency: &cc.ConversionInfo.ToCurrency,
ConversionRate: &cc.ConversionInfo.Rate,
})
}); err != nil {
log.Error("Failed to persist conversion data",
"error", err,
"transaction_id", cc.TransactionID,
"user_id", cc.UserID,
"account_id", cc.AccountID,
)
return err
}
log.Info(
"✅ [SUCCESS] conversion persisted",
"transaction_id", cc.TransactionID,
)
return nil
// NOTE: Conversion fee application is deferred to a future enhancement.
// This will be implemented as part of the fee calculation system.
}
}
package conversion
import (
"github.com/amirasaad/fintech/pkg/domain/events"
)
// DepositEventFactory produces the deposit-flow follow-up event for a
// completed currency conversion.
type DepositEventFactory struct{}

// CreateNextEvent returns the event built by
// events.NewDepositCurrencyConverted from cc.
func (*DepositEventFactory) CreateNextEvent(
	cc *events.CurrencyConverted,
) events.Event {
	next := events.NewDepositCurrencyConverted(cc)
	return next
}
// WithdrawEventFactory produces the withdraw-flow follow-up event for a
// completed currency conversion.
type WithdrawEventFactory struct{}

// CreateNextEvent returns the event built by
// events.NewWithdrawCurrencyConverted from cc.
func (*WithdrawEventFactory) CreateNextEvent(
	cc *events.CurrencyConverted,
) events.Event {
	next := events.NewWithdrawCurrencyConverted(cc)
	return next
}
// TransferEventFactory produces the transfer-flow follow-up event for a
// completed currency conversion.
type TransferEventFactory struct{}

// CreateNextEvent returns the event built by
// events.NewTransferCurrencyConverted from cc.
func (*TransferEventFactory) CreateNextEvent(
	cc *events.CurrencyConverted,
) events.Event {
	next := events.NewTransferCurrencyConverted(cc)
	return next
}
// Package conversion handles currency conversion events and logic.
package conversion
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
exchangeprovider "github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/amirasaad/fintech/pkg/registry"
"github.com/amirasaad/fintech/pkg/service/exchange"
)
// HandleRequested returns a handler for CurrencyConversionRequested events.
//
// For each event the handler:
//  1. picks the EventFactory registered in factories under the event's
//     FlowType (an unknown or nil entry is an error);
//  2. converts ccr.Amount to ccr.To through an exchange service built from
//     exchangeRegistry and exchangeRateProvider;
//  3. emits a CurrencyConverted event carrying the converted amount, the
//     conversion info, the transaction ID and the original request;
//  4. asks the factory for the flow-specific follow-up event and emits it.
//
// Errors from the conversion and from emitting CurrencyConverted are returned
// as-is; a failure to emit the follow-up event is wrapped. A nil logger falls
// back to slog.Default().
func HandleRequested(
	bus eventbus.Bus,
	exchangeRegistry registry.Provider,
	exchangeRateProvider exchangeprovider.Exchange,
	logger *slog.Logger,
	factories map[string]EventFactory,
) func(ctx context.Context, e events.Event) error {
	if logger == nil {
		logger = slog.Default()
	}
	return func(ctx context.Context, e events.Event) error {
		log := logger.With(
			"handler", "Conversion.HandleRequested",
			"event_type", e.Type(),
		)
		log.Info("🟢 [START] Event received")

		ccr, ok := e.(*events.CurrencyConversionRequested)
		if !ok {
			log.Error(
				"Unexpected event type",
				"event", e,
			)
			return fmt.Errorf("unexpected event type %T", e)
		}
		log.Debug(
			"ConversionRequestedEvent details",
			"event", ccr,
		)

		// Use the factory map to get the correct event factory for the flow type.
		// A nil entry is treated like a missing one so it cannot panic later.
		factory, found := factories[ccr.FlowType]
		if !found || factory == nil {
			log.Error(
				"Unknown flow type in ConversionRequestedEvent, discarding",
				"flow_type", ccr.FlowType,
			)
			return fmt.Errorf("unknown flow type %s", ccr.FlowType)
		}

		srv := exchange.New(exchangeRegistry, exchangeRateProvider, log)
		convertedMoney, convInfo, err := srv.Convert(ctx, ccr.Amount, ccr.To)
		if err != nil {
			log.Error(
				"Failed to convert currency",
				"error", err,
				"event_type", ccr.Type(),
				"event_id", ccr.ID,
			)
			return err
		}

		// Log OriginalRequest details for debugging
		log.Debug(
			"Creating CurrencyConverted event",
			"original_request_type", fmt.Sprintf("%T", ccr.OriginalRequest),
			"original_request_nil", ccr.OriginalRequest == nil,
			"transaction_id", ccr.TransactionID,
		)

		cc := events.NewCurrencyConverted(
			ccr,
			func(cc *events.CurrencyConverted) {
				cc.TransactionID = ccr.TransactionID
				cc.ConvertedAmount = convertedMoney
				cc.ConversionInfo = convInfo
				// Ensure OriginalRequest is preserved
				cc.OriginalRequest = ccr.OriginalRequest
			},
		)
		log.Info(
			"🔄 [PROCESS] Conversion completed successfully",
			"amount", ccr.Amount,
			"to", convertedMoney,
		)

		log.Info("📤 [EMIT] Emitting event", "event_type", cc.Type(), "event_id", cc.ID)
		if err = bus.Emit(ctx, cc); err != nil {
			log.Error(
				"Failed to emit done",
				"error", err,
				"event_type", cc.Type(),
				"event_id", cc.ID,
			)
			return err
		}

		// Delegate the creation of the next event to the factory.
		nextEvent := factory.CreateNextEvent(cc)
		if nextEvent == nil {
			// Calling Type() on a nil event below would panic; report it instead.
			err := fmt.Errorf("factory for flow type %s produced no event", ccr.FlowType)
			log.Error(
				"Failed to create next event",
				"error", err,
				"event_id", ccr.ID,
				"correlation_id", ccr.CorrelationID,
			)
			return err
		}
		log.Info(
			"✅ Created next event",
			"event_type", nextEvent.Type(),
			"event_id", ccr.ID,
			"correlation_id", ccr.CorrelationID,
		)
		log.Debug(
			"🔧 Next event details",
			"event", nextEvent,
			"event_type", fmt.Sprintf("%T", nextEvent),
			"correlation_id", ccr.CorrelationID,
		)

		log.Info("📤 [EMIT] Emitting event", "event_type", nextEvent.Type())
		// Emit the next event in the flow
		if err := bus.Emit(ctx, nextEvent); err != nil {
			log.Error(
				"Failed to emit next event",
				"error", err,
				"event_type", nextEvent.Type(),
				"event_id", ccr.ID,
				"correlation_id", ccr.CorrelationID,
			)
			return fmt.Errorf("failed to emit next event: %w", err)
		}
		log.Info(
			"✅ Successfully emitted next event",
			"event_type", nextEvent.Type(),
			"event_id", ccr.ID,
			"correlation_id", ccr.CorrelationID,
		)
		return nil
	}
}
package conversion
import (
"time"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/google/uuid"
)
// NewValidConversionRequestedEvent builds a CurrencyConversionRequested event
// for tests from the given flow, with a nil original request and the amount,
// target currency code and transaction ID applied as options.
func NewValidConversionRequestedEvent(
	flow events.FlowEvent,
	transactionID uuid.UUID,
	amount money.Money,
	to string,
) *events.CurrencyConversionRequested {
	return events.NewCurrencyConversionRequested(
		flow,
		nil, // no original request
		events.WithConversionAmount(&amount),
		events.WithConversionTo(money.Code(to)),
		events.WithConversionTransactionID(transactionID),
	)
}
// NewValidConversionInfo builds an exchange.RateInfo for tests with the given
// currency pair and rate, stamped with the current time and the provider
// name "test".
func NewValidConversionInfo(
	fromCurrency, toCurrency string,
	rateValue float64,
) *exchange.RateInfo {
	info := new(exchange.RateInfo)
	info.FromCurrency = fromCurrency
	info.ToCurrency = toCurrency
	info.Rate = rateValue
	info.Timestamp = time.Now()
	info.Provider = "test"
	return info
}
package fees
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/repository"
)
// HandleCalculated returns a handler for FeesCalculated events.
//
// Inside uow.Do the handler resolves the transaction and account
// repositories, builds a FeeCalculator and calls ApplyFees with the event's
// transaction ID and fee. It returns an error when the event is not a
// *events.FeesCalculated, when a repository or the calculator cannot be
// obtained, or when applying the fees fails. A nil logger falls back to
// slog.Default().
func HandleCalculated(
	uow repository.UnitOfWork,
	logger *slog.Logger,
) eventbus.HandlerFunc {
	if logger == nil {
		logger = slog.Default()
	}
	return func(
		ctx context.Context,
		e events.Event,
	) error {
		log := logger.With(
			"handler", "fees.HandleCalculated",
			"event_type", e.Type(),
		)
		log.Info("🟢 [START] Processing FeesCalculated event")

		// Type assert to get the FeesCalculated event
		fc, ok := e.(*events.FeesCalculated)
		if !ok {
			err := fmt.Errorf("unexpected event type: %s", e.Type())
			log.Error("unexpected event type", "error", err)
			return err
		}

		log = log.With(
			"transaction_id", fc.TransactionID,
			"event_id", fc.ID,
			"fee_amount", fc.Fee.Amount,
		)

		if err := uow.Do(ctx, func(uow repository.UnitOfWork) error {
			// Get transaction repository
			txRepo, err := common.GetTransactionRepository(uow, log)
			if err != nil {
				log.Error(
					"failed to get transaction repository",
					"error", err,
					"transaction_id", fc.TransactionID,
				)
				return fmt.Errorf("failed to get transaction repository: %w", err)
			}

			// Get account repository
			accRepo, err := common.GetAccountRepository(uow, log)
			if err != nil {
				log.Error(
					"failed to get account repository",
					"error", err,
					"transaction_id", fc.TransactionID,
				)
				return fmt.Errorf("failed to get account repository: %w", err)
			}

			// Create fee calculator and apply fees. NewFeeCalculator returns nil
			// when a repository is nil; calling ApplyFees on that would panic.
			calculator := NewFeeCalculator(txRepo, accRepo, log)
			if calculator == nil {
				err := fmt.Errorf("failed to create fee calculator: repository is nil")
				log.Error("failed to create fee calculator",
					"error", err,
					"transaction_id", fc.TransactionID,
				)
				return err
			}
			if err := calculator.ApplyFees(ctx, fc.TransactionID, fc.Fee); err != nil {
				log.Error("failed to apply fees",
					"error", err,
					"transaction_id", fc.TransactionID,
					"fee_amount", fc.Fee.Amount,
				)
				return fmt.Errorf("failed to apply fees: %w", err)
			}

			log.Info("✅ Successfully processed fee calculation")
			return nil
		}); err != nil {
			log.Error("failed to process FeesCalculated event", "error", err)
			return err
		}

		log.Info("🟢 [END] Processing FeesCalculated event")
		return nil
	}
}
package fees
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/mapper"
"github.com/amirasaad/fintech/pkg/money"
repoaccount "github.com/amirasaad/fintech/pkg/repository/account"
repotransaction "github.com/amirasaad/fintech/pkg/repository/transaction"
"github.com/google/uuid"
)
// FeeCalculator handles fee calculation and application to transactions and accounts
type FeeCalculator struct {
// txRepo is used to read the transaction and to write its updated fee.
txRepo repotransaction.Repository
// accRepo is used to read the account and to write its reduced balance.
accRepo repoaccount.Repository
// logger receives the error and info records emitted while applying fees.
logger *slog.Logger
}
// NewFeeCalculator builds a FeeCalculator over the given repositories.
//
// It returns nil when txRepo or accRepo is nil. A nil logger is not an error:
// slog.Default() is used in its place.
func NewFeeCalculator(
	txRepo repotransaction.Repository,
	accRepo repoaccount.Repository,
	logger *slog.Logger,
) *FeeCalculator {
	switch {
	case txRepo == nil, accRepo == nil:
		return nil
	case logger == nil:
		logger = slog.Default()
	}

	calc := new(FeeCalculator)
	calc.txRepo = txRepo
	calc.accRepo = accRepo
	calc.logger = logger
	return calc
}
// ApplyFees loads the transaction identified by transactionID, adds fee to
// the fee recorded on it, and then deducts fee.Amount from the balance of the
// transaction's account. The first failing step aborts the sequence and its
// error is returned.
func (fc *FeeCalculator) ApplyFees(
	ctx context.Context,
	transactionID uuid.UUID,
	fee account.Fee,
) error {
	tx, err := fc.txRepo.Get(ctx, transactionID)
	if err != nil {
		fc.logger.Error("failed to get transaction", "error", err, "transaction_id", transactionID)
		return err
	}

	// Record the fee on the transaction first, then charge the account.
	if err = fc.updateTransactionFee(ctx, tx, fee); err != nil {
		return err
	}
	return fc.updateAccountBalance(ctx, tx.AccountID, fee.Amount)
}
// updateTransactionFee adds fee.Amount to the fee already stored on tx and
// writes the total back through the transaction repository.
//
// It fails when tx has no currency, when the stored fee cannot be turned into
// a money value in that currency, when the two amounts cannot be added, or
// when the repository update fails. Each failure is logged before returning.
func (fc *FeeCalculator) updateTransactionFee(
	ctx context.Context,
	tx *dto.TransactionRead,
	fee account.Fee,
) error {
	// The stored fee can only be interpreted together with a currency.
	if tx.Currency == "" {
		missing := fmt.Errorf("transaction %s has no currency set", tx.ID)
		fc.logger.Error("transaction has no currency",
			"error", missing,
			"transaction_id", tx.ID,
		)
		return missing
	}

	// Lift the stored fee into a money value so it can be summed with the new one.
	existing, err := money.New(tx.Fee, money.Code(tx.Currency))
	if err != nil {
		fc.logger.Error("invalid transaction fee amount",
			"error", err,
			"transaction_id", tx.ID,
			"fee", tx.Fee,
			"currency", tx.Currency,
		)
		return fmt.Errorf("invalid transaction fee amount: %w", err)
	}

	combined, err := existing.Add(fee.Amount)
	if err != nil {
		fc.logger.Error("failed to add fees",
			"error", err,
			"transaction_id", tx.ID,
			"existing_fee", existing,
			"new_fee", fee.Amount,
		)
		return fmt.Errorf("failed to add fees: %w", err)
	}

	// The update DTO takes a pointer, hence the local.
	combinedAmount := combined.Amount()
	err = fc.txRepo.Update(ctx, tx.ID, dto.TransactionUpdate{Fee: &combinedAmount})
	if err != nil {
		fc.logger.Error("failed to update transaction with fees",
			"error", err,
			"transaction_id", tx.ID,
			"fee", fee.Amount,
		)
		return fmt.Errorf("failed to update transaction: %w", err)
	}

	fc.logger.Info("updated transaction with fee",
		"transaction_id", tx.ID,
		"total_fee", combined,
	)
	return nil
}
// updateAccountBalance deducts feeAmount from the balance of the account
// identified by accountID and writes the new balance back through the account
// repository.
//
// It fails when the account cannot be loaded or mapped to the domain model,
// when the subtraction is rejected, or when the repository update fails. Each
// failure is logged together with its error before returning.
func (fc *FeeCalculator) updateAccountBalance(
	ctx context.Context,
	accountID uuid.UUID,
	feeAmount *money.Money,
) error {
	// Get the account
	acc, err := fc.accRepo.Get(ctx, accountID)
	if err != nil {
		fc.logger.Error("failed to get account", "error", err, "account_id", accountID)
		return err
	}

	// Convert to domain model to use money operations
	domainAcc, err := mapper.MapAccountReadToDomain(acc)
	if err != nil {
		fc.logger.Error("error creating account from dto", "error", err, "account_id", accountID)
		return err
	}

	// Calculate new balance
	newBalance, err := domainAcc.Balance.Subtract(feeAmount)
	if err != nil {
		// Include the error itself, like every other failure log in this method.
		fc.logger.Error("failed to subtract fee",
			"error", err,
			"fee", feeAmount,
			"current_balance", domainAcc.Balance,
			"account_id", accountID,
		)
		return fmt.Errorf("failed to subtract fee from balance: %w", err)
	}

	// Update account balance; the update DTO takes a pointer, hence the local.
	balanceAmount := newBalance.Amount()
	if err := fc.accRepo.Update(
		ctx,
		acc.ID,
		dto.AccountUpdate{Balance: &balanceAmount},
	); err != nil {
		fc.logger.Error("failed to update account balance",
			"error", err,
			"account_id", accountID,
			"new_balance", balanceAmount,
		)
		return err
	}

	fc.logger.Info("updated account balance with fee deduction",
		"account_id", accountID,
		"fee_deducted", feeAmount,
		"new_balance", balanceAmount,
	)
	return nil
}
package payment
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/mapper"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/google/uuid"
)
// ExtractPaymentCompletedKey derives the idempotency key of a PaymentCompleted
// event: its non-empty payment ID when present, otherwise its transaction ID
// as a string. Any other event type, or an event carrying neither identifier,
// yields the empty string.
func ExtractPaymentCompletedKey(e events.Event) string {
	pc, ok := e.(*events.PaymentCompleted)
	switch {
	case !ok:
		return ""
	case pc.PaymentID != nil && *pc.PaymentID != "":
		return *pc.PaymentID
	case pc.TransactionID != uuid.Nil:
		return pc.TransactionID.String()
	default:
		return ""
	}
}
// HandleCompleted handles PaymentCompletedEvent,
// updates the transaction status in the DB, and publishes a follow-up event if needed.
func HandleCompleted(
bus eventbus.Bus,
uow repository.UnitOfWork,
logger *slog.Logger,
) func(
ctx context.Context,
e events.Event,
) error {
return func(
ctx context.Context,
e events.Event,
) error {
if logger == nil {
logger = slog.Default()
}
log := logger.With(
"handler", "payment.HandleCompleted",
"event", e,
"event_type", e.Type(),
)
log.Info(
"🟢 [START] HandleCompleted received event",
)
log.Debug("🟢 Handling PaymentCompleted event",
"event_type", e.Type(),
"event", fmt.Sprintf("%+v", e),
)
pc, ok := e.(*events.PaymentCompleted)
if !ok {
log.Error(
"Skipping unexpected event type",
"event", e,
)
return nil
}
// Build log fields safely without dereferencing nil pointers
logFields := []any{
"user_id", pc.UserID,
"account_id", pc.AccountID,
"transaction_id", pc.TransactionID,
}
if pc.PaymentID != nil {
logFields = append(logFields, "payment_id", *pc.PaymentID)
}
log = log.With(logFields...)
if err := uow.Do(ctx, func(uow repository.UnitOfWork) error {
accRepo, err := common.GetAccountRepository(uow, log)
if err != nil {
return err
}
txRepo, err := common.GetTransactionRepository(uow, log)
if err != nil {
log.Error("failed to get transaction repository", "error", err)
return err
}
if pc.PaymentID == nil {
log.Error("payment ID is nil")
return fmt.Errorf("payment ID is nil")
}
// Lookup transaction by payment ID or transaction ID
lookupResult := common.LookupTransactionByPaymentOrID(
ctx,
txRepo,
pc.PaymentID,
pc.TransactionID,
log,
)
if lookupResult.Error != nil {
return lookupResult.Error
}
if !lookupResult.Found {
return nil // Skip gracefully if transaction not found
}
tx := lookupResult.Transaction
// Update the transaction with the payment ID if it wasn't set
if tx.PaymentID == nil || (tx.PaymentID != nil && *tx.PaymentID != *pc.PaymentID) {
update := dto.TransactionUpdate{
PaymentID: pc.PaymentID,
}
if uerr := txRepo.Update(ctx, tx.ID, update); uerr != nil {
log.Error(
"failed to update transaction with payment ID",
"transaction_id", tx.ID,
"payment_id", pc.PaymentID,
"error", uerr,
)
return fmt.Errorf("failed to update transaction: %w", uerr)
}
tx.PaymentID = pc.PaymentID
}
log = log.With(
"transaction_id", tx.ID,
"user_id", tx.UserID,
)
acc, err := accRepo.Get(ctx, tx.AccountID)
if err != nil {
log.Error(
"failed to get account",
"error", err,
)
return err
}
domainAcc, err := mapper.MapAccountReadToDomain(acc)
if err != nil {
log.Error(
"failed to map account to domain",
"error", err,
)
return err
}
// Log provider fee details before calculation
newBalance, err := domainAcc.Balance.Add(pc.Amount)
if err != nil {
log.Error(
"failed to add net transaction amount to balance",
"error", err,
)
return err
}
oldStatus := tx.Status
status := string(account.TransactionStatusCompleted)
tx.Status = status
// Store the gross amount in the transaction
amount := pc.Amount.Amount()
currency := pc.Amount.Currency().String()
balance := newBalance.Amount()
update := dto.TransactionUpdate{
Status: &status,
Amount: &amount,
Currency: ¤cy,
Balance: &balance,
}
if err = txRepo.Update(ctx, tx.ID, update); err != nil {
log.Error(
"failed to update transaction status",
"error", err,
)
return err
}
log.Info(
"✅ [SUCCESS] transaction status updated",
"old_status", oldStatus,
"new_status", tx.Status,
)
f64Balance := newBalance.Amount()
if err := accRepo.Update(
ctx,
tx.AccountID,
dto.AccountUpdate{Balance: &f64Balance},
); err != nil {
log.Error(
"failed to update account balance",
"error", err,
)
return err
}
log.Info(
"✅ [SUCCESS] account balance updated",
"account_id", acc.ID,
"new_balance", newBalance,
"balance", domainAcc.Balance,
)
log.Info(
"✅ [SUCCESS] emitted FeesCalculated event",
"transaction_id", tx.ID)
return nil
}); err != nil {
log.Error(
"uow.Do failed",
"error", err,
)
return err
}
return nil
}
}
package payment
import (
"context"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/repository"
)
// HandleFailed returns a handler for PaymentFailed events that marks the
// event's transaction as failed and records the event's payment ID on it.
//
// The transaction repository is resolved from, and the update is executed
// inside, the unit of work passed to uow.Do, so the write is part of what
// that unit of work commits. The handler returns an error when the event is
// not a *events.PaymentFailed, when the repository cannot be obtained, when
// the update fails, or when uow.Do itself reports a failure. A nil logger
// falls back to slog.Default(). bus is accepted but not used by this handler.
func HandleFailed(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	logger *slog.Logger,
) eventbus.HandlerFunc {
	if logger == nil {
		logger = slog.Default()
	}
	return func(ctx context.Context, event events.Event) error {
		log := logger.With("handler", "payment.HandleFailed", "event_type", event.Type())
		log.Info("handling payment failed event")

		// Check if the event is a PaymentFailed event
		pf, ok := event.(*events.PaymentFailed)
		if !ok {
			err := fmt.Errorf("expected PaymentFailed event, got %T", event)
			log.Error("invalid event type", "error", err)
			return err
		}

		// Use the transaction ID from the event
		txID := pf.TransactionID
		log = log.With(
			"transaction_id", txID,
			"payment_id", pf.PaymentID,
			"user_id", pf.UserID,
			"account_id", pf.AccountID,
			"correlation_id", pf.CorrelationID,
		)

		status := string(account.TransactionStatusFailed)

		// stepErr remembers a failure raised by our own steps so it can be
		// returned with its original message instead of being reported as a
		// commit failure.
		var stepErr error
		commitErr := uow.Do(ctx, func(uow repository.UnitOfWork) error {
			// Get the transaction repository from the transactional unit of work
			txRepo, err := common.GetTransactionRepository(uow, log)
			if err != nil {
				stepErr = fmt.Errorf("failed to get transaction repository: %w", err)
				log.Error("repository error", "error", stepErr)
				return stepErr
			}

			// Update the transaction status to failed
			if err := txRepo.Update(ctx, txID, dto.TransactionUpdate{
				PaymentID: pf.PaymentID,
				Status:    &status,
			}); err != nil {
				stepErr = fmt.Errorf("failed to update transaction status: %w", err)
				log.Error("update error", "error", stepErr)
				return stepErr
			}

			log.Info("committing transaction update")
			return nil
		})
		if stepErr != nil {
			return stepErr
		}
		if commitErr != nil {
			err := fmt.Errorf("failed to commit transaction: %w", commitErr)
			log.Error("commit error", "error", err)
			return err
		}

		log.Info("successfully processed payment failed event")
		return nil
	}
}
package payment
import (
"context"
"errors"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/provider/payment"
"github.com/google/uuid"
)
// ExtractPaymentInitiatedKey derives the idempotency key of a PaymentInitiated
// event: its transaction ID as a string. Any other event type, or an event
// whose transaction ID is uuid.Nil, yields the empty string.
func ExtractPaymentInitiatedKey(e events.Event) string {
	if pi, ok := e.(*events.PaymentInitiated); ok && pi.TransactionID != uuid.Nil {
		return pi.TransactionID.String()
	}
	return ""
}
// HandleInitiated returns a handler for PaymentInitiated events that asks
// paymentProvider to initiate a payment for the event's user, account,
// amount, currency and transaction ID.
//
// The handler returns an error when the event is not a
// *events.PaymentInitiated or when the provider call fails. A nil logger
// falls back to slog.Default(). bus is accepted but not used by this handler.
func HandleInitiated(
	bus eventbus.Bus,
	paymentProvider payment.Payment,
	logger *slog.Logger,
) eventbus.HandlerFunc {
	if logger == nil {
		logger = slog.Default()
	}
	return func(ctx context.Context, e events.Event) error {
		log := logger.With(
			"handler", "payment.HandleInitiated",
			"event_type", e.Type(),
		)
		log.Debug("🔄 Handling PaymentInitiated event",
			"event_type", e.Type(),
			"event", fmt.Sprintf("%+v", e),
		)

		pi, ok := e.(*events.PaymentInitiated)
		if !ok {
			log.Error(
				"unexpected event type",
				"event_type", fmt.Sprintf("%T", e),
			)
			return errors.New("unexpected event type")
		}

		log = log.With(
			"user_id", pi.UserID,
			"account_id", pi.AccountID,
			"transaction_id", pi.TransactionID,
			"correlation_id", pi.CorrelationID,
		)
		transactionID := pi.TransactionID

		// Call payment provider. The result is named so that it does not shadow
		// the imported payment package.
		amount := pi.Amount.Amount()
		currency := pi.Amount.Currency().String()
		initiated, err := paymentProvider.InitiatePayment(
			ctx,
			&payment.InitiatePaymentParams{
				UserID:        pi.UserID,
				AccountID:     pi.AccountID,
				Amount:        amount,
				Currency:      currency,
				TransactionID: transactionID,
			},
		)
		if err != nil {
			log.Error(
				"Payment initiation failed",
				"error", err,
				"user_id", pi.UserID,
				"account_id", pi.AccountID,
				"transaction_id", transactionID,
			)
			return err
		}

		log.Info(
			"✅ [SUCCESS] Initiated payment",
			"transaction_id", transactionID,
			"payment", initiated,
		)
		return nil
	}
}
package payment
import (
"context"
"errors"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/google/uuid"
)
// ExtractPaymentProcessedKey derives the idempotency key of a PaymentProcessed
// event: its non-empty payment ID when present, otherwise its transaction ID
// as a string. Any other event type, or an event carrying neither identifier,
// yields the empty string.
func ExtractPaymentProcessedKey(e events.Event) string {
	pp, ok := e.(*events.PaymentProcessed)
	switch {
	case !ok:
		return ""
	case pp.PaymentID != nil && *pp.PaymentID != "":
		return *pp.PaymentID
	case pp.TransactionID != uuid.Nil:
		return pp.TransactionID.String()
	default:
		return ""
	}
}
// HandleProcessed returns a handler for PaymentProcessed events that records
// the payment ID on the matching transaction and sets its status to
// "processed".
//
// Inside uow.Do the transaction is looked up by payment ID or transaction ID.
// A transaction that is not found is skipped with a nil error. When the lookup
// yields a record it is updated in place; when the lookup reports success
// without a record, a transaction is upserted by payment ID, which requires
// the event to carry one. Events of any other type are rejected with an
// error. A nil logger falls back to slog.Default().
func HandleProcessed(
	uow repository.UnitOfWork,
	logger *slog.Logger,
) eventbus.HandlerFunc {
	if logger == nil {
		logger = slog.Default()
	}
	return func(
		ctx context.Context,
		e events.Event,
	) error {
		log := logger.With(
			"handler", "HandleProcessed",
			"event_type", e.Type(),
		)
		log.Info("🟢 [START] event received")

		pp, ok := e.(*events.PaymentProcessed)
		if !ok {
			log.Error(
				"Unexpected event type for payment processed",
				"event", e,
			)
			return errors.New("unexpected event type")
		}

		// Build log fields safely without dereferencing nil pointers
		logFields := []any{
			"transaction_id", pp.TransactionID,
		}
		if pp.PaymentID != nil {
			logFields = append(logFields, "payment_id", *pp.PaymentID)
		}
		log = log.With(logFields...)
		log.Info(
			"🔄 [PROCESS] Updating transaction with payment ID")

		// Update the transaction with payment ID
		err := uow.Do(ctx, func(uow repository.UnitOfWork) error {
			txRepo, err := common.GetTransactionRepository(uow, log)
			if err != nil {
				log.Error(
					"Failed to get transaction repo",
					"error", err,
				)
				return fmt.Errorf("failed to get transaction repo: %w", err)
			}

			// Lookup transaction by payment ID or transaction ID
			lookupResult := common.LookupTransactionByPaymentOrID(
				ctx,
				txRepo,
				pp.PaymentID,
				pp.TransactionID,
				log,
			)
			if lookupResult.Error != nil {
				return lookupResult.Error
			}
			if !lookupResult.Found {
				return nil // Skip gracefully if transaction not found
			}

			tx := lookupResult.Transaction
			transactionID := lookupResult.TransactionID
			status := "processed"

			// If transaction exists, update it with payment ID
			if tx != nil {
				update := dto.TransactionUpdate{
					PaymentID: pp.PaymentID,
					Status:    &status,
				}
				if err := txRepo.Update(ctx, transactionID, update); err != nil {
					log.Error(
						"Failed to update transaction with payment ID",
						"error", err,
					)
					return fmt.Errorf("failed to update transaction: %w", err)
				}
				log.Info(
					"Updated existing transaction with payment ID",
				)
				return nil
			}

			// The lookup reported success without a record. The upsert below is
			// keyed by payment ID, and the lookup may have been made by
			// transaction ID alone, so the pointer must be checked first.
			if pp.PaymentID == nil {
				log.Error("payment ID is nil")
				return fmt.Errorf("payment ID is nil")
			}

			// If transaction doesn't exist, create a new one
			txCreate := dto.TransactionCreate{
				ID:          transactionID,
				UserID:      pp.UserID,
				AccountID:   pp.AccountID,
				Status:      status,
				MoneySource: "Stripe", // Default money source for Stripe payments
				PaymentID:   pp.PaymentID,
			}
			// Set amount and currency if available
			if pp.Amount != nil {
				txCreate.Amount = int64(pp.Amount.Amount())
				txCreate.Currency = pp.Amount.Currency().String()
			}

			// Create the transaction using UpsertByPaymentID which handles both create and update
			if err := txRepo.UpsertByPaymentID(ctx, *pp.PaymentID, txCreate); err != nil {
				log.Error(
					"Failed to create/update transaction with payment ID",
					"error", err,
				)
				return fmt.Errorf("failed to create/update transaction: %w", err)
			}
			log.Info(
				"Transaction updated with payment ID",
			)
			return nil
		})
		if err != nil {
			log.Error(
				"Uow.Do failed",
				"error", err,
			)
			return err
		}

		log.Info("✅ [SUCCESS] event processed")
		return nil
	}
}
package payment
import (
"context"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/handler/testutils"
"github.com/amirasaad/fintech/pkg/repository"
repoaccount "github.com/amirasaad/fintech/pkg/repository/account"
"github.com/amirasaad/fintech/pkg/repository/transaction"
"github.com/stretchr/testify/mock"
)
// createValidPaymentCompletedEvent builds a PaymentCompleted event from the
// test helper: the flow carries the helper's event and correlation IDs with
// flow type "payment", and the event itself gets the fixed payment ID
// "test-payment-id", the helper's transaction ID and amount, and the status
// "completed".
func createValidPaymentCompletedEvent(
	h *testutils.TestHelper,
) *events.PaymentCompleted {
	// Snapshot the amount now; the option below may run later.
	amount := h.Amount

	flow := &events.FlowEvent{
		ID:            h.EventID,
		CorrelationID: h.CorrelationID,
		FlowType:      "payment",
	}
	populate := func(pc *events.PaymentCompleted) {
		id := "test-payment-id"
		pc.PaymentID = &id
		pc.TransactionID = h.TransactionID
		pc.Amount = amount
		pc.Status = "completed"
	}
	return events.NewPaymentCompleted(flow, populate)
}
// createValidPaymentFailedEvent builds a PaymentFailed event from the test
// helper: the flow carries the helper's event and correlation IDs with flow
// type "payment", the helper's payment ID is copied over when it is set, the
// helper's transaction ID is applied, and the reason is
// "payment processing failed".
func createValidPaymentFailedEvent(
	h *testutils.TestHelper,
) *events.PaymentFailed {
	flow := &events.FlowEvent{
		ID:            h.EventID,
		CorrelationID: h.CorrelationID,
		FlowType:      "payment",
	}
	populate := func(pf *events.PaymentFailed) {
		// Leave the event's payment ID untouched when the helper has none.
		if h.PaymentID != nil {
			pf.PaymentID = h.PaymentID
		}
		pf.TransactionID = h.TransactionID
	}
	failed := events.NewPaymentFailed(flow, populate)
	return failed.WithReason("payment processing failed")
}
// setupSuccessfulTest configures mocks for a successful payment completion.
//
// It expects exactly one h.UOW.Do call with h.Ctx. When that call happens, the
// remaining expectations are registered (transaction lookup by payment ID,
// account lookup, account update, transaction update) and the callback passed
// to Do is then run against h.UOW; its error is returned unchanged.
func setupSuccessfulTest(h *testutils.TestHelper) {
// Use the amount directly from the test helper
amount := h.Amount
// Setup payment ID to match what createValidPaymentCompletedEvent creates
paymentID := "test-payment-id"
// Only fills in h.PaymentID when the caller left it unset.
// NOTE(review): the fixtures and expectations below always use the local
// paymentID, even if h.PaymentID was preset to a different value.
if h.PaymentID == nil {
h.PaymentID = &paymentID
}
// Setup test transaction with payment ID matching the event
tx := &dto.TransactionRead{
ID: h.TransactionID,
UserID: h.UserID,
AccountID: h.AccountID,
PaymentID: &paymentID,
Status: "pending",
Currency: amount.CurrencyCode().String(),
Amount: amount.AmountFloat(),
}
// Setup test account
testAccount := &dto.AccountRead{
ID: h.AccountID,
UserID: h.UserID,
Balance: amount.AmountFloat(),
Currency: amount.CurrencyCode().String(),
}
// doFn stands in for UnitOfWork.Do: it registers the repository
// expectations lazily (using the ctx that Do was called with) and then
// invokes the callback under test.
doFn := func(ctx context.Context, fn func(uow repository.UnitOfWork) error) error {
// Transaction repository lookup; unlike the account repository lookup
// below, this expectation has no Once() limit.
h.UOW.
EXPECT().
GetRepository(
(*transaction.Repository)(nil),
).
Return(
h.MockTxRepo, nil,
)
h.MockTxRepo.
EXPECT().
GetByPaymentID(ctx, paymentID).
Return(tx, nil).
Once()
h.UOW.
EXPECT().
GetRepository(
(*repoaccount.Repository)(nil),
).
Return(
h.MockAccRepo, nil,
).Once()
h.MockAccRepo.
EXPECT().
Get(ctx, h.AccountID).
Return(testAccount, nil).
Once()
// Setup mock expectations for account update
h.MockAccRepo.EXPECT().
Update(ctx, h.AccountID, mock.MatchedBy(func(update dto.AccountUpdate) bool {
// Verify the account balance is being updated correctly
// (only checks that a positive balance is supplied, not its exact value)
return update.Balance != nil && *update.Balance > 0
})).
Return(nil).
Once()
// Setup mock expectations for transaction update
h.MockTxRepo.EXPECT().
Update(ctx, h.TransactionID, mock.MatchedBy(func(update dto.TransactionUpdate) bool {
// Verify the transaction status is being updated to "completed"
return update.Status != nil && *update.Status == "completed"
})).
Return(nil).
Once()
err := fn(h.UOW)
return err
}
// Do must be called exactly once, with h.Ctx and a callback of the
// expected function type.
h.UOW.
EXPECT().
Do(
h.Ctx,
mock.AnythingOfType("func(repository.UnitOfWork) error")).
RunAndReturn(doFn).
Once()
}
package testutils
import (
"context"
"io"
"log/slog"
"sync"
"testing"
"github.com/amirasaad/fintech/internal/fixtures/mocks"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/money"
"github.com/google/uuid"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)
// Default fixture values; New uses them to build TestHelper.Amount and
// TestHelper.FeeAmount when no option has set those fields.
const (
// DefaultCurrencyCode is the default currency code used in tests
DefaultCurrencyCode = "USD"
// DefaultAmount is the default amount used in tests (100.00),
// expressed in major currency units.
DefaultAmount = 100.0
// DefaultFeeAmount is the default fee amount used in tests (1.00),
// expressed in major currency units.
DefaultFeeAmount = 1.0
)
// TestEvent is an empty placeholder event type for tests.
type TestEvent struct{}

// Type returns the fixed event type identifier "test.event".
func (e *TestEvent) Type() string {
	const eventType = "test.event"
	return eventType
}
// TestHelper contains all test dependencies and helper methods.
// New populates every field; the With* methods overwrite individual fields
// afterwards.
type TestHelper struct {
// T is the test this helper was created for.
T *testing.T
// Handler is the handler under test; New installs a no-op handler when
// none is provided.
Handler eventbus.HandlerFunc
MockPaymentProvider *mocks.PaymentProvider
// Ctx defaults to context.Background() in New.
Ctx context.Context
Bus *mocks.Bus
UOW *mocks.UnitOfWork
MockAccRepo *mocks.AccountRepository
MockTxRepo *mocks.TransactionRepository
// Logger discards its output by default (see New).
Logger *slog.Logger
// Test data
// The UUID fields are filled with random values by New when left as uuid.Nil.
UserID uuid.UUID
AccountID uuid.UUID
// PaymentID is optional and stays nil unless set explicitly.
PaymentID *string
EventID uuid.UUID
CorrelationID uuid.UUID
TransactionID uuid.UUID
// Amount and FeeAmount default to DefaultAmount and DefaultFeeAmount in
// DefaultCurrencyCode.
Amount *money.Money
FeeAmount *money.Money
}
// New builds a TestHelper bound to t.
//
// It creates a discarding logger and fresh mocks, runs the default options
// followed by opts, and then fills every field that is still unset: a no-op
// handler, a background context, a bus mock, random UUIDs, and the default
// amount and fee amount. The test fails immediately if a default amount cannot
// be constructed.
func New(t *testing.T, opts ...TestOption) *TestHelper {
	t.Helper()

	h := &TestHelper{
		T:                   t,
		Logger:              slog.New(slog.NewTextHandler(io.Discard, nil)),
		UOW:                 mocks.NewUnitOfWork(t),
		MockAccRepo:         mocks.NewAccountRepository(t),
		MockTxRepo:          mocks.NewTransactionRepository(t),
		MockPaymentProvider: mocks.NewPaymentProvider(t),
	}

	// Package defaults run first so caller-supplied options can override them.
	for _, apply := range defaultTestOptions {
		apply(h)
	}
	for _, apply := range opts {
		apply(h)
	}

	if h.Handler == nil {
		h.Handler = eventbus.HandlerFunc(
			func(ctx context.Context, event events.Event) error {
				return nil
			})
	}
	if h.Ctx == nil {
		h.Ctx = context.Background()
	}
	if h.Bus == nil {
		h.Bus = mocks.NewBus(t)
	}

	// Give every identifier that is still the zero UUID a random value.
	ids := []*uuid.UUID{
		&h.UserID,
		&h.AccountID,
		&h.TransactionID,
		&h.EventID,
		&h.CorrelationID,
	}
	for _, id := range ids {
		if *id == uuid.Nil {
			*id = uuid.New()
		}
	}

	if h.Amount == nil {
		defaultAmount, err := money.New(DefaultAmount, money.Code(DefaultCurrencyCode).ToCurrency())
		require.NoError(t, err, "failed to create default amount")
		h.Amount = defaultAmount
	}
	if h.FeeAmount == nil {
		defaultFee, err := money.New(DefaultFeeAmount, money.Code(DefaultCurrencyCode).ToCurrency())
		require.NoError(t, err, "failed to create default fee amount")
		h.FeeAmount = defaultFee
	}

	return h
}
// TestOption defines a function type for test options.
// An option receives the helper under construction and may mutate it.
type TestOption func(*TestHelper)
var (
// initOnce guards the one-time initialization hook in defaultTestOptions.
initOnce sync.Once
)
// defaultTestOptions are applied by New before any caller-supplied options.
var defaultTestOptions = []TestOption{
// Placeholder for one-time currency registry initialization; the guarded
// function is currently empty, so this option has no effect on the helper.
func(h *TestHelper) {
// Use sync.Once to ensure initialization happens only once
initOnce.Do(func() {
// No need to initialize currency registry as it's handled by the money package
})
},
}
// WithHandler sets a custom handler for the test helper.
// It overwrites h.Handler in place and returns the same helper so calls can
// be chained.
func (h *TestHelper) WithHandler(
handler eventbus.HandlerFunc) *TestHelper {
h.Handler = handler
return h
}
// WithContext sets the context for the test helper.
// It overwrites h.Ctx in place and returns the same helper so calls can be
// chained.
func (h *TestHelper) WithContext(ctx context.Context) *TestHelper {
h.Ctx = ctx
return h
}
// WithAmount sets a custom amount for the test helper.
// It overwrites h.Amount in place (the pointer is stored, not copied) and
// returns the same helper so calls can be chained.
func (h *TestHelper) WithAmount(amount *money.Money) *TestHelper {
h.Amount = amount
return h
}
// WithFeeAmount sets a custom fee amount for the test helper.
// It overwrites h.FeeAmount in place (the pointer is stored, not copied) and
// returns the same helper so calls can be chained.
func (h *TestHelper) WithFeeAmount(amount *money.Money) *TestHelper {
h.FeeAmount = amount
return h
}
// WithUserID sets a custom user ID for the test helper.
// It overwrites h.UserID in place and returns the same helper so calls can be
// chained.
func (h *TestHelper) WithUserID(id uuid.UUID) *TestHelper {
h.UserID = id
return h
}
// WithAccountID sets a custom account ID for the test helper.
// It overwrites h.AccountID in place and returns the same helper so calls can
// be chained.
func (h *TestHelper) WithAccountID(id uuid.UUID) *TestHelper {
h.AccountID = id
return h
}
// WithTransactionID sets a custom transaction ID for the test helper.
// It overwrites h.TransactionID in place and returns the same helper so calls
// can be chained.
func (h *TestHelper) WithTransactionID(d uuid.UUID) *TestHelper {
h.TransactionID = d
return h
}
// WithPaymentID sets a custom payment ID for the test helper.
// It overwrites h.PaymentID in place (nil clears it) and returns the same
// helper so calls can be chained.
func (h *TestHelper) WithPaymentID(id *string) *TestHelper {
h.PaymentID = id
return h
}
// CreateValidTransaction returns a pending transaction DTO assembled from the
// helper's IDs, payment ID and amount. The amount is expressed in major
// currency units. h.Amount must be non-nil.
func (h *TestHelper) CreateValidTransaction() *dto.TransactionRead {
	tx := &dto.TransactionRead{
		ID:        h.TransactionID,
		UserID:    h.UserID,
		AccountID: h.AccountID,
		PaymentID: h.PaymentID,
		Status:    string(account.TransactionStatusPending),
	}
	tx.Amount = h.Amount.AmountFloat()
	tx.Currency = h.Amount.CurrencyCode().String()
	return tx
}
// CreateValidAccount returns an account DTO owned by the helper's user whose
// balance and currency are taken from h.Amount. The balance is expressed in
// major currency units. h.Amount must be non-nil.
func (h *TestHelper) CreateValidAccount() *dto.AccountRead {
	acc := &dto.AccountRead{
		ID:     h.AccountID,
		UserID: h.UserID,
	}
	acc.Balance = h.Amount.AmountFloat()
	acc.Currency = h.Amount.CurrencyCode().String()
	return acc
}
// SetupMocks configures the default mock expectations: two optional (Maybe)
// GetRepository expectations on the unit of work, each matching any argument,
// returning the account repository mock and the transaction repository mock
// respectively.
// NOTE(review): both expectations match the same calls (mock.Anything), so
// presumably only the first one registered (the account repository) is ever
// returned — confirm against testify's expectation matching order.
func (h *TestHelper) SetupMocks() {
h.UOW.EXPECT().GetRepository(mock.Anything).Return(h.MockAccRepo, nil).Maybe()
h.UOW.EXPECT().GetRepository(mock.Anything).Return(h.MockTxRepo, nil).Maybe()
}
// AssertExpectations is currently a no-op: its body is empty and it asserts
// nothing itself.
// NOTE(review): presumably the mocks created with mocks.NewXxx(t) verify their
// own expectations when the test finishes — confirm in the generated mocks.
func (h *TestHelper) AssertExpectations() {
}
package mapper
import (
"fmt"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/money"
)
// MapAccountReadToDomain maps a dto.AccountRead to a domain Account.
//
// The DTO balance and currency are first turned into a money value, and the
// account is then built from the DTO's ID, user ID, timestamps and that money
// value. An error is returned when dto is nil, when the balance/currency pair
// is rejected by money.New, or when the account builder fails.
func MapAccountReadToDomain(dto *dto.AccountRead) (*account.Account, error) {
	// Guard against a nil DTO instead of panicking on the field access below.
	if dto == nil {
		return nil, fmt.Errorf("error creating account from dto: dto is nil")
	}
	balance, err := money.New(dto.Balance, money.Code(dto.Currency))
	if err != nil {
		return nil, fmt.Errorf("error creating money from dto: %w", err)
	}
	acc, err := account.New().
		WithID(dto.ID).
		WithUserID(dto.UserID).
		WithBalance(balance.Amount()).
		WithCurrency(money.Code(balance.Currency().String())).
		WithCreatedAt(dto.CreatedAt).
		WithUpdatedAt(dto.UpdatedAt).
		Build()
	if err != nil {
		return nil, fmt.Errorf("error creating account from dto: %w", err)
	}
	return acc, nil
}
package middleware
import (
"github.com/amirasaad/fintech/pkg/config"
jwtware "github.com/gofiber/contrib/jwt"
"github.com/gofiber/fiber/v2"
)
// JwtProtected returns a Fiber middleware that guards routes with JWT
// validation. Tokens are checked against cfg.Secret and failures are reported
// through jwtError.
func JwtProtected(cfg *config.Jwt) fiber.Handler {
	signingKey := jwtware.SigningKey{Key: []byte(cfg.Secret)}
	jwtConfig := jwtware.Config{
		SigningKey:   signingKey,
		ErrorHandler: jwtError,
	}
	return jwtware.New(jwtConfig)
}
// jwtError writes the JSON error response for a failed JWT check.
// An error whose text is exactly "Missing or malformed JWT" yields
// 400 Bad Request; every other error yields 401 Unauthorized.
func jwtError(c *fiber.Ctx, err error) error {
	status, message := fiber.StatusUnauthorized, "Invalid or expired JWT"
	if err.Error() == "Missing or malformed JWT" {
		status, message = fiber.StatusBadRequest, "Missing or malformed JWT"
	}
	return c.Status(status).
		JSON(fiber.Map{"status": "error", "message": message, "data": nil})
}
// Package money provides functionality for handling monetary values.
//
// It is a value object that represents a monetary value in a specific currency.
// Invariants:
// - Amount is always stored in the smallest currency unit (e.g., cents for USD).
// - Currency code must be valid ISO 4217 (3 uppercase letters).
// - All arithmetic operations require matching currencies.
package money
import (
"encoding/json"
"fmt"
"math"
"math/big"
)
// Sentinel errors returned (directly or wrapped) by this package.
var (
// ErrInvalidAmount is returned when an invalid amount is provided,
// for example a NaN or infinite float.
ErrInvalidAmount = fmt.Errorf("invalid amount float")
// ErrAmountExceedsMaxSafeInt is returned when an amount exceeds the maximum safe integer value.
ErrAmountExceedsMaxSafeInt = fmt.Errorf("amount exceeds maximum safe integer value")
// ErrInvalidCurrency is returned when a currency code or currency fails
// validation. LessThan also returns it when the two operands have
// different currencies.
ErrInvalidCurrency = fmt.Errorf("invalid currency code")
)
// Amount represents a monetary amount as an integer in the
// smallest currency unit (e.g., cents for USD).
// It is a type alias, so Amount and int64 are interchangeable without
// conversion.
type Amount = int64
// Code represents a currency code (e.g., "USD", "EUR")
// Code is defined in codes.go
// ToCurrency maps a Code to its Currency. USD, EUR, GBP and JPY resolve to
// their predefined instances; any other code yields a Currency with that code
// and two decimal places. The code itself is not validated here.
func (c Code) ToCurrency() Currency {
	if c == USD {
		return USDCurrency
	}
	if c == EUR {
		return EURCurrency
	}
	if c == GBP {
		return GBPCurrency
	}
	if c == JPY {
		return JPYCurrency
	}
	return Currency{Code: c, Decimals: 2}
}
// IsValid reports whether the code consists of exactly three uppercase ASCII
// letters (A-Z). It checks the shape only, not membership in ISO 4217.
func (c Code) IsValid() bool {
	if len(c) != 3 {
		return false
	}
	for i := 0; i < 3; i++ {
		if c[i] < 'A' || c[i] > 'Z' {
			return false
		}
	}
	return true
}
// String returns the string representation of the currency code.
// It is a plain conversion; the code is not validated or normalized.
func (c Code) String() string {
return string(c)
}
// Currency represents a monetary unit with its standard decimal places.
// Currency values are comparable with ==; two currencies are equal only when
// both Code and Decimals match.
type Currency struct {
Code Code // 3-letter ISO 4217 code (e.g., "USD")
Decimals int // Number of decimal places (0-8)
}
// IsValid reports whether the currency is well formed: Decimals must lie in
// the range 0-8 and Code must be exactly three uppercase ASCII letters.
func (c Currency) IsValid() bool {
	if c.Decimals < 0 || c.Decimals > 8 {
		return false
	}
	code := c.Code
	if len(code) != 3 {
		return false
	}
	for i := 0; i < 3; i++ {
		if code[i] < 'A' || code[i] > 'Z' {
			return false
		}
	}
	return true
}
// String returns the currency code as a string.
// The number of decimals is not part of the output.
func (c Currency) String() string { return string(c.Code) }
// Common currency codes are defined in codes.go
// Common currency instances, each pairing a code with the number of decimal
// places used when converting to and from the smallest unit.
var (
USDCurrency = Currency{Code: USD, Decimals: 2}
EURCurrency = Currency{Code: EUR, Decimals: 2}
GBPCurrency = Currency{Code: GBP, Decimals: 2}
JPYCurrency = Currency{Code: JPY, Decimals: 0} // Japanese Yen has no decimal places
)
// DefaultCurrency is the default currency (USD)
var DefaultCurrency = USDCurrency
// DefaultCode is the default currency code (USD)
var DefaultCode = USD
// Money represents a monetary value in a specific currency.
// Invariants:
// - Amount is always stored in the smallest currency unit (e.g., cents for USD).
// - Currency must be valid (valid ISO 4217 code and valid decimal places).
// - All arithmetic operations require matching currencies.
//
// Both fields are unexported; values are created through the constructors
// (New, NewFromSmallestUnit, Zero, Must, NewFromData) or by JSON decoding.
type Money struct {
amount Amount
currency Currency
}
// MarshalJSON implements json.Marshaler. The value is encoded as an object
// with "amount" (the integer amount in the smallest currency unit) and
// "currency" (the currency code).
func (m Money) MarshalJSON() ([]byte, error) {
	fields := make(map[string]any, 2)
	fields["amount"] = m.amount
	fields["currency"] = m.currency.Code
	return json.Marshal(fields)
}
// UnmarshalJSON implements json.Unmarshaler. It expects an object with an
// integer "amount" in the smallest currency unit and a "currency" code.
// "JPY" is given zero decimal places and every other code two. The receiver
// is left untouched when decoding fails or the currency is invalid.
func (m *Money) UnmarshalJSON(data []byte) error {
	var payload struct {
		Amount   int64  `json:"amount"`
		Currency string `json:"currency"`
	}
	if err := json.Unmarshal(data, &payload); err != nil {
		return err
	}

	decimals := 2
	if payload.Currency == "JPY" {
		decimals = 0
	}
	parsed := Currency{Code: Code(payload.Currency), Decimals: decimals}
	if !parsed.IsValid() {
		return fmt.Errorf("invalid currency code: %s", payload.Currency)
	}

	m.amount, m.currency = payload.Amount, parsed
	return nil
}
// Zero returns a Money with amount 0 in the given currency.
// currency may be a Code or a Currency; any other type silently falls back to
// USD. The currency is not validated.
func Zero(currency interface{}) *Money {
	zero := &Money{currency: USDCurrency}
	switch v := currency.(type) {
	case Currency:
		zero.currency = v
	case Code:
		zero.currency = v.ToCurrency()
	}
	return zero
}
// Must is like New but panics instead of returning an error. The panic
// message includes the rejected amount and currency and the underlying error.
// It is meant for values known to be valid, such as constants and test
// fixtures.
func Must(amount float64, currency Currency) *Money {
	m, err := New(amount, currency)
	if err == nil {
		return m
	}
	panic(fmt.Sprintf("money.Must(%v, %v): %v", amount, currency, err))
}
// NewFromData builds a Money directly from a stored smallest-unit amount and
// a currency code string (used for DB hydration). Nothing is validated, and
// the currency is always given two decimal places regardless of the code.
//
// Deprecated: use NewFromSmallestUnit instead.
func NewFromData(amount int64, cc string) *Money {
	// The code is deliberately accepted as-is so that legacy rows and test
	// data with unusual codes can still be loaded.
	raw := Currency{Code: Code(cc), Decimals: 2}
	return &Money{amount: amount, currency: raw}
}
// New builds a Money from an amount in major units (e.g. dollars) and a
// currency given as a string code, a Code, or a Currency.
//
// A string must be exactly three uppercase letters. The resolved currency must
// pass Currency.IsValid. The amount is converted to the smallest currency unit
// with rounding to the nearest integer. An error is returned for an
// unsupported currency type, an invalid currency (wrapping
// ErrInvalidCurrency), or an amount that cannot be converted.
func New(amount float64, currency any) (*Money, error) {
	var resolved Currency
	switch cur := currency.(type) {
	case Currency:
		resolved = cur
	case Code:
		resolved = cur.ToCurrency()
	case string:
		// Handle string currency codes like "USD".
		if len(cur) != 3 {
			return nil, fmt.Errorf("%w: invalid currency code length: %s", ErrInvalidCurrency, cur)
		}
		code := Code(cur)
		if !code.IsValid() {
			return nil, fmt.Errorf("%w: %s", ErrInvalidCurrency, cur)
		}
		resolved = code.ToCurrency()
	default:
		return nil, fmt.Errorf(
			"invalid currency type: %T, expected string, Code, or Currency",
			currency,
		)
	}

	if !resolved.IsValid() {
		return nil, fmt.Errorf("%w: %v", ErrInvalidCurrency, resolved)
	}

	// Scale major units into the smallest unit (e.g. dollars to cents).
	units, err := convertToSmallestUnit(amount, resolved)
	if err != nil {
		return nil, err
	}
	return &Money{amount: Amount(units), currency: resolved}, nil
}
// NewFromSmallestUnit builds a Money from an amount already expressed in the
// smallest currency unit. currency may be a Code or a Currency; other types
// are rejected with an error, as is a currency that fails Currency.IsValid
// (wrapping ErrInvalidCurrency).
func NewFromSmallestUnit(amount int64, currency interface{}) (*Money, error) {
	var resolved Currency
	switch cur := currency.(type) {
	case Currency:
		resolved = cur
	case Code:
		resolved = cur.ToCurrency()
	default:
		return nil, fmt.Errorf("invalid currency type: %T", currency)
	}

	if resolved.IsValid() {
		return &Money{amount: Amount(amount), currency: resolved}, nil
	}
	return nil, fmt.Errorf("%w: %v", ErrInvalidCurrency, resolved)
}
// Amount returns the amount of the Money object in the smallest currency unit.
// The receiver is dereferenced, so calling it on a nil *Money panics.
func (m *Money) Amount() Amount {
return m.amount
}
// AmountFloat returns the amount in major currency units (e.g. dollars for
// USD) as a float64: the stored smallest-unit amount divided by
// 10^Decimals. The division is carried out on rationals and converted to the
// nearest float64 at the end.
func (m *Money) AmountFloat() float64 {
	scale := new(big.Rat).SetFloat64(math.Pow10(m.currency.Decimals))
	units := new(big.Rat).SetInt64(int64(m.amount))
	major, _ := new(big.Rat).Quo(units, scale).Float64()
	return major
}
// Currency returns the currency of the Money object.
// The Currency is returned by value. The receiver is dereferenced, so calling
// it on a nil *Money panics.
func (m *Money) Currency() Currency {
return m.currency
}
// CurrencyCode returns the currency code of the Money object.
// The receiver is dereferenced, so calling it on a nil *Money panics.
func (m *Money) CurrencyCode() Code {
return m.currency.Code
}
// IsCurrency checks if the money object has the specified currency.
// The comparison covers the whole Currency value, so both the code and the
// number of decimals must match.
func (m *Money) IsCurrency(currency Currency) bool {
return m.currency == currency
}
// Add returns a new Money object with the sum of amounts.
// Invariants enforced:
//   - Currencies must match.
//   - The sum must fit in int64.
//
// Neither operand is modified. Both m and other must be non-nil. An error is
// returned when the currencies differ or when the sum would overflow.
func (m *Money) Add(other *Money) (*Money, error) {
	if m.currency != other.currency {
		return nil, fmt.Errorf(
			"cannot add different currencies: %s and %s",
			m.currency.Code,
			other.currency.Code,
		)
	}
	a, b := int64(m.amount), int64(other.amount)
	sum := a + b
	// Go's signed addition wraps silently. A wrap happened if adding a
	// positive value lowered the total or adding a negative value raised it.
	if (b > 0 && sum < a) || (b < 0 && sum > a) {
		return nil, fmt.Errorf("addition result would overflow")
	}
	return &Money{
		amount:   Amount(sum),
		currency: m.currency,
	}, nil
}
// Subtract returns a new Money object with the difference of amounts.
// The result can be negative if the subtrahend is larger than the minuend.
// Invariants enforced:
//   - Currencies must match.
//   - The difference must fit in int64.
//
// Neither operand is modified. Both m and other must be non-nil. An error is
// returned when the currencies differ or when the difference would overflow.
func (m *Money) Subtract(other *Money) (*Money, error) {
	if m.currency != other.currency {
		return nil, fmt.Errorf(
			"cannot subtract different currencies: %s and %s",
			m.currency.Code,
			other.currency.Code,
		)
	}
	a, b := int64(m.amount), int64(other.amount)
	diff := a - b
	// Go's signed subtraction wraps silently. A wrap happened if subtracting
	// a positive value raised the total or subtracting a negative value
	// lowered it.
	if (b > 0 && diff > a) || (b < 0 && diff < a) {
		return nil, fmt.Errorf("subtraction result would overflow")
	}
	return &Money{
		amount:   Amount(diff),
		currency: m.currency,
	}, nil
}
// Negate returns a new Money with the sign of the amount flipped and the same
// currency; the receiver is not modified.
// NOTE(review): negating the minimum int64 value wraps back to itself in Go,
// so that one amount cannot be negated — confirm whether callers can reach it.
func (m *Money) Negate() *Money {
return &Money{
amount: -m.amount,
currency: m.currency,
}
}
// Equals reports whether both values are non-nil and have the same amount and
// the same currency (code and decimals). A nil receiver or a nil argument
// always yields false, including when both are nil.
func (m *Money) Equals(other *Money) bool {
	switch {
	case m == nil, other == nil:
		return false
	case m.amount != other.amount:
		return false
	default:
		return m.currency == other.currency
	}
}
// GreaterThan reports whether m's amount is strictly greater than other's.
// The two values must share the same currency; otherwise false and a
// descriptive error are returned. Both m and other must be non-nil.
func (m *Money) GreaterThan(other *Money) (bool, error) {
	if m.currency == other.currency {
		return m.amount > other.amount, nil
	}
	return false, fmt.Errorf(
		"cannot compare different currencies: %s and %s",
		m.currency.Code,
		other.currency.Code,
	)
}
// LessThan reports whether m's amount is strictly less than other's.
// The two values must share the same currency; otherwise false and
// ErrInvalidCurrency are returned. Both m and other must be non-nil.
func (m *Money) LessThan(other *Money) (bool, error) {
	if m.IsSameCurrency(other) {
		return m.amount < other.amount, nil
	}
	return false, ErrInvalidCurrency
}
// IsSameCurrency checks if the current Money object has the same currency as another Money object.
// Both the code and the number of decimals must match. Both m and other are
// dereferenced, so a nil value on either side panics.
func (m *Money) IsSameCurrency(other *Money) bool {
return m.currency == other.currency
}
// IsPositive returns true if the Money is not nil and its amount is greater than zero.
// It is safe to call on a nil *Money, which yields false.
func (m *Money) IsPositive() bool {
return m != nil && m.amount > 0
}
// IsNegative returns true if the Money is not nil and its amount is less than zero.
// It is safe to call on a nil *Money, which yields false.
func (m *Money) IsNegative() bool {
return m != nil && m.amount < 0
}
// IsZero returns true if the Money is nil or its amount is zero.
// Unlike IsPositive and IsNegative, a nil *Money yields true here.
func (m *Money) IsZero() bool {
return m == nil || m.amount == 0
}
// Abs returns the absolute value of the amount. A non-negative value is
// returned as the receiver itself (no copy); a negative value yields a new
// negated Money.
func (m *Money) Abs() *Money {
	if m.amount >= 0 {
		return m
	}
	return m.Negate()
}
// Multiply multiplies the Money amount by a scalar factor.
// The result is rounded to the nearest integer (halves away from zero).
// Invariants enforced:
//   - Factor must be a finite, non-negative number.
//   - Result must not overflow int64.
//
// Returns a new Money object or an error if the factor is invalid or would cause overflow.
func (m *Money) Multiply(factor float64) (*Money, error) {
	if factor < 0 {
		return nil, fmt.Errorf("factor cannot be negative")
	}
	// NaN and +Inf cannot be represented as a big.Rat (SetFloat64 returns
	// nil), so reject them here instead of panicking in Mul below.
	if math.IsNaN(factor) || math.IsInf(factor, 0) {
		return nil, fmt.Errorf("factor must be a finite number")
	}
	// Convert to big.Rat for precise multiplication
	amount := new(big.Rat).SetInt64(int64(m.amount))
	f := new(big.Rat).SetFloat64(factor)
	result := new(big.Rat).Mul(amount, f)
	// Convert to float64 for overflow check and rounding
	resultFloat, _ := result.Float64()
	// float64(math.MaxInt64) is exactly 2^63, one past the largest int64, so
	// the upper bound is exclusive; -2^63 is a valid int64, so the lower
	// bound is inclusive.
	if resultFloat >= float64(math.MaxInt64) || resultFloat < float64(math.MinInt64) {
		return nil, fmt.Errorf("multiplication result would overflow")
	}
	// Round to nearest integer to avoid truncation of fractional cents
	rounded := int64(math.Round(resultFloat))
	return &Money{
		amount:   Amount(rounded),
		currency: m.currency,
	}, nil
}
// Divide divides the Money amount by a scalar divisor.
// The result is rounded to the nearest integer (halves away from zero).
// Invariants enforced:
//   - Divisor must be a finite, positive number.
//   - Result must not overflow int64.
//
// Returns a new Money object or an error if the divisor is invalid or the
// quotient does not fit in int64.
func (m *Money) Divide(divisor float64) (*Money, error) {
	// NaN fails every ordered comparison, so it is tested explicitly.
	if math.IsNaN(divisor) || divisor <= 0 {
		return nil, fmt.Errorf("divisor must be positive")
	}
	// +Inf cannot be represented as a big.Rat (SetFloat64 returns nil), so
	// reject it here instead of panicking in Quo below.
	if math.IsInf(divisor, 1) {
		return nil, fmt.Errorf("divisor must be a finite number")
	}
	// Convert to big.Rat for precise division
	amount := new(big.Rat).SetInt64(int64(m.amount))
	d := new(big.Rat).SetFloat64(divisor)
	result := new(big.Rat).Quo(amount, d)
	resultFloat, _ := result.Float64()
	// Check the range before converting: an out-of-range float-to-int
	// conversion has no defined result. float64(math.MaxInt64) is exactly
	// 2^63, hence the exclusive upper bound.
	if resultFloat >= float64(math.MaxInt64) || resultFloat < float64(math.MinInt64) {
		return nil, fmt.Errorf("division result would overflow")
	}
	// Round to nearest integer
	rounded := int64(math.Round(resultFloat))
	return &Money{
		amount:   Amount(rounded),
		currency: m.currency,
	}, nil
}
// String formats the value as "<amount> <code>", with the amount in major
// units printed using the currency's number of decimal places,
// e.g. "12.34 USD".
func (m *Money) String() string {
	precision, code := m.currency.Decimals, m.currency.Code
	return fmt.Sprintf("%.*f %s", precision, m.AmountFloat(), code)
}
// convertToSmallestUnit converts a float64 amount in major units to the
// smallest currency unit by multiplying it by 10^currency.Decimals and
// rounding to the nearest integer (halves away from zero).
//
// The multiplication is done on rationals; the product is then converted to
// float64 for rounding. It returns an error wrapping ErrInvalidAmount when the
// amount is NaN or infinite, and one wrapping ErrAmountExceedsMaxSafeInt when
// the rounded value does not fit in int64.
func convertToSmallestUnit(amount float64, currency Currency) (int64, error) {
	// Validate input is a finite number
	if math.IsNaN(amount) || math.IsInf(amount, 0) {
		return 0, fmt.Errorf("%w: non-finite amount", ErrInvalidAmount)
	}
	// Convert to big.Rat for precise decimal arithmetic
	factor := new(big.Rat).SetFloat64(math.Pow10(currency.Decimals))
	amountRat := new(big.Rat).SetFloat64(amount)
	if amountRat == nil {
		return 0, fmt.Errorf("%w: failed to convert amount to rational number", ErrInvalidAmount)
	}
	// Multiply by the decimal factor
	result := new(big.Rat).Mul(amountRat, factor)
	// Convert to float64 for rounding, then check bounds
	resultFloat, _ := result.Float64()
	rounded := math.Round(resultFloat)
	// float64(math.MaxInt64) is exactly 2^63, one past the largest int64, so
	// the upper bound is exclusive; -2^63 is a valid int64, so the lower
	// bound is inclusive.
	if rounded >= float64(math.MaxInt64) || rounded < float64(math.MinInt64) {
		return 0, fmt.Errorf(
			"%w: amount exceeds maximum representable value",
			ErrAmountExceedsMaxSafeInt,
		)
	}
	return int64(rounded), nil
}
// Package money provides functionality for handling monetary values.
package money
// Currency describes a monetary unit: its code and the number of decimal
// places used when converting to and from the smallest unit.
type Currency struct {
	Code     string // 3-letter ISO 4217 code (e.g., "USD")
	Decimals int    // Number of decimal places (0-18)
}

// Predefined currencies.
var (
	USD = Currency{Code: "USD", Decimals: 2} // US Dollar
	EUR = Currency{Code: "EUR", Decimals: 2} // Euro
	GBP = Currency{Code: "GBP", Decimals: 2} // British Pound
	JPY = Currency{Code: "JPY", Decimals: 0} // Japanese Yen
)

// DefaultCurrency is the default currency (USD).
var DefaultCurrency = USD

// IsValid reports whether the currency is well formed: Decimals must lie in
// the range 0-18 and Code must be exactly three uppercase ASCII letters.
func (c Currency) IsValid() bool {
	if c.Decimals < 0 || c.Decimals > 18 || len(c.Code) != 3 {
		return false
	}
	for i := 0; i < len(c.Code); i++ {
		if ch := c.Code[i]; ch < 'A' || ch > 'Z' {
			return false
		}
	}
	return true
}

// String returns the currency code.
func (c Currency) String() string {
	return c.Code
}
// Package money provides functionality for handling monetary values.
package money
import (
"encoding/json"
"errors"
"fmt"
"math"
"math/big"
)
// Common errors. They are created with errors.New, so callers can match them
// with errors.Is.
var (
// ErrInvalidCurrency is returned when an invalid currency is provided
// or when there's a currency mismatch in operations.
ErrInvalidCurrency = errors.New("invalid currency")
// ErrInvalidAmount is returned when an invalid amount is provided.
ErrInvalidAmount = errors.New("invalid amount")
// ErrAmountExceedsMaxSafeInt is returned when an amount exceeds the maximum safe integer value.
ErrAmountExceedsMaxSafeInt = errors.New("amount exceeds maximum safe integer value")
)
// Amount represents a monetary amount as an integer in the
// smallest currency unit (e.g., cents for USD).
// It is a type alias, so Amount and int64 are interchangeable without
// conversion.
type Amount = int64
// Money represents a monetary value in a specific currency.
// Invariants:
// - Amount is always stored in the smallest currency unit (e.g., cents for USD).
// - Currency must be valid (valid ISO 4217 code and valid decimal places).
// - All arithmetic operations require matching currencies.
//
// Both fields are unexported; values are created through the constructors
// (New, NewFromSmallestUnit, Zero, Must, NewFromData) or by JSON decoding.
type Money struct {
amount Amount
currency Currency
}
// MarshalJSON implements json.Marshaler. The value is encoded as an object
// with "amount" (a float in major currency units, e.g. dollars) and
// "currency" (the currency code).
func (m Money) MarshalJSON() ([]byte, error) {
	fields := make(map[string]any, 2)
	fields["amount"] = m.AmountFloat()
	fields["currency"] = m.currency.Code
	return json.Marshal(fields)
}
// UnmarshalJSON implements json.Unmarshaler. It expects the shape produced by
// MarshalJSON: "amount" as a float in major currency units and "currency" as
// a 3-letter code.
//
// The currency's decimal places are derived from the code the same way
// NewFromData does (JPY has none; every other code gets two), so that the
// amount is scaled back to the smallest unit correctly and a
// marshal/unmarshal round trip preserves the value. The receiver is left
// untouched when decoding, currency validation or amount conversion fails.
func (m *Money) UnmarshalJSON(data []byte) error {
	var aux struct {
		Amount   float64 `json:"amount"`
		Currency string  `json:"currency"`
	}
	if err := json.Unmarshal(data, &aux); err != nil {
		return err
	}
	// Create currency and validate
	currency := Currency{Code: aux.Currency}
	switch aux.Currency {
	case "JPY":
		currency.Decimals = 0
	default:
		currency.Decimals = 2 // Default to 2 decimal places
	}
	if !currency.IsValid() {
		return fmt.Errorf("invalid currency: %s", aux.Currency)
	}
	smallestUnit, err := convertToSmallestUnit(aux.Amount, currency)
	if err != nil {
		return fmt.Errorf("invalid amount: %w", err)
	}
	m.amount = Amount(smallestUnit)
	m.currency = currency
	return nil
}
// Zero returns a Money with amount 0 in the given currency.
// The currency is stored as-is and not validated.
func Zero(currency Currency) *Money {
	zero := new(Money)
	zero.currency = currency
	return zero
}
// Must is like New but panics instead of returning an error. The panic
// message includes the rejected amount, the currency code and the underlying
// error. It is meant for values known to be valid, such as constants and test
// fixtures.
func Must(amount float64, currency Currency) *Money {
	money, err := New(amount, currency)
	if err == nil {
		return money
	}
	panic(fmt.Sprintf(
		"money: invalid arguments to Must(%v, %v): %v",
		amount,
		currency.Code,
		err,
	))
}
// NewFromData builds a Money directly from a stored smallest-unit amount and
// a currency code string (used for DB hydration). Nothing is validated.
// "JPY" gets zero decimal places; every other code gets two.
//
// Deprecated: use NewFromSmallestUnit instead.
func NewFromData(amount int64, currencyCode string) *Money {
	decimals := 2 // Applies to USD, EUR, GBP and any unrecognized code.
	if currencyCode == "JPY" {
		decimals = 0
	}
	return &Money{
		amount:   amount,
		currency: Currency{Code: currencyCode, Decimals: decimals},
	}
}
// New builds a Money from an amount in major units (e.g. dollars) and a
// currency. The currency must pass Currency.IsValid, and the amount is
// converted to the smallest currency unit (e.g. cents for USD). An error is
// returned when the currency is invalid or the amount cannot be converted.
func New(amount float64, currency Currency) (*Money, error) {
	if !currency.IsValid() {
		return nil, fmt.Errorf("invalid currency: %v", currency)
	}

	units, err := convertToSmallestUnit(amount, currency)
	if err != nil {
		return nil, fmt.Errorf("invalid amount: %w", err)
	}

	created := &Money{amount: Amount(units), currency: currency}
	return created, nil
}
// NewFromSmallestUnit builds a Money from an amount already expressed in the
// smallest currency unit. The currency must pass Currency.IsValid; otherwise
// an error is returned.
func NewFromSmallestUnit(amount int64, currency Currency) (*Money, error) {
	if currency.IsValid() {
		return &Money{amount: Amount(amount), currency: currency}, nil
	}
	return nil, fmt.Errorf("invalid currency: %v", currency)
}
// Amount returns the amount of the Money object in the smallest currency unit.
// The receiver is dereferenced, so calling it on a nil *Money panics.
func (m *Money) Amount() Amount {
return m.amount
}
// AmountFloat returns the amount in major currency units (e.g. dollars for
// USD) as a float64: the stored smallest-unit amount divided by
// 10^Decimals. The division is carried out on rationals and converted to the
// nearest float64 at the end.
func (m *Money) AmountFloat() float64 {
	scale := new(big.Rat).SetFloat64(math.Pow10(m.currency.Decimals))
	units := new(big.Rat).SetInt64(int64(m.amount))
	major, _ := new(big.Rat).Quo(units, scale).Float64()
	return major
}
// Currency returns the currency of the Money object.
// The Currency is returned by value. The receiver is dereferenced, so calling
// it on a nil *Money panics.
func (m *Money) Currency() Currency {
return m.currency
}
// IsCurrency checks if the money object has the specified currency.
// The comparison covers the whole Currency value, so both the code and the
// number of decimals must match.
func (m *Money) IsCurrency(currency Currency) bool {
return m.currency == currency
}
// Add returns a new Money object with the sum of amounts.
// Returns an error if the currencies don't match or if the sum does not fit
// in int64. Neither operand is modified. Both m and other must be non-nil.
func (m *Money) Add(other *Money) (*Money, error) {
	if m.currency != other.currency {
		return nil, fmt.Errorf(
			"cannot add different currencies: %s and %s",
			m.currency,
			other.currency,
		)
	}
	a, b := int64(m.amount), int64(other.amount)
	sum := a + b
	// Go's signed addition wraps silently. A wrap happened if adding a
	// positive value lowered the total or adding a negative value raised it.
	if (b > 0 && sum < a) || (b < 0 && sum > a) {
		return nil, fmt.Errorf("addition result would overflow")
	}
	return &Money{
		amount:   Amount(sum),
		currency: m.currency,
	}, nil
}
// Subtract returns a new Money object with the difference of amounts.
// The result can be negative if the subtrahend is larger than the minuend.
// Returns an error if the currencies don't match or if the difference does
// not fit in int64. Neither operand is modified. Both m and other must be
// non-nil.
func (m *Money) Subtract(other *Money) (*Money, error) {
	if m.currency != other.currency {
		return nil, fmt.Errorf(
			"cannot subtract different currencies: %s and %s",
			m.currency,
			other.currency,
		)
	}
	a, b := int64(m.amount), int64(other.amount)
	diff := a - b
	// Go's signed subtraction wraps silently. A wrap happened if subtracting
	// a positive value raised the total or subtracting a negative value
	// lowered it.
	if (b > 0 && diff > a) || (b < 0 && diff < a) {
		return nil, fmt.Errorf("subtraction result would overflow")
	}
	return &Money{
		amount:   Amount(diff),
		currency: m.currency,
	}, nil
}
// Negate returns a new Money with the sign of the amount flipped and the same
// currency; the receiver is not modified.
// NOTE(review): negating the minimum int64 value wraps back to itself in Go,
// so that one amount cannot be negated — confirm whether callers can reach it.
func (m *Money) Negate() *Money {
return &Money{
amount: -m.amount,
currency: m.currency,
}
}
// Equals reports whether both values are non-nil and have the same amount and
// the same currency (code and decimals). A nil receiver or a nil argument
// always yields false, including when both are nil.
func (m *Money) Equals(other *Money) bool {
	switch {
	case m == nil, other == nil:
		return false
	case m.amount != other.amount:
		return false
	default:
		return m.currency == other.currency
	}
}
// GreaterThan reports whether m's amount is strictly greater than other's.
// The two values must share the same currency; otherwise false and a
// descriptive error are returned. Both m and other must be non-nil.
func (m *Money) GreaterThan(other *Money) (bool, error) {
	if m.currency == other.currency {
		return m.amount > other.amount, nil
	}
	return false, fmt.Errorf(
		"cannot compare different currencies: %s and %s",
		m.currency,
		other.currency,
	)
}
// LessThan reports whether m's amount is strictly less than other's.
// The two values must share the same currency; otherwise false and
// ErrInvalidCurrency are returned. Both m and other must be non-nil.
func (m *Money) LessThan(other *Money) (bool, error) {
	if m.IsSameCurrency(other) {
		return m.amount < other.amount, nil
	}
	return false, ErrInvalidCurrency
}
// IsSameCurrency checks if the current Money object has the same currency as another Money object.
// Both the code and the number of decimals must match. Both m and other are
// dereferenced, so a nil value on either side panics.
func (m *Money) IsSameCurrency(other *Money) bool {
return m.currency == other.currency
}
// IsPositive reports whether the amount is strictly greater than zero.
func (m *Money) IsPositive() bool {
	return 0 < m.amount
}

// IsNegative reports whether the amount is strictly less than zero.
func (m *Money) IsNegative() bool {
	return 0 > m.amount
}

// IsZero reports whether the amount is exactly zero.
func (m *Money) IsZero() bool {
	return 0 == m.amount
}
// Abs returns the absolute value of the Money amount.
// A non-negative receiver is returned as-is (the same pointer); a negative
// one yields a freshly negated Money.
func (m *Money) Abs() *Money {
	if m.amount >= 0 {
		return m
	}
	return m.Negate()
}
// Multiply multiplies the Money amount by a scalar factor.
// Invariants enforced:
//   - Result must be a number (a NaN factor, or 0 * Inf, is rejected).
//   - Result must not overflow int64.
//   - Result is rounded to the nearest integer to preserve precision.
//
// Returns Money or an error if any invariant is violated.
func (m *Money) Multiply(factor float64) (*Money, error) {
	resultFloat := float64(m.amount) * factor
	// NaN compares false against every bound, so it must be rejected
	// explicitly before the range check below.
	if math.IsNaN(resultFloat) {
		return nil, fmt.Errorf("multiplication result is not a number")
	}
	// Round to nearest integer to avoid truncation of fractional cents.
	rounded := math.Round(resultFloat)
	// float64(math.MaxInt64) is exactly 2^63, which does NOT fit in int64,
	// hence the inclusive upper bound. -2^63 is representable.
	if rounded >= float64(math.MaxInt64) || rounded < float64(math.MinInt64) {
		return nil, fmt.Errorf("multiplication result would overflow")
	}
	return &Money{
		amount:   Amount(int64(rounded)),
		currency: m.currency,
	}, nil
}
// Divide divides the Money amount by a scalar divisor.
// Invariants enforced:
//   - Divisor must be a finite, non-zero number.
//   - Result must not overflow int64.
//   - Division must not lose precision (the quotient must be integral).
//
// Returns Money or an error if any invariant is violated.
func (m *Money) Divide(divisor float64) (*Money, error) {
	if divisor == 0 {
		return nil, fmt.Errorf("division by zero")
	}
	// NaN would slip through the range comparisons and an infinite divisor
	// would quietly collapse every amount to zero.
	if math.IsNaN(divisor) || math.IsInf(divisor, 0) {
		return nil, fmt.Errorf("divisor must be a finite number")
	}
	resultFloat := float64(m.amount) / divisor
	// float64(math.MaxInt64) is exactly 2^63, which does NOT fit in int64,
	// hence the inclusive upper bound.
	if resultFloat >= float64(math.MaxInt64) || resultFloat < float64(math.MinInt64) {
		return nil, fmt.Errorf("division result would overflow")
	}
	// The quotient must already be an integer; anything else loses precision.
	if resultFloat != math.Trunc(resultFloat) {
		return nil, fmt.Errorf("division would result in precision loss")
	}
	return &Money{
		amount:   Amount(int64(resultFloat)),
		currency: m.currency,
	}, nil
}
// String formats the Money as "<amount> <currency>", printing the amount
// with as many fractional digits as the currency declares.
func (m *Money) String() string {
	precision := m.currency.Decimals
	value := m.AmountFloat()
	return fmt.Sprintf("%.*f %s", precision, value, m.currency)
}
// convertToSmallestUnit converts a float64 amount to the smallest currency
// unit (amount * 10^currency.Decimals), rounded to the nearest integer.
// The scaling is done with exact rational arithmetic; only the final
// rounding step goes through float64.
//
// Returns an error if amount is NaN or infinite, if the decimal count is too
// large to scale by, or if the scaled value does not fit in an int64.
func convertToSmallestUnit(amount float64, currency Currency) (int64, error) {
	// big.Rat.SetFloat64 returns nil for non-finite input, which would make
	// the multiplication below panic.
	if math.IsNaN(amount) || math.IsInf(amount, 0) {
		return 0, fmt.Errorf("amount must be a finite number")
	}
	factor := new(big.Rat).SetFloat64(math.Pow10(currency.Decimals))
	if factor == nil {
		// Pow10 overflowed to +Inf for an absurdly large decimal count.
		return 0, fmt.Errorf("unsupported number of decimals: %d", currency.Decimals)
	}
	amountRat := new(big.Rat).SetFloat64(amount)
	result := new(big.Rat).Mul(amountRat, factor)
	// Round to nearest integer; exactness of the float conversion is not
	// needed because the value is rounded right after.
	resultFloat, _ := result.Float64()
	rounded := math.Round(resultFloat)
	// float64(math.MaxInt64) is exactly 2^63, which does NOT fit in int64.
	if rounded >= float64(math.MaxInt64) || rounded < float64(math.MinInt64) {
		return 0, fmt.Errorf("amount would overflow")
	}
	return int64(rounded), nil
}
package exchange
import (
"context"
"sync"
"time"
)
// Cache provides an in-memory cache for exchange rates.
// Entries are keyed by currency pair (see cacheKey) and carry an absolute
// expiry time computed from ttl when they are stored.
type Cache struct {
// store maps a currency-pair key to its cached entry.
store map[string]rateCacheEntry
// mu is taken by the methods of Cache around every access to store.
mu sync.RWMutex
// ttl is the lifetime added to the current time when an entry is stored.
ttl time.Duration
}
// rateCacheEntry is a cached rate together with the instant after which it
// is treated as expired.
type rateCacheEntry struct {
value *RateInfo
expiresAt time.Time
}
// NewCache returns an empty Cache whose entries live for ttl after being
// stored.
func NewCache(ttl time.Duration) *Cache {
	cache := new(Cache)
	cache.ttl = ttl
	cache.store = make(map[string]rateCacheEntry)
	return cache
}
// GetRate looks up the cached rate for the from/to currency pair.
// A missing or expired entry is reported as (nil, nil); the error result is
// always nil.
func (c *Cache) GetRate(ctx context.Context, from, to string) (*RateInfo, error) {
	pair := cacheKey(from, to)

	c.mu.RLock()
	defer c.mu.RUnlock()

	cached, found := c.store[pair]
	// An entry whose expiry lies in the past counts as a miss.
	if !found || time.Now().After(cached.expiresAt) {
		return nil, nil
	}
	return cached.value, nil
}
// StoreRate caches rate under its FromCurrency/ToCurrency pair, replacing any
// previous entry and setting the expiry to now + ttl.
// A nil rate is ignored. The error result is always nil.
func (c *Cache) StoreRate(ctx context.Context, rate *RateInfo) error {
	if rate == nil {
		return nil
	}
	pair := cacheKey(rate.FromCurrency, rate.ToCurrency)

	c.mu.Lock()
	defer c.mu.Unlock()

	fresh := rateCacheEntry{value: rate}
	fresh.expiresAt = time.Now().Add(c.ttl)
	c.store[pair] = fresh
	return nil
}
// BatchGetRates looks up the cached rates from one source currency to each of
// the given target currencies.
// The returned map is keyed by target currency and only contains the targets
// that have an unexpired entry. The error result is always nil.
func (c *Cache) BatchGetRates(
	ctx context.Context,
	from string,
	to []string,
) (map[string]*RateInfo, error) {
	found := make(map[string]*RateInfo, len(to))

	c.mu.RLock()
	defer c.mu.RUnlock()

	// One timestamp for the whole batch so every entry is judged alike.
	cutoff := time.Now()
	for i := range to {
		cached, ok := c.store[cacheKey(from, to[i])]
		if !ok || !cutoff.Before(cached.expiresAt) {
			continue
		}
		found[to[i]] = cached.value
	}
	return found, nil
}
// Clear drops every cached entry by swapping in a fresh, empty map.
func (c *Cache) Clear() {
	c.mu.Lock()
	c.store = make(map[string]rateCacheEntry)
	c.mu.Unlock()
}
// cacheKey builds the cache key for a currency pair as "<from>_<to>".
func cacheKey(from, to string) string {
	key := make([]byte, 0, len(from)+len(to)+1)
	key = append(key, from...)
	key = append(key, '_')
	key = append(key, to...)
	return string(key)
}
package provider
import (
"context"
"errors"
"sync"
"time"
)
// HealthCheckAll runs CheckHealth on every provider concurrently and returns
// the outcome of each check keyed by the provider's metadata name (a nil
// error means healthy).
// Providers that do not expose a Metadata() ProviderMetadata method are
// skipped, and providers sharing a name overwrite each other's result.
// The call blocks until every started check has finished.
func HealthCheckAll(
	ctx context.Context,
	providers []HealthChecker,
) map[string]error {
	var (
		guard   sync.Mutex
		pending sync.WaitGroup
	)
	outcome := make(map[string]error)

	for _, candidate := range providers {
		described, hasMeta := candidate.(interface{ Metadata() ProviderMetadata })
		if !hasMeta {
			continue
		}
		pending.Add(1)
		// Bind per-iteration copies so each goroutine sees its own provider.
		checker, label := candidate, described.Metadata().Name
		go func() {
			defer pending.Done()
			healthErr := checker.CheckHealth(ctx)
			guard.Lock()
			defer guard.Unlock()
			outcome[label] = healthErr
		}()
	}

	pending.Wait()
	return outcome
}
// FirstHealthy checks the providers in order and returns the first one whose
// CheckHealth succeeds. If none does (or the list is empty) it returns an
// error.
func FirstHealthy(
	ctx context.Context,
	providers []HealthChecker,
) (HealthChecker, error) {
	for i := range providers {
		if err := providers[i].CheckHealth(ctx); err != nil {
			continue
		}
		return providers[i], nil
	}
	return nil, errors.New("no healthy providers available")
}
// AllHealthy checks every provider sequentially and returns nil when all of
// them are healthy. Otherwise the individual failures are combined into a
// single error with errors.Join.
func AllHealthy(
	ctx context.Context,
	providers []HealthChecker,
) error {
	var failures []error
	for i := range providers {
		err := providers[i].CheckHealth(ctx)
		if err == nil {
			continue
		}
		failures = append(failures, err)
	}
	// errors.Join yields nil when there is nothing to join.
	return errors.Join(failures...)
}
// RateStats holds summary statistics over a set of exchange rates.
type RateStats struct {
	Min     float64
	Max     float64
	Average float64
	Count   int
}

// CalculateStats computes the minimum, maximum, arithmetic mean and count of
// the given rates. An empty input yields the zero RateStats.
func CalculateStats(rates []float64) RateStats {
	var stats RateStats
	if len(rates) == 0 {
		return stats
	}

	stats.Count = len(rates)
	stats.Min, stats.Max = rates[0], rates[0]

	total := 0.0
	for i := range rates {
		if rates[i] < stats.Min {
			stats.Min = rates[i]
		}
		if rates[i] > stats.Max {
			stats.Max = rates[i]
		}
		total += rates[i]
	}
	stats.Average = total / float64(stats.Count)
	return stats
}
// RateHistory tracks historical rate data, keeping at most size entries
// (Add drops the oldest entries once the limit is exceeded).
type RateHistory struct {
// rates holds the retained entries in insertion order (oldest first).
rates []RateInfo
// mu is taken by the methods of RateHistory around every access to rates.
mu sync.RWMutex
// size is the maximum number of entries retained.
size int
}
// NewRateHistory creates a new RateHistory that retains at most size entries.
// A negative size is treated as 0 (nothing is retained) rather than panicking
// when the backing slice is allocated.
func NewRateHistory(size int) *RateHistory {
	if size < 0 {
		size = 0
	}
	return &RateHistory{
		rates: make([]RateInfo, 0, size),
		size:  size,
	}
}
// Add appends rate to the history and, if the history now holds more than
// size entries, drops the oldest ones so that only the newest size remain.
func (h *RateHistory) Add(rate RateInfo) {
	h.mu.Lock()
	defer h.mu.Unlock()

	h.rates = append(h.rates, rate)
	if excess := len(h.rates) - h.size; excess > 0 {
		// Discard the oldest entries from the front.
		h.rates = h.rates[excess:]
	}
}
// Get returns a snapshot of the history, oldest entry first. The returned
// slice is a copy, so callers may modify it freely.
func (h *RateHistory) Get() []RateInfo {
	h.mu.RLock()
	defer h.mu.RUnlock()

	snapshot := make([]RateInfo, len(h.rates))
	for i := range h.rates {
		snapshot[i] = h.rates[i]
	}
	return snapshot
}
// Average returns the mean Rate of all history entries whose Timestamp is
// strictly after since. It returns 0 when no entry qualifies.
func (h *RateHistory) Average(since time.Time) float64 {
	h.mu.RLock()
	defer h.mu.RUnlock()

	total, matched := 0.0, 0
	for i := range h.rates {
		if !h.rates[i].Timestamp.After(since) {
			continue
		}
		total += h.rates[i].Rate
		matched++
	}
	if matched > 0 {
		return total / float64(matched)
	}
	return 0
}
package queries
import "github.com/google/uuid"
// GetAccountQuery identifies an account lookup by account ID and user ID.
type GetAccountQuery struct {
AccountID uuid.UUID
UserID uuid.UUID
}
// Type returns the query's name, "GetAccountQuery".
func (q GetAccountQuery) Type() string { return "GetAccountQuery" }
// GetAccountResult carries the data returned for a GetAccountQuery.
// NOTE(review): Balance is a float64; the unit (major vs. smallest currency
// unit) is not visible here — confirm against the producer.
type GetAccountResult struct {
AccountID uuid.UUID
UserID uuid.UUID
Balance float64
Currency string
}
package registry
import (
"encoding/json"
"errors"
"fmt"
"maps"
"sync"
"time"
)
// BaseEntity provides a thread-safe default implementation of core entity interfaces.
// It serves as a foundation for domain-specific entities by providing common
// functionality including:
//   - Unique identifier management (ID)
//   - Naming and activation state
//   - Key-value metadata storage
//   - Creation and modification timestamps
//   - Concurrent access safety
//
// BaseEntity implements the following interfaces:
//   - Identifiable: For ID management
//   - Named: For name-related operations
//   - ActivationController: For activation state management
//   - MetadataController: For metadata operations
//   - Timestamped: For creation/update timestamps
//   - Entity: Composite interface for backward compatibility
//
// Example usage:
//
//	type User struct {
//		registry.BaseEntity
//		Email    string
//		Password string `json:"-"`
//	}
//
// All exported methods are safe for concurrent access. The struct uses a read-write mutex
// to protect all internal state. When embedding BaseEntity, ensure proper initialization
// of the embedded fields.
type BaseEntity struct {
	id        string
	name      string
	active    bool
	metadata  map[string]string
	createdAt time.Time
	updatedAt time.Time
	mu        sync.RWMutex
	// Deprecated: Use ID() and SetID() methods instead
	BEId string
	// Deprecated: Use Name() and SetName() methods instead
	BEName string
	// Deprecated: Use Active() and SetActive() methods instead
	BEActive bool
	// Deprecated: Use Metadata() and related methods instead
	BEMetadata map[string]string
	// Deprecated: Use CreatedAt() and UpdatedAt() methods instead
	BECreatedAt time.Time
	// Deprecated: Use CreatedAt() and UpdatedAt() methods instead
	BEUpdatedAt time.Time
}

// MarshalJSON implements the json.Marshaler interface.
// It emits the core fields ("id", "name", "active", "created_at",
// "updated_at") plus, when non-empty, a nested "metadata" object.
// This method is safe for concurrent access.
func (e *BaseEntity) MarshalJSON() ([]byte, error) {
	e.mu.RLock()
	defer e.mu.RUnlock()
	// Create a map to hold the JSON representation with core fields
	data := map[string]any{
		"id":         e.id,
		"name":       e.name,
		"active":     e.active,
		"created_at": e.createdAt.Format(time.RFC3339Nano),
		"updated_at": e.updatedAt.Format(time.RFC3339Nano),
	}
	// Include metadata as a separate object, not at the root level
	if len(e.metadata) > 0 {
		// Marshal a copy so the encoder never touches the live map
		metadataCopy := make(map[string]string, len(e.metadata))
		maps.Copy(metadataCopy, e.metadata)
		data["metadata"] = metadataCopy
	}
	return json.Marshal(data)
}

// UnmarshalJSON implements the json.Unmarshaler interface.
//
// It restores the core fields written by MarshalJSON ("id", "name",
// "active", "created_at", "updated_at", "metadata"), still accepts the
// deprecated exported BE* fields, and stores any other root-level string
// value as a metadata entry. When "id", "name" or "active" are present they
// are mirrored into the deprecated BE* fields, as the setters do.
//
// The whole update runs under the write lock, so this method is safe for
// concurrent access. It returns an error if the document is not a JSON
// object, if a core field has the wrong JSON type, or if a timestamp is not
// RFC 3339.
func (e *BaseEntity) UnmarshalJSON(data []byte) error {
	// Decode into raw messages first so individual keys can be inspected
	// without touching the entity.
	var rawData map[string]json.RawMessage
	if err := json.Unmarshal(data, &rawData); err != nil {
		return fmt.Errorf("failed to unmarshal BaseEntity: %w", err)
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	// Alias has BaseEntity's fields but none of its methods, which avoids
	// recursing back into UnmarshalJSON. Decoding through it fills the
	// deprecated exported BE* fields.
	type Alias BaseEntity
	aux := &struct {
		*Alias
		CreatedAt string `json:"created_at"`
		UpdatedAt string `json:"updated_at"`
	}{
		Alias: (*Alias)(e),
	}
	if err := json.Unmarshal(data, aux); err != nil {
		return fmt.Errorf("failed to unmarshal BaseEntity: %w", err)
	}

	// encoding/json ignores unexported fields, so the core fields have to be
	// restored by hand. Pointers distinguish "absent or null" from a value.
	if raw, ok := rawData["id"]; ok {
		var id *string
		if err := json.Unmarshal(raw, &id); err != nil {
			return fmt.Errorf("failed to parse id: %w", err)
		}
		if id != nil {
			e.id, e.BEId = *id, *id
		}
	}
	if raw, ok := rawData["name"]; ok {
		var name *string
		if err := json.Unmarshal(raw, &name); err != nil {
			return fmt.Errorf("failed to parse name: %w", err)
		}
		if name != nil {
			e.name, e.BEName = *name, *name
		}
	}
	if raw, ok := rawData["active"]; ok {
		var active *bool
		if err := json.Unmarshal(raw, &active); err != nil {
			return fmt.Errorf("failed to parse active: %w", err)
		}
		if active != nil {
			// Active() ORs both flags, so the mirror must follow too.
			e.active, e.BEActive = *active, *active
		}
	}

	// Parse the timestamps; the entity is only updated on success.
	if aux.CreatedAt != "" {
		createdAt, err := time.Parse(time.RFC3339Nano, aux.CreatedAt)
		if err != nil {
			return fmt.Errorf("failed to parse created_at: %w", err)
		}
		e.createdAt = createdAt
	}
	if aux.UpdatedAt != "" {
		updatedAt, err := time.Parse(time.RFC3339Nano, aux.UpdatedAt)
		if err != nil {
			return fmt.Errorf("failed to parse updated_at: %w", err)
		}
		e.updatedAt = updatedAt
	}

	if e.metadata == nil {
		e.metadata = make(map[string]string)
	}
	// Merge the nested metadata object if it exists. A malformed object is
	// ignored on purpose (best effort).
	if metadataData, ok := rawData["metadata"]; ok {
		var metadata map[string]string
		if err := json.Unmarshal(metadataData, &metadata); err == nil {
			maps.Copy(e.metadata, metadata)
		}
	}
	// Treat any other root-level string value as a metadata entry.
	for k, v := range rawData {
		switch k {
		case "id", "name", "active", "created_at", "updated_at", "metadata":
			continue
		}
		var strVal string
		if err := json.Unmarshal(v, &strVal); err == nil {
			e.metadata[k] = strVal
		}
	}
	return nil
}
// Compile-time assertions that *BaseEntity satisfies each of the listed
// interfaces; the build fails here if a required method is missing.
var (
_ Identity = (*BaseEntity)(nil)
_ Named = (*BaseEntity)(nil)
_ ActivationController = (*BaseEntity)(nil)
_ MetadataController = (*BaseEntity)(nil)
_ Timestamped = (*BaseEntity)(nil)
_ MetadataReader = (*BaseEntity)(nil)
_ MetadataWriter = (*BaseEntity)(nil)
_ MetadataRemover = (*BaseEntity)(nil)
_ MetadataClearer = (*BaseEntity)(nil)
_ Entity = (*BaseEntity)(nil) // For backward compatibility
)
// DeleteMetadata removes a metadata key from the entity.
// The key is removed from the deprecated BEMetadata mirror as well:
// SetMetadata writes to both maps and GetMetadataValue falls back to
// BEMetadata, so a key left there would reappear after deletion.
// It's safe for concurrent access.
func (e *BaseEntity) DeleteMetadata(key string) {
	e.mu.Lock()
	defer e.mu.Unlock()
	// delete on a nil map is a no-op, so no nil check is needed here.
	delete(e.BEMetadata, key)
	if e.metadata == nil {
		e.metadata = make(map[string]string)
		return
	}
	delete(e.metadata, key)
	e.updatedAt = time.Now().UTC()
}
// SetID replaces the entity's ID (and its deprecated BEId mirror) and bumps
// the update timestamp. An empty id is rejected with an error.
// It's safe for concurrent access.
func (e *BaseEntity) SetID(id string) error {
	e.mu.Lock()
	defer e.mu.Unlock()

	if len(id) == 0 {
		return errors.New("id cannot be empty")
	}
	// Keep the legacy field in step for backward compatibility.
	e.id, e.BEId = id, id
	e.updatedAt = time.Now().UTC()
	return nil
}
// SetName replaces the entity's name (and its deprecated BEName mirror) and
// bumps the update timestamp. An empty name is rejected with an error.
// It's safe for concurrent access.
func (e *BaseEntity) SetName(name string) error {
	e.mu.Lock()
	defer e.mu.Unlock()

	if len(name) == 0 {
		return errors.New("name cannot be empty")
	}
	// Keep the legacy field in step for backward compatibility.
	e.name, e.BEName = name, name
	e.updatedAt = time.Now().UTC()
	return nil
}
// SetActive stores the activation state (also in the deprecated BEActive
// mirror) and bumps the update timestamp.
// It's safe for concurrent access.
func (e *BaseEntity) SetActive(active bool) {
	e.mu.Lock()
	defer e.mu.Unlock()

	// Keep the legacy field in step for backward compatibility.
	e.active, e.BEActive = active, active
	e.updatedAt = time.Now().UTC()
}
// SetMetadata stores value under key in the entity's metadata (and in the
// deprecated BEMetadata mirror), creating either map on first use, and bumps
// the update timestamp.
// It's safe for concurrent access.
func (e *BaseEntity) SetMetadata(key, value string) {
	e.mu.Lock()
	defer e.mu.Unlock()

	// Lazily allocate both maps before writing.
	if e.metadata == nil {
		e.metadata = make(map[string]string)
	}
	if e.BEMetadata == nil {
		e.BEMetadata = make(map[string]string)
	}
	e.metadata[key] = value
	e.BEMetadata[key] = value // For backward compatibility
	e.updatedAt = time.Now().UTC()
}
// ClearMetadata discards all metadata by replacing both the metadata map and
// the deprecated BEMetadata mirror with fresh empty maps, and bumps the
// update timestamp.
// It's safe for concurrent access.
func (e *BaseEntity) ClearMetadata() {
	e.mu.Lock()
	defer e.mu.Unlock()

	e.metadata, e.BEMetadata = make(map[string]string), make(map[string]string)
	e.updatedAt = time.Now().UTC()
}
// Metadata returns a snapshot of the entity's metadata. The result is always
// a fresh, non-nil map, so callers can modify it without affecting the
// entity.
// It's safe for concurrent access.
func (e *BaseEntity) Metadata() map[string]string {
	e.mu.RLock()
	defer e.mu.RUnlock()

	snapshot := make(map[string]string, len(e.metadata))
	for key, value := range e.metadata {
		snapshot[key] = value
	}
	return snapshot
}
// ID returns the entity's ID, falling back to the deprecated BEId field when
// the primary ID is empty.
// It's safe for concurrent access.
func (e *BaseEntity) ID() string {
	e.mu.RLock()
	defer e.mu.RUnlock()

	if e.id != "" || e.BEId == "" {
		return e.id
	}
	return e.BEId // For backward compatibility
}
// Name returns the entity's name, falling back to the deprecated BEName field
// when the primary name is empty.
// It's safe for concurrent access.
func (e *BaseEntity) Name() string {
	e.mu.RLock()
	defer e.mu.RUnlock()

	if e.name != "" || e.BEName == "" {
		return e.name
	}
	return e.BEName // For backward compatibility
}
// Active reports whether the entity is active. The entity counts as active
// when either the primary flag or the deprecated BEActive flag is set.
// It's safe for concurrent access.
func (e *BaseEntity) Active() bool {
	e.mu.RLock()
	defer e.mu.RUnlock()

	if e.active {
		return true
	}
	return e.BEActive // For backward compatibility
}
// CreatedAt returns the entity's creation timestamp.
// It's safe for concurrent access.
func (e *BaseEntity) CreatedAt() time.Time {
	e.mu.RLock()
	created := e.createdAt
	e.mu.RUnlock()
	return created
}

// UpdatedAt returns the timestamp of the entity's last modification.
// It's safe for concurrent access.
func (e *BaseEntity) UpdatedAt() time.Time {
	e.mu.RLock()
	updated := e.updatedAt
	e.mu.RUnlock()
	return updated
}
// Compile-time assertions that *BaseEntity satisfies the fine-grained
// reader/setter interfaces as well as the composite Entity interface; the
// build fails here if a required method is missing.
var (
_ Identifier = (*BaseEntity)(nil)
_ IDSetter = (*BaseEntity)(nil)
_ Named = (*BaseEntity)(nil)
_ NameSetter = (*BaseEntity)(nil)
_ ActiveStatusChecker = (*BaseEntity)(nil)
_ ActivationController = (*BaseEntity)(nil)
_ MetadataReader = (*BaseEntity)(nil)
_ MetadataWriter = (*BaseEntity)(nil)
_ MetadataRemover = (*BaseEntity)(nil)
_ MetadataClearer = (*BaseEntity)(nil)
_ Timestamped = (*BaseEntity)(nil)
_ Entity = (*BaseEntity)(nil)
)
// NewBaseEntity creates a new BaseEntity with the given id and name.
// The entity is active by default, starts with empty metadata, and has both
// its creation and update time set to the current UTC time. The deprecated
// BE* mirror fields are initialised with the same id, name and active state.
//
// No validation is performed: empty id or name values are accepted as-is.
//
// This function returns a concrete *BaseEntity type. If you need the Entity interface,
// the return value can be assigned to an Entity variable.
func NewBaseEntity(id, name string) *BaseEntity {
	created := time.Now().UTC()
	return &BaseEntity{
		// Deprecated mirrors, kept for backward compatibility.
		BEId:       id,
		BEName:     name,
		BEActive:   true,
		BEMetadata: make(map[string]string),

		id:        id,
		name:      name,
		active:    true,
		metadata:  make(map[string]string),
		createdAt: created,
		updatedAt: created,
	}
}
// RemoveMetadata removes a metadata key and updates the updated timestamp.
// The key is removed from the deprecated BEMetadata mirror as well:
// SetMetadata writes to both maps and GetMetadataValue falls back to
// BEMetadata, so a key left there would reappear after removal.
// If the key exists in neither map, this is a no-op.
// This method is safe for concurrent access.
func (e *BaseEntity) RemoveMetadata(key string) {
	e.mu.Lock()
	defer e.mu.Unlock()

	// Lookups and deletes on nil maps are safe no-ops.
	_, inPrimary := e.metadata[key]
	_, inLegacy := e.BEMetadata[key]
	if !inPrimary && !inLegacy {
		return
	}
	delete(e.metadata, key)
	delete(e.BEMetadata, key)
	e.updatedAt = time.Now().UTC()
}
// SetMetadataMap merges every key-value pair of metadata into the entity's
// metadata (and into the deprecated BEMetadata mirror) under a single lock
// acquisition, then bumps the update timestamp. Existing keys are
// overwritten; an empty or nil input is a no-op.
// This method is safe for concurrent access.
func (e *BaseEntity) SetMetadataMap(metadata map[string]string) {
	if len(metadata) == 0 {
		return
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	// Lazily allocate both maps, sized for the incoming entries.
	if e.metadata == nil {
		e.metadata = make(map[string]string, len(metadata))
	}
	if e.BEMetadata == nil {
		e.BEMetadata = make(map[string]string, len(metadata))
	}
	for key, value := range metadata {
		e.metadata[key] = value
		e.BEMetadata[key] = value // Backward compatibility
	}
	e.updatedAt = time.Now().UTC()
}
// HasMetadata reports whether the entity's metadata contains key. An empty
// key is never present. Only the primary metadata map is consulted.
// This method is safe for concurrent access.
func (e *BaseEntity) HasMetadata(key string) bool {
	if len(key) == 0 {
		return false
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	_, present := e.metadata[key]
	return present
}
// GetMetadataValue returns the value for a metadata key and whether it exists.
// The primary metadata map is consulted first; the deprecated BEMetadata map
// serves as a fallback for backward compatibility. An empty key is never
// found.
//
// This method is more efficient than Metadata() when you only need one value,
// and it is safe for concurrent access: it performs reads only. A legacy
// value is deliberately not copied into the primary map here, because only
// the read lock is held and a map write would race with other readers.
func (e *BaseEntity) GetMetadataValue(key string) (string, bool) {
	if key == "" {
		return "", false
	}
	e.mu.RLock()
	defer e.mu.RUnlock()

	// Lookups on nil maps are safe and simply report "not found".
	if val, exists := e.metadata[key]; exists {
		return val, true
	}
	// Backward compatibility - check in BEMetadata if not found in metadata
	val, exists := e.BEMetadata[key]
	return val, exists
}
// MustNewBaseEntity returns NewBaseEntity(id, name). It performs no
// validation of its own and does not panic.
// Use only when you're certain the inputs are valid.
func MustNewBaseEntity(id, name string) *BaseEntity {
	return NewBaseEntity(id, name)
}
package registry
// EnhancedRegistry is the old name for the Enhanced type.
//
// Deprecated: Use Enhanced instead.
type EnhancedRegistry = Enhanced
// NewEnhancedRegistry creates a new enhanced registry by delegating to
// NewEnhanced.
//
// Deprecated: Use NewEnhanced instead.
func NewEnhancedRegistry(config Config) *Enhanced {
return NewEnhanced(config)
}
// RegistryProvider is the old name for the Provider interface.
//
// Deprecated: Use Provider instead.
type RegistryProvider = Provider
// RegistryConfig is the old name for the Config struct.
//
// Deprecated: Use Config instead.
type RegistryConfig = Config
// RegistryEntity is the old name for the Entity interface.
//
// Deprecated: Use Entity instead.
type RegistryEntity = Entity
// RegistryCache is the old name for the Cache interface.
//
// Deprecated: Use Cache instead.
type RegistryCache = Cache
// RegistryPersistence is the old name for the Persistence interface.
//
// Deprecated: Use Persistence instead.
type RegistryPersistence = Persistence
// RegistryMetrics is the old name for the Metrics interface.
//
// Deprecated: Use Metrics instead.
type RegistryMetrics = Metrics
// RegistryHealth is the old name for the Health interface.
//
// Deprecated: Use Health instead.
type RegistryHealth = Health
// RegistryEventBus is the old name for the EventBus interface.
//
// Deprecated: Use EventBus instead.
type RegistryEventBus = EventBus
// RegistryValidator is the old name for the Validator interface.
//
// Deprecated: Use Validator instead.
type RegistryValidator = Validator
// RegistryObserver is the old name for the Observer interface.
//
// Deprecated: Use Observer instead.
type RegistryObserver = Observer
// RegistryEvent is the old name for the Event struct.
//
// Deprecated: Use Event instead.
type RegistryEvent = Event
// RegistryFactory is the old name for the Factory interface.
//
// Deprecated: Use Factory instead.
type RegistryFactory = Factory
// RegistryFactoryImpl is the old name for the FactoryImpl struct.
//
// Deprecated: Use FactoryImpl instead.
type RegistryFactoryImpl = FactoryImpl
// NewRegistryFactory is the old name for NewFactory.
//
// Deprecated: Use NewFactory instead.
func NewRegistryFactory() Factory {
return NewFactory()
}
// NewRegistryBuilder is the old name for NewBuilder.
//
// Deprecated: Use NewBuilder instead.
func NewRegistryBuilder() *Builder {
return NewBuilder()
}
package registry
import (
"context"
"fmt"
"log"
"strconv"
"strings"
"sync"
"time"
)
// Enhanced provides a full-featured registry implementation.
// Entities are kept in an in-memory map; the validator, cache, persistence,
// metrics, health and eventBus collaborators are optional and are skipped
// by the registry methods when nil.
type Enhanced struct {
config Config
// entities maps entity ID to the stored entity.
entities map[string]Entity
// mu is taken by the registry methods around access to entities.
mu sync.RWMutex
// observers are notified after registrations, updates and removals.
observers []Observer
validator Validator
cache Cache
persistence Persistence
metrics Metrics
health Health
eventBus EventBus
}
// NewEnhanced returns an empty Enhanced registry using the given config.
// No validator, cache, persistence, metrics, health checker or event bus is
// attached; use the With* methods to add them.
func NewEnhanced(config Config) *Enhanced {
	registry := new(Enhanced)
	registry.config = config
	registry.entities = make(map[string]Entity)
	registry.observers = make([]Observer, 0)
	return registry
}
// WithValidator sets the validator for the registry and returns the registry
// so calls can be chained. The assignment takes no lock.
func (r *Enhanced) WithValidator(validator Validator) *Enhanced {
r.validator = validator
return r
}
// WithCache sets the cache for the registry and returns the registry so
// calls can be chained. The assignment takes no lock.
func (r *Enhanced) WithCache(cache Cache) *Enhanced {
r.cache = cache
return r
}
// WithPersistence sets the persistence layer for the registry and returns
// the registry so calls can be chained. The assignment takes no lock.
func (r *Enhanced) WithPersistence(persistence Persistence) *Enhanced {
r.persistence = persistence
return r
}
// WithMetrics sets the metrics collector for the registry and returns the
// registry so calls can be chained. The assignment takes no lock.
func (r *Enhanced) WithMetrics(metrics Metrics) *Enhanced {
r.metrics = metrics
return r
}
// WithHealth sets the health checker for the registry and returns the
// registry so calls can be chained. The assignment takes no lock.
func (r *Enhanced) WithHealth(health Health) *Enhanced {
r.health = health
return r
}
// WithEventBus sets the event bus for the registry and returns the registry
// so calls can be chained. The assignment takes no lock.
func (r *Enhanced) WithEventBus(eventBus EventBus) *Enhanced {
r.eventBus = eventBus
return r
}
// Register adds or updates an entity in the registry.
//
// The entity is validated (when a validator is set) and then stored as a
// fresh BaseEntity copy carrying its ID, name, active state and metadata,
// plus an "active" metadata entry. Afterwards the cache, persistence layer,
// metrics, event bus and observers are updated; failures in the cache,
// persistence and event bus are logged and do not fail the call.
//
// The MaxEntities limit is enforced under the write lock and only for new
// IDs, so concurrent registrations cannot exceed it and updating an entity
// that is already registered still works when the registry is full.
//
// Returns an error if validation fails or if the registry is full.
func (r *Enhanced) Register(ctx context.Context, entity Entity) error {
	start := time.Now()
	defer func() {
		if r.metrics != nil {
			r.metrics.RecordLatency("register", time.Since(start))
		}
	}()
	// Validate entity if validator is set
	if r.validator != nil {
		if err := r.validator.Validate(ctx, entity); err != nil {
			if r.metrics != nil {
				r.metrics.IncrementError()
			}
			return fmt.Errorf("validation failed: %w", err)
		}
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	// Check if this is an update
	id := entity.ID()
	_, exists := r.entities[id]

	// Check max entities limit. Updates never grow the map, so only a new
	// ID can be rejected.
	if !exists && r.config.MaxEntities > 0 && len(r.entities) >= r.config.MaxEntities {
		if r.metrics != nil {
			r.metrics.IncrementError()
		}
		return fmt.Errorf("registry is full (max entities: %d)", r.config.MaxEntities)
	}

	// For any entity type, store a new BaseEntity copy so the registry does
	// not share state with the caller's value.
	stored := NewBaseEntity(id, entity.Name())
	stored.SetActive(entity.Active())
	for k, v := range entity.Metadata() {
		stored.SetMetadata(k, v)
	}
	// Ensure the active status is reflected in metadata for backward compatibility
	stored.SetMetadata("active", strconv.FormatBool(entity.Active()))
	r.entities[stored.ID()] = stored

	// Update cache if enabled
	if r.cache != nil {
		if err := r.cache.Set(ctx, entity); err != nil {
			log.Printf("warning: failed to update cache: %v", err)
		}
	}
	// Update persistence if enabled
	if r.persistence != nil {
		if err := r.persistence.Save(ctx, r.getAllEntitiesLocked()); err != nil {
			log.Printf("warning: failed to persist registry: %v", err)
		}
	}
	// Update metrics
	if r.metrics != nil {
		r.metrics.IncrementRegistration()
		r.metrics.SetEntityCount(len(r.entities))
		r.metrics.SetActiveCount(r.countActiveLocked())
	}
	// Emit event
	if r.eventBus != nil {
		eventType := EventEntityRegistered
		if exists {
			eventType = EventEntityUpdated
		}
		if err := r.emitEvent(eventType, entity); err != nil {
			log.Printf("warning: failed to emit %s event: %v", eventType, err)
		}
	}
	// Notify observers
	for _, observer := range r.observers {
		if exists {
			observer.OnEntityUpdated(ctx, entity)
		} else {
			observer.OnEntityRegistered(ctx, entity)
		}
	}
	return nil
}
// Get retrieves an entity by ID.
// The cache (when configured) is consulted first; on a miss the entity is
// read from the registry map and written back to the cache. A failed cache
// write is only logged. Returns an error if the ID is not registered.
func (r *Enhanced) Get(ctx context.Context, id string) (Entity, error) {
	began := time.Now()
	defer func() {
		if r.metrics == nil {
			return
		}
		r.metrics.RecordLatency("get", time.Since(began))
	}()

	// Fast path: serve from the cache when possible.
	if r.cache != nil {
		if cached, hit := r.cache.Get(ctx, id); hit {
			if r.metrics != nil {
				r.metrics.IncrementLookup()
			}
			return cached, nil
		}
	}

	r.mu.RLock()
	found, registered := r.entities[id]
	r.mu.RUnlock()

	if !registered {
		if r.metrics != nil {
			r.metrics.IncrementError()
		}
		return nil, fmt.Errorf("entity not found: %s", id)
	}

	// Warm the cache for the next lookup; a failure here must not fail the
	// operation.
	if r.cache != nil {
		if cacheErr := r.cache.Set(ctx, found); cacheErr != nil {
			log.Printf("warning: failed to update cache for entity %s: %v", found.ID(), cacheErr)
		}
	}
	if r.metrics != nil {
		r.metrics.IncrementLookup()
	}
	return found, nil
}
// Unregister removes an entity from the registry.
//
// After the entity is deleted, the cache entry is dropped, the remaining
// entities are persisted (matching Register, so a removed entity does not
// linger in the persisted snapshot), metrics are refreshed, an unregistered
// event is emitted and observers are notified. Failures in the cache,
// persistence and event bus are logged and do not fail the call.
//
// Returns an error if the ID is not registered.
func (r *Enhanced) Unregister(ctx context.Context, id string) error {
	start := time.Now()
	defer func() {
		if r.metrics != nil {
			r.metrics.RecordLatency("unregister", time.Since(start))
		}
	}()
	r.mu.Lock()
	defer r.mu.Unlock()
	if _, exists := r.entities[id]; !exists {
		if r.metrics != nil {
			r.metrics.IncrementError()
		}
		return fmt.Errorf("entity not found: %s", id)
	}
	delete(r.entities, id)
	// Remove from cache if available
	if r.cache != nil {
		if err := r.cache.Delete(ctx, id); err != nil {
			// Log cache delete error but don't fail the operation
			log.Printf("warning: failed to delete entity %s from cache: %v", id, err)
		}
	}
	// Update persistence if enabled so the removal survives a reload
	if r.persistence != nil {
		if err := r.persistence.Save(ctx, r.getAllEntitiesLocked()); err != nil {
			log.Printf("warning: failed to persist registry: %v", err)
		}
	}
	// Update metrics
	if r.metrics != nil {
		r.metrics.IncrementUnregistration()
		r.metrics.SetEntityCount(len(r.entities))
		activeCount := r.countActiveLocked()
		r.metrics.SetActiveCount(activeCount)
	}
	// Publish event
	if r.eventBus != nil {
		event := Event{
			Type:      EventEntityUnregistered,
			EntityID:  id,
			Timestamp: time.Now(),
		}
		if err := r.eventBus.Emit(ctx, event); err != nil {
			// Log event emission error but don't fail the operation
			log.Printf("warning: failed to emit unregister event for entity %s: %v", id, err)
		}
	}
	// Notify observers
	for _, observer := range r.observers {
		observer.OnEntityUnregistered(ctx, id)
	}
	return nil
}
// IsRegistered reports whether an entity with the given ID is present in the
// registry.
func (r *Enhanced) IsRegistered(ctx context.Context, id string) bool {
	r.mu.RLock()
	_, registered := r.entities[id]
	r.mu.RUnlock()
	return registered
}
// List returns every registered entity in unspecified (map iteration) order.
// The error result is always nil.
func (r *Enhanced) List(ctx context.Context) ([]Entity, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	all := make([]Entity, len(r.entities))
	next := 0
	for _, registered := range r.entities {
		all[next] = registered
		next++
	}
	return all, nil
}
// ListActive returns the registered entities whose Active() is true, in
// unspecified order. The result is an empty, non-nil slice when none match.
// The error result is always nil.
func (r *Enhanced) ListActive(ctx context.Context) ([]Entity, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	active := []Entity{}
	for _, candidate := range r.entities {
		if !candidate.Active() {
			continue
		}
		active = append(active, candidate)
	}
	return active, nil
}
// ListByMetadata returns the registered entities whose metadata contains key
// with exactly the given value, in unspecified order. The result is an empty,
// non-nil slice when none match. The error result is always nil.
func (r *Enhanced) ListByMetadata(
	ctx context.Context,
	key, value string,
) ([]Entity, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	matches := []Entity{}
	for _, candidate := range r.entities {
		// Indexing a nil map is safe and simply reports "absent".
		stored, present := candidate.Metadata()[key]
		if !present || stored != value {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches, nil
}
// Count returns the total number of registered entities.
// The error result is always nil.
func (r *Enhanced) Count(ctx context.Context) (int, error) {
	r.mu.RLock()
	total := len(r.entities)
	r.mu.RUnlock()
	return total, nil
}

// CountActive returns the number of registered entities whose Active() is
// true. The error result is always nil.
func (r *Enhanced) CountActive(ctx context.Context) (int, error) {
	r.mu.RLock()
	active := r.countActiveLocked()
	r.mu.RUnlock()
	return active, nil
}

// countActiveLocked counts the active entities. It takes no lock itself; the
// caller must already hold r.mu.
func (r *Enhanced) countActiveLocked() int {
	active := 0
	for _, candidate := range r.entities {
		if !candidate.Active() {
			continue
		}
		active++
	}
	return active
}
// GetMetadata returns the value stored under key in the metadata of the
// entity with the given ID. It returns an error if the entity cannot be
// retrieved or if the key is absent.
func (r *Enhanced) GetMetadata(ctx context.Context, id, key string) (string, error) {
	entity, err := r.Get(ctx, id)
	if err != nil {
		return "", err
	}
	value, present := entity.Metadata()[key]
	if !present {
		return "", fmt.Errorf("metadata key not found: %s", key)
	}
	return value, nil
}
// SetMetadata sets specific metadata for an entity and re-registers it so
// the change is stored.
//
// When a validator is set, the prospective metadata (current entries plus
// the new pair) is validated first, using a private copy so the entity's own
// map is never touched by validation. Protected keys are skipped with a
// logged warning and a nil result.
//
// The value is written through the Entity's own SetMetadata method for every
// implementation; writing into the map returned by Metadata() is unreliable
// because that map may be a copy (the write is lost) or nil (the write
// panics).
//
// Returns an error if the entity cannot be retrieved, if validation fails,
// or if re-registration fails.
func (r *Enhanced) SetMetadata(ctx context.Context, id, key, value string) error {
	entity, err := r.Get(ctx, id)
	if err != nil {
		return err
	}
	// Validate metadata if validator is set
	if r.validator != nil {
		current := entity.Metadata()
		candidate := make(map[string]string, len(current)+1)
		for k, v := range current {
			candidate[k] = v
		}
		candidate[key] = value
		if err := r.validator.ValidateMetadata(ctx, candidate); err != nil {
			return fmt.Errorf("metadata validation failed: %w", err)
		}
	}
	// Skip protected fields with a warning
	if isProtectedField(key) {
		log.Printf("warning: skipping metadata set for protected field: %s", key)
		return nil
	}
	entity.SetMetadata(key, value)
	// Re-register the entity to update it
	return r.Register(ctx, entity)
}
// RemoveMetadata removes the given metadata key from the entity with the
// given ID and re-registers the entity so the change is stored.
// Returns an error if the entity cannot be retrieved or if re-registration
// fails.
func (r *Enhanced) RemoveMetadata(ctx context.Context, id, key string) error {
	entity, err := r.Get(ctx, id)
	if err != nil {
		return err
	}

	if base, isBase := entity.(*BaseEntity); isBase {
		// BaseEntity offers a dedicated, lock-protected delete.
		base.DeleteMetadata(key)
	} else {
		// Fallback for other implementations: delete from whatever map
		// Metadata() hands back.
		delete(entity.Metadata(), key)
	}

	// Re-register the entity to update it
	return r.Register(ctx, entity)
}
// Activate marks the entity registered under id as active.
//
// While holding the registry write lock it flips the entity's active flag
// (when the entity exposes SetActive), mirrors the state into the "active"
// metadata key, then refreshes the cache, persistence and metrics and
// notifies the event bus and observers. Cache, persistence and event-bus
// failures are only logged as warnings; the sole error returned is for an
// unknown id.
func (r *Enhanced) Activate(ctx context.Context, id string) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	target, found := r.entities[id]
	if !found {
		return fmt.Errorf("entity not found: %s", id)
	}

	// Flip the flag through SetActive when available, and always mirror it
	// into metadata for backward compatibility.
	if toggler, ok := target.(interface{ SetActive(bool) }); ok {
		toggler.SetActive(true)
	}
	target.SetMetadata("active", "true")

	// Best-effort cache refresh.
	if r.cache != nil {
		if cacheErr := r.cache.Set(ctx, target); cacheErr != nil {
			log.Printf("warning: failed to update cache: %v", cacheErr)
		}
	}

	// Best-effort persistence of the whole registry.
	if r.persistence != nil {
		if saveErr := r.persistence.Save(ctx, r.getAllEntitiesLocked()); saveErr != nil {
			log.Printf("warning: failed to persist registry: %v", saveErr)
		}
	}

	// This is an update, so only the active gauge changes.
	if r.metrics != nil {
		r.metrics.SetActiveCount(r.countActiveLocked())
	}

	if r.eventBus != nil {
		if emitErr := r.emitEvent(EventEntityActivated, target); emitErr != nil {
			log.Printf("warning: failed to emit %s event: %v", EventEntityActivated, emitErr)
		}
	}

	// Directly registered observers are told about the update.
	for _, obs := range r.observers {
		obs.OnEntityUpdated(ctx, target)
	}
	return nil
}
// Deactivate marks the entity registered under id as inactive.
//
// While holding the registry write lock it clears the entity's active flag
// (when the entity exposes SetActive), mirrors the state into the "active"
// metadata key, then refreshes the cache, persistence and metrics and
// notifies the event bus and observers. Cache, persistence and event-bus
// failures are only logged as warnings; the sole error returned is for an
// unknown id.
func (r *Enhanced) Deactivate(ctx context.Context, id string) error {
	r.mu.Lock()
	defer r.mu.Unlock()

	target, found := r.entities[id]
	if !found {
		return fmt.Errorf("entity not found: %s", id)
	}

	// Clear the flag through SetActive when available, and always mirror it
	// into metadata for backward compatibility.
	if toggler, ok := target.(interface{ SetActive(bool) }); ok {
		toggler.SetActive(false)
	}
	target.SetMetadata("active", "false")

	// Best-effort cache refresh.
	if r.cache != nil {
		if cacheErr := r.cache.Set(ctx, target); cacheErr != nil {
			log.Printf("warning: failed to update cache: %v", cacheErr)
		}
	}

	// Best-effort persistence of the whole registry.
	if r.persistence != nil {
		if saveErr := r.persistence.Save(ctx, r.getAllEntitiesLocked()); saveErr != nil {
			log.Printf("warning: failed to persist registry: %v", saveErr)
		}
	}

	// This is an update, so only the active gauge changes.
	if r.metrics != nil {
		r.metrics.SetActiveCount(r.countActiveLocked())
	}

	if r.eventBus != nil {
		if emitErr := r.emitEvent(EventEntityDeactivated, target); emitErr != nil {
			log.Printf("warning: failed to emit %s event: %v", EventEntityDeactivated, emitErr)
		}
	}

	// Directly registered observers are told about the update.
	for _, obs := range r.observers {
		obs.OnEntityUpdated(ctx, target)
	}
	return nil
}
// ...
// Search returns every registered entity whose name matches query according
// to the contains helper. The result is never nil and the error is always nil.
// The registry read lock is held for the duration of the scan.
func (r *Enhanced) Search(ctx context.Context, query string) ([]Entity, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	matches := []Entity{}
	for _, candidate := range r.entities {
		if !contains(candidate.Name(), query) {
			continue
		}
		matches = append(matches, candidate)
	}
	return matches, nil
}
// SearchByMetadata returns the entities whose metadata contains every
// key/value pair in metadata with an exactly equal value. An empty filter
// matches all entities. The result is never nil and the error is always nil.
func (r *Enhanced) SearchByMetadata(
	ctx context.Context,
	metadata map[string]string,
) ([]Entity, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	found := []Entity{}
nextEntity:
	for _, candidate := range r.entities {
		have := candidate.Metadata()
		for wantKey, wantValue := range metadata {
			// A missing key or a differing value disqualifies the entity.
			if got, ok := have[wantKey]; !ok || got != wantValue {
				continue nextEntity
			}
		}
		found = append(found, candidate)
	}
	return found, nil
}
// AddObserver appends observer to the registry's observer list while holding
// the registry write lock. No duplicate check is performed, so adding the
// same observer twice stores it twice.
func (r *Enhanced) AddObserver(observer Observer) {
r.mu.Lock()
defer r.mu.Unlock()
r.observers = append(r.observers, observer)
}
// RemoveObserver removes the first registered observer that compares equal
// to observer. It is a no-op when no such observer is registered.
func (r *Enhanced) RemoveObserver(observer Observer) {
	r.mu.Lock()
	defer r.mu.Unlock()

	for idx := range r.observers {
		if r.observers[idx] != observer {
			continue
		}
		// Splice the match out and stop: only one occurrence is removed.
		r.observers = append(r.observers[:idx], r.observers[idx+1:]...)
		return
	}
}
// contains reports whether substr occurs anywhere in s, ignoring case.
// An empty substr matches every s.
//
// The previous implementation mixed case-sensitive equality/prefix/suffix
// checks with a case-insensitive substring check, so "Foo" did not match
// "foo" while "Food" did. Every input the old version accepted is still
// accepted; the comparison is now uniformly case-insensitive.
func contains(s, substr string) bool {
	return containsSubstring(s, substr)
}

// containsSubstring reports whether the lower-cased s contains the
// lower-cased substr.
func containsSubstring(s, substr string) bool {
	return strings.Contains(strings.ToLower(s), strings.ToLower(substr))
}
// emitEvent builds an Event of the given type for entity and emits it on the
// registry's event bus using a background context. The entity's metadata,
// when non-empty, is copied into the event's Metadata map. The caller must
// ensure r.eventBus is non-nil.
func (r *Enhanced) emitEvent(eventType string, entity Entity) error {
	evt := Event{
		Type:      eventType,
		EntityID:  entity.ID(),
		Entity:    entity,
		Timestamp: time.Now(),
	}

	// Copy rather than share the entity's metadata; leave the field nil when
	// there is nothing to copy.
	if source := entity.Metadata(); len(source) > 0 {
		copied := make(map[string]any)
		for key, val := range source {
			copied[key] = val
		}
		evt.Metadata = copied
	}

	return r.eventBus.Emit(context.Background(), evt)
}
// getAllEntitiesLocked returns a freshly allocated slice holding every entity
// currently in the registry, in map iteration order.
// The caller must already hold the registry lock.
func (r *Enhanced) getAllEntitiesLocked() []Entity {
	snapshot := make([]Entity, 0, len(r.entities))
	for id := range r.entities {
		snapshot = append(snapshot, r.entities[id])
	}
	return snapshot
}
package registry
import (
"context"
"fmt"
"strings"
"time"
)
// FactoryImpl is the stateless default implementation of Factory.
type FactoryImpl struct{}
// NewFactory returns a new FactoryImpl as a Factory.
func NewFactory() Factory {
return &FactoryImpl{}
}
// Create builds a registry from config.
//
// When config.RedisURL is set the call is delegated to CreateWithRedisCache.
// Otherwise an Enhanced registry is created and a validator, an in-memory
// cache (when CacheSize > 0) and an event bus are attached according to the
// corresponding config switches.
func (f *FactoryImpl) Create(
	ctx context.Context,
	config Config,
) (Provider, error) {
	if url := config.RedisURL; url != "" {
		return f.CreateWithRedisCache(ctx, config, url)
	}

	reg := NewEnhanced(config)
	if config.EnableValidation {
		reg.WithValidator(NewSimpleValidator())
	}
	// CacheSize only switches the cache on; the memory cache is sized by TTL.
	if config.CacheSize > 0 {
		reg.WithCache(NewMemoryCache(config.CacheTTL))
	}
	if config.EnableEvents {
		reg.WithEventBus(NewSimpleEventBus())
	}
	return reg, nil
}
// CreateWithPersistence builds a registry backed by persistence and seeds it
// with the entities returned by persistence.Load.
//
// A Load failure is returned to the caller. It used to be ignored silently,
// which produced an empty registry that could later overwrite the unreadable
// store on the next save. A failure to register a loaded entity is also
// returned. The validator, in-memory cache and event bus are attached after
// loading, according to the config switches.
func (f *FactoryImpl) CreateWithPersistence(
	ctx context.Context,
	config Config,
	persistence Persistence,
) (Provider, error) {
	registry := NewEnhanced(config)
	// Add persistence
	registry.WithPersistence(persistence)

	// Load existing entities; surface read/parse failures instead of starting
	// from an empty registry.
	entities, err := persistence.Load(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to load persisted entities: %w", err)
	}
	for _, entity := range entities {
		if err := registry.Register(ctx, entity); err != nil {
			return nil, fmt.Errorf("failed to load entity %s: %w", entity.ID(), err)
		}
	}

	// Add other default implementations
	if config.EnableValidation {
		registry.WithValidator(NewSimpleValidator())
	}
	if config.CacheSize > 0 {
		registry.WithCache(NewMemoryCache(config.CacheTTL))
	}
	if config.EnableEvents {
		registry.WithEventBus(NewSimpleEventBus())
	}
	return registry, nil
}
// CreateWithCache builds a registry that uses the caller-supplied cache.
// A validator and an event bus are attached when enabled in config. The
// returned error is always nil.
func (f *FactoryImpl) CreateWithCache(
	ctx context.Context,
	config Config,
	cache Cache,
) (Provider, error) {
	reg := NewEnhanced(config)
	reg.WithCache(cache)

	if config.EnableValidation {
		reg.WithValidator(NewSimpleValidator())
	}
	if config.EnableEvents {
		reg.WithEventBus(NewSimpleEventBus())
	}
	return reg, nil
}
// CreateWithRedisCache builds a registry whose cache is backed by Redis.
//
// It creates a Redis client for redisURL, normalises the key prefix from
// config ("registry:" when empty, otherwise guaranteed to end in ":"), wraps
// the client in a Redis cache using config.CacheTTL and delegates to
// CreateWithCache. A client-creation failure is returned wrapped.
func (f *FactoryImpl) CreateWithRedisCache(
	ctx context.Context,
	config Config,
	redisURL string,
) (Provider, error) {
	client, err := NewRedisClient(redisURL)
	if err != nil {
		return nil, fmt.Errorf("failed to create Redis client: %w", err)
	}

	prefix := config.RedisKeyPrefix
	switch {
	case prefix == "":
		prefix = "registry:"
	case !strings.HasSuffix(prefix, ":"):
		prefix += ":"
	}

	return f.CreateWithCache(ctx, config, NewRedisCache(client, prefix, config.CacheTTL))
}
// CreateWithMetrics builds a registry that reports to the caller-supplied
// metrics collector. A validator, an in-memory cache (when CacheSize > 0) and
// an event bus are attached according to config. The returned error is
// always nil.
func (f *FactoryImpl) CreateWithMetrics(
	ctx context.Context,
	config Config,
	metrics Metrics,
) (Provider, error) {
	reg := NewEnhanced(config)
	reg.WithMetrics(metrics)

	if config.EnableValidation {
		reg.WithValidator(NewSimpleValidator())
	}
	if config.CacheSize > 0 {
		reg.WithCache(NewMemoryCache(config.CacheTTL))
	}
	if config.EnableEvents {
		reg.WithEventBus(NewSimpleEventBus())
	}
	return reg, nil
}
// CreateFullFeatured builds a registry with the validator, in-memory cache,
// metrics, health checker and event bus all attached, regardless of the
// individual enable switches in config.
//
// When config.EnablePersistence is set, file persistence at
// config.PersistencePath is attached and the registry is seeded from it.
// A Load failure is returned to the caller. It used to be ignored silently,
// which produced an empty registry that could later overwrite the unreadable
// file on the next save. A failure to register a loaded entity is also
// returned.
func (f *FactoryImpl) CreateFullFeatured(
	ctx context.Context,
	config Config,
) (Provider, error) {
	registry := NewEnhanced(config)
	// Add all implementations
	registry.WithValidator(NewSimpleValidator())
	registry.WithCache(NewMemoryCache(config.CacheTTL))
	registry.WithMetrics(NewSimpleMetrics())
	registry.WithHealth(NewSimpleHealth())
	registry.WithEventBus(NewSimpleEventBus())

	if !config.EnablePersistence {
		return registry, nil
	}

	persistence := NewFilePersistence(config.PersistencePath)
	registry.WithPersistence(persistence)

	// Load existing entities; surface read/parse failures instead of starting
	// from an empty registry.
	entities, err := persistence.Load(ctx)
	if err != nil {
		return nil, fmt.Errorf("failed to load persisted entities: %w", err)
	}
	for _, entity := range entities {
		if err := registry.Register(ctx, entity); err != nil {
			return nil, fmt.Errorf("failed to load entity %s: %w", entity.ID(), err)
		}
	}
	return registry, nil
}
// CreateForTesting builds a registry named "test-registry" with events
// disabled and only a validator attached. The cache settings are recorded in
// the config, but no cache is attached here. The returned error is always nil.
func (f *FactoryImpl) CreateForTesting(ctx context.Context) (Provider, error) {
	reg := NewEnhanced(Config{
		Name:             "test-registry",
		MaxEntities:      1000,
		EnableEvents:     false,
		EnableValidation: true,
		CacheSize:        100,
		CacheTTL:         time.Minute,
	})
	reg.WithValidator(NewSimpleValidator())
	return reg, nil
}
// CreateForProduction builds a full-featured registry called name with
// persistence enabled at persistencePath, a 10000-entity limit, a 5-minute
// cache TTL and a 30-second auto-save interval. It delegates to
// CreateFullFeatured.
func (f *FactoryImpl) CreateForProduction(
	ctx context.Context,
	name string,
	persistencePath string,
) (Provider, error) {
	return f.CreateFullFeatured(ctx, Config{
		Name:              name,
		MaxEntities:       10000,
		EnableEvents:      true,
		EnableValidation:  true,
		CacheSize:         1000,
		CacheTTL:          5 * time.Minute,
		EnablePersistence: true,
		PersistencePath:   persistencePath,
		AutoSaveInterval:  30 * time.Second,
	})
}
// CreateForDevelopment builds an in-memory registry called name with a
// validator, simple metrics and an event bus attached. No persistence and no
// cache are attached here, although cache settings are recorded in the
// config. The returned error is always nil.
func (f *FactoryImpl) CreateForDevelopment(
	ctx context.Context,
	name string,
) (Provider, error) {
	reg := NewEnhanced(Config{
		Name:             name,
		MaxEntities:      1000,
		EnableEvents:     true,
		EnableValidation: true,
		CacheSize:        100,
		CacheTTL:         time.Minute,
	})
	reg.WithValidator(NewSimpleValidator())
	reg.WithMetrics(NewSimpleMetrics())
	reg.WithEventBus(NewSimpleEventBus())
	return reg, nil
}
// Convenience functions for common registry creation patterns
// NewBasicRegistry returns a registry named "basic-registry" with events,
// validation and a one-minute cache enabled.
// The error from Create is discarded, so the result is nil if creation fails.
func NewBasicRegistry() Provider {
	cfg := Config{
		Name:             "basic-registry",
		EnableEvents:     true,
		EnableValidation: true,
		CacheSize:        100,
		CacheTTL:         time.Minute,
	}
	reg, _ := NewRegistryFactory().Create(context.Background(), cfg)
	return reg
}
// NewPersistentRegistry returns a registry named "persistent-registry" that
// is persisted to filePath via FilePersistence, with events, validation and a
// one-minute cache enabled. Errors from CreateWithPersistence are returned.
func NewPersistentRegistry(filePath string) (Provider, error) {
	cfg := Config{
		Name:              "persistent-registry",
		EnableEvents:      true,
		EnableValidation:  true,
		CacheSize:         100,
		CacheTTL:          time.Minute,
		EnablePersistence: true,
		PersistencePath:   filePath,
		AutoSaveInterval:  time.Minute,
	}
	return NewRegistryFactory().CreateWithPersistence(
		context.Background(),
		cfg,
		NewFilePersistence(filePath),
	)
}
// NewCachedRegistry returns a registry named "cached-registry" configured
// with the given cache size and TTL, with events and validation enabled.
// The error from Create is discarded, so the result is nil if creation fails.
func NewCachedRegistry(cacheSize int, cacheTTL time.Duration) Provider {
	cfg := Config{
		Name:             "cached-registry",
		EnableEvents:     true,
		EnableValidation: true,
		CacheSize:        cacheSize,
		CacheTTL:         cacheTTL,
	}
	reg, _ := NewRegistryFactory().Create(context.Background(), cfg)
	return reg
}
// NewMonitoredRegistry returns a registry called name that reports to a new
// SimpleMetrics collector, with events, validation and a one-minute cache
// enabled. The error from CreateWithMetrics is discarded, so the result is
// nil if creation fails.
func NewMonitoredRegistry(name string) Provider {
	factory := NewRegistryFactory()
	cfg := Config{
		Name:             name,
		EnableEvents:     true,
		EnableValidation: true,
		CacheSize:        100,
		CacheTTL:         time.Minute,
	}
	reg, _ := factory.CreateWithMetrics(context.Background(), cfg, NewSimpleMetrics())
	return reg
}
// BuildRegistry creates a registry from the builder's current configuration
// by passing b.Build() to the factory's Create with a background context.
func (b *Builder) BuildRegistry() (Provider, error) {
	return NewRegistryFactory().Create(context.Background(), b.Build())
}
package registry
import (
"context"
"encoding/json"
"fmt"
"os"
"sync"
"time"
)
// MemoryCache implements Cache using in-memory storage.
// Entries are kept in a map keyed by entity ID; each entry expires ttl after
// it was stored. mu guards the map.
type MemoryCache struct {
cache map[string]cacheEntry
mu sync.RWMutex
ttl time.Duration
}
// cacheEntry pairs a cached entity with the instant after which it is
// treated as expired.
type cacheEntry struct {
entity Entity
expiresAt time.Time
}
// NewMemoryCache creates an empty memory cache whose entries expire ttl after
// being stored, and starts the background cleanup goroutine. That goroutine
// has no stop mechanism in this type.
func NewMemoryCache(ttl time.Duration) *MemoryCache {
	c := &MemoryCache{
		ttl:   ttl,
		cache: map[string]cacheEntry{},
	}
	go c.cleanup()
	return c
}
// Get returns the cached entity for id and true, or nil and false when the
// id is absent or its entry has expired.
//
// The lookup runs under the read lock. An expired entry is evicted under the
// write lock: the previous version deleted from the map while holding only
// the read lock, which is a concurrent map write when several readers hit
// the same expired entry.
func (c *MemoryCache) Get(ctx context.Context, id string) (Entity, bool) {
	c.mu.RLock()
	entry, exists := c.cache[id]
	c.mu.RUnlock()

	if !exists {
		return nil, false
	}
	if !time.Now().After(entry.expiresAt) {
		return entry.entity, true
	}

	// Re-check under the write lock: the entry may have been refreshed or
	// removed between releasing the read lock and getting here.
	c.mu.Lock()
	if current, ok := c.cache[id]; ok && time.Now().After(current.expiresAt) {
		delete(c.cache, id)
	}
	c.mu.Unlock()
	return nil, false
}
// Set stores entity under its ID, replacing any existing entry, with an
// expiry of now plus the cache TTL. It always returns nil.
func (c *MemoryCache) Set(ctx context.Context, entity Entity) error {
	c.mu.Lock()
	defer c.mu.Unlock()

	id := entity.ID()
	expiry := time.Now().Add(c.ttl)
	c.cache[id] = cacheEntry{entity: entity, expiresAt: expiry}
	return nil
}
// Delete removes the entry for id from the cache under the write lock.
// Deleting an absent id is a no-op; the method always returns nil.
func (c *MemoryCache) Delete(ctx context.Context, id string) error {
c.mu.Lock()
defer c.mu.Unlock()
delete(c.cache, id)
return nil
}
// Clear removes all entities from the cache by replacing the underlying map
// with a fresh empty one. It always returns nil.
func (c *MemoryCache) Clear(ctx context.Context) error {
c.mu.Lock()
defer c.mu.Unlock()
c.cache = make(map[string]cacheEntry)
return nil
}
// Size returns the number of entries currently stored in the cache map.
// Entries that have expired but have not yet been evicted are included.
func (c *MemoryCache) Size() int {
c.mu.RLock()
defer c.mu.RUnlock()
return len(c.cache)
}
// cleanup evicts expired entries once per minute. It loops over the ticker
// channel and therefore runs for as long as the ticker keeps firing.
func (c *MemoryCache) cleanup() {
	ticker := time.NewTicker(time.Minute)
	defer ticker.Stop()

	for range ticker.C {
		func() {
			c.mu.Lock()
			defer c.mu.Unlock()

			cutoff := time.Now()
			for key := range c.cache {
				if cutoff.After(c.cache[key].expiresAt) {
					delete(c.cache, key)
				}
			}
		}()
	}
}
// FilePersistence implements Persistence by storing all entities in a single
// JSON file at filePath. mu serialises Save, Load and Clear.
type FilePersistence struct {
filePath string
mu sync.Mutex
}
// NewFilePersistence creates a file persistence layer that reads from and
// writes to filePath. The file itself is not touched until Save, Load or
// Clear is called.
func NewFilePersistence(filePath string) *FilePersistence {
return &FilePersistence{
filePath: filePath,
}
}
// Save writes entities to the persistence file as an indented JSON array,
// replacing the previous contents. Each entity is stored as an object with
// the keys id, name, active, metadata, created_at and updated_at.
// Marshalling and write failures are returned wrapped.
func (p *FilePersistence) Save(ctx context.Context, entities []Entity) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	// Flatten each entity into a plain map that encoding/json can handle.
	var records []map[string]interface{}
	for _, e := range entities {
		record := map[string]interface{}{
			"id":         e.ID(),
			"name":       e.Name(),
			"active":     e.Active(),
			"metadata":   e.Metadata(),
			"created_at": e.CreatedAt(),
			"updated_at": e.UpdatedAt(),
		}
		records = append(records, record)
	}

	payload, err := json.MarshalIndent(records, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal entities: %w", err)
	}

	if writeErr := os.WriteFile(p.filePath, payload, 0644); writeErr != nil {
		return fmt.Errorf("failed to write file: %w", writeErr)
	}
	return nil
}
// Load reads the persistence file and rebuilds the entities stored in it.
//
// A missing file yields an empty, non-nil slice. Read and JSON parse failures
// are returned wrapped. Each record is turned into a BaseEntity from its id,
// name and active fields plus every string-valued metadata entry; fields of
// an unexpected type fall back to their zero value. The stored created_at and
// updated_at values are not restored.
func (p *FilePersistence) Load(ctx context.Context) ([]Entity, error) {
	p.mu.Lock()
	defer p.mu.Unlock()

	if _, statErr := os.Stat(p.filePath); os.IsNotExist(statErr) {
		return []Entity{}, nil
	}

	raw, err := os.ReadFile(p.filePath)
	if err != nil {
		return nil, fmt.Errorf("failed to read file: %w", err)
	}

	var records []map[string]interface{}
	if err := json.Unmarshal(raw, &records); err != nil {
		return nil, fmt.Errorf("failed to parse JSON: %w", err)
	}

	var loaded []Entity
	for _, record := range records {
		id, _ := record["id"].(string)
		name, _ := record["name"].(string)
		active, _ := record["active"].(bool)

		entity := NewBaseEntity(id, name)
		entity.SetActive(active)

		// Only string metadata values survive the round trip.
		if rawMeta, ok := record["metadata"].(map[string]interface{}); ok {
			for key, val := range rawMeta {
				if text, isString := val.(string); isString {
					entity.SetMetadata(key, text)
				}
			}
		}
		loaded = append(loaded, entity)
	}
	return loaded, nil
}
// Delete is a no-op for file persistence and always returns nil.
// Individual entities are not removed from the file; the whole file is
// rewritten on the next Save.
func (p *FilePersistence) Delete(ctx context.Context, id string) error {
// For file persistence, we don't delete individual entities
// The entire file is rewritten on save
return nil
}
// Clear removes the persistence file.
//
// A file that does not exist is treated as already cleared and is not an
// error, matching Load, which treats a missing file as an empty store. Any
// other removal error is returned unchanged.
func (p *FilePersistence) Clear(ctx context.Context) error {
	p.mu.Lock()
	defer p.mu.Unlock()

	if err := os.Remove(p.filePath); err != nil && !os.IsNotExist(err) {
		return err
	}
	return nil
}
// SimpleMetrics implements Metrics using simple in-memory counters.
// All fields are guarded by mu.
type SimpleMetrics struct {
	registrations   int64
	unregistrations int64
	lookups         int64
	errors          int64
	entityCount     int
	activeCount     int
	latencies       map[string][]time.Duration
	mu              sync.RWMutex
}

// NewSimpleMetrics creates a new simple metrics collector with all counters
// at zero and no recorded latencies.
func NewSimpleMetrics() *SimpleMetrics {
	return &SimpleMetrics{
		latencies: make(map[string][]time.Duration),
	}
}

// IncrementRegistration increments the registration counter.
func (m *SimpleMetrics) IncrementRegistration() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.registrations++
}

// IncrementUnregistration increments the unregistration counter.
func (m *SimpleMetrics) IncrementUnregistration() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.unregistrations++
}

// IncrementLookup increments the lookup counter.
func (m *SimpleMetrics) IncrementLookup() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.lookups++
}

// IncrementError increments the error counter.
func (m *SimpleMetrics) IncrementError() {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.errors++
}

// SetEntityCount sets the entity count gauge.
func (m *SimpleMetrics) SetEntityCount(count int) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.entityCount = count
}

// SetActiveCount sets the active entity count gauge.
func (m *SimpleMetrics) SetActiveCount(count int) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.activeCount = count
}

// RecordLatency appends duration to the samples kept for operation, keeping
// only the most recent 100 samples per operation. It is safe to call on a
// zero-value SimpleMetrics: the sample map is created on first use.
func (m *SimpleMetrics) RecordLatency(operation string, duration time.Duration) {
	const maxSamples = 100

	m.mu.Lock()
	defer m.mu.Unlock()

	if m.latencies == nil {
		m.latencies = make(map[string][]time.Duration)
	}
	samples := append(m.latencies[operation], duration)
	if len(samples) > maxSamples {
		samples = samples[len(samples)-maxSamples:]
	}
	m.latencies[operation] = samples
}

// GetStats returns a snapshot of the current metrics keyed by
// "registrations", "unregistrations", "lookups", "errors", "entity_count",
// "active_count" and "latencies".
//
// The "latencies" value is a map[string][]time.Duration that is a deep copy
// of the internal samples, so callers may read or modify it without racing
// with, or corrupting, later RecordLatency calls.
func (m *SimpleMetrics) GetStats() map[string]interface{} {
	m.mu.RLock()
	defer m.mu.RUnlock()

	latencies := make(map[string][]time.Duration, len(m.latencies))
	for operation, samples := range m.latencies {
		latencies[operation] = append([]time.Duration(nil), samples...)
	}

	return map[string]interface{}{
		"registrations":   m.registrations,
		"unregistrations": m.unregistrations,
		"lookups":         m.lookups,
		"errors":          m.errors,
		"entity_count":    m.entityCount,
		"active_count":    m.activeCount,
		"latencies":       latencies,
	}
}
// SimpleEventBus implements EventBus using an in-memory list of observers.
// mu guards the observers slice.
type SimpleEventBus struct {
observers []Observer
mu sync.RWMutex
}
// NewSimpleEventBus creates a new simple event bus with no subscribers.
func NewSimpleEventBus() *SimpleEventBus {
return &SimpleEventBus{
observers: make([]Observer, 0),
}
}
// Subscribe appends observer to the event bus under the write lock.
// No duplicate check is performed; the method always returns nil.
func (b *SimpleEventBus) Subscribe(observer Observer) error {
b.mu.Lock()
defer b.mu.Unlock()
b.observers = append(b.observers, observer)
return nil
}
// Unsubscribe removes the first subscribed observer that compares equal to
// observer. It is a no-op when the observer is not subscribed and always
// returns nil.
func (b *SimpleEventBus) Unsubscribe(observer Observer) error {
	b.mu.Lock()
	defer b.mu.Unlock()

	for idx := range b.observers {
		if b.observers[idx] != observer {
			continue
		}
		// Splice the match out and stop: only one occurrence is removed.
		b.observers = append(b.observers[:idx], b.observers[idx+1:]...)
		return nil
	}
	return nil
}
// Emit delivers event to every subscribed observer by calling the callback
// that matches event.Type; events of an unknown type are ignored.
// The observer list is copied under the read lock and the callbacks run
// after the lock is released. Emit always returns nil.
func (b *SimpleEventBus) Emit(ctx context.Context, event Event) error {
	b.mu.RLock()
	recipients := make([]Observer, len(b.observers))
	copy(recipients, b.observers)
	b.mu.RUnlock()

	for _, recipient := range recipients {
		switch event.Type {
		case EventEntityRegistered:
			recipient.OnEntityRegistered(ctx, event.Entity)
		case EventEntityUnregistered:
			recipient.OnEntityUnregistered(ctx, event.EntityID)
		case EventEntityUpdated:
			recipient.OnEntityUpdated(ctx, event.Entity)
		case EventEntityActivated:
			recipient.OnEntityActivated(ctx, event.EntityID)
		case EventEntityDeactivated:
			recipient.OnEntityDeactivated(ctx, event.EntityID)
		}
	}
	return nil
}
// SimpleValidator implements Validator with basic validation: lists of
// required and forbidden metadata keys plus optional per-key value checks.
type SimpleValidator struct {
	requiredMetadata  []string
	forbiddenMetadata []string
	validators        map[string]func(string) error
}

// NewSimpleValidator creates a validator with no required or forbidden keys
// and no custom checks.
func NewSimpleValidator() *SimpleValidator {
	return &SimpleValidator{
		requiredMetadata:  make([]string, 0),
		forbiddenMetadata: make([]string, 0),
		validators:        make(map[string]func(string) error),
	}
}

// WithRequiredMetadata replaces the list of metadata keys that must be
// present and returns the validator for chaining.
func (v *SimpleValidator) WithRequiredMetadata(fields []string) *SimpleValidator {
	v.requiredMetadata = fields
	return v
}

// WithForbiddenMetadata replaces the list of metadata keys that must be
// absent and returns the validator for chaining.
func (v *SimpleValidator) WithForbiddenMetadata(fields []string) *SimpleValidator {
	v.forbiddenMetadata = fields
	return v
}

// WithValidator registers a custom check for the value of metadata key
// field, replacing any previous check for that key, and returns the
// validator for chaining. It is safe on a zero-value SimpleValidator: the
// check map is created on first use instead of panicking on a nil map.
func (v *SimpleValidator) WithValidator(
	field string,
	validator func(string) error,
) *SimpleValidator {
	if v.validators == nil {
		v.validators = make(map[string]func(string) error)
	}
	v.validators[field] = validator
	return v
}
// Validate checks that entity has a non-empty ID and name and then validates
// its metadata with ValidateMetadata. The first failure is returned.
func (v *SimpleValidator) Validate(ctx context.Context, entity Entity) error {
	switch {
	case entity.ID() == "":
		return fmt.Errorf("entity ID is required")
	case entity.Name() == "":
		return fmt.Errorf("entity name is required")
	}
	return v.ValidateMetadata(ctx, entity.Metadata())
}
// ValidateMetadata checks metadata against the validator's rules in order:
// every required key must be present, no forbidden key may be present, and
// each custom check must accept the value of its key when that key exists.
// The first violation is returned; custom-check errors are wrapped.
func (v *SimpleValidator) ValidateMetadata(ctx context.Context, metadata map[string]string) error {
	for _, key := range v.requiredMetadata {
		if _, present := metadata[key]; !present {
			return fmt.Errorf("required metadata field missing: %s", key)
		}
	}

	for _, key := range v.forbiddenMetadata {
		if _, present := metadata[key]; present {
			return fmt.Errorf("forbidden metadata field present: %s", key)
		}
	}

	for field, check := range v.validators {
		value, present := metadata[field]
		if !present {
			// Custom checks only apply to keys that are actually set.
			continue
		}
		if err := check(value); err != nil {
			return fmt.Errorf("validation failed for field %s: %w", field, err)
		}
	}
	return nil
}
// SimpleHealth implements Health by remembering the most recent error.
// The checker is healthy exactly when no error is stored. mu guards lastError.
type SimpleHealth struct {
	lastError error
	mu        sync.RWMutex
}

// NewSimpleHealth creates a health checker that starts out healthy.
func NewSimpleHealth() *SimpleHealth {
	return new(SimpleHealth)
}

// IsHealthy reports whether no error is currently stored.
func (h *SimpleHealth) IsHealthy(ctx context.Context) bool {
	return h.GetLastError() == nil
}

// GetHealthStatus returns a status map with "healthy" (bool) and "timestamp"
// (the current time), plus "last_error" (the error text) when an error is
// stored.
func (h *SimpleHealth) GetHealthStatus(ctx context.Context) map[string]interface{} {
	h.mu.RLock()
	defer h.mu.RUnlock()

	failure := h.lastError
	report := map[string]interface{}{
		"healthy":   failure == nil,
		"timestamp": time.Now(),
	}
	if failure != nil {
		report["last_error"] = failure.Error()
	}
	return report
}

// GetLastError returns the stored error, or nil when healthy.
func (h *SimpleHealth) GetLastError() error {
	h.mu.RLock()
	defer h.mu.RUnlock()
	return h.lastError
}

// SetError stores err as the last error; passing nil marks the checker
// healthy again.
func (h *SimpleHealth) SetError(err error) {
	h.mu.Lock()
	h.lastError = err
	h.mu.Unlock()
}

// ClearError discards the stored error, marking the checker healthy.
func (h *SimpleHealth) ClearError() {
	h.SetError(nil)
}
package registry
import (
"context"
"strings"
"time"
)
// Core interfaces following Go's idiomatic naming conventions
// Basic interfaces (single-method)
// Identifier exposes an entity's ID.
type Identifier interface {
ID() string
}
// IDSetter changes an entity's ID and may reject the new value.
type IDSetter interface {
SetID(id string) error
}
// Named exposes an entity's name.
type Named interface {
Name() string
}
// NameSetter changes an entity's name and may reject the new value.
type NameSetter interface {
SetName(name string) error
}
// ActiveStatusChecker defines the interface for checking if an entity is active
type ActiveStatusChecker interface {
Active() bool
}
// ActivationSetter switches an entity's active flag on or off.
type ActivationSetter interface {
SetActive(active bool)
}
// MetadataReader exposes an entity's metadata as string key/value pairs.
type MetadataReader interface {
Metadata() map[string]string
}
// MetadataWriter sets a single metadata key to a value.
type MetadataWriter interface {
SetMetadata(key, value string)
}
// MetadataRemover removes a single metadata key.
type MetadataRemover interface {
RemoveMetadata(key string)
}
// MetadataClearer removes all metadata.
type MetadataClearer interface {
ClearMetadata()
}
// Timestamped exposes an entity's creation and last-update times.
type Timestamped interface {
CreatedAt() time.Time
UpdatedAt() time.Time
}
// Composite interfaces

// Identity combines reading and setting an entity's ID.
type Identity interface {
Identifier
IDSetter
}
// Nameable combines reading and setting an entity's name.
type Nameable interface {
Named
NameSetter
}
// ActivationController combines reading and setting an entity's active flag.
type ActivationController interface {
ActiveStatusChecker
ActivationSetter
}
// MetadataController combines reading, writing, removing and clearing
// entity metadata.
type MetadataController interface {
MetadataReader
MetadataWriter
MetadataRemover
MetadataClearer
}
// EntityCore is the full entity contract: identity, name, activation,
// metadata and timestamps.
type EntityCore interface {
Identity
Nameable
ActivationController
MetadataController
Timestamped
}
// Entity is the main interface that all registry entities must implement.
// It is an alias for EntityCore, which is a composition of smaller, focused
// interfaces.
//
// Deprecated: Use EntityCore for new code.
type Entity = EntityCore
// EntityFactory creates new entity instances from an ID and a name.
type EntityFactory interface {
NewEntity(id, name string) (EntityCore, error)
}
// EntityValidator validates entity state, returning an error when the
// entity is invalid.
type EntityValidator interface {
Validate() error
}
// EntityLifecycle defines hooks for entity lifecycle events.
// Each hook returns an error; how callers react to that error is not defined
// here.
type EntityLifecycle interface {
BeforeCreate() error
AfterCreate() error
BeforeUpdate() error
AfterUpdate() error
BeforeDelete() error
AfterDelete() error
}
// EntityFull combines all entity-related interfaces: the core contract,
// self-validation and lifecycle hooks.
type EntityFull interface {
EntityCore
EntityValidator
EntityLifecycle
}
// Provider defines the interface for registry implementations.
// Entities are addressed by their string ID; every operation takes a context.
type Provider interface {
// Core operations
Register(ctx context.Context, entity Entity) error
Get(ctx context.Context, id string) (Entity, error)
Unregister(ctx context.Context, id string) error
IsRegistered(ctx context.Context, id string) bool
// Listing operations
List(ctx context.Context) ([]Entity, error)
ListActive(ctx context.Context) ([]Entity, error)
ListByMetadata(ctx context.Context, key, value string) ([]Entity, error)
// Counting operations
Count(ctx context.Context) (int, error)
CountActive(ctx context.Context) (int, error)
// Metadata operations (single key on the entity identified by id)
GetMetadata(ctx context.Context, id, key string) (string, error)
SetMetadata(ctx context.Context, id, key, value string) error
RemoveMetadata(ctx context.Context, id, key string) error
// Lifecycle operations
Activate(ctx context.Context, id string) error
Deactivate(ctx context.Context, id string) error
// Search operations
Search(ctx context.Context, query string) ([]Entity, error)
SearchByMetadata(ctx context.Context, metadata map[string]string) ([]Entity, error)
}
// Observer defines the interface for registry event observers.
// Registration and update callbacks receive the entity; unregistration,
// activation and deactivation callbacks receive only its ID.
type Observer interface {
OnEntityRegistered(ctx context.Context, entity Entity)
OnEntityUnregistered(ctx context.Context, id string)
OnEntityUpdated(ctx context.Context, entity Entity)
OnEntityActivated(ctx context.Context, id string)
OnEntityDeactivated(ctx context.Context, id string)
}
// Event represents a registry event.
// Type is one of the EventEntity* constants. Entity and Metadata are
// optional and omitted from JSON when empty.
type Event struct {
Type string `json:"type"`
EntityID string `json:"entity_id"`
Entity Entity `json:"entity,omitempty"`
Timestamp time.Time `json:"timestamp"`
Metadata map[string]any `json:"metadata,omitempty"`
}
// Event type constants used as Event.Type values.
const (
EventEntityRegistered = "entity_registered"
EventEntityUnregistered = "entity_unregistered"
EventEntityUpdated = "entity_updated"
EventEntityActivated = "entity_activated"
EventEntityDeactivated = "entity_deactivated"
)
// EventBus defines the interface for registry event handling: observers
// subscribe and unsubscribe, and events are delivered through Emit.
type EventBus interface {
Subscribe(observer Observer) error
Unsubscribe(observer Observer) error
Emit(ctx context.Context, event Event) error
}
// Config holds configuration for registry implementations.
// All fields except MetadataValidators are serialised to JSON under the
// names given in their tags.
type Config struct {
// General settings
Name string `json:"name"`
MaxEntities int `json:"max_entities"`
DefaultTTL time.Duration `json:"default_ttl"`
EnableEvents bool `json:"enable_events"`
EnableValidation bool `json:"enable_validation"`
// Cache settings
CacheSize int `json:"cache_size"`
CacheTTL time.Duration `json:"cache_ttl"`
// Redis cache settings
RedisURL string `json:"redis_url"` // Redis server URL
RedisKeyPrefix string `json:"redis_key_prefix"` // Prefix for Redis keys
RedisPoolSize int `json:"redis_pool_size"` // Max connections in pool
RedisMinIdleConns int `json:"redis_min_idle_conns"` // Min idle connections
RedisMaxRetries int `json:"redis_max_retries"` // Max retries for commands
RedisDialTimeout time.Duration `json:"redis_dial_timeout"` // Dial timeout
RedisReadTimeout time.Duration `json:"redis_read_timeout"` // Read timeout
RedisWriteTimeout time.Duration `json:"redis_write_timeout"` // Write timeout
// Advanced features
EnableCompression bool `json:"enable_compression"`
EnablePersistence bool `json:"enable_persistence"`
PersistencePath string `json:"persistence_path"`
// Auto-save settings
AutoSaveInterval time.Duration `json:"auto_save_interval"`
// Metadata validation
RequiredMetadata []string `json:"required_metadata"`
ForbiddenMetadata []string `json:"forbidden_metadata"`
// MetadataValidators maps a metadata key to a check for its value; it is
// excluded from JSON.
MetadataValidators map[string]func(string) error `json:"-"`
}
// Validator defines the interface for entity validation: a whole entity or
// just a metadata map can be checked.
type Validator interface {
Validate(ctx context.Context, entity Entity) error
ValidateMetadata(ctx context.Context, metadata map[string]string) error
}
// Persistence defines the interface for registry persistence.
// Save and Load operate on the complete entity list; Delete addresses a
// single entity by ID; Clear removes everything.
type Persistence interface {
Save(ctx context.Context, entities []Entity) error
Load(ctx context.Context) ([]Entity, error)
Delete(ctx context.Context, id string) error
Clear(ctx context.Context) error
}
// Cache defines the interface for registry caching.
// Get reports a miss through its boolean result rather than an error.
type Cache interface {
Get(ctx context.Context, id string) (Entity, bool)
Set(ctx context.Context, entity Entity) error
Delete(ctx context.Context, id string) error
Clear(ctx context.Context) error
Size() int
}
// Metrics defines the interface for registry metrics: event counters,
// entity gauges and per-operation latency recording.
type Metrics interface {
IncrementRegistration()
IncrementUnregistration()
IncrementLookup()
IncrementError()
SetEntityCount(count int)
SetActiveCount(count int)
RecordLatency(operation string, duration time.Duration)
}
// Health defines the interface for registry health checks.
type Health interface {
IsHealthy(ctx context.Context) bool
GetHealthStatus(ctx context.Context) map[string]interface{}
GetLastError() error
}
// Factory defines the interface for creating registry instances.
// Each method builds a Provider from a Config; the WithX variants
// additionally take the persistence, cache or metrics implementation to use.
type Factory interface {
Create(
ctx context.Context,
config Config,
) (Provider, error)
CreateWithPersistence(
ctx context.Context,
config Config,
persistence Persistence,
) (Provider, error)
CreateWithCache(
ctx context.Context,
config Config,
cache Cache,
) (Provider, error)
CreateWithMetrics(
ctx context.Context,
config Config,
metrics Metrics,
) (Provider, error)
}
// Builder provides a fluent interface for building registry configurations.
// It accumulates settings in config; Build returns the result.
type Builder struct {
config Config
}
// NewBuilder creates a new registry builder whose defaults enable events and
// validation and set a cache size of 1000 with a 5-minute TTL.
func NewBuilder() *Builder {
return &Builder{
config: Config{
EnableEvents: true,
EnableValidation: true,
CacheSize: 1000,
CacheTTL: 5 * time.Minute,
},
}
}
// WithName sets the registry name and returns the builder for chaining.
func (b *Builder) WithName(name string) *Builder {
b.config.Name = name
return b
}
// WithMaxEntities sets the maximum number of entities and returns the
// builder for chaining.
func (b *Builder) WithMaxEntities(max int) *Builder {
b.config.MaxEntities = max
return b
}
// WithDefaultTTL sets the default TTL for entities and returns the builder
// for chaining.
func (b *Builder) WithDefaultTTL(ttl time.Duration) *Builder {
b.config.DefaultTTL = ttl
return b
}
// WithCache sets the cache size and cache TTL and returns the builder for
// chaining.
func (b *Builder) WithCache(size int, ttl time.Duration) *Builder {
b.config.CacheSize = size
b.config.CacheTTL = ttl
return b
}
// WithRedis sets the Redis URL and fills in defaults for any Redis setting
// that is still at its zero value: key prefix "registry:", pool size 10,
// 5 minimum idle connections, 3 retries, a 5s dial timeout and 3s read and
// write timeouts. Settings that were already configured are left untouched.
// It returns the builder for chaining.
func (b *Builder) WithRedis(url string) *Builder {
	cfg := &b.config
	cfg.RedisURL = url

	if cfg.RedisKeyPrefix == "" {
		cfg.RedisKeyPrefix = "registry:"
	}
	if cfg.RedisPoolSize == 0 {
		cfg.RedisPoolSize = 10
	}
	if cfg.RedisMinIdleConns == 0 {
		cfg.RedisMinIdleConns = 5
	}
	if cfg.RedisMaxRetries == 0 {
		cfg.RedisMaxRetries = 3
	}
	if cfg.RedisDialTimeout == 0 {
		cfg.RedisDialTimeout = 5 * time.Second
	}
	if cfg.RedisReadTimeout == 0 {
		cfg.RedisReadTimeout = 3 * time.Second
	}
	if cfg.RedisWriteTimeout == 0 {
		cfg.RedisWriteTimeout = 3 * time.Second
	}
	return b
}
// WithKeyPrefix sets the Redis key prefix for the registry. A non-empty
// prefix that does not already end in ":" gets one appended; an empty prefix
// is stored unchanged.
func (b *Builder) WithKeyPrefix(prefix string) *Builder {
	normalized := prefix
	if normalized != "" {
		if !strings.HasSuffix(normalized, ":") {
			normalized += ":"
		}
	}
	b.config.RedisKeyPrefix = normalized
	return b
}
// WithRedisAdvanced sets every Redis option explicitly. Unlike WithRedis it
// applies no defaults, and unlike WithKeyPrefix it stores prefix exactly as
// given (no trailing ":" is added).
func (b *Builder) WithRedisAdvanced(
	url string,
	prefix string,
	poolSize int,
	minIdleConns int,
	maxRetries int,
	dialTimeout time.Duration,
	readTimeout time.Duration,
	writeTimeout time.Duration,
) *Builder {
	cfg := &b.config

	// Endpoint and key namespace.
	cfg.RedisURL, cfg.RedisKeyPrefix = url, prefix

	// Pool sizing and retries.
	cfg.RedisPoolSize, cfg.RedisMinIdleConns, cfg.RedisMaxRetries = poolSize, minIdleConns, maxRetries

	// Timeouts.
	cfg.RedisDialTimeout, cfg.RedisReadTimeout, cfg.RedisWriteTimeout = dialTimeout, readTimeout, writeTimeout
	return b
}
// WithPersistence switches persistence on and records the storage path and
// the auto-save interval. It returns the Builder for chaining.
func (b *Builder) WithPersistence(path string, interval time.Duration) *Builder {
	cfg := &b.config
	cfg.EnablePersistence = true
	cfg.PersistencePath, cfg.AutoSaveInterval = path, interval
	return b
}
// WithValidation records the required and forbidden metadata keys. The
// slices are stored as passed, without copying, so later changes by the
// caller are visible through the configuration.
func (b *Builder) WithValidation(required, forbidden []string) *Builder {
	cfg := &b.config
	cfg.RequiredMetadata, cfg.ForbiddenMetadata = required, forbidden
	return b
}
// Build returns the configuration assembled so far. Config is returned by
// value, so the Builder can keep being used afterwards.
func (b *Builder) Build() Config {
	built := b.config
	return built
}
package registry
import "strings"
// protectedFields is the set of core entity field names that must not be
// stored as metadata keys.
var protectedFields = map[string]bool{
	"id":        true,
	"name":      true,
	"active":    true,
	"createdAt": true,
	"updatedAt": true,
}

// isProtectedField reports whether key names a protected core field. An exact
// match is tried first; otherwise key is compared case-insensitively (via
// strings.ToLower) against every protected name.
func isProtectedField(key string) bool {
	matched := protectedFields[key]
	if !matched {
		want := strings.ToLower(key)
		for field := range protectedFields {
			if strings.ToLower(field) == want {
				matched = true
				break
			}
		}
	}
	return matched
}
package registry
import (
"context"
"encoding/json"
"fmt"
"time"
"github.com/redis/go-redis/v9"
)
// NewRedisCache wraps client in a RedisCache that namespaces keys with prefix
// and stores entries with the given ttl. An empty prefix falls back to
// "registry:".
func NewRedisCache(client *redis.Client, prefix string, ttl time.Duration) *RedisCache {
	cache := &RedisCache{client: client, prefix: prefix, ttl: ttl}
	if cache.prefix == "" {
		cache.prefix = "registry:"
	}
	return cache
}
// RedisCache implements the Cache interface using Redis
// Note: This is a simplified version. The full implementation should be in registry_redis_cache.go
type RedisCache struct {
client *redis.Client // Redis connection used for every operation
prefix string // prepended to every entity ID to form the Redis key
ttl time.Duration // expiry applied by Set; values <= 0 fall back to 24h there
}
// Get looks up the entity stored under prefix+id and decodes it from JSON
// into a BaseEntity. Any Redis error (including a missing key) or decode
// failure is reported as a cache miss: (nil, false).
func (c *RedisCache) Get(ctx context.Context, id string) (Entity, bool) {
	raw, err := c.client.Get(ctx, c.prefix+id).Result()
	if err != nil {
		return nil, false
	}

	cached := new(BaseEntity)
	if json.Unmarshal([]byte(raw), cached) != nil {
		return nil, false
	}
	return cached, true
}
// Set JSON-encodes entity and writes it under prefix+entity.ID(). The entry
// expires after the cache's ttl, or after 24 hours when that ttl is not
// positive. Encoding and Redis errors are returned as-is.
func (c *RedisCache) Set(ctx context.Context, entity Entity) error {
	payload, err := json.Marshal(entity)
	if err != nil {
		return err
	}

	expiry := 24 * time.Hour // fallback when no positive TTL is configured
	if c.ttl > 0 {
		expiry = c.ttl
	}
	return c.client.Set(ctx, c.prefix+entity.ID(), payload, expiry).Err()
}
// Delete removes the entry stored under prefix+id and returns any error from
// the Redis DEL command.
func (c *RedisCache) Delete(ctx context.Context, id string) error {
	key := c.prefix + id
	return c.client.Del(ctx, key).Err()
}
// Clear deletes every key that matches prefix+"*". Keys are discovered with a
// SCAN iterator and deleted one at a time; the first failed delete aborts the
// sweep, otherwise any iterator error is returned.
func (c *RedisCache) Clear(ctx context.Context) error {
	pattern := c.prefix + "*"
	it := c.client.Scan(ctx, 0, pattern, 0).Iterator()
	for it.Next(ctx) {
		if err := c.client.Del(ctx, it.Val()).Err(); err != nil {
			return err
		}
	}
	return it.Err()
}
// Size returns the number of keys currently matching prefix+"*", or 0 when
// the lookup fails. It issues a KEYS command with a background context, so
// the caller cannot cancel it.
func (c *RedisCache) Size() int {
	pattern := c.prefix + "*"
	matches, err := c.client.Keys(context.Background(), pattern).Result()
	if err == nil {
		return len(matches)
	}
	return 0
}
// NewRedisClient parses url, creates a Redis client and verifies the
// connection with a PING bounded by a five second timeout.
//
// It returns an error when the URL cannot be parsed or the server does not
// answer the PING. In the latter case the client is closed before returning,
// so its connection pool is not leaked.
func NewRedisClient(url string) (*redis.Client, error) {
	opt, err := redis.ParseURL(url)
	if err != nil {
		return nil, fmt.Errorf("failed to parse Redis URL: %w", err)
	}
	client := redis.NewClient(opt)

	// Test the connection.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := client.Ping(ctx).Err(); err != nil {
		// The caller never receives the client, so release it here. The close
		// error is irrelevant next to the connection failure being reported.
		_ = client.Close()
		return nil, fmt.Errorf("failed to connect to Redis: %w", err)
	}
	return client, nil
}
package registry
import (
"errors"
"maps"
"sync"
"time"
)
// ErrNotFound is the sentinel error for a missing entity.
// NOTE(review): none of the Registry methods visible here return it (Get
// returns nil for unknown IDs) — confirm where it is used.
var ErrNotFound = errors.New("entity not found")
// Registry is a thread-safe registry for managing entities that implement the Entity interface
type Registry struct {
entities map[string]Entity // registered entities keyed by the ID passed to Register
mu sync.RWMutex // guards entities; methods take it for reading or writing
}
// New returns an empty Registry that is ready for use.
func New() *Registry {
	reg := new(Registry)
	reg.entities = make(map[string]Entity)
	return reg
}
// Register stores a snapshot of entity under id, replacing any previous
// entry. The snapshot is a new BaseEntity carrying the entity's ID, name,
// active flag, creation time and a copy of its metadata; its update time is
// set to the current time. Note the map key is the id argument, while the
// snapshot's own ID comes from entity.ID().
func (r *Registry) Register(id string, entity Entity) {
	r.mu.Lock()
	defer r.mu.Unlock()

	snapshot := &BaseEntity{
		id:        entity.ID(),
		name:      entity.Name(),
		active:    entity.Active(),
		metadata:  make(map[string]string),
		createdAt: entity.CreatedAt(),
		updatedAt: time.Now(),
	}
	// Copy the metadata so the stored snapshot does not alias the caller's map.
	for k, v := range entity.Metadata() {
		snapshot.metadata[k] = v
	}
	r.entities[id] = snapshot
}
// Get returns the entity registered under id, or nil when no such entity
// exists.
func (r *Registry) Get(id string) Entity {
	r.mu.RLock()
	defer r.mu.RUnlock()

	found, ok := r.entities[id]
	if !ok {
		// Unknown ID: report absence with an untyped nil.
		return nil
	}
	return found
}
// IsRegistered reports whether an entity is stored under id.
func (r *Registry) IsRegistered(id string) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()

	if _, ok := r.entities[id]; ok {
		return true
	}
	return false
}
// ListRegistered returns the IDs of all registered entities. The order
// follows map iteration and is therefore unspecified; the result is never
// nil.
func (r *Registry) ListRegistered() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	ids := make([]string, len(r.entities))
	next := 0
	for id := range r.entities {
		ids[next] = id
		next++
	}
	return ids
}
// ListActive returns the IDs of all entities whose Active method reports
// true. The order is unspecified, and the result is nil when nothing is
// active.
func (r *Registry) ListActive() []string {
	r.mu.RLock()
	defer r.mu.RUnlock()

	var ids []string
	for id, candidate := range r.entities {
		if !candidate.Active() {
			continue
		}
		ids = append(ids, id)
	}
	return ids
}
// Unregister removes the entity stored under id. It reports whether an entry
// existed and was removed.
func (r *Registry) Unregister(id string) bool {
	r.mu.Lock()
	defer r.mu.Unlock()

	_, found := r.entities[id]
	if found {
		delete(r.entities, id)
	}
	return found
}
// Count returns how many entities are currently registered.
func (r *Registry) Count() int {
	r.mu.RLock()
	total := len(r.entities)
	r.mu.RUnlock()
	return total
}
// GetMetadata returns the metadata value stored under key for the entity
// registered as id. The boolean is false when the entity is unknown, has no
// metadata map, or lacks the key.
func (r *Registry) GetMetadata(id, key string) (string, bool) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	entity, ok := r.entities[id]
	if !ok {
		return "", false
	}
	md := entity.Metadata()
	if md == nil {
		return "", false
	}
	value, ok := md[key]
	return value, ok
}
// SetMetadata sets key to value in the metadata of the entity registered as
// id and reports whether that entity exists. The stored entity is replaced by
// a new BaseEntity holding a copied metadata map, so maps handed out earlier
// are not modified.
// NOTE(review): the update time is carried over unchanged here, whereas
// RemoveMetadata stamps the current time — confirm this is intended.
func (r *Registry) SetMetadata(id, key, value string) bool {
	r.mu.Lock()
	defer r.mu.Unlock()

	current, ok := r.entities[id]
	if !ok {
		return false
	}

	// Work on a private copy of the metadata.
	updated := make(map[string]string)
	maps.Copy(updated, current.Metadata())
	updated[key] = value

	r.entities[id] = &BaseEntity{
		id:        current.ID(),
		name:      current.Name(),
		active:    current.Active(),
		metadata:  updated,
		createdAt: current.CreatedAt(),
		updatedAt: current.UpdatedAt(),
	}
	return true
}
// RemoveMetadata deletes key from the metadata of the entity registered as
// id. It returns false when the entity is unknown or the key is absent;
// otherwise the stored entity is replaced by a new BaseEntity with a copied,
// reduced metadata map and its update time set to the current time.
func (r *Registry) RemoveMetadata(id, key string) bool {
	r.mu.Lock()
	defer r.mu.Unlock()

	current, ok := r.entities[id]
	if !ok {
		return false
	}

	// Work on a private copy of the metadata.
	remaining := make(map[string]string)
	maps.Copy(remaining, current.Metadata())

	// Nothing to remove: leave the stored entity untouched.
	if _, present := remaining[key]; !present {
		return false
	}
	delete(remaining, key)

	r.entities[id] = &BaseEntity{
		id:        current.ID(),
		name:      current.Name(),
		active:    current.Active(),
		metadata:  remaining,
		createdAt: current.CreatedAt(),
		updatedAt: time.Now(),
	}
	return true
}
// Create and manage registry instances explicitly in your application code.
package stripeconnect
import (
"context"
"github.com/amirasaad/fintech/pkg/domain"
)
// Repository defines the interface for Stripe Connect related operations.
// The implementation in this file delegates every call to a UserRepository.
type Repository interface {
// SaveStripeAccountID saves the Stripe Connect account ID for a user
SaveStripeAccountID(ctx context.Context, userID, accountID string) error
// GetStripeAccountID retrieves the Stripe Connect account ID for a user
GetStripeAccountID(ctx context.Context, userID string) (string, error)
// UpdateOnboardingStatus updates the onboarding status for a user's Stripe account
UpdateOnboardingStatus(ctx context.Context, userID string, completed bool) error
// GetOnboardingStatus checks if the user has completed Stripe onboarding
GetOnboardingStatus(ctx context.Context, userID string) (bool, error)
}
// NewRepository returns a Repository that delegates to userRepo.
func NewRepository(userRepo UserRepository) Repository {
	repo := new(repository)
	repo.userRepo = userRepo
	return repo
}
// repository is the Repository implementation returned by NewRepository.
type repository struct {
userRepo UserRepository // backing store every method delegates to
}
// UserRepository defines the minimal user repository interface needed by Stripe Connect
// This avoids circular dependencies between packages
type UserRepository interface {
// GetStripeAccountID gets the Stripe account ID for a user
GetStripeAccountID(ctx context.Context, userID string) (string, error)
// GetStripeOnboardingStatus gets the Stripe onboarding status for a user
GetStripeOnboardingStatus(ctx context.Context, userID string) (bool, error)
// UpdateStripeAccount updates the Stripe account information for a user:
// the account ID together with the onboarding-complete flag.
UpdateStripeAccount(
ctx context.Context,
userID, accountID string,
onboardingComplete bool,
) error
// UpdateStripeOnboardingStatus updates the Stripe onboarding status for a user
UpdateStripeOnboardingStatus(ctx context.Context, userID string, completed bool) error
}
// SaveStripeAccountID stores accountID for userID through the user
// repository. The onboarding-complete flag is always written as false by this
// call.
func (r *repository) SaveStripeAccountID(ctx context.Context, userID, accountID string) error {
	const onboardingComplete = false
	return r.userRepo.UpdateStripeAccount(ctx, userID, accountID, onboardingComplete)
}
// GetStripeAccountID returns the Stripe account ID stored for userID.
// Repository errors are passed through; an empty stored ID is reported as
// domain.ErrNotFound.
func (r *repository) GetStripeAccountID(ctx context.Context, userID string) (string, error) {
	id, err := r.userRepo.GetStripeAccountID(ctx, userID)
	switch {
	case err != nil:
		return "", err
	case id == "":
		return "", domain.ErrNotFound
	}
	return id, nil
}
// UpdateOnboardingStatus records whether userID has completed Stripe
// onboarding by delegating to the user repository.
func (r *repository) UpdateOnboardingStatus(ctx context.Context, userID string, completed bool) error {
	err := r.userRepo.UpdateStripeOnboardingStatus(ctx, userID, completed)
	return err
}
// GetOnboardingStatus reports whether userID has completed Stripe
// onboarding, as stored by the user repository. On error the status is
// always false.
func (r *repository) GetOnboardingStatus(ctx context.Context, userID string) (bool, error) {
	completed, err := r.userRepo.GetStripeOnboardingStatus(ctx, userID)
	if err == nil {
		return completed, nil
	}
	return false, err
}
// Package account provides business logic for interacting with
// domain entities such as accounts and transactions.
// It defines the Service struct and its
// methods for creating accounts, depositing and withdrawing funds,
// retrieving account details, listing transactions, and checking account balances.
//
// The service layer follows clean architecture principles
// and uses the decorator pattern for transaction management.
// All business operations are wrapped with automatic transaction management,
// error recovery, and structured logging.
package account
import (
"context"
"errors"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/commands"
"github.com/amirasaad/fintech/pkg/domain/events"
"github.com/amirasaad/fintech/pkg/domain"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/repository"
repoaccount "github.com/amirasaad/fintech/pkg/repository/account"
stripeconnect "github.com/amirasaad/fintech/pkg/service/stripeconnect"
"github.com/google/uuid"
)
// Service provides business logic for account operations including
// creation, deposits, withdrawals, and balance inquiries.
type Service struct {
bus eventbus.Bus // receives the deposit/withdraw/transfer request events
uow repository.UnitOfWork // gives access to repositories and runs transactional work
logger *slog.Logger // structured logger for operation failures
stripeConnectSvc stripeconnect.Service // consulted by Withdraw for onboarding status
}
// New returns a Service wired with the given event bus, unit of work, logger
// and Stripe Connect service. The dependencies are stored as passed.
func New(
	bus eventbus.Bus,
	uow repository.UnitOfWork,
	logger *slog.Logger,
	stripeConnectSvc stripeconnect.Service,
) *Service {
	svc := new(Service)
	svc.bus = bus
	svc.uow = uow
	svc.logger = logger
	svc.stripeConnectSvc = stripeConnectSvc
	return svc
}
// CreateAccount creates a new account for create.UserID inside a unit of
// work and returns the persisted account.
//
// An empty create.Currency is replaced by money.DefaultCode. The currency is
// resolved before the duplicate check, so a request without a currency is
// rejected when the user already holds a default-currency account. Creation
// fails when the user already has an account in the resolved currency, when
// the domain builder rejects the input, or when a repository call fails. All
// errors are wrapped with "account creation failed".
func (s *Service) CreateAccount(
	ctx context.Context,
	create dto.AccountCreate,
) (*dto.AccountRead, error) {
	var result *dto.AccountRead
	err := s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repoAny, err := uow.GetRepository((*repoaccount.Repository)(nil))
		if err != nil {
			return err
		}
		// Checked assertion: a misconfigured unit of work must surface as an
		// error rather than a panic.
		acctRepo, ok := repoAny.(repoaccount.Repository)
		if !ok {
			return errors.New("invalid account repository type")
		}

		// Resolve the effective currency first so every later step agrees on it.
		curr := money.Code(create.Currency)
		if curr == "" {
			curr = money.DefaultCode
		}

		// Check if user already has an account with the same currency.
		existingAccounts, err := acctRepo.ListByUser(ctx, create.UserID)
		if err != nil {
			return fmt.Errorf("failed to check existing accounts: %w", err)
		}
		for _, acc := range existingAccounts {
			if acc.Currency == curr.String() {
				return fmt.Errorf("user already has an account with currency %s", curr.String())
			}
		}

		// Enforce domain invariants.
		domainAcc, err := account.New().WithUserID(create.UserID).WithCurrency(curr).Build()
		if err != nil {
			return err
		}

		// Map to DTO for persistence.
		createDTO := dto.AccountCreate{
			ID:       domainAcc.ID,
			UserID:   domainAcc.UserID,
			Balance:  int64(domainAcc.Balance.Amount()), // or 0 if always zero at creation
			Currency: curr.String(),
		}
		if err = acctRepo.Create(ctx, createDTO); err != nil {
			return fmt.Errorf("failed to create account: %w", err)
		}

		// Fetch for read DTO.
		read, err := acctRepo.Get(ctx, domainAcc.ID)
		if err != nil {
			return fmt.Errorf("failed to fetch created account: %w", err)
		}
		result = read
		return nil
	})
	if err != nil {
		return nil, fmt.Errorf("account creation failed: %w", err)
	}
	return result, nil
}
// Deposit validates the requested amount and emits a DepositRequested event
// on the bus under a freshly generated ID. The amount is expressed in the
// command's own currency. It returns the money validation error or the bus
// error, unwrapped.
func (s *Service) Deposit(
	ctx context.Context,
	cmd commands.Deposit,
) error {
	// Always use the source currency for the initial deposit event.
	currency := money.Code(cmd.Currency)
	amount, err := money.New(cmd.Amount, currency)
	if err != nil {
		return err
	}

	event := events.NewDepositRequested(
		cmd.UserID,
		cmd.AccountID,
		uuid.New(),
		events.WithDepositAmount(amount),
	)
	return s.bus.Emit(ctx, event)
}
// Withdraw validates a withdrawal request and emits a WithdrawRequested event
// on the bus.
//
// The user must have completed Stripe Connect onboarding: a lookup error
// other than domain.ErrNotFound is returned wrapped, and an incomplete (or
// not found) onboarding yields domain.ErrStripeOnboardingIncomplete. The bank
// account number is attached to the event only when an external target with a
// non-empty number is supplied.
func (s *Service) Withdraw(
	ctx context.Context,
	cmd commands.Withdraw,
) error {
	onboarded, err := s.stripeConnectSvc.IsOnboardingComplete(ctx, cmd.UserID)
	switch {
	case err != nil && !errors.Is(err, domain.ErrNotFound):
		return fmt.Errorf("failed to check Stripe Connect status: %w", err)
	case !onboarded:
		return domain.ErrStripeOnboardingIncomplete
	}

	amount, err := money.New(cmd.Amount, money.Code(cmd.Currency))
	if err != nil {
		return fmt.Errorf("invalid amount: %w", err)
	}

	// Event options: always the amount, optionally the bank account number.
	opts := []events.WithdrawRequestedOpt{events.WithWithdrawAmount(amount)}
	if target := cmd.ExternalTarget; target != nil && target.BankAccountNumber != "" {
		opts = append(opts, events.WithWithdrawBankAccountNumber(target.BankAccountNumber))
	}

	event := events.NewWithdrawRequested(cmd.UserID, cmd.AccountID, uuid.New(), opts...)
	return s.bus.Emit(ctx, event)
}
// ListUserAccounts returns all accounts for a specific user.
//
// The lookup runs inside a unit of work. Any failure (repository
// unavailable, wrong repository type, or the listing itself) is logged and
// returned, with a nil slice.
func (s *Service) ListUserAccounts(
	ctx context.Context,
	userID uuid.UUID,
) ([]*dto.AccountRead, error) {
	var accounts []*dto.AccountRead
	err := s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repoAny, err := uow.GetRepository((*repoaccount.Repository)(nil))
		if err != nil {
			return fmt.Errorf("failed to get account repository: %w", err)
		}
		// Checked assertion: a misconfigured unit of work must surface as an
		// error rather than a panic.
		acctRepo, ok := repoAny.(repoaccount.Repository)
		if !ok {
			return errors.New("invalid account repository type")
		}
		accounts, err = acctRepo.ListByUser(ctx, userID)
		if err != nil {
			return fmt.Errorf("failed to list user accounts: %w", err)
		}
		return nil
	})
	if err != nil {
		s.logger.Error("Failed to list user accounts", "error", err, "userID", userID)
		return nil, err
	}
	return accounts, nil
}
// Transfer validates the requested amount and emits a TransferRequested
// event from cmd.AccountID to cmd.ToAccountID under a freshly generated ID.
// It returns the money validation error or the bus error, unwrapped.
func (s *Service) Transfer(
	ctx context.Context,
	cmd commands.Transfer,
) error {
	currency := money.Code(cmd.Currency)
	amount, err := money.New(cmd.Amount, currency)
	if err != nil {
		return err
	}

	event := events.NewTransferRequested(
		cmd.UserID,
		cmd.AccountID,
		uuid.New(),
		events.WithTransferDestAccountID(cmd.ToAccountID),
		events.WithTransferRequestedAmount(amount),
	)
	return s.bus.Emit(ctx, event)
}
package account
import (
	"context"
	"errors"

	"github.com/amirasaad/fintech/pkg/dto"
	repoaccount "github.com/amirasaad/fintech/pkg/repository/account"
	transactionrepo "github.com/amirasaad/fintech/pkg/repository/transaction"
	"github.com/google/uuid"
)
// GetAccount retrieves an account by ID.
//
// It returns an error when the account repository cannot be obtained or has
// an unexpected type; previously the latter case returned (nil, nil), which
// callers could not tell apart from success. Otherwise the repository's
// result is returned as-is.
//
// NOTE(review): userID is accepted but not compared with the account's
// owner here (GetBalance does make that comparison) — confirm whether an
// ownership check belongs in this method as well.
func (s *Service) GetAccount(
	ctx context.Context,
	userID, accountID uuid.UUID,
) (
	account *dto.AccountRead,
	err error,
) {
	repoAny, err := s.uow.GetRepository((*repoaccount.Repository)(nil))
	if err != nil {
		return nil, err
	}
	repo, ok := repoAny.(repoaccount.Repository)
	if !ok {
		return nil, errors.New("invalid account repository type")
	}
	return repo.Get(ctx, accountID)
}
// GetTransactions retrieves all transactions for a specific account.
//
// The account must exist and belong to userID; otherwise an error is
// returned and no transactions are listed. An error is also returned when
// either repository cannot be obtained or has an unexpected type;
// previously those type mismatches returned (nil, nil).
func (s *Service) GetTransactions(
	ctx context.Context,
	userID, accountID uuid.UUID,
) (
	transactions []*dto.TransactionRead,
	err error,
) {
	// First, validate that the account exists and belongs to the user.
	accountRepoAny, err := s.uow.GetRepository((*repoaccount.Repository)(nil))
	if err != nil {
		return nil, err
	}
	accountRepo, ok := accountRepoAny.(repoaccount.Repository)
	if !ok {
		return nil, errors.New("invalid account repository type")
	}
	acc, err := accountRepo.Get(ctx, accountID)
	if err != nil {
		return nil, err
	}
	if acc == nil {
		return nil, errors.New("account not found")
	}
	if acc.UserID != userID {
		return nil, errors.New("account does not belong to user")
	}

	// Then, get the transactions.
	transactionRepoAny, err := s.uow.GetRepository((*transactionrepo.Repository)(nil))
	if err != nil {
		return nil, err
	}
	transactionRepo, ok := transactionRepoAny.(transactionrepo.Repository)
	if !ok {
		return nil, errors.New("invalid transaction repository type")
	}
	return transactionRepo.ListByAccount(ctx, accountID)
}
// GetBalance retrieves the current balance of an account for the specified
// user.
//
// It returns an error when the repository cannot be obtained or has an
// unexpected type, when the account cannot be loaded, or when the account
// does not belong to userID. Previously the type mismatch and the ownership
// mismatch both returned (0, nil), indistinguishable from a real zero
// balance.
func (s *Service) GetBalance(
	ctx context.Context,
	userID, accountID uuid.UUID,
) (
	balance float64,
	err error,
) {
	repoAny, err := s.uow.GetRepository((*repoaccount.Repository)(nil))
	if err != nil {
		return 0, err
	}
	repo, ok := repoAny.(repoaccount.Repository)
	if !ok {
		return 0, errors.New("invalid account repository type")
	}
	acc, err := repo.Get(ctx, accountID)
	if err != nil {
		return 0, err
	}
	if acc == nil {
		return 0, errors.New("account not found")
	}
	if acc.UserID != userID {
		return 0, errors.New("account does not belong to user")
	}
	return acc.Balance, nil
}
package auth
import (
"context"
"fmt"
"log/slog"
"time"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/domain/user"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/repository"
repouser "github.com/amirasaad/fintech/pkg/repository/user"
"github.com/amirasaad/fintech/pkg/utils"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// contextKey is a private type for context keys so values stored by this
// package cannot collide with keys from other packages.
type contextKey string
// userContextKey is the context key under which Service.GetCurrentUserId
// stores the *jwt.Token and JWTStrategy.GetCurrentUserID reads it back.
const userContextKey contextKey = "user"
// Strategy is the pluggable authentication mechanism used by Service.
// JWTStrategy and BasicAuthStrategy in this file implement it.
type Strategy interface {
// Login authenticates identity (email or username) with password.
Login(ctx context.Context, identity, password string) (*dto.UserRead, error)
// GetCurrentUserID extracts the authenticated user's ID from ctx.
GetCurrentUserID(ctx context.Context) (uuid.UUID, error)
// GenerateToken issues a token for u; strategies without tokens return "".
GenerateToken(ctx context.Context, u *dto.UserRead) (string, error)
}
// Service is the authentication facade: it logs each call and delegates
// login, token generation and user-ID extraction to its Strategy.
type Service struct {
uow repository.UnitOfWork // stored by New; not used by the Service methods visible here
strategy Strategy // performs the actual authentication work
logger *slog.Logger // structured logger for every operation
}
// New returns a Service that authenticates through strategy and logs to
// logger.
func New(
	uow repository.UnitOfWork,
	strategy Strategy,
	logger *slog.Logger,
) *Service {
	svc := new(Service)
	svc.uow = uow
	svc.strategy = strategy
	svc.logger = logger
	return svc
}
// NewWithBasic returns a Service backed by a BasicAuthStrategy that shares
// the given unit of work and logger.
func NewWithBasic(
	uow repository.UnitOfWork,
	logger *slog.Logger,
) *Service {
	strategy := &BasicAuthStrategy{uow: uow, logger: logger}
	return New(uow, strategy, logger)
}
// NewWithJWT returns a Service backed by a JWTStrategy configured with cfg
// and sharing the given unit of work and logger.
func NewWithJWT(
	uow repository.UnitOfWork,
	cfg *config.Jwt,
	logger *slog.Logger,
) *Service {
	strategy := &JWTStrategy{uow: uow, cfg: cfg, logger: logger}
	return New(uow, strategy, logger)
}
// CheckPasswordHash reports whether password matches hash according to
// utils.CheckPasswordHash. Every call is logged at info level and a mismatch
// additionally at error level; neither the password nor the hash is logged.
func (s *Service) CheckPasswordHash(
	password, hash string,
) bool {
	s.logger.Info("CheckPasswordHash called")
	if utils.CheckPasswordHash(password, hash) {
		return true
	}
	s.logger.Error("Password hash check failed", "valid", false)
	return false
}
// ValidEmail reports whether email is accepted by utils.IsEmail. The address
// is written to the debug log.
func (s *Service) ValidEmail(email string) bool {
	s.logger.Debug("ValidEmail called", "email", email)
	ok := utils.IsEmail(email)
	return ok
}
// GetCurrentUserId resolves the user ID carried by token. The token is
// placed in a fresh background context under userContextKey and handed to
// the strategy's GetCurrentUserID; failures are logged and returned
// unchanged.
func (s *Service) GetCurrentUserId(
	token *jwt.Token,
) (userID uuid.UUID, err error) {
	log := s.logger.With("context", "GetCurrentUserId")
	log.Debug("GetCurrentUserId called")

	ctx := context.WithValue(context.Background(), userContextKey, token)
	userID, err = s.strategy.GetCurrentUserID(ctx)
	if err != nil {
		log.Error("GetCurrentUserId failed", "error", err)
		return userID, err
	}

	log.Info("GetCurrentUserId successful", "userID", userID)
	return userID, nil
}
// Login authenticates identity with password through the configured
// strategy. The strategy's result is returned unchanged; failures are logged
// with the identity, successes with the user ID.
func (s *Service) Login(
	ctx context.Context,
	identity, password string,
) (u *dto.UserRead, err error) {
	log := s.logger.With("context", "Login")
	log.Debug("Login called", "identity", identity)

	if u, err = s.strategy.Login(ctx, identity, password); err != nil {
		log.Error("Login failed", "identity", identity, "error", err)
		return u, err
	}

	log.Info("Login successful", "userID", u.ID)
	return u, nil
}
// GenerateToken asks the configured strategy for a token for u. On failure
// the error is logged and returned with an empty token. u must be non-nil:
// its ID is read for logging before the strategy is called.
func (s *Service) GenerateToken(
	ctx context.Context,
	u *dto.UserRead,
) (string, error) {
	log := s.logger.With("userID", u.ID)
	log.Debug("GenerateToken called")

	token, err := s.strategy.GenerateToken(ctx, u)
	if err == nil {
		log.Info("GenerateToken successful")
		return token, nil
	}

	log.Error("GenerateToken failed", "userID", u.ID, "error", err)
	return "", err
}
// JWTStrategy implements Strategy for JWT-based authentication
type JWTStrategy struct {
uow repository.UnitOfWork // source of the user repository used by Login
cfg *config.Jwt // provides the signing secret and token expiry
logger *slog.Logger // structured logger for every operation
}
// NewJWTStrategy returns a JWTStrategy using the given unit of work, JWT
// configuration and logger.
func NewJWTStrategy(
	uow repository.UnitOfWork,
	cfg *config.Jwt,
	logger *slog.Logger,
) *JWTStrategy {
	strategy := new(JWTStrategy)
	strategy.uow = uow
	strategy.cfg = cfg
	strategy.logger = logger
	return strategy
}
// GenerateToken issues an HS256-signed JWT for u. The claims carry the
// username, email, user ID (as a string) and an "exp" timestamp set to now
// plus the configured expiry. The token is signed with the configured
// secret; a signing failure is logged and returned with an empty token.
func (s *JWTStrategy) GenerateToken(
	ctx context.Context,
	u *dto.UserRead) (string, error) {
	log := s.logger.With("userID", u.ID)
	log.Debug("GenerateToken called", "userID", u.ID)

	claims := jwt.MapClaims{
		"username": u.Username,
		"email":    u.Email,
		"user_id":  u.ID.String(),
		"exp":      time.Now().Add(s.cfg.Expiry).Unix(),
	}
	signed, err := jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString([]byte(s.cfg.Secret))
	if err != nil {
		log.Error("GenerateToken failed", "userID", u.ID, "error", err)
		return "", err
	}

	log.Info("GenerateToken successful")
	return signed, nil
}
// Login authenticates identity (an email address or a username) with
// password inside a unit of work.
//
// It returns the user on success. Every authentication failure — lookup
// error, unknown user or wrong password — yields user.ErrUserUnauthorized and
// a nil user, so no user data (including the password hash) accompanies an
// error. Whenever no user could be loaded, the password is still compared
// with a dummy hash so that failed lookups cost about as much as failed
// password checks.
func (s *JWTStrategy) Login(
	ctx context.Context,
	identity, password string,
) (
	u *dto.UserRead,
	err error,
) {
	log := s.logger.With("context", "Login", "identity", identity)
	log.Debug("Login called")

	// Stand-in hash used only to equalize timing when there is no user.
	const dummyHash = "$2a$10$7zFqzDbD3RrlkMTczbXG9OWZ0FLOXjIxXzSZ.QZxkVXjXcx7QZQiC"

	err = s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repoAny, err := uow.GetRepository((*repouser.Repository)(nil))
		if err != nil {
			return fmt.Errorf("failed to get user repository: %w", err)
		}
		repo, ok := repoAny.(repouser.Repository)
		if !ok {
			return fmt.Errorf("invalid user repository type")
		}

		// Check if identity is email or username.
		var found *dto.UserRead
		if utils.IsEmail(identity) {
			found, err = repo.GetByEmail(ctx, identity)
		} else {
			found, err = repo.GetByUsername(ctx, identity)
		}
		if err != nil || found == nil {
			// Always check a password hash to avoid timing attacks, whether
			// the lookup failed or simply found nothing.
			_ = utils.CheckPasswordHash(password, dummyHash)
			if err != nil {
				log.Error("Login failed", "error", err)
			} else {
				log.Error("Login failed", "error", user.ErrUserUnauthorized)
			}
			return user.ErrUserUnauthorized
		}
		if !utils.CheckPasswordHash(password, found.HashedPassword) {
			log.Error("Login failed", "error", user.ErrUserUnauthorized)
			return user.ErrUserUnauthorized
		}

		// Publish the user only after the password has been verified.
		u = found
		return nil
	})
	if err != nil {
		return nil, err
	}
	return u, nil
}
// GetCurrentUserID extracts the user ID from the *jwt.Token stored in ctx
// under userContextKey. It returns user.ErrUserUnauthorized when the token is
// missing, its claims are not jwt.MapClaims, or the "user_id" claim is not a
// string; a malformed ID returns the uuid parse error.
func (s *JWTStrategy) GetCurrentUserID(
	ctx context.Context,
) (userID uuid.UUID, err error) {
	log := s.logger.With("context", "GetCurrentUserID")
	log.Debug("GetCurrentUserID called")

	// unauthorized logs and produces the shared rejection result.
	unauthorized := func() (uuid.UUID, error) {
		log.Error("GetCurrentUserID failed", "error", user.ErrUserUnauthorized)
		return uuid.UUID{}, user.ErrUserUnauthorized
	}

	token, ok := ctx.Value(userContextKey).(*jwt.Token)
	if !ok || token == nil {
		return unauthorized()
	}
	claims, ok := token.Claims.(jwt.MapClaims)
	if !ok {
		return unauthorized()
	}
	rawID, ok := claims["user_id"].(string)
	if !ok {
		return unauthorized()
	}

	userID, err = uuid.Parse(rawID)
	if err != nil {
		log.Error("GetCurrentUserID failed", "error", err)
		return userID, err
	}
	log.Info("GetCurrentUserID successful", "userID", userID)
	return userID, nil
}
// BasicAuthStrategy implements Strategy for CLI (no JWT, just password check)
type BasicAuthStrategy struct {
uow repository.UnitOfWork // source of the user repository used by Login
logger *slog.Logger // structured logger for every operation
}
// NewBasicAuthStrategy returns a BasicAuthStrategy using the given unit of
// work and logger.
func NewBasicAuthStrategy(
	uow repository.UnitOfWork,
	logger *slog.Logger,
) *BasicAuthStrategy {
	strategy := new(BasicAuthStrategy)
	strategy.uow = uow
	strategy.logger = logger
	return strategy
}
// Login authenticates identity (an email address or a username) with
// password.
//
// Repository failures are returned wrapped. An unknown user or a wrong
// password yields user.ErrUserUnauthorized with a nil user. The password is
// verified against the user's stored hash — not a fixed placeholder — and
// is never written to the log.
func (s *BasicAuthStrategy) Login(
	ctx context.Context,
	identity, password string,
) (u *dto.UserRead, err error) {
	log := s.logger.With("identity", identity)
	log.Info("BasicAuth Login called")

	repoAny, err := s.uow.GetRepository((*repouser.Repository)(nil))
	if err != nil {
		err = fmt.Errorf("failed to get user repository: %w", err)
		log.Error("Failed to get user repository", "error", err)
		return nil, err
	}
	repo, ok := repoAny.(repouser.Repository)
	if !ok {
		log.Error("Invalid user repository type")
		return nil, fmt.Errorf("invalid user repository type")
	}

	log.Info("Looking up user")
	if utils.IsEmail(identity) {
		u, err = repo.GetByEmail(ctx, identity)
	} else {
		u, err = repo.GetByUsername(ctx, identity)
	}
	// If there was an error from the repository, return it.
	if err != nil {
		log.Error("Repository error", "error", err)
		return nil, fmt.Errorf("repository error: %w", err)
	}
	// If user not found, return unauthorized.
	if u == nil {
		log.Info("User not found", "identity", identity)
		return nil, user.ErrUserUnauthorized
	}
	log.Info("User found", "userID", u.ID, "username", u.Username)

	// Compare against the hash stored for this user. Credentials must never
	// be logged, so only the outcome is recorded.
	if !utils.CheckPasswordHash(password, u.HashedPassword) {
		log.Error("Password comparison failed", "error", user.ErrUserUnauthorized)
		return nil, user.ErrUserUnauthorized
	}
	log.Info("Password comparison succeeded")
	return u, nil
}
// GetCurrentUserID always returns uuid.Nil and a nil error; this strategy
// reads nothing from ctx.
func (s *BasicAuthStrategy) GetCurrentUserID(ctx context.Context) (uuid.UUID, error) {
	s.logger.With("context", "GetCurrentUserID").Debug("GetCurrentUserID called")
	return uuid.Nil, nil
}
// GenerateToken always returns an empty token and a nil error because basic
// auth issues no tokens. u must be non-nil: its ID is read for logging.
func (s *BasicAuthStrategy) GenerateToken(ctx context.Context, u *dto.UserRead) (string, error) {
	s.logger.With("userID", u.ID).Debug("GenerateToken called")
	// No token for basic auth.
	return "", nil
}
package checkout
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"strconv"
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/registry"
"github.com/google/uuid"
)
// Session represents a checkout session with its metadata
type Session struct {
ID string `json:"id"` // session identifier; also the registry entity ID
TransactionID uuid.UUID `json:"transaction_id"` // transaction this checkout belongs to
UserID uuid.UUID `json:"user_id"` // user the session was created for
AccountID uuid.UUID `json:"account_id"` // account the session was created for
Amount int64 `json:"amount"` // amount in the currency's smallest unit (see FormatAmount)
Currency string `json:"currency"` // currency code, validated via money.Code
Status string `json:"status"` // lifecycle status; CreateSession starts at "created"
CheckoutURL string `json:"checkout_url"` // URL passed in by the caller of CreateSession
CreatedAt time.Time `json:"created_at"` // creation time in UTC
ExpiresAt time.Time `json:"expires_at"` // expiry time in UTC
}
// Service provides high-level operations for managing checkout sessions
type Service struct {
registry registry.Provider // stores sessions as entities with searchable metadata
logger *slog.Logger // structured logger supplied by the caller
}
// New returns a checkout Service that stores sessions in reg and logs to
// logger.
func New(reg registry.Provider, logger *slog.Logger) *Service {
	svc := new(Service)
	svc.registry = reg
	svc.logger = logger
	return svc
}
// CreateSession builds a checkout session in status "created", validates it
// and saves it to the registry.
//
// CreatedAt and ExpiresAt derive from a single clock reading, so ExpiresAt
// is exactly CreatedAt + expiresIn. The id parameter is currently unused
// (sessionID becomes the session ID), and ctx is not forwarded because
// saveSession takes no context. Validation and save errors are returned
// wrapped.
func (s *Service) CreateSession(
	ctx context.Context,
	sessionID string,
	id string,
	txID uuid.UUID,
	userID uuid.UUID,
	accountID uuid.UUID,
	amount int64,
	currencyCode string,
	checkoutURL string,
	expiresIn time.Duration,
) (*Session, error) {
	// Read the clock once so both timestamps share the same base.
	now := time.Now().UTC()

	session := &Session{
		ID:            sessionID,
		TransactionID: txID,
		UserID:        userID,
		AccountID:     accountID,
		Amount:        amount,
		Currency:      currencyCode,
		Status:        "created",
		CheckoutURL:   checkoutURL,
		CreatedAt:     now,
		ExpiresAt:     now.Add(expiresIn),
	}

	// Validate the session.
	if err := session.Validate(); err != nil {
		return nil, fmt.Errorf("invalid session: %w", err)
	}

	// Save to registry.
	if err := s.saveSession(session); err != nil {
		return nil, fmt.Errorf("failed to save session: %w", err)
	}
	return session, nil
}
// GetSession loads the registry entity stored under id and converts it to a
// Session. Registry errors are wrapped; a nil entity is reported as
// "session not found".
func (s *Service) GetSession(
	ctx context.Context,
	id string,
) (*Session, error) {
	entity, err := s.registry.Get(ctx, id)
	switch {
	case err != nil:
		return nil, fmt.Errorf("error getting session: %w", err)
	case entity == nil:
		return nil, fmt.Errorf("session not found: %s", id)
	}
	return s.entityToSession(entity)
}
// GetSessionByTransactionID finds the session whose "transaction_id"
// metadata equals txID. When several entities match, the first one returned
// by the registry is used. Registry errors are wrapped; no match is an
// error.
func (s *Service) GetSessionByTransactionID(
	ctx context.Context,
	txID uuid.UUID,
) (*Session, error) {
	matches, err := s.registry.ListByMetadata(ctx, "transaction_id", txID.String())
	if err != nil {
		return nil, fmt.Errorf("error searching for session: %w", err)
	}
	if len(matches) == 0 {
		return nil, fmt.Errorf("session with transaction ID %s not found", txID)
	}

	first := matches[0]
	return s.entityToSession(first)
}
// GetSessionsByUserID returns every session whose "user_id" metadata equals
// userID. The first entity that cannot be converted aborts the call with an
// error. The result is nil when the user has no sessions.
func (s *Service) GetSessionsByUserID(ctx context.Context, userID uuid.UUID) ([]*Session, error) {
	matches, err := s.registry.ListByMetadata(ctx, "user_id", userID.String())
	if err != nil {
		return nil, fmt.Errorf("error getting sessions by user ID: %w", err)
	}

	var result []*Session
	for i := range matches {
		converted, convErr := s.entityToSession(matches[i])
		if convErr != nil {
			return nil, fmt.Errorf("error converting entity to session: %w", convErr)
		}
		result = append(result, converted)
	}
	return result, nil
}
// UpdateStatus updates the status of a checkout session
func (s *Service) UpdateStatus(
ctx context.Context,
id, status string,
) error {
// Get the existing entity
entity, err := s.registry.Get(ctx, id)
if err != nil {
return fmt.Errorf("error getting session: %w", err)
}
if entity == nil {
return fmt.Errorf("session not found: %s", id)
}
// Update the status in metadata
metadata := entity.Metadata()
metadata["status"] = status
// Update the active status based on the new status
active := status != "expired" && status != "canceled" && status != "failed"
// Create a new entity with updated fields
updatedEntity := ®istry.BaseEntity{
BEId: entity.ID(),
BEName: entity.Name(),
BEActive: active,
BEMetadata: metadata,
}
// Save the updated entity
err = s.registry.Register(ctx, updatedEntity)
if err != nil {
return fmt.Errorf("failed to update session status: %w", err)
}
return nil
}
// Validate checks the session's required fields and returns an error for the
// first violation found, in this order: empty ID, nil transaction ID, nil
// user ID, nil account ID, non-positive amount, invalid currency code. A
// valid session yields nil.
func (s *Session) Validate() error {
	switch {
	case s.ID == "":
		return fmt.Errorf("session ID cannot be empty")
	case s.TransactionID == uuid.Nil:
		return fmt.Errorf("transaction ID cannot be nil")
	case s.UserID == uuid.Nil:
		return fmt.Errorf("user ID cannot be nil")
	case s.AccountID == uuid.Nil:
		return fmt.Errorf("account ID cannot be nil")
	case s.Amount <= 0:
		return fmt.Errorf("amount must be positive")
	case !money.Code(s.Currency).IsValid():
		return fmt.Errorf("invalid currency code: %s", s.Currency)
	}
	return nil
}
// FormatAmount renders the session amount as a string. The amount is treated
// as a value in the currency's smallest unit; the actual formatting is
// delegated to the String method of the resulting money value.
func (s *Session) FormatAmount() (string, error) {
	code := money.Code(s.Currency)
	value, buildErr := money.NewFromSmallestUnit(s.Amount, code)
	if buildErr != nil {
		return "", fmt.Errorf("failed to create money object: %w", buildErr)
	}
	formatted := value.String()
	return formatted, nil
}
// saveSession saves the session to the registry
func (s *Service) saveSession(session *Session) error {
// Create a base entity with the session data
entity := ®istry.BaseEntity{
BEId: session.ID,
BEName: fmt.Sprintf("checkout_session_%s", session.TransactionID.String()),
BEActive: session.Status != "expired" &&
session.Status != "canceled" && session.Status != "failed",
BEMetadata: make(map[string]string),
}
// Add all fields as metadata for searchability
entity.SetMetadata("transaction_id", session.TransactionID.String())
entity.SetMetadata("user_id", session.UserID.String())
entity.SetMetadata("account_id", session.AccountID.String())
entity.SetMetadata("amount", fmt.Sprintf("%d", session.Amount))
entity.SetMetadata("currency", session.Currency)
entity.SetMetadata("status", session.Status)
entity.SetMetadata("checkout_url", session.CheckoutURL)
entity.SetMetadata("created_at", session.CreatedAt.Format(time.RFC3339))
entity.SetMetadata("expires_at", session.ExpiresAt.Format(time.RFC3339))
// Store in registry
ctx := context.Background()
err := s.registry.Register(ctx, entity)
if err != nil {
return fmt.Errorf("failed to register session: %w", err)
}
return nil
}
// entityToSession converts a registry.Entity into a validated Session.
//
// Field values are read from the entity metadata:
//   - "transaction_id", "user_id" and "account_id" must be valid UUIDs when
//     present and non-empty; a malformed value aborts the conversion.
//   - "amount" is parsed as a base-10 int64; "created_at" and "expires_at" are
//     parsed as RFC 3339 timestamps. A malformed value leaves the field at its
//     zero value and is reported through the service logger.
//   - "currency", "status" and "checkout_url" are copied verbatim.
//
// The resulting session must pass Validate, otherwise an error is returned.
func (s *Service) entityToSession(entity registry.Entity) (*Session, error) {
	if entity == nil {
		return nil, fmt.Errorf("entity cannot be nil")
	}
	metadata := entity.Metadata()
	// Debug: Log all metadata keys and values
	s.logger.Debug("Entity metadata", "metadata", metadata)

	sessionID := entity.ID()

	// parseID reads an optional UUID; an absent or empty value yields uuid.Nil.
	parseID := func(key, label string) (uuid.UUID, error) {
		raw := metadata[key]
		if raw == "" {
			return uuid.Nil, nil
		}
		id, err := uuid.Parse(raw)
		if err != nil {
			return uuid.Nil, fmt.Errorf("invalid %s in metadata: %w", label, err)
		}
		return id, nil
	}

	// parseTime reads an optional RFC 3339 timestamp; a malformed value is
	// logged and treated as absent.
	parseTime := func(key string) time.Time {
		raw := metadata[key]
		if raw == "" {
			return time.Time{}
		}
		parsed, err := time.Parse(time.RFC3339, raw)
		if err != nil {
			s.logger.Warn("Ignoring unparsable timestamp in session metadata",
				"session_id", sessionID,
				"key", key,
				"value", raw,
				"error", err,
			)
			return time.Time{}
		}
		return parsed
	}

	session := &Session{ID: sessionID}

	var err error
	if session.TransactionID, err = parseID("transaction_id", "transaction ID"); err != nil {
		return nil, err
	}
	if session.UserID, err = parseID("user_id", "user ID"); err != nil {
		return nil, err
	}
	if session.AccountID, err = parseID("account_id", "account ID"); err != nil {
		return nil, err
	}

	// A malformed amount leaves Amount at zero, which Validate rejects below;
	// the warning records the underlying cause.
	if raw := metadata["amount"]; raw != "" {
		amt, parseErr := strconv.ParseInt(raw, 10, 64)
		if parseErr != nil {
			s.logger.Warn("Ignoring unparsable amount in session metadata",
				"session_id", sessionID,
				"value", raw,
				"error", parseErr,
			)
		} else {
			session.Amount = amt
		}
	}

	session.Currency = metadata["currency"]
	session.Status = metadata["status"]
	session.CheckoutURL = metadata["checkout_url"]
	session.CreatedAt = parseTime("created_at")
	session.ExpiresAt = parseTime("expires_at")

	if err := session.Validate(); err != nil {
		return nil, fmt.Errorf("invalid session data in registry: %w", err)
	}
	return session, nil
}
// ToJSON serializes the session with encoding/json and returns the encoded
// bytes together with any marshaling error.
func (s *Session) ToJSON() ([]byte, error) {
	encoded, err := json.Marshal(s)
	return encoded, err
}
// FromJSON decodes a Session from its JSON representation. The decoded
// session must also pass Validate; a session that decodes cleanly but is
// invalid is reported as an error and not returned.
func FromJSON(data []byte) (*Session, error) {
	decoded := new(Session)
	if err := json.Unmarshal(data, decoded); err != nil {
		return nil, fmt.Errorf("failed to unmarshal session: %w", err)
	}
	if err := decoded.Validate(); err != nil {
		return nil, fmt.Errorf("invalid session data: %w", err)
	}
	return decoded, nil
}
package currency
import (
"context"
"fmt"
"log/slog"
"strconv"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/registry"
)
// ---- Entity ----
// Entity describes a currency as accepted by Service.Register.
// It embeds registry.Entity and adds the currency-specific attributes below.
type Entity struct {
registry.Entity
// Code is the currency code; Register uses its string form as the registry entity ID.
Code money.Code `json:"code"`
// Name is passed by Register to registry.NewBaseEntity as the entity name.
Name string `json:"name"`
// Symbol is stored by Register under the "symbol" metadata key.
Symbol string `json:"symbol"`
// Decimals is stored by Register under the "decimals" metadata key (base-10 text).
Decimals int `json:"decimals"`
// Country is stored by Register under the "country" metadata key; omitted from JSON when empty.
Country string `json:"country,omitempty"`
// Region is stored by Register under the "region" metadata key; omitted from JSON when empty.
Region string `json:"region,omitempty"`
// Active sets the entity's active flag and is also stored under the "active" metadata key.
Active bool `json:"active"`
}
// Service provides business logic for currency operations.
// Its methods read from and write to the underlying registry.
type Service struct {
// registry is the store that holds one entity per currency.
registry registry.Provider
// logger is the service logger; New tags it with service=Currency.
logger *slog.Logger
}
// New builds a currency Service on top of the given registry. A nil logger is
// replaced by slog.Default(); in either case the logger is tagged with
// service=Currency.
func New(
	registry registry.Provider,
	logger *slog.Logger,
) *Service {
	base := logger
	if base == nil {
		base = slog.Default()
	}
	svc := &Service{registry: registry}
	svc.logger = base.With("service", "Currency")
	return svc
}
// Get looks up the registry entity stored under code and converts it to a
// money.Currency. Lookup and conversion failures are wrapped and returned.
func (s *Service) Get(ctx context.Context, code string) (*money.Currency, error) {
	found, getErr := s.registry.Get(ctx, code)
	if getErr != nil {
		return nil, fmt.Errorf("failed to get currency: %w", getErr)
	}

	currency, convErr := toCurrency(found)
	if convErr != nil {
		return nil, fmt.Errorf("failed to convert entity: %w", convErr)
	}
	return currency, nil
}
// ListSupported returns the IDs of all active registry entities, i.e. the
// codes of the currently supported currencies.
func (s *Service) ListSupported(ctx context.Context) ([]string, error) {
	active, listErr := s.registry.ListActive(ctx)
	if listErr != nil {
		return nil, fmt.Errorf("failed to list active currencies: %w", listErr)
	}

	codes := make([]string, len(active))
	for i := range active {
		codes[i] = active[i].ID()
	}
	return codes, nil
}
// ListAll returns every registered currency, active or not. Entities that
// cannot be converted are logged and left out of the result rather than
// failing the whole call.
func (s *Service) ListAll(ctx context.Context) ([]*money.Currency, error) {
	all, listErr := s.registry.List(ctx)
	if listErr != nil {
		return nil, fmt.Errorf("failed to list currencies: %w", listErr)
	}

	currencies := make([]*money.Currency, 0, len(all))
	for _, item := range all {
		converted, convErr := toCurrency(item)
		if convErr == nil {
			currencies = append(currencies, converted)
			continue
		}
		s.logger.Error("failed to convert entity to meta", "error", convErr, "id", item.ID())
	}
	return currencies, nil
}
// Register stores meta in the registry. The entity is keyed by the currency
// code, named after meta.Name, and carries symbol, decimals, country, region
// and active as metadata values; the active flag is also set on the entity.
func (s *Service) Register(ctx context.Context, meta Entity) error {
	record := registry.NewBaseEntity(meta.Code.String(), meta.Name)
	record.SetActive(meta.Active)

	// Written one by one, in the same order as the fields are declared.
	record.SetMetadata("symbol", meta.Symbol)
	record.SetMetadata("decimals", strconv.Itoa(meta.Decimals))
	record.SetMetadata("country", meta.Country)
	record.SetMetadata("region", meta.Region)
	record.SetMetadata("active", strconv.FormatBool(meta.Active))

	err := s.registry.Register(ctx, record)
	return err
}
// Unregister removes the currency stored under code from the registry.
// It delegates to the registry's Unregister and returns its error unchanged.
func (s *Service) Unregister(ctx context.Context, code string) error {
return s.registry.Unregister(ctx, code)
}
// Activate marks the currency stored under code as active.
// It delegates to the registry's Activate and returns its error unchanged.
func (s *Service) Activate(ctx context.Context, code string) error {
return s.registry.Activate(ctx, code)
}
// Deactivate marks the currency stored under code as inactive.
// It delegates to the registry's Deactivate and returns its error unchanged.
func (s *Service) Deactivate(ctx context.Context, code string) error {
return s.registry.Deactivate(ctx, code)
}
// IsSupported reports whether the currency stored under code is both
// registered and active. Any lookup error, as well as a missing entity, is
// reported as "not supported".
func (s *Service) IsSupported(ctx context.Context, code string) bool {
	entity, err := s.registry.Get(ctx, code)
	// The registry may report a missing entry as a nil entity without an
	// error, so both conditions are treated as "not registered".
	if err != nil || entity == nil {
		return false
	}
	return entity.Active()
}
// Search runs the registry's search with query and converts the matching
// entities to currencies. Entities that fail conversion are logged and
// skipped.
func (s *Service) Search(
	ctx context.Context,
	query string,
) ([]*money.Currency, error) {
	hits, searchErr := s.registry.Search(ctx, query)
	if searchErr != nil {
		return nil, fmt.Errorf("failed to search currencies: %w", searchErr)
	}

	currencies := make([]*money.Currency, 0, len(hits))
	for _, hit := range hits {
		converted, convErr := toCurrency(hit)
		if convErr == nil {
			currencies = append(currencies, converted)
			continue
		}
		s.logger.Error("failed to convert entity to meta", "error", convErr, "id", hit.ID())
	}
	return currencies, nil
}
// SearchByRegion returns the currencies whose "region" metadata matches
// region. Entities that fail conversion are logged and skipped.
func (s *Service) SearchByRegion(
	ctx context.Context,
	region string,
) ([]*money.Currency, error) {
	filter := map[string]string{"region": region}
	hits, searchErr := s.registry.SearchByMetadata(ctx, filter)
	if searchErr != nil {
		return nil, fmt.Errorf("failed to search currencies by region: %w", searchErr)
	}

	currencies := make([]*money.Currency, 0, len(hits))
	for _, hit := range hits {
		converted, convErr := toCurrency(hit)
		if convErr == nil {
			currencies = append(currencies, converted)
			continue
		}
		s.logger.Error("failed to convert entity to meta", "error", convErr, "id", hit.ID())
	}
	return currencies, nil
}
// GetStatistics reports how many currencies are registered and how many of
// them are active, under the keys "total_currencies" and "active_currencies".
func (s *Service) GetStatistics(
	ctx context.Context,
) (map[string]any, error) {
	totalCount, totalErr := s.registry.Count(ctx)
	if totalErr != nil {
		return nil, fmt.Errorf("failed to count total currencies: %w", totalErr)
	}
	activeCount, activeErr := s.registry.CountActive(ctx)
	if activeErr != nil {
		return nil, fmt.Errorf("failed to count active currencies: %w", activeErr)
	}

	stats := make(map[string]any, 2)
	stats["total_currencies"] = totalCount
	stats["active_currencies"] = activeCount
	return stats, nil
}
// ValidateCode checks the format of code via money.Code.IsValid and returns
// money.ErrInvalidCurrency when the check fails. The registry is not
// consulted and ctx is not used.
func (s *Service) ValidateCode(
	ctx context.Context,
	code string,
) error {
	if money.Code(code).IsValid() {
		return nil
	}
	return money.ErrInvalidCurrency
}
// GetDefault returns the currency registered under money.DefaultCode.
// A failed lookup is wrapped; a conversion error is returned as-is.
func (s *Service) GetDefault(
	ctx context.Context,
) (*money.Currency, error) {
	defaultCode := money.DefaultCode.String()
	found, getErr := s.registry.Get(ctx, defaultCode)
	if getErr != nil {
		return nil, fmt.Errorf("failed to get default currency: %w", getErr)
	}
	return toCurrency(found)
}
// toCurrency builds a money.Currency from a registry entity: the entity ID
// becomes the currency code and the "decimals" metadata value becomes the
// number of decimals. When that value is missing or not an integer, the
// decimals default to 2.
//
// money.Currency has no notion of an active flag, so the entity's active
// status is not carried over; use IsSupported to check it.
func toCurrency(entity registry.Entity) (*money.Currency, error) {
	if entity == nil {
		return nil, fmt.Errorf("entity is nil")
	}

	decimals := 2 // fallback when metadata carries no usable value
	if raw, present := entity.Metadata()["decimals"]; present {
		parsed, convErr := strconv.Atoi(raw)
		if convErr == nil {
			decimals = parsed
		}
	}

	currency := new(money.Currency)
	currency.Code = money.Code(entity.ID())
	currency.Decimals = decimals
	return currency, nil
}
package exchange
import (
"context"
"errors"
"fmt"
"log/slog"
"math"
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/amirasaad/fintech/pkg/registry"
)
// ---- Errors ----
// Sentinel errors of the exchange service; compare with errors.Is.
var (
// ErrInvalidAmount is returned by validateAmount for a zero or negative amount.
ErrInvalidAmount = errors.New("invalid amount")
// ErrNoProvidersAvailable is returned by GetRate when the rate is not cached
// and the service has no provider to fall back to.
ErrNoProvidersAvailable = errors.New("no exchange rate providers available")
// ErrInvalidExchangeRate indicates that an exchange rate is not usable.
ErrInvalidExchangeRate = errors.New("invalid exchange rate")
)
// ---- Constants ----
// Cache-related constants of the exchange service.
const (
// DefaultCacheTTL is the default time-to-live for cached exchange rates (15 minutes).
DefaultCacheTTL = 15 * time.Minute
// LastUpdatedKey is the key used to store the last update timestamp.
LastUpdatedKey = "exr:last_updated"
)
// ---- Entity ----
// ExchangeRateInfo is the registry entity used to cache a single exchange rate.
// It embeds registry.BaseEntity; newExchangeRateInfo sets the entity ID and name to "<From>:<To>".
type ExchangeRateInfo struct {
registry.BaseEntity
// From is the source currency code.
From string
// To is the target currency code.
To string
// Rate is the conversion factor from From to To.
Rate float64
// Source is the name of the provider the rate came from.
Source string
// Timestamp is set to the current UTC time when the value is built by newExchangeRateInfo.
Timestamp time.Time
}
// newExchangeRateInfo builds the cache entity for one currency pair. The
// embedded BaseEntity gets "<from>:<to>" as both ID and name, which is the
// key the service later uses to read the rate back from the registry. The
// timestamp is the current UTC time.
func newExchangeRateInfo(
	from, to string,
	rate float64,
	source string,
) *ExchangeRateInfo {
	pairKey := from + ":" + to
	base := registry.NewBaseEntity(pairKey, pairKey)

	info := new(ExchangeRateInfo)
	info.BaseEntity = *base
	info.From = from
	info.To = to
	info.Rate = rate
	info.Source = source
	info.Timestamp = time.Now().UTC()
	return info
}
// ---- Conversion Helpers ----
// ---- Helper Functions ----
// validateAmount rejects amounts that cannot be converted: a nil amount
// yields a descriptive error, a zero or negative amount yields
// ErrInvalidAmount.
func validateAmount(amount *money.Money) error {
	switch {
	case amount == nil:
		return errors.New("amount cannot be nil")
	case amount.IsNegative(), amount.IsZero():
		return ErrInvalidAmount
	default:
		return nil
	}
}
// ---- Service ----
// Service handles currency exchange operations with cache-first approach:
// rates are looked up in the registry first and fetched from the provider on a miss.
type Service struct {
// provider is the upstream source of exchange rates; GetRate treats a nil provider as "none available".
provider exchange.Exchange
registry registry.Provider // Registry for cached exchange rates
// logger receives cache hit/miss diagnostics and caching errors.
logger *slog.Logger
}
// New wires an exchange Service from a rate cache registry and an upstream
// provider. When log is nil, slog.Default() is used instead.
func New(
	registry registry.Provider,
	provider exchange.Exchange,
	log *slog.Logger,
) *Service {
	svc := &Service{provider: provider, registry: registry, logger: log}
	if svc.logger == nil {
		svc.logger = slog.Default()
	}
	return svc
}
// processAndCacheRate validates a freshly fetched rate and stores it in the
// registry.
//
// Two entities are registered: the direct rate under "<from>:<to>" and, when
// the rate is large enough to be inverted safely, the inverse rate under
// "<to>:<from>". Both carry a "last_updated" metadata value in RFC 3339
// (nanosecond) format.
//
// Caching is best-effort: a nil rate, an invalid rate (zero, negative, NaN or
// infinite) and registry failures are logged and otherwise ignored, so the
// method never fails the caller. ctx is forwarded to the registry.
func (s *Service) processAndCacheRate(
	ctx context.Context,
	from, to string,
	rate *exchange.RateInfo,
) {
	providerName := s.provider.Metadata().Name

	if rate == nil {
		err := fmt.Errorf("provider %s returned nil rate", providerName)
		s.logger.Error("Failed to fetch exchange rate",
			"from", from,
			"to", to,
			"provider", providerName,
			"error", err,
		)
		return
	}

	// Never cache a rate that cannot be used for a conversion; a poisoned
	// cache entry would be served until it is replaced.
	if rate.Rate <= 0 || math.IsNaN(rate.Rate) || math.IsInf(rate.Rate, 0) {
		s.logger.Error("Refusing to cache invalid exchange rate",
			"from", from,
			"to", to,
			"provider", providerName,
			"rate", rate.Rate,
			"error", ErrInvalidExchangeRate,
		)
		return
	}

	lastUpdated := time.Now().UTC().Format(time.RFC3339Nano)

	// Register the direct rate (from -> to).
	rateInfo := newExchangeRateInfo(from, to, rate.Rate, providerName)
	rateInfo.SetMetadata("last_updated", lastUpdated)
	if err := s.registry.Register(ctx, rateInfo); err != nil {
		s.logger.Error("Failed to cache exchange rate",
			"from", from,
			"to", to,
			"error", err,
		)
	}

	// Register the inverse rate (to -> from) unless the rate is so close to
	// zero that the division would be meaningless.
	if rate.Rate > 1e-10 {
		inverseInfo := newExchangeRateInfo(to, from, 1.0/rate.Rate, providerName)
		inverseInfo.SetMetadata("last_updated", lastUpdated)
		if err := s.registry.Register(ctx, inverseInfo); err != nil {
			s.logger.Error("Failed to cache inverse exchange rate",
				"from", to,
				"to", from,
				"error", err,
			)
		}
	}
}
// ---- Public Service Methods ----
// Name returns the fixed service name "ExchangeService".
func (s *Service) Name() string { return "ExchangeService" }
// IsHealthy always reports true; no health check is performed.
func (s *Service) IsHealthy() bool { return true }
// Convert converts amount into the currency to and returns the converted
// amount together with the rate that was applied.
//
// The amount must be non-nil and strictly positive. When the source and
// target currencies are equal, the input amount is returned unchanged with an
// identity rate of 1.0. Otherwise the rate is obtained through GetRate, which
// consults the cache before the provider.
func (s *Service) Convert(
	ctx context.Context,
	amount *money.Money,
	to money.Code,
) (*money.Money, *exchange.RateInfo, error) {
	if err := validateAmount(amount); err != nil {
		return nil, nil, fmt.Errorf("invalid amount: %w", err)
	}

	source, target := amount.Currency().String(), to.String()

	// Same currency on both sides: nothing to convert.
	if source == target {
		identity := &exchange.RateInfo{
			FromCurrency: source,
			ToCurrency:   target,
			Rate:         1.0,
			Provider:     "identity",
			Timestamp:    time.Now(),
		}
		return amount, identity, nil
	}

	quote, err := s.GetRate(ctx, source, target)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get exchange rate: %w", err)
	}

	scaled, err := amount.Multiply(quote.Rate)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to convert amount: %w", err)
	}

	// Re-create the value in the target currency.
	out, err := money.New(scaled.AmountFloat(), to)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create money: %w", err)
	}
	return out, quote, nil
}
// GetRate returns the exchange rate from one currency to another using a
// cache-first approach.
//
// Resolution order:
//  1. Empty currency codes are rejected.
//  2. Identical codes yield an identity rate of 1.0.
//  3. A rate cached in the registry is returned as-is.
//  4. Otherwise the rate is fetched from the provider and cached on a
//     best-effort basis.
//
// ErrNoProvidersAvailable is returned when the rate is not cached and no
// provider is configured. A provider that reports success without a rate is
// treated as a failure wrapping ErrInvalidExchangeRate, so a nil rate is
// never returned together with a nil error.
func (s *Service) GetRate(
	ctx context.Context,
	from,
	to string,
) (*exchange.RateInfo, error) {
	if from == "" || to == "" {
		return nil, fmt.Errorf("invalid currency codes: from='%s', to='%s'", from, to)
	}

	if from == to {
		return &exchange.RateInfo{
			FromCurrency: from,
			ToCurrency:   to,
			Rate:         1.0,
			Provider:     "identity",
			Timestamp:    time.Now(),
		}, nil
	}

	if cached, ok := s.getRateFromCache(ctx, from, to); ok {
		return cached, nil
	}

	if s.provider == nil {
		return nil, ErrNoProvidersAvailable
	}

	rate, err := s.provider.FetchRate(ctx, from, to)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch rates from provider: %w", err)
	}
	if rate == nil {
		// Callers dereference the returned rate, so an absent rate must
		// surface as an error.
		return nil, fmt.Errorf(
			"provider returned no rate for %s to %s: %w",
			from, to, ErrInvalidExchangeRate,
		)
	}

	s.processAndCacheRate(ctx, from, to, rate)
	return rate, nil
}
// IsSupported reports whether a conversion between from and to is possible.
// Identical currencies are always supported. Any other pair is delegated to
// the provider; without a provider no cross-currency pair is supported.
func (s *Service) IsSupported(from, to string) bool {
	if from == to {
		return true
	}
	// GetRate treats a nil provider as a valid state, so it must not be
	// dereferenced here either.
	if s.provider == nil {
		return false
	}
	return s.provider.IsSupported(from, to)
}
// ---- Private Service Methods ----
// getRateFromCache looks up the rate cached under "<from>:<to>" in the
// registry. The boolean result is false on a registry error, on a missing
// entry, and when the stored entity is not an *ExchangeRateInfo.
func (s *Service) getRateFromCache(
	ctx context.Context,
	from, to string,
) (*exchange.RateInfo, bool) {
	cacheKey := from + ":" + to

	cached, err := s.registry.Get(ctx, cacheKey)
	switch {
	case err != nil:
		s.logger.Debug("Cache miss (error)", "key", cacheKey, "error", err)
		return nil, false
	case cached == nil:
		s.logger.Debug("Cache miss (not found)", "key", cacheKey)
		return nil, false
	}

	// Only entities written by this service carry the rate directly.
	info, isRateInfo := cached.(*ExchangeRateInfo)
	if !isRateInfo {
		return nil, false
	}

	s.logger.Debug("Cache hit", "key", cacheKey, "rate", info.Rate)
	result := &exchange.RateInfo{
		FromCurrency: info.From,
		ToCurrency:   info.To,
		Rate:         info.Rate,
		Provider:     info.Source,
		Timestamp:    info.Timestamp,
	}
	return result, true
}
package stripeconnect
import (
"context"
"errors"
"fmt"
"log/slog"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/domain"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/google/uuid"
"github.com/stripe/stripe-go/v82"
)
// Service defines the interface for Stripe Connect operations.
// All methods are keyed by the application's user ID, not by the Stripe account ID.
type Service interface {
// CreateAccount creates a new Stripe Connect account for a user.
// The implementation in this file returns the user's existing account
// instead when a Stripe account ID is already stored for the user.
CreateAccount(ctx context.Context, userID uuid.UUID) (*stripe.Account, error)
// GenerateOnboardingURL generates a Stripe onboarding URL for the user.
// The implementation in this file obtains the account through CreateAccount first.
GenerateOnboardingURL(ctx context.Context, userID uuid.UUID) (string, error)
// GetAccount retrieves the Stripe Connect account for a user.
// The implementation in this file returns domain.ErrNotFound when the user
// has no stored Stripe account ID.
GetAccount(ctx context.Context, userID uuid.UUID) (*stripe.Account, error)
// IsOnboardingComplete checks if the user has completed Stripe onboarding
IsOnboardingComplete(ctx context.Context, userID uuid.UUID) (bool, error)
}
// stripeConnectService is the Service implementation returned by New.
type stripeConnectService struct {
// client is the Stripe API client used for account and account-link calls.
client *stripe.Client
// uow is the unit of work from which the user repository is obtained.
uow repository.UnitOfWork
// cfg holds the Stripe settings; the onboarding refresh/return URLs are read from it.
cfg *config.Stripe
}
// New creates a new instance of the Stripe Connect service using the official Stripe client.
// The client is constructed from cfg.ApiKey, so cfg must be non-nil.
// The logger parameter is accepted but is not stored on the returned service.
// Deprecated: Use NewClientService instead for better client management
// NOTE(review): NewClientService is not visible in this part of the file — confirm that the
// deprecation notice above is still accurate.
func New(
uow repository.UnitOfWork,
logger *slog.Logger,
cfg *config.Stripe,
) Service {
return &stripeConnectService{
client: stripe.NewClient(cfg.ApiKey),
uow: uow,
cfg: cfg,
}
}
// CreateAccount returns the Stripe Connect account of the given user,
// creating it first when the user has none.
//
// When a Stripe account ID is already stored for the user, that account is
// fetched from Stripe and returned. Otherwise a new Express account is
// created with the card payments and transfers capabilities requested, and
// its ID is saved for the user with onboarding marked as incomplete. If the
// ID cannot be saved, deletion of the new Stripe account is attempted
// (best-effort) and an error is returned.
func (s *stripeConnectService) CreateAccount(
	ctx context.Context,
	userID uuid.UUID,
) (*stripe.Account, error) {
	repo, err := common.GetUserRepository(s.uow, slog.Default())
	if err != nil {
		return nil, fmt.Errorf("failed to get user repository: %w", err)
	}

	// A "not found" lookup simply means there is no account yet.
	knownID, lookupErr := repo.GetStripeAccountID(ctx, userID)
	if lookupErr != nil && !errors.Is(lookupErr, domain.ErrNotFound) {
		return nil, fmt.Errorf("failed to check existing account: %w", lookupErr)
	}

	if knownID != "" {
		existing, getErr := s.client.V1Accounts.GetByID(ctx, knownID, nil)
		if getErr != nil {
			return nil, fmt.Errorf("failed to get existing Stripe account: %w", getErr)
		}
		return existing, nil
	}

	capabilities := &stripe.AccountCreateCapabilitiesParams{
		CardPayments: &stripe.AccountCreateCapabilitiesCardPaymentsParams{
			Requested: stripe.Bool(true),
		},
		Transfers: &stripe.AccountCreateCapabilitiesTransfersParams{
			Requested: stripe.Bool(true),
		},
	}
	params := &stripe.AccountCreateParams{
		Type: stripe.String(string(stripe.AccountTypeExpress)),
		// NOTE: Country is hardcoded to "US" for now. This should be made configurable
		// or derived from user profile data in a future enhancement.
		Country:      stripe.String("US"),
		Capabilities: capabilities,
	}

	created, createErr := s.client.V1Accounts.Create(ctx, params)
	if createErr != nil {
		return nil, fmt.Errorf("failed to create Stripe account: %w", createErr)
	}

	if saveErr := repo.UpdateStripeAccount(ctx, userID, created.ID, false); saveErr != nil {
		// Without a stored reference the account would be orphaned, so try
		// to remove it again; the outcome of the cleanup is ignored.
		_, _ = s.client.V1Accounts.Delete(ctx, created.ID, nil) // nolint:errcheck
		return nil, fmt.Errorf("failed to save Stripe account ID: %w", saveErr)
	}
	return created, nil
}
// GenerateOnboardingURL returns a Stripe account-onboarding link for the
// user. The account is obtained through CreateAccount, so it is created on
// demand. The refresh and return URLs of the link come from the service
// configuration.
func (s *stripeConnectService) GenerateOnboardingURL(
	ctx context.Context,
	userID uuid.UUID,
) (string, error) {
	account, acctErr := s.CreateAccount(ctx, userID)
	if acctErr != nil {
		return "", fmt.Errorf("failed to get or create Stripe account: %w", acctErr)
	}

	linkParams := new(stripe.AccountLinkCreateParams)
	linkParams.Account = stripe.String(account.ID)
	linkParams.RefreshURL = stripe.String(s.cfg.OnboardingRefreshURL)
	linkParams.ReturnURL = stripe.String(s.cfg.OnboardingReturnURL)
	linkParams.Type = stripe.String("account_onboarding")

	link, linkErr := s.client.V1AccountLinks.Create(ctx, linkParams)
	if linkErr != nil {
		return "", fmt.Errorf("failed to create account link: %w", linkErr)
	}
	return link.URL, nil
}
// GetAccount fetches the Stripe account linked to the user. It returns
// domain.ErrNotFound when the stored Stripe account ID is empty; repository
// and Stripe failures are wrapped.
func (s *stripeConnectService) GetAccount(
	ctx context.Context,
	userID uuid.UUID,
) (*stripe.Account, error) {
	repo, err := common.GetUserRepository(s.uow, slog.Default())
	if err != nil {
		return nil, fmt.Errorf("failed to get user repository: %w", err)
	}

	accountID, err := repo.GetStripeAccountID(ctx, userID)
	switch {
	case err != nil:
		return nil, fmt.Errorf("failed to get Stripe account ID: %w", err)
	case accountID == "":
		return nil, domain.ErrNotFound
	}

	account, err := s.client.V1Accounts.GetByID(ctx, accountID, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to get Stripe account: %w", err)
	}
	return account, nil
}
// IsOnboardingComplete reports whether the user has finished Stripe
// onboarding.
//
// A locally stored "complete" status is trusted and returned immediately.
// Otherwise the account is fetched from Stripe; onboarding counts as complete
// when details have been submitted and payouts are enabled. The freshly
// computed status is written back to the local store before it is returned.
func (s *stripeConnectService) IsOnboardingComplete(
	ctx context.Context,
	userID uuid.UUID,
) (bool, error) {
	repo, err := common.GetUserRepository(s.uow, slog.Default())
	if err != nil {
		return false, fmt.Errorf("failed to get user repository: %w", err)
	}

	// A missing local record is not an error; it just means Stripe has to be asked.
	localDone, statusErr := repo.GetStripeOnboardingStatus(ctx, userID)
	if statusErr != nil && !errors.Is(statusErr, domain.ErrNotFound) {
		return false, fmt.Errorf("failed to get local onboarding status: %w", statusErr)
	}
	if localDone {
		return true, nil
	}

	account, acctErr := s.GetAccount(ctx, userID)
	if acctErr != nil {
		return false, fmt.Errorf("failed to get account: %w", acctErr)
	}

	done := account.DetailsSubmitted && account.PayoutsEnabled
	if saveErr := repo.UpdateStripeAccount(ctx, userID, account.ID, done); saveErr != nil {
		return false, fmt.Errorf("failed to update local onboarding status: %w", saveErr)
	}
	return done, nil
}
// Package user provides business logic for user management operations.
// It uses the decorator pattern for transaction management and includes comprehensive logging.
package user
import (
"context"
"log/slog"
"github.com/amirasaad/fintech/pkg/domain/user"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/handler/common"
"github.com/amirasaad/fintech/pkg/repository"
"github.com/amirasaad/fintech/pkg/utils"
"github.com/google/uuid"
)
// Service provides business logic for user operations including creation, updates, and deletion.
// Every operation runs inside a unit-of-work callback (uow.Do).
type Service struct {
// uow is the unit of work used to run operations and to obtain the user repository.
uow repository.UnitOfWork
// logger is the logger handed to the repository lookup helper.
logger *slog.Logger
}
// New creates a new Service with a UnitOfWork and logger.
// A nil logger is replaced by slog.Default() so that the service methods can
// always pass a usable logger on.
func New(
	uow repository.UnitOfWork,
	logger *slog.Logger,
) *Service {
	if logger == nil {
		logger = slog.Default()
	}
	return &Service{
		uow:    uow,
		logger: logger,
	}
}
// CreateUser creates a new user account in a transaction.
//
// The domain user is built from username, email and password via user.New
// and then persisted through the user repository. On any failure the
// returned user is nil and the error is returned.
func (s *Service) CreateUser(
	ctx context.Context,
	username, email, password string,
) (u *user.User, err error) {
	// Use the service's own logger like the other methods do; fall back to
	// the default logger when none was injected.
	logger := s.logger
	if logger == nil {
		logger = slog.Default()
	}

	err = s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repo, err := common.GetUserRepository(s.uow, logger)
		if err != nil {
			return err
		}
		u, err = user.New(username, email, password)
		if err != nil {
			return err
		}
		return repo.Create(ctx, &dto.UserCreate{
			ID:       u.ID,
			Username: u.Username,
			Email:    u.Email,
			Password: u.Password,
		})
	})
	if err != nil {
		// Never hand out a user whose creation did not complete.
		return nil, err
	}
	return u, nil
}
// GetUser retrieves a user by ID in a transaction.
//
// userID must be the textual form of a UUID; a malformed ID is returned as
// the parse error. On any failure the returned user is nil.
func (s *Service) GetUser(
	ctx context.Context,
	userID string,
) (u *dto.UserRead, err error) {
	// Use the service's own logger like the other methods do; fall back to
	// the default logger when none was injected.
	logger := s.logger
	if logger == nil {
		logger = slog.Default()
	}

	err = s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repo, err := common.GetUserRepository(s.uow, logger)
		if err != nil {
			return err
		}
		uid, parseErr := uuid.Parse(userID)
		if parseErr != nil {
			return parseErr
		}
		u, err = repo.Get(ctx, uid)
		return err
	})
	if err != nil {
		return nil, err
	}
	return u, nil
}
// GetUserByEmail looks a user up by email address inside a unit-of-work
// callback. On any failure the returned user is nil.
func (s *Service) GetUserByEmail(
	ctx context.Context,
	email string,
) (u *dto.UserRead, err error) {
	var found *dto.UserRead
	txErr := s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repo, repoErr := common.GetUserRepository(s.uow, s.logger)
		if repoErr != nil {
			return repoErr
		}
		var getErr error
		found, getErr = repo.GetByEmail(ctx, email)
		return getErr
	})
	if txErr != nil {
		return nil, txErr
	}
	return found, nil
}
// GetUserByUsername looks a user up by username inside a unit-of-work
// callback. On any failure the returned user is nil.
func (s *Service) GetUserByUsername(
	ctx context.Context,
	username string,
) (u *dto.UserRead, err error) {
	var found *dto.UserRead
	txErr := s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repo, repoErr := common.GetUserRepository(s.uow, s.logger)
		if repoErr != nil {
			return repoErr
		}
		var getErr error
		found, getErr = repo.GetByUsername(ctx, username)
		return getErr
	})
	if txErr != nil {
		return nil, txErr
	}
	return found, nil
}
// UpdateUser applies update to the user identified by userID inside a
// unit-of-work callback. userID must be the textual form of a UUID. The user
// is loaded first; user.ErrUserNotFound is returned when the repository
// yields no user for the ID.
func (s *Service) UpdateUser(
	ctx context.Context,
	userID string,
	update *dto.UserUpdate,
) (err error) {
	return s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repo, repoErr := common.GetUserRepository(s.uow, s.logger)
		if repoErr != nil {
			return repoErr
		}

		uid, parseErr := uuid.Parse(userID)
		if parseErr != nil {
			return parseErr
		}

		// Make sure the user exists before touching it.
		existing, getErr := repo.Get(ctx, uid)
		switch {
		case getErr != nil:
			return getErr
		case existing == nil:
			return user.ErrUserNotFound
		}
		return repo.Update(ctx, uid, update)
	})
}
// DeleteUser removes the user identified by userID inside a unit-of-work
// callback. userID must be the textual form of a UUID; a malformed ID is
// returned as the parse error.
func (s *Service) DeleteUser(
	ctx context.Context,
	userID string,
) (err error) {
	return s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repo, repoErr := common.GetUserRepository(s.uow, s.logger)
		if repoErr != nil {
			return repoErr
		}
		uid, parseErr := uuid.Parse(userID)
		if parseErr != nil {
			return parseErr
		}
		return repo.Delete(ctx, uid)
	})
}
// ValidUser checks a set of credentials inside a unit-of-work callback.
//
// identifier is tried as an email address first and as a username second.
// When neither lookup yields a user (including when a lookup fails), the
// result is (false, nil): lookup errors are not surfaced to the caller.
// Otherwise the password is compared against the stored hash.
func (s *Service) ValidUser(
	ctx context.Context,
	identifier string, // Can be either email or username
	password string,
) (
	valid bool,
	err error,
) {
	err = s.uow.Do(ctx, func(uow repository.UnitOfWork) error {
		repo, repoErr := common.GetUserRepository(s.uow, s.logger)
		if repoErr != nil {
			return repoErr
		}

		account, lookupErr := repo.GetByEmail(ctx, identifier)
		if lookupErr != nil || account == nil {
			// Not an email we know: fall back to the username.
			account, lookupErr = repo.GetByUsername(ctx, identifier)
		}
		if lookupErr != nil || account == nil {
			// Unknown under both identifiers; valid stays false.
			return nil
		}

		valid = utils.CheckPasswordHash(password, account.HashedPassword)
		return nil
	})
	return valid, err
}
package utils
import (
"net/mail"
"golang.org/x/crypto/bcrypt"
)
// HashPassword hashes a plain password using bcrypt with cost 14.
// It is the exported entry point and delegates to hashPassword.
func HashPassword(password string) (string, error) {
return hashPassword(password)
}
// hashPassword produces the bcrypt hash of password with a cost factor of 14
// and returns it as a string, together with any error reported by bcrypt.
func hashPassword(password string) (string, error) {
	digest, err := bcrypt.GenerateFromPassword([]byte(password), 14)
	hashed := string(digest)
	return hashed, err
}
// CheckPasswordHash reports whether password matches the given bcrypt hash.
// Any comparison error, including a malformed hash, yields false.
func CheckPasswordHash(password, hash string) bool {
	mismatch := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
	return mismatch == nil
}
// IsEmail returns true if the string is a valid bare email address.
//
// The input is parsed with net/mail and must consist of the address alone:
// forms that mail.ParseAddress also accepts, such as a display name
// ("Jane <jane@example.com>") or an address wrapped in angle brackets, are
// rejected because the parsed address then differs from the input.
func IsEmail(email string) bool {
	addr, err := mail.ParseAddress(email)
	if err != nil {
		return false
	}
	return addr.Address == email
}
package main
import (
"context"
"log/slog"
"os"
"strings"
"time"
"github.com/segmentio/kafka-go"
)
// RunSmokeTest produces and consumes messages on multiple topics
// to verify Kafka cluster functionality locally.
//
// Configuration is read from the environment:
//   - BROKERS: comma-separated broker addresses
//     (default "localhost:9093,localhost:9092").
//   - GROUP_ID: consumer group ID (default "fintech").
//
// The test creates the topics if needed, writes one message per topic and
// then reads one message per topic. The whole run is bounded by a 30 second
// timeout. The first failure is logged and returned.
func RunSmokeTest() error {
	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelInfo}))

	brokers := smokeBrokerList(os.Getenv("BROKERS"))
	groupID := strings.TrimSpace(os.Getenv("GROUP_ID"))
	if groupID == "" {
		groupID = "fintech"
	}
	topics := []string{
		"fintech.events.test.event",
		"fintech.events.other.event",
	}

	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	if err := smokeEnsureTopics(ctx, logger, brokers[0], topics); err != nil {
		return err
	}
	if err := smokeProduce(ctx, logger, brokers, topics); err != nil {
		return err
	}
	for _, t := range topics {
		if err := smokeConsumeOne(ctx, logger, brokers, groupID, t); err != nil {
			return err
		}
	}

	logger.Info("kafka smoke test passed")
	return nil
}

// smokeBrokerList splits a comma-separated broker list, trimming whitespace
// and dropping empty entries. It falls back to the local default brokers when
// no address remains, so the result is never empty.
func smokeBrokerList(raw string) []string {
	var brokers []string
	for _, part := range strings.Split(raw, ",") {
		if addr := strings.TrimSpace(part); addr != "" {
			brokers = append(brokers, addr)
		}
	}
	if len(brokers) == 0 {
		return []string{"localhost:9093", "localhost:9092"}
	}
	return brokers
}

// smokeEnsureTopics creates each topic through the given broker, treating an
// "already exists" response as success.
func smokeEnsureTopics(
	ctx context.Context,
	logger *slog.Logger,
	broker string,
	topics []string,
) error {
	dialer := &kafka.Dialer{Timeout: 5 * time.Second}
	conn, err := dialer.DialContext(ctx, "tcp", broker)
	if err != nil {
		logger.Error("dial failed", "error", err)
		return err
	}
	defer func() { _ = conn.Close() }()

	for _, t := range topics {
		err = conn.CreateTopics(kafka.TopicConfig{
			Topic:             t,
			NumPartitions:     1,
			ReplicationFactor: 1,
		})
		if err != nil && !strings.Contains(strings.ToLower(err.Error()), "already exists") {
			logger.Error("create topic failed", "topic", t, "error", err)
			return err
		}
		logger.Info("topic ready", "topic", t)
	}
	return nil
}

// smokeProduce writes one timestamped message to every topic.
func smokeProduce(
	ctx context.Context,
	logger *slog.Logger,
	brokers, topics []string,
) error {
	w := &kafka.Writer{
		Addr:                   kafka.TCP(brokers...),
		AllowAutoTopicCreation: true,
		RequiredAcks:           kafka.RequireOne,
		Balancer:               &kafka.Hash{},
	}
	defer func() { _ = w.Close() }()

	for i, t := range topics {
		err := w.WriteMessages(ctx, kafka.Message{
			Topic: t,
			Key:   []byte("key"),
			Value: []byte("message-" + time.Now().Format(time.RFC3339Nano)),
			Time:  time.Now(),
		})
		if err != nil {
			logger.Error("write failed", "topic", t, "error", err)
			return err
		}
		logger.Info("produced", "topic", t, "index", i)
	}
	return nil
}

// smokeConsumeOne reads and commits a single message from topic, waiting at
// most 10 seconds for it. Keeping this in its own function releases the
// reader and the read timeout as soon as the topic is done instead of at the
// end of the whole smoke test.
func smokeConsumeOne(
	ctx context.Context,
	logger *slog.Logger,
	brokers []string,
	groupID, topic string,
) error {
	r := kafka.NewReader(kafka.ReaderConfig{
		Brokers:     brokers,
		GroupID:     groupID,
		Topic:       topic,
		StartOffset: kafka.FirstOffset,
		MinBytes:    1,
		MaxBytes:    10e6,
		MaxWait:     500 * time.Millisecond,
	})
	defer func() { _ = r.Close() }()

	readCtx, cancelRead := context.WithTimeout(ctx, 10*time.Second)
	defer cancelRead()

	msg, err := r.FetchMessage(readCtx)
	if err != nil {
		logger.Error("fetch failed", "topic", topic, "error", err)
		return err
	}
	logger.Info("consumed", "topic", topic, "value", string(msg.Value))

	// A failed commit does not invalidate the round trip that was just
	// verified, so it is reported but does not fail the smoke test.
	if err := r.CommitMessages(ctx, msg); err != nil {
		logger.Warn("commit failed", "topic", topic, "error", err)
	}
	return nil
}
// main executes the Kafka smoke test and terminates the process with exit
// status 1 when the test reports an error.
func main() {
	err := RunSmokeTest()
	if err == nil {
		return
	}
	os.Exit(1)
}
package account
import (
"errors"
"strings"
"github.com/amirasaad/fintech/pkg/commands"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/domain"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/middleware"
"github.com/amirasaad/fintech/pkg/money"
accountsvc "github.com/amirasaad/fintech/pkg/service/account"
authsvc "github.com/amirasaad/fintech/pkg/service/auth"
stripeconnectsvc "github.com/amirasaad/fintech/pkg/service/stripeconnect"
"github.com/amirasaad/fintech/webapi/common"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// Routes registers HTTP routes for account-related operations using the Fiber web framework.
// It sets up endpoints for listing and creating accounts,
// depositing, withdrawing and transferring funds, retrieving account balances,
// and listing account transactions.
// Every account route is registered behind the JWT middleware built from cfg.Auth.Jwt.
//
// Routes:
// - GET /accounts : List all accounts of the authenticated user.
// - POST /account : Create a new account for the authenticated user.
// - POST /account/:id/deposit : Deposit funds into the specified account.
// - POST /account/:id/withdraw : Withdraw funds from the specified account.
// - POST /account/:id/transfer : Transfer funds from the specified account.
// - GET /account/:id/balance : Retrieve the balance of the specified account.
// - GET /account/:id/transactions : List transactions for the specified account.
//
// When stripeConnectSvc is non-nil, the Stripe Connect handlers are additionally
// mapped onto the /stripe route group and handed the same kind of JWT middleware.
func Routes(
app *fiber.App,
accountSvc *accountsvc.Service,
authSvc *authsvc.Service,
stripeConnectSvc stripeconnectsvc.Service,
cfg *config.App,
) {
// List all accounts for the authenticated user
app.Get(
"/accounts",
middleware.JwtProtected(cfg.Auth.Jwt),
ListUserAccounts(accountSvc, authSvc),
)
// Create a new account
app.Post(
"/account",
middleware.JwtProtected(cfg.Auth.Jwt),
CreateAccount(accountSvc, authSvc),
)
// Deposit funds into an account
app.Post(
"/account/:id/deposit",
middleware.JwtProtected(cfg.Auth.Jwt),
Deposit(accountSvc, authSvc),
)
// Withdraw funds from an account
app.Post(
"/account/:id/withdraw",
middleware.JwtProtected(cfg.Auth.Jwt),
Withdraw(accountSvc, authSvc),
)
// Transfer funds from an account
app.Post(
"/account/:id/transfer",
middleware.JwtProtected(cfg.Auth.Jwt),
Transfer(accountSvc, authSvc),
)
// Get account balance
app.Get(
"/account/:id/balance",
middleware.JwtProtected(cfg.Auth.Jwt),
GetBalance(accountSvc, authSvc),
)
// Stripe Connect routes
if stripeConnectSvc != nil {
stripeHandlers := NewStripeConnectHandlers(stripeConnectSvc, authSvc)
jwtMiddleware := middleware.JwtProtected(cfg.Auth.Jwt)
// Create a group for all Stripe Connect routes with /stripe prefix
stripeGroup := app.Group("/stripe")
stripeHandlers.MapRoutes(stripeGroup, jwtMiddleware)
}
// List account transactions
app.Get(
"/account/:id/transactions",
middleware.JwtProtected(cfg.Auth.Jwt),
GetTransactions(accountSvc, authSvc),
)
}
// ListUserAccounts returns a Fiber handler that retrieves all accounts for the
// authenticated user.
//
// The handler reads the JWT token from the "user" request local, asks authSvc
// for the user ID and calls accountSvc.ListUserAccounts. A nil result is
// replaced by an empty slice so the JSON payload is [] rather than null.
// @Summary List user accounts
// @Description Retrieves all accounts belonging to the authenticated user.
// @Tags accounts
// @Accept json
// @Produce json
// @Success 200 {object} common.Response{data=[]dto.AccountRead} "List of user accounts"
// @Failure 401 {object} common.ProblemDetails "Unauthorized"
// @Failure 500 {object} common.ProblemDetails "Internal server error"
// @Router /accounts [get]
// @Security Bearer
func ListUserAccounts(
	accountSvc *accountsvc.Service,
	authSvc *authsvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		token, ok := c.Locals("user").(*jwt.Token)
		if !ok {
			return common.ProblemDetailsJSON(c, "Unauthorized", nil, "missing user context")
		}
		userID, err := authSvc.GetCurrentUserId(token)
		if err != nil {
			// Errorw is the key/value variant; plain Error would concatenate
			// the keys and values into the message text.
			log.Errorw("failed to get user ID from token", "error", err)
			return common.ProblemDetailsJSON(c, "Invalid user ID", err)
		}
		accounts, err := accountSvc.ListUserAccounts(c.Context(), userID)
		if err != nil {
			log.Errorw("failed to list user accounts", "error", err, "user_id", userID)
			return common.ProblemDetailsJSON(c, "Failed to list accounts", err)
		}
		if accounts == nil {
			accounts = []*dto.AccountRead{} // Return empty array instead of null
		}
		log.Infow("successfully listed user accounts", "count", len(accounts), "user_id", userID)
		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"Accounts retrieved successfully",
			accounts,
		)
	}
}
// CreateAccount returns a Fiber handler that creates a new account for the
// authenticated user.
//
// The handler reads the JWT token from the "user" request local, asks authSvc
// for the user ID, binds and validates a CreateAccountRequest body and calls
// accountSvc.CreateAccount. A service error whose message contains
// "user already has an account with currency" is reported as 409 Conflict;
// any other service error is handed to common.ProblemDetailsJSON, which
// derives the status from the error.
// @Summary Create a new account
// @Description Creates a new account for the authenticated user.
//
// You can specify the currency for the account.
// Returns the created account details.
//
// @Tags accounts
// @Accept json
// @Produce json
// @Param request body CreateAccountRequest false "Account details"
// @Success 201 {object} common.Response "Account created successfully"
// @Failure 400 {object} common.ProblemDetails "Invalid request"
// @Failure 401 {object} common.ProblemDetails "Unauthorized"
// @Failure 409 {object} common.ProblemDetails "Account with this currency already exists"
// @Failure 429 {object} common.ProblemDetails "Too many requests"
// @Failure 500 {object} common.ProblemDetails "Internal server error"
// @Router /account [post]
// @Security Bearer
func CreateAccount(
	accountSvc *accountsvc.Service,
	authSvc *authsvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		log.Info("creating new account")
		token, ok := c.Locals("user").(*jwt.Token)
		if !ok {
			return common.ProblemDetailsJSON(c, "Unauthorized", nil, "missing user context")
		}
		userID, err := authSvc.GetCurrentUserId(token)
		if err != nil {
			// Errorw is the key/value variant; plain Error would concatenate
			// the keys and values into the message text.
			log.Errorw("failed to get user ID from token", "error", err)
			return common.ProblemDetailsJSON(c, "Invalid user ID", err)
		}
		input, err := common.BindAndValidate[CreateAccountRequest](c)
		if input == nil {
			return err // error response already written
		}
		a, err := accountSvc.CreateAccount(
			c.Context(),
			dto.AccountCreate{
				UserID:   userID,
				Currency: input.Currency,
			},
		)
		if err != nil {
			log.Errorw("failed to create account", "error", err)
			// The duplicate-currency case is recognised by its message text.
			if strings.Contains(err.Error(), "user already has an account with currency") {
				return common.ProblemDetailsJSON(
					c,
					"Account creation failed",
					err,
					"You already have an account with this currency.",
					fiber.StatusConflict, // 409 Conflict
				)
			}
			return common.ProblemDetailsJSON(c, "Failed to create account", err)
		}
		log.Infow("account created", "account_id", a.ID)
		return common.SuccessResponseJSON(
			c,
			fiber.StatusCreated,
			"Account created",
			a,
		)
	}
}
// Deposit returns a Fiber handler that submits a deposit into one of the
// authenticated user's accounts.
//
// The handler reads the JWT token from the "user" request local, asks authSvc
// for the user ID, parses the :id route parameter as a UUID, binds and
// validates a DepositRequest body and calls accountSvc.Deposit. The currency
// defaults to money.USD when the request leaves it empty. The request's
// MoneySource is validated but is not forwarded to the deposit command.
// On success the handler answers 202 Accepted with an empty data object.
// @Summary Deposit funds into an account
// @Description Adds funds to the specified account. Specify the amount, currency,
// and optional money source. The deposit is accepted for asynchronous processing.
// @Tags accounts
// @Accept json
// @Produce json
// @Param id path string true "Account ID"
// @Param request body DepositRequest true "Deposit details"
// @Success 202 {object} common.Response "Deposit accepted for processing"
// @Failure 400 {object} common.ProblemDetails "Invalid request"
// @Failure 401 {object} common.ProblemDetails "Unauthorized"
// @Failure 429 {object} common.ProblemDetails "Too many requests"
// @Failure 500 {object} common.ProblemDetails "Internal server error"
// @Router /account/{id}/deposit [post]
// @Security Bearer
func Deposit(
	accountSvc *accountsvc.Service,
	authSvc *authsvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		// Infow/Errorw are the key/value variants; plain Info/Error would
		// concatenate the keys and values into the message text.
		log.Infow("deposit handler called", "account_id", c.Params("id"))
		token, ok := c.Locals("user").(*jwt.Token)
		if !ok {
			return common.ProblemDetailsJSON(c, "Unauthorized", nil, "missing user context")
		}
		userID, err := authSvc.GetCurrentUserId(token)
		if err != nil {
			log.Errorw("failed to get user ID from token", "error", err)
			return common.ProblemDetailsJSON(c, "Invalid user ID", err)
		}
		accountID, err := uuid.Parse(c.Params("id"))
		if err != nil {
			log.Errorw("invalid account ID for deposit", "error", err)
			return common.ProblemDetailsJSON(
				c,
				"Invalid account ID",
				err,
				"Account ID must be a valid UUID",
				fiber.StatusBadRequest,
			)
		}
		input, err := common.BindAndValidate[DepositRequest](c)
		if input == nil {
			return err // error response already written
		}
		currencyCode := money.USD
		if input.Currency != "" {
			currencyCode = money.Code(input.Currency)
		}
		depositCmd := commands.Deposit{
			UserID:    userID,
			AccountID: accountID,
			Amount:    input.Amount,
			Currency:  string(currencyCode),
			// Add MoneySource, TargetCurrency, etc. if needed
		}
		err = accountSvc.Deposit(c.Context(), depositCmd)
		if err != nil {
			log.Errorw(
				"failed to process deposit",
				"error", err,
				"user_id", userID,
				"account_id", accountID,
			)
			return common.ProblemDetailsJSON(c, "Failed to process deposit", err)
		}
		log.Infow("successfully processed deposit", "account_id", accountID, "user_id", userID)
		return common.SuccessResponseJSON(
			c,
			fiber.StatusAccepted,
			"Deposit request is being processed. "+
				"Your deposit is being started and will be completed soon.",
			fiber.Map{},
		)
	}
}
// Withdraw returns a Fiber handler that submits a withdrawal from one of the
// authenticated user's accounts to an external target.
//
// The handler performs the following steps:
//  1. Reads the JWT token from the "user" request local and asks authSvc for
//     the user ID.
//  2. Parses the :id route parameter as a UUID.
//  3. Binds and validates a WithdrawRequest body.
//  4. Rejects the request with 400 when the external target is missing or has
//     no populated field, or when the currency is empty (the currency is
//     mandatory here even though the request tag marks it optional).
//  5. Calls accountSvc.Withdraw and answers 202 Accepted with an empty data
//     object on success.
//
// Service errors are mapped as follows: domain.ErrStripeOnboardingIncomplete
// becomes 403 Forbidden, an error whose message contains "insufficient funds"
// becomes 400 Bad Request, and anything else is handed to
// common.ProblemDetailsJSON, which derives the status from the error.
//
// @Summary Withdraw funds from an account
// @Description Withdraws a specified amount from the user's account.
// Specify the amount and currency. The withdrawal is accepted for asynchronous processing.
//
// @Tags accounts
// @Accept json
// @Produce json
// @Param id path string true "Account ID"
// @Param request body WithdrawRequest true "Withdrawal details"
// @Success 202 {object} common.Response "Withdrawal accepted for processing"
// @Failure 400 {object} common.ProblemDetails "Invalid request"
// @Failure 401 {object} common.ProblemDetails "Unauthorized"
// @Failure 403 {object} common.ProblemDetails "Stripe Connect onboarding required"
// @Failure 429 {object} common.ProblemDetails "Too many requests"
// @Failure 500 {object} common.ProblemDetails "Internal server error"
// @Router /account/{id}/withdraw [post]
// @Security Bearer
func Withdraw(
	accountSvc *accountsvc.Service,
	authSvc *authsvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		token, ok := c.Locals("user").(*jwt.Token)
		if !ok {
			return common.ProblemDetailsJSON(c, "Unauthorized", nil, "missing user context")
		}
		userID, err := authSvc.GetCurrentUserId(token)
		if err != nil {
			// Errorw is the key/value variant; plain Error would concatenate
			// the keys and values into the message text.
			log.Errorw("failed to get user ID from token", "error", err)
			return common.ProblemDetailsJSON(c, "Invalid user ID", err)
		}
		accountID, err := uuid.Parse(c.Params("id"))
		if err != nil {
			log.Errorw("invalid account ID for withdrawal", "error", err)
			return common.ProblemDetailsJSON(
				c,
				"Invalid account ID",
				err,
				"Account ID must be a valid UUID",
				fiber.StatusBadRequest,
			)
		}
		input, err := common.BindAndValidate[WithdrawRequest](c)
		if input == nil {
			return err // error response already written
		}
		// ExternalTarget is a pointer: check it for nil before reading its
		// fields so a missing target yields a 400 instead of a panic.
		target := input.ExternalTarget
		if target == nil ||
			(target.BankAccountNumber == "" &&
				target.RoutingNumber == "" &&
				target.ExternalWalletAddress == "") {
			return common.ProblemDetailsJSON(
				c,
				"Invalid external target",
				nil,
				"At least one external target field must be provided",
				fiber.StatusBadRequest,
			)
		}
		// An empty currency is rejected rather than defaulted.
		currencyCode := money.Code(input.Currency)
		if currencyCode == "" {
			return common.ProblemDetailsJSON(
				c,
				"Invalid currency code",
				nil,
				"Please provide a valid 3-letter ISO 4217 currency code",
				fiber.StatusBadRequest,
			)
		}
		withdrawCmd := commands.Withdraw{
			UserID:    userID,
			AccountID: accountID,
			Amount:    input.Amount,
			Currency:  string(currencyCode),
			ExternalTarget: &commands.ExternalTarget{
				BankAccountNumber:     target.BankAccountNumber,
				RoutingNumber:         target.RoutingNumber,
				ExternalWalletAddress: target.ExternalWalletAddress,
			},
		}
		if err = accountSvc.Withdraw(c.Context(), withdrawCmd); err != nil {
			log.Errorw(
				"failed to process withdrawal",
				"error", err,
				"user_id", userID,
				"account_id", accountID,
			)
			// Handle Stripe Connect onboarding error specifically
			if errors.Is(err, domain.ErrStripeOnboardingIncomplete) {
				return common.ProblemDetailsJSON(
					c,
					"Stripe Connect onboarding required",
					err,
					"Please complete Stripe Connect onboarding before making a withdrawal",
					fiber.StatusForbidden,
				)
			}
			// Insufficient funds is recognised by its message text.
			if strings.Contains(err.Error(), "insufficient funds") {
				return common.ProblemDetailsJSON(
					c,
					"Insufficient funds",
					err,
					"Your account does not have sufficient funds for this withdrawal",
					fiber.StatusBadRequest,
				)
			}
			return common.ProblemDetailsJSON(c, "Failed to process withdrawal", err)
		}
		return common.SuccessResponseJSON(
			c,
			fiber.StatusAccepted,
			"Withdrawal request is being processed. "+
				"Your withdrawal is being started and will be completed soon.",
			fiber.Map{},
		)
	}
}
// Transfer returns a Fiber handler that submits a transfer from one of the
// authenticated user's accounts to another account.
//
// The handler reads the JWT token from the "user" request local, asks authSvc
// for the user ID, parses the :id route parameter as the source account UUID,
// binds and validates a TransferRequest body, parses the destination account
// ID as a UUID and calls accountSvc.Transfer. The currency defaults to
// money.USD when the request leaves it empty. On success the handler answers
// 202 Accepted with an empty data object.
// @Summary Transfer funds between accounts
// @Description Transfers a specified amount from one account to another.
// Specify the source and destination account IDs, amount, and currency.
// The transfer is accepted for asynchronous processing.
// @Tags accounts
// @Accept json
// @Produce json
// @Param id path string true "Source Account ID"
// @Param request body TransferRequest true "Transfer details"
// @Success 202 {object} common.Response "Transfer accepted for processing"
// @Failure 400 {object} common.ProblemDetails "Invalid request"
// @Failure 401 {object} common.ProblemDetails "Unauthorized"
// @Failure 422 {object} common.ProblemDetails "Unprocessable entity"
// @Failure 429 {object} common.ProblemDetails "Too many requests"
// @Failure 500 {object} common.ProblemDetails "Internal server error"
// @Router /account/{id}/transfer [post]
// @Security Bearer
func Transfer(
	accountSvc *accountsvc.Service,
	authSvc *authsvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		// Infow/Errorw are the key/value variants; plain Info/Error would
		// concatenate the keys and values into the message text.
		log.Infow("transfer handler called", "account_id", c.Params("id"))
		token, ok := c.Locals("user").(*jwt.Token)
		if !ok {
			return common.ProblemDetailsJSON(c, "Unauthorized", nil, "missing user context")
		}
		userID, err := authSvc.GetCurrentUserId(token)
		if err != nil {
			log.Errorw("failed to get user ID from token", "error", err)
			return common.ProblemDetailsJSON(c, "Invalid user ID", err)
		}
		sourceAccountID, err := uuid.Parse(c.Params("id"))
		if err != nil {
			log.Errorw("invalid source account ID for transfer", "error", err)
			return common.ProblemDetailsJSON(
				c,
				"Invalid account ID",
				err,
				"Account ID must be a valid UUID",
				fiber.StatusBadRequest,
			)
		}
		input, err := common.BindAndValidate[TransferRequest](c)
		if input == nil {
			return err // error response already written
		}
		destAccountID, err := uuid.Parse(input.DestinationAccountID)
		if err != nil {
			log.Errorw("invalid destination account ID for transfer", "error", err)
			return common.ProblemDetailsJSON(
				c,
				"Invalid destination account ID",
				err,
				"Destination Account ID must be a valid UUID",
				fiber.StatusBadRequest,
			)
		}
		currencyCode := money.USD
		if input.Currency != "" {
			currencyCode = money.Code(input.Currency)
		}
		// Construct transfer command
		cmd := commands.Transfer{
			UserID:      userID,
			AccountID:   sourceAccountID,
			ToAccountID: destAccountID,
			Amount:      input.Amount,
			Currency:    currencyCode.String(),
		}
		err = accountSvc.Transfer(c.Context(), cmd)
		if err != nil {
			log.Errorw(
				"failed to transfer funds",
				"error", err,
				"user_id", userID,
				"account_id", sourceAccountID,
			)
			return common.ProblemDetailsJSON(c, "Failed to transfer", err)
		}
		log.Infow("successfully transferred funds",
			"amount", input.Amount,
			"currency", input.Currency,
			"from_account_id", sourceAccountID,
			"to_account_id", destAccountID,
			"user_id", userID,
		)
		return common.SuccessResponseJSON(
			c,
			fiber.StatusAccepted,
			"Transfer request is being processed. "+
				"Your transfer is being started and will be completed soon.",
			fiber.Map{},
		)
	}
}
// GetTransactions returns a Fiber handler that lists the transactions of a
// specific account.
//
// The handler reads the JWT token from the "user" request local, asks authSvc
// for the user ID, parses the :id route parameter as a UUID and calls
// accountSvc.GetTransactions with both IDs. Each returned transaction is
// converted to a TransactionDTO; the response data is always an array, empty
// when there are no transactions.
// @Summary Get account transactions
// @Description Retrieves a list of transactions for the specified account.
// Returns an array of transaction details.
// @Tags accounts
// @Accept json
// @Produce json
// @Param id path string true "Account ID"
// @Success 200 {object} common.Response "Transactions fetched"
// @Failure 400 {object} common.ProblemDetails "Invalid request"
// @Failure 401 {object} common.ProblemDetails "Unauthorized"
// @Failure 429 {object} common.ProblemDetails "Too many requests"
// @Failure 500 {object} common.ProblemDetails "Internal server error"
// @Router /account/{id}/transactions [get]
// @Security Bearer
func GetTransactions(
	accountSvc *accountsvc.Service,
	authSvc *authsvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		token, ok := c.Locals("user").(*jwt.Token)
		if !ok {
			return common.ProblemDetailsJSON(c, "Unauthorized", nil, "missing user context")
		}
		userID, err := authSvc.GetCurrentUserId(token)
		if err != nil {
			// Errorw is the key/value variant; plain Error would concatenate
			// the keys and values into the message text.
			log.Errorw("failed to get user ID from token", "error", err)
			return common.ProblemDetailsJSON(c, "Invalid user ID", err)
		}
		id, err := uuid.Parse(c.Params("id"))
		if err != nil {
			log.Errorw("invalid account ID for transactions", "error", err)
			return common.ProblemDetailsJSON(
				c,
				"Invalid account ID",
				err,
				"Account ID must be a valid UUID",
				fiber.StatusBadRequest,
			)
		}
		tx, err := accountSvc.GetTransactions(c.Context(), userID, id)
		if err != nil {
			// The account ID travels as a key/value pair; the message carries
			// no format verb because Errorw does not format it.
			log.Errorw(
				"failed to list transactions",
				"error", err,
				"account_id", id,
			)
			return common.ProblemDetailsJSON(c, "Failed to list transactions", err)
		}
		dtos := make([]*TransactionDTO, 0, len(tx))
		for _, t := range tx {
			dtos = append(dtos, &TransactionDTO{
				ID:        t.ID.String(),
				UserID:    t.UserID.String(),
				AccountID: t.AccountID.String(),
				Amount:    t.Amount,
				Currency:  string(t.Currency),
				Balance:   t.Balance,
				CreatedAt: t.CreatedAt.Format("2006-01-02T15:04:05Z07:00"),
			})
		}
		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"Transactions fetched",
			dtos,
		)
	}
}
// GetBalance returns a Fiber handler that retrieves the balance of a specific
// account.
//
// The handler reads the JWT token from the "user" request local, asks authSvc
// for the user ID, parses the :id route parameter as a UUID and calls
// accountSvc.GetBalance with both IDs. On success the balance is returned
// under the "balance" key of the response data.
// @Summary Get account balance
// @Description Retrieves the current balance for the specified account.
// Returns the balance amount and currency.
// @Tags accounts
// @Accept json
// @Produce json
// @Param id path string true "Account ID"
// @Success 200 {object} common.Response "Balance fetched"
// @Failure 400 {object} common.ProblemDetails "Invalid request"
// @Failure 401 {object} common.ProblemDetails "Unauthorized"
// @Failure 429 {object} common.ProblemDetails "Too many requests"
// @Failure 500 {object} common.ProblemDetails "Internal server error"
// @Router /account/{id}/balance [get]
// @Security Bearer
func GetBalance(
	accountSvc *accountsvc.Service,
	authSvc *authsvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		token, ok := c.Locals("user").(*jwt.Token)
		if !ok {
			return common.ProblemDetailsJSON(c, "Unauthorized", nil, "missing user context")
		}
		userID, err := authSvc.GetCurrentUserId(token)
		if err != nil {
			// Errorw is the key/value variant; plain Error would concatenate
			// the keys and values into the message text.
			log.Errorw("failed to get user ID from token", "error", err)
			return common.ProblemDetailsJSON(c, "Invalid user ID", err)
		}
		id, err := uuid.Parse(c.Params("id"))
		if err != nil {
			log.Errorw("invalid account ID for balance", "error", err)
			return common.ProblemDetailsJSON(
				c,
				"Invalid account ID",
				err,
				"Account ID must be a valid UUID",
				fiber.StatusBadRequest,
			)
		}
		balance, err := accountSvc.GetBalance(c.Context(), userID, id)
		if err != nil {
			log.Errorf("Failed to fetch balance for account ID %s: %v", id, err)
			return common.ProblemDetailsJSON(
				c,
				"Failed to fetch balance",
				err,
			)
		}
		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"Balance fetched",
			fiber.Map{"balance": balance},
		)
	}
}
package account
import (
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/provider/exchange"
)
//revive:disable
// CreateAccountRequest represents the request body for creating a new account.
type CreateAccountRequest struct {
// Currency is optional; when present the validation tag requires exactly
// three uppercase alphabetic characters.
Currency string `json:"currency" validate:"omitempty,len=3,uppercase,alpha"`
}
// DepositRequest represents the request body for depositing funds into an account.
type DepositRequest struct {
// Amount is required and must be greater than zero.
Amount float64 `json:"amount" xml:"amount" form:"amount" validate:"required,gt=0"`
// Currency is optional; when present it must be three uppercase characters.
Currency string `json:"currency" validate:"omitempty,len=3,uppercase"`
// MoneySource is required and must be between 2 and 64 characters long.
MoneySource string `json:"money_source" validate:"required,min=2,max=64"`
}
// ExternalTarget represents the destination for an external withdrawal, such as a bank account or wallet.
// Every field is individually optional at the validation-tag level; a field
// that is present must satisfy the length bounds in its tag.
type ExternalTarget struct {
// BankAccountNumber, when present, must be 6 to 34 characters long.
BankAccountNumber string `json:"bank_account_number,omitempty" validate:"omitempty,min=6,max=34"`
// RoutingNumber, when present, must be 6 to 12 characters long.
RoutingNumber string `json:"routing_number,omitempty" validate:"omitempty,min=6,max=12"`
// ExternalWalletAddress, when present, must be 6 to 128 characters long.
ExternalWalletAddress string `json:"external_wallet_address,omitempty" validate:"omitempty,min=6,max=128"`
}
// WithdrawRequest represents the request body for withdrawing funds from an account.
type WithdrawRequest struct {
// Amount is required and must be greater than zero.
Amount float64 `json:"amount" xml:"amount" form:"amount" validate:"required,gt=0"`
// Currency is optional at the tag level; when present it must be three
// uppercase characters.
Currency string `json:"currency" validate:"omitempty,len=3,uppercase"`
// ExternalTarget is required; it is a pointer, so an absent target is nil.
ExternalTarget *ExternalTarget `json:"external_target" validate:"required"`
}
// TransferRequest represents the request body for transferring funds between accounts.
type TransferRequest struct {
// Amount is required and must be greater than zero.
Amount float64 `json:"amount" validate:"required,gt=0"`
// Currency is optional; when present the validation tag requires exactly
// three uppercase alphabetic characters.
Currency string `json:"currency" validate:"omitempty,len=3,uppercase,alpha"`
// DestinationAccountID is required and must be a version 4 UUID string.
DestinationAccountID string `json:"destination_account_id" validate:"required,uuid4"`
}
// TransactionDTO is the API response representation of a transaction.
// Identifiers and the creation time are carried as strings.
type TransactionDTO struct {
ID string `json:"id"`
UserID string `json:"user_id"`
AccountID string `json:"account_id"`
Amount float64 `json:"amount"`
Balance float64 `json:"balance"`
// CreatedAt is a pre-formatted timestamp string.
CreatedAt string `json:"created_at"`
Currency string `json:"currency"`
MoneySource string `json:"money_source"`
}
// ConversionInfoDTO holds conversion details for API responses: the amount and
// currency before and after conversion, and the rate applied.
type ConversionInfoDTO struct {
OriginalAmount float64 `json:"original_amount"`
OriginalCurrency string `json:"original_currency"`
ConvertedAmount float64 `json:"converted_amount"`
ConvertedCurrency string `json:"converted_currency"`
ConversionRate float64 `json:"conversion_rate"`
}
// TransferResponseDTO is the API response for a transfer operation. It carries
// the outgoing and incoming transactions plus a single conversion_info field
// (like deposit/withdraw). All three fields are pointers and may be nil.
type TransferResponseDTO struct {
Outgoing *TransactionDTO `json:"outgoing_transaction"`
Incoming *TransactionDTO `json:"incoming_transaction"`
ConversionInfo *ConversionInfoDTO `json:"conversion_info"`
}
// ToTransactionDTO converts a dto.TransactionRead into its API representation.
// A nil input yields a nil result. Identifiers are rendered with String() and
// the creation time is formatted with the layout "2006-01-02T15:04:05Z07:00".
// MoneySource is left at its zero value.
func ToTransactionDTO(tx *dto.TransactionRead) *TransactionDTO {
	if tx == nil {
		return nil
	}
	createdAt := tx.CreatedAt.Format("2006-01-02T15:04:05Z07:00")
	return &TransactionDTO{
		ID:        tx.ID.String(),
		UserID:    tx.UserID.String(),
		AccountID: tx.AccountID.String(),
		Amount:    tx.Amount,
		Balance:   tx.Balance,
		CreatedAt: createdAt,
		Currency:  tx.Currency,
	}
}
// ToConversionInfoDTO converts an exchange.RateInfo into a ConversionInfoDTO.
// A nil input yields a nil result. Only the currency pair and the rate are
// copied; OriginalAmount and ConvertedAmount stay zero because RateInfo does
// not carry them directly.
func ToConversionInfoDTO(convInfo *exchange.RateInfo) *ConversionInfoDTO {
	if convInfo == nil {
		return nil
	}
	out := new(ConversionInfoDTO)
	out.OriginalCurrency = convInfo.FromCurrency
	out.ConvertedCurrency = convInfo.ToCurrency
	out.ConversionRate = convInfo.Rate
	return out
}
// ToTransferResponseDTO assembles a TransferResponseDTO from the outgoing and
// incoming transactions and the rate information, converting each part with
// ToTransactionDTO or ToConversionInfoDTO. Nil inputs produce nil fields.
func ToTransferResponseDTO(txOut, txIn *dto.TransactionRead, convInfo *exchange.RateInfo) *TransferResponseDTO {
	resp := new(TransferResponseDTO)
	resp.Outgoing = ToTransactionDTO(txOut)
	resp.Incoming = ToTransactionDTO(txIn)
	resp.ConversionInfo = ToConversionInfoDTO(convInfo)
	return resp
}
//revive:enable
package account
import (
"context"
authsvc "github.com/amirasaad/fintech/pkg/service/auth"
"github.com/amirasaad/fintech/pkg/service/stripeconnect"
"github.com/amirasaad/fintech/webapi/account/dto"
"github.com/amirasaad/fintech/webapi/common"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// errMissingUserContext marks a request that carries no JWT token under the
// "user" request local.
var errMissingUserContext = fiber.NewError(fiber.StatusUnauthorized, "missing user context")

// getUserIDFromContext resolves the authenticated user's ID from the JWT token
// stored under the "user" request local.
//
// On failure it returns uuid.Nil together with a non-nil error:
// errMissingUserContext when no token is present, or the error reported by
// authSvc.GetCurrentUserId. It deliberately does not write the response
// itself, so the outcome of writing an error body can never be mistaken for a
// successful lookup; callers pass the error to writeAuthProblem.
func (h *StripeConnectHandlers) getUserIDFromContext(c *fiber.Ctx) (uuid.UUID, error) {
	token, ok := c.Locals("user").(*jwt.Token)
	if !ok {
		return uuid.Nil, errMissingUserContext
	}
	userID, err := h.authSvc.GetCurrentUserId(token)
	if err != nil {
		log.Errorf("Failed to parse user ID from token: %v", err)
		return uuid.Nil, err
	}
	return userID, nil
}

// writeAuthProblem writes the problem response for an error returned by
// getUserIDFromContext and returns the result of that write.
func (h *StripeConnectHandlers) writeAuthProblem(c *fiber.Ctx, err error) error {
	// errMissingUserContext is a package-private sentinel that is never
	// wrapped, so a direct comparison is sufficient.
	if err == errMissingUserContext {
		return common.ProblemDetailsJSON(c, "Unauthorized", nil, "missing user context")
	}
	return common.ProblemDetailsJSON(c, "Invalid user ID", err)
}

// StripeConnectHandlers groups the HTTP handlers for Stripe Connect
// onboarding together with the services they depend on.
type StripeConnectHandlers struct {
	stripeConnectSvc stripeconnect.Service
	authSvc          *authsvc.Service
}

// NewStripeConnectHandlers builds a StripeConnectHandlers from the given
// Stripe Connect service and auth service.
func NewStripeConnectHandlers(
	stripeConnectSvc stripeconnect.Service,
	authSvc *authsvc.Service,
) *StripeConnectHandlers {
	return &StripeConnectHandlers{
		stripeConnectSvc: stripeConnectSvc,
		authSvc:          authSvc,
	}
}

// InitiateOnboarding initiates the Stripe Connect onboarding flow.
// It resolves the authenticated user, asks the Stripe Connect service for an
// onboarding URL and returns that URL as the response data.
// @Summary Initiate Stripe Connect onboarding
// @Description Generates a Stripe Connect onboarding URL for the authenticated user
// @Tags account
// @Produce json
// @Security BearerAuth
// @Success 200 {object} dto.InitiateOnboardingResponse
// @Failure 401 {object} dto.ErrorResponse
// @Failure 500 {object} dto.ErrorResponse
// @Router /stripe/account/onboard [post]
func (h *StripeConnectHandlers) InitiateOnboarding(c *fiber.Ctx) error {
	// Get user ID from JWT token; stop here when it cannot be resolved.
	userID, err := h.getUserIDFromContext(c)
	if err != nil {
		return h.writeAuthProblem(c, err)
	}
	onboardingURL, err := h.stripeConnectSvc.GenerateOnboardingURL(
		context.Background(),
		userID,
	)
	if err != nil {
		return common.ProblemDetailsJSON(c, "Failed to generate onboarding URL", err)
	}
	return common.SuccessResponseJSON(
		c,
		fiber.StatusOK,
		"Onboarding URL generated successfully",
		onboardingURL,
	)
}

// GetOnboardingStatus checks if the authenticated user has completed
// the Stripe Connect onboarding process and returns the result as a
// dto.OnboardingStatusResponse.
// @Summary Check Stripe Connect onboarding status
// @Description Returns the onboarding completion status
// for the authenticated user's Stripe Connect account
// @Tags account
// @Accept json
// @Produce json
// @Security BearerAuth
// @Success 200 {object} dto.OnboardingStatusResponse
// @Failure 401 {object} dto.ErrorResponse
// @Failure 403 {object} dto.ErrorResponse
// @Failure 500 {object} dto.ErrorResponse
// @Router /stripe/account/onboard/status [get]
func (h *StripeConnectHandlers) GetOnboardingStatus(c *fiber.Ctx) error {
	// Get user ID from JWT token; stop here when it cannot be resolved.
	userID, err := h.getUserIDFromContext(c)
	if err != nil {
		return h.writeAuthProblem(c, err)
	}
	isComplete, err := h.stripeConnectSvc.IsOnboardingComplete(context.Background(), userID)
	if err != nil {
		return common.ProblemDetailsJSON(c, "Failed to get onboarding status", err)
	}
	return common.SuccessResponseJSON(
		c,
		fiber.StatusOK,
		"Onboarding status retrieved successfully",
		dto.OnboardingStatusResponse{
			IsComplete: isComplete,
		},
	)
}
// MapRoutes registers the Stripe Connect onboarding endpoints on router under
// the "/account/onboard" sub-path. jwtMiddleware is installed on that group so
// it runs for both endpoints:
//   - POST /account/onboard/       : InitiateOnboarding
//   - GET  /account/onboard/status : GetOnboardingStatus
func (h *StripeConnectHandlers) MapRoutes(router fiber.Router, jwtMiddleware fiber.Handler) {
	onboarding := router.Group("/account/onboard")
	onboarding.Use(jwtMiddleware)
	onboarding.Post("/", h.InitiateOnboarding)
	onboarding.Get("/status", h.GetOnboardingStatus)
}
package auth
import (
authsvc "github.com/amirasaad/fintech/pkg/service/auth"
"github.com/amirasaad/fintech/webapi/common"
"github.com/gofiber/fiber/v2"
)
// Routes registers the authentication endpoints on the Fiber application.
// Currently that is a single route: POST /auth/login, served by Login.
func Routes(app *fiber.App, authSvc *authsvc.Service) {
	loginHandler := Login(authSvc)
	app.Post("/auth/login", loginHandler)
}
// Login returns a Fiber handler that authenticates a user and responds with a
// token under the "token" key.
//
// The handler binds and validates a LoginInput body and calls authSvc.Login
// with the identity and password. A login error whose message is exactly
// "user unauthorized", or a nil user without an error, produces a 401
// response; any other error from Login or GenerateToken is reported as
// "Internal Server Error".
// @Summary User login
// @Description Authenticate user with identity (username or email) and password
// @Tags auth
// @Accept json
// @Produce json
// @Param request body LoginInput true "Login credentials"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 429 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /auth/login [post]
func Login(authSvc *authsvc.Service) fiber.Handler {
	return func(c *fiber.Ctx) error {
		// rejectCredentials writes the shared 401 response for bad credentials.
		rejectCredentials := func() error {
			return common.ProblemDetailsJSON(
				c,
				"Invalid identity or password",
				nil,
				"Identity or password is incorrect",
				fiber.StatusUnauthorized,
			)
		}
		// serverError writes the shared response for unexpected failures.
		serverError := func(cause error) error {
			return common.ProblemDetailsJSON(c, "Internal Server Error", cause)
		}

		input, err := common.BindAndValidate[LoginInput](c)
		if input == nil {
			return err // Error already written by BindAndValidate
		}

		user, err := authSvc.Login(c.Context(), input.Identity, input.Password)
		switch {
		case err != nil && err.Error() == "user unauthorized":
			return rejectCredentials()
		case err != nil:
			return serverError(err)
		case user == nil:
			return rejectCredentials()
		}

		token, err := authSvc.GenerateToken(c.Context(), user)
		if err != nil {
			return serverError(err)
		}
		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"Success login",
			fiber.Map{"token": token},
		)
	}
}
package checkout
import (
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/middleware"
authsvc "github.com/amirasaad/fintech/pkg/service/auth"
"github.com/amirasaad/fintech/pkg/service/checkout"
"github.com/amirasaad/fintech/webapi/common"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/golang-jwt/jwt/v5"
)
// Routes registers the checkout endpoints on the Fiber application.
// Currently that is a single route, GET /checkout/sessions/pending, placed
// behind a JWT middleware built from cfg.Auth.Jwt and served by
// GetPendingSessions.
func Routes(
	app *fiber.App,
	checkoutSvc *checkout.Service,
	authSvc *authsvc.Service,
	cfg *config.App,
) {
	guard := middleware.JwtProtected(cfg.Auth.Jwt)
	pendingHandler := GetPendingSessions(checkoutSvc, authSvc)
	app.Get("/checkout/sessions/pending", guard, pendingHandler)
}
// GetPendingSessions returns a Fiber handler that lists the pending checkout
// sessions of the current user.
//
// The handler reads the JWT token from the "user" request local, asks authSvc
// for the user ID and fetches that user's sessions from checkoutSvc. Only
// sessions whose Status equals "created" are included in the response; the
// response data is always an array, empty when nothing matches.
// @Summary Get pending checkout sessions
// @Description Retrieves a list of pending checkout sessions for the authenticated user.
// @Tags checkout
// @Accept json
// @Produce json
// @Success 200 {object} common.Response "Pending sessions fetched"
// @Failure 401 {object} common.ProblemDetails "Unauthorized"
// @Failure 500 {object} common.ProblemDetails "Internal server error"
// @Router /checkout/sessions/pending [get]
// @Security Bearer
func GetPendingSessions(checkoutSvc *checkout.Service, authSvc *authsvc.Service) fiber.Handler {
	return func(c *fiber.Ctx) error {
		token, hasToken := c.Locals("user").(*jwt.Token)
		if !hasToken {
			return common.ProblemDetailsJSON(c, "Unauthorized", nil, "missing user context")
		}

		userID, err := authSvc.GetCurrentUserId(token)
		if err != nil {
			log.Errorf("Failed to parse user ID from token: %v", err)
			return common.ProblemDetailsJSON(c, "Invalid user ID", err)
		}

		sessions, err := checkoutSvc.GetSessionsByUserID(c.Context(), userID)
		if err != nil {
			log.Errorf("Failed to get pending sessions: %v", err)
			return common.ProblemDetailsJSON(c, "Failed to get pending sessions", err)
		}

		pending := make([]*SessionDTO, 0, len(sessions))
		for _, sess := range sessions {
			// Skip everything that is not in the "created" state.
			if sess.Status != "created" {
				continue
			}
			item := &SessionDTO{
				ID:            sess.ID,
				TransactionID: sess.TransactionID.String(),
				UserID:        sess.UserID.String(),
				AccountID:     sess.AccountID.String(),
				Amount:        sess.Amount,
				Currency:      sess.Currency,
				Status:        sess.Status,
				CheckoutURL:   sess.CheckoutURL,
				CreatedAt:     sess.CreatedAt,
				ExpiresAt:     sess.ExpiresAt,
			}
			pending = append(pending, item)
		}
		return common.SuccessResponseJSON(c, fiber.StatusOK, "Pending sessions fetched", pending)
	}
}
package common
import (
"errors"
"github.com/amirasaad/fintech/pkg/domain"
"github.com/amirasaad/fintech/pkg/domain/account"
"github.com/amirasaad/fintech/pkg/domain/user"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/provider/exchange"
"github.com/go-playground/validator/v10"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
)
// Response defines the standard API response structure for success cases.
// SuccessResponseJSON fills it and writes it as the JSON body.
type Response struct {
	Status  int    `json:"status"`         // HTTP status code, mirrored in the body
	Message string `json:"message"`        // Human-readable explanation
	Data    any    `json:"data,omitempty"` // Response data; left out of the JSON when empty
}
// ProblemDetails follows RFC 9457 Problem Details for HTTP APIs.
// ProblemDetailsJSON builds a value of this type for every error response.
type ProblemDetails struct {
	Type     string `json:"type,omitempty"`     // A URI reference that identifies the problem type
	Title    string `json:"title"`              // Short, human-readable summary
	Status   int    `json:"status"`             // HTTP status code
	Detail   string `json:"detail,omitempty"`   // Human-readable explanation
	Instance string `json:"instance,omitempty"` // URI reference
	Errors   any    `json:"errors,omitempty"`   // Optional: additional error details
}
// ProblemDetailsJSON writes a problem+json error response.
//
// The status code starts at 400 Bad Request; when err is non-nil it is
// replaced by the code errorToStatusCode maps err to, and the detail defaults
// to err.Error() (or the generic "Unprocessable entity" for
// domain.ErrAlreadyExists, so duplicate-key details are not exposed).
//
// Each value in detailOrStatus is interpreted by its dynamic type:
//   - int: overrides the status code;
//   - string: replaces the detail text;
//   - error: replaces the detail text with its message;
//   - anything else: stored in the "errors" field.
//
// When several values of the same kind are passed, the last one wins.
//
// The function always returns nil: a failure to encode the body is only
// logged, so handlers can `return ProblemDetailsJSON(...)` without triggering
// Fiber's error handler.
func ProblemDetailsJSON(
	c *fiber.Ctx,
	title string,
	err error,
	detailOrStatus ...any,
) error {
	status := fiber.StatusBadRequest
	pdDetail := ""
	var pdErrors any
	var customStatus *int
	if err != nil {
		status = errorToStatusCode(err)
		// Use generic message for duplicate key errors
		if errors.Is(err, domain.ErrAlreadyExists) {
			pdDetail = "Unprocessable entity"
		} else {
			pdDetail = err.Error()
		}
	}
	// Check for custom detail or status code in variadic args
	for _, arg := range detailOrStatus {
		switch v := arg.(type) {
		case int:
			customStatus = &v
		case string:
			pdDetail = v
		case error:
			pdDetail = v.Error()
		default:
			pdErrors = v
		}
	}
	// Use custom status if provided
	if customStatus != nil {
		status = *customStatus
	}
	pd := ProblemDetails{
		Type:     "about:blank",
		Status:   status,
		Title:    title,
		Detail:   pdDetail,
		Errors:   pdErrors,
		Instance: c.Path(),
	}
	if err := c.Status(status).JSON(pd); err != nil {
		log.Errorf("ProblemDetailsJSON failed: %v", err)
	}
	// Set the media type after JSON(): Ctx.JSON resets Content-Type to
	// application/json, which would silently discard a header set earlier.
	c.Set(fiber.HeaderContentType, "application/problem+json")
	return nil
}
// BindAndValidate decodes the request body into a T and validates it with
// go-playground/validator.
//
// On success it returns a pointer to the populated value and a nil error.
// On failure a problem+json response has already been written and the
// returned pointer is nil. The returned error is then whatever
// ProblemDetailsJSON returned (which is always nil) for body-parse failures
// and field validation failures, and the validator's own error for any other
// validation failure. Callers must therefore test the pointer for nil and not
// rely on the error alone.
func BindAndValidate[T any](c *fiber.Ctx) (*T, error) {
	var payload T
	if parseErr := c.BodyParser(&payload); parseErr != nil {
		writeErr := ProblemDetailsJSON(
			c,
			"Invalid request body",
			parseErr,
			"Request body could not be parsed or has invalid types",
			fiber.StatusBadRequest,
		)
		return nil, writeErr
	}

	validationErr := validator.New().Struct(payload)
	if validationErr == nil {
		return &payload, nil
	}

	fieldErrs, isFieldErrs := validationErr.(validator.ValidationErrors)
	if !isFieldErrs {
		// Not a per-field failure (e.g. the validator rejected the value
		// itself): report it and hand the raw error back to the caller.
		_ = ProblemDetailsJSON(
			c,
			"Validation failed",
			validationErr,
			"Request validation failed",
			fiber.StatusBadRequest,
		)
		return nil, validationErr
	}

	// Map each failing field name to the tag of the rule it violated.
	details := make(map[string]string, len(fieldErrs))
	for _, fieldErr := range fieldErrs {
		details[fieldErr.Field()] = fieldErr.Tag()
	}
	writeErr := ProblemDetailsJSON(
		c,
		"Validation failed",
		nil,
		"Request validation failed",
		details,
		fiber.StatusBadRequest,
	)
	return nil, writeErr
}
// SuccessResponseJSON sends a success envelope as JSON.
//
// The given status is used both as the HTTP status code and as the Status
// field of the Response body; message and data are copied into the body
// unchanged. The error returned is the one produced by the JSON write.
func SuccessResponseJSON(c *fiber.Ctx, status int, message string, data any) error {
	envelope := Response{Status: status, Message: message, Data: data}
	c.Status(status)
	return c.JSON(envelope)
}
// errorToStatusCode maps domain errors to appropriate HTTP status codes.
//
// The table is scanned top to bottom and the first sentinel matched by
// errors.Is decides the result; errors matching none of them map to
// 500 Internal Server Error.
func errorToStatusCode(err error) int {
	mappings := []struct {
		target error
		status int
	}{
		// Domain errors
		{domain.ErrAlreadyExists, fiber.StatusUnprocessableEntity},
		// Account errors
		{account.ErrAccountNotFound, fiber.StatusNotFound},
		{account.ErrDepositAmountExceedsMaxSafeInt, fiber.StatusBadRequest},
		{account.ErrTransactionAmountMustBePositive, fiber.StatusBadRequest},
		{account.ErrInsufficientFunds, fiber.StatusUnprocessableEntity},
		// Money errors
		{money.ErrInvalidCurrency, fiber.StatusBadRequest},
		{money.ErrAmountExceedsMaxSafeInt, fiber.StatusBadRequest},
		// Exchange errors
		{exchange.ErrUnsupportedPair, fiber.StatusUnprocessableEntity},
		{exchange.ErrProviderUnavailable, fiber.StatusServiceUnavailable},
		// User errors
		{user.ErrUserNotFound, fiber.StatusNotFound},
		{user.ErrUserUnauthorized, fiber.StatusUnauthorized},
	}
	for _, m := range mappings {
		if errors.Is(err, m.target) {
			return m.status
		}
	}
	return fiber.StatusInternalServerError
}
package currency
import (
"fmt"
"strings"
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/middleware"
"github.com/amirasaad/fintech/pkg/money"
authsvc "github.com/amirasaad/fintech/pkg/service/auth"
currencysvc "github.com/amirasaad/fintech/pkg/service/currency"
"github.com/amirasaad/fintech/webapi/common"
"github.com/gofiber/fiber/v2"
)
// Routes sets up the currency routes under /api/currencies.
//
// Public GET endpoints are registered first, followed by the /admin group
// whose handlers are each protected by the JWT middleware built from
// cfg.Auth.Jwt.
//
// Registration order matters: Fiber tries routes in the order they were
// added, so every fixed path (/search, /statistics, /default, ...) must be
// registered before the catch-all /:code, otherwise the parameter route
// would capture those requests.
func Routes(
	r fiber.Router,
	currencySvc *currencysvc.Service,
	authSvc *authsvc.Service,
	cfg *config.App,
) {
	currencyGroup := r.Group("/api/currencies")
	// Public endpoints: fixed paths first
	currencyGroup.Get(
		"/",
		ListCurrencies(currencySvc),
	)
	currencyGroup.Get(
		"/supported",
		ListSupportedCurrencies(currencySvc),
	)
	currencyGroup.Get(
		"/search",
		SearchCurrencies(currencySvc),
	)
	currencyGroup.Get(
		"/statistics",
		GetCurrencyStatistics(currencySvc),
	)
	currencyGroup.Get(
		"/default",
		GetDefaultCurrency(currencySvc),
	)
	currencyGroup.Get(
		"/region/:region",
		SearchCurrenciesByRegion(currencySvc),
	)
	// Public endpoints: parameterized paths last so they cannot shadow the
	// fixed paths above
	currencyGroup.Get(
		"/:code",
		GetCurrency(currencySvc),
	)
	currencyGroup.Get(
		"/:code/supported",
		CheckCurrencySupported(currencySvc),
	)
	// Admin endpoints (require authentication)
	adminGroup := currencyGroup.Group("/admin")
	adminGroup.Post(
		"/",
		middleware.JwtProtected(cfg.Auth.Jwt),
		RegisterCurrency(currencySvc),
	)
	adminGroup.Delete(
		"/:code",
		middleware.JwtProtected(cfg.Auth.Jwt),
		UnregisterCurrency(currencySvc),
	)
	adminGroup.Put(
		"/:code/activate",
		middleware.JwtProtected(cfg.Auth.Jwt),
		ActivateCurrency(currencySvc),
	)
	adminGroup.Put(
		"/:code/deactivate",
		middleware.JwtProtected(cfg.Auth.Jwt),
		DeactivateCurrency(currencySvc),
	)
}
// ListCurrencies returns a Fiber handler for listing all available currencies.
//
// The handler calls currencySvc.ListAll and wraps the result in the standard
// success envelope; a service error is reported through
// common.ProblemDetailsJSON.
// @Summary List all currencies
// @Description Get a list of all available currencies in the system
// @Tags currencies
// @Accept json
// @Produce json
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 429 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies [get]
func ListCurrencies(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		currencies, err := currencySvc.ListAll(c.Context())
		if err != nil {
			return common.ProblemDetailsJSON(
				c,
				"Failed to list currencies",
				err,
			)
		}
		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"Currencies fetched successfully",
			currencies,
		)
	}
}
// ListSupportedCurrencies returns a Fiber handler that lists all supported
// currency codes.
//
// The handler calls currencySvc.ListSupported and places the result in the
// Data field of the standard success envelope; a service error is reported
// through common.ProblemDetailsJSON.
// @Summary List supported currencies
// @Description Get all supported currency codes
// @Tags currencies
// @Accept json
// @Produce json
// @Success 200 {object} common.Response
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/supported [get]
func ListSupportedCurrencies(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		currencies, err := currencySvc.ListSupported(c.Context())
		if err != nil {
			return common.ProblemDetailsJSON(
				c,
				"Failed to list supported currencies",
				err,
			)
		}
		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"Supported currencies fetched successfully",
			currencies,
		)
	}
}
// GetCurrency returns a Fiber handler that looks up one currency by the
// "code" path parameter.
//
// An empty or invalid code is answered with 400; a lookup failure is reported
// with the title "Currency not found" and a status derived from the error by
// common.ProblemDetailsJSON.
// @Summary Get currency by code
// @Description Get currency information by ISO 4217 code
// @Tags currencies
// @Accept json
// @Produce json
// @Param code path string true "Currency code (e.g., USD, EUR)"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 404 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/{code} [get]
func GetCurrency(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		code := c.Params("code")
		if len(code) == 0 {
			const title, detail = "Currency code is required", "Missing currency code"
			return common.ProblemDetailsJSON(c, title, nil, detail, fiber.StatusBadRequest)
		}

		// Reject malformed codes before hitting the registry.
		validationErr := currencySvc.ValidateCode(c.Context(), code)
		if validationErr != nil {
			return common.ProblemDetailsJSON(
				c,
				"Invalid currency code",
				validationErr,
				"Currency code must be a valid ISO 4217 code",
				fiber.StatusBadRequest,
			)
		}

		found, lookupErr := currencySvc.Get(c.Context(), code)
		if lookupErr != nil {
			return common.ProblemDetailsJSON(c, "Currency not found", lookupErr)
		}

		const okMessage = "Currency fetched successfully"
		return common.SuccessResponseJSON(c, fiber.StatusOK, okMessage, found)
	}
}
// CheckCurrencySupported returns a Fiber handler that reports whether the
// currency named by the "code" path parameter is supported.
//
// An empty or invalid code is answered with 400. Otherwise the response data
// is a map with the requested "code" and the boolean "supported" returned by
// currencySvc.IsSupported.
// @Summary Check if currency is supported
// @Description Check if a currency code is supported
// @Tags currencies
// @Accept json
// @Produce json
// @Param code path string true "Currency code (e.g., USD, EUR)"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Router /api/currencies/{code}/supported [get]
func CheckCurrencySupported(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		code := c.Params("code")
		if len(code) == 0 {
			const title, detail = "Currency code is required", "Missing currency code"
			return common.ProblemDetailsJSON(c, title, nil, detail, fiber.StatusBadRequest)
		}

		// Reject malformed codes before asking the service.
		validationErr := currencySvc.ValidateCode(c.Context(), code)
		if validationErr != nil {
			return common.ProblemDetailsJSON(
				c,
				"Invalid currency code",
				validationErr,
				"Currency code must be a valid ISO 4217 code",
				fiber.StatusBadRequest,
			)
		}

		isSupported := currencySvc.IsSupported(c.Context(), code)
		payload := fiber.Map{"code": code, "supported": isSupported}
		const okMessage = "Currency support checked successfully"
		return common.SuccessResponseJSON(c, fiber.StatusOK, okMessage, payload)
	}
}
// SearchCurrencies returns a Fiber handler that searches currencies using the
// "q" query parameter.
//
// A missing or empty "q" is answered with 400; a service error is reported
// through common.ProblemDetailsJSON.
// @Summary Search currencies
// @Description Search for currencies by name
// @Tags currencies
// @Accept json
// @Produce json
// @Param q query string true "Search query"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/search [get]
func SearchCurrencies(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		term := c.Query("q")
		if len(term) == 0 {
			const title, detail = "Search query is required", "Missing search query"
			return common.ProblemDetailsJSON(c, title, nil, detail, fiber.StatusBadRequest)
		}

		matches, searchErr := currencySvc.Search(c.Context(), term)
		if searchErr != nil {
			return common.ProblemDetailsJSON(c, "Failed to search currencies", searchErr)
		}

		const okMessage = "Currencies searched successfully"
		return common.SuccessResponseJSON(c, fiber.StatusOK, okMessage, matches)
	}
}
// SearchCurrenciesByRegion returns a Fiber handler that lists the currencies
// of the region named by the "region" path parameter.
//
// An empty region is answered with 400; a service error is reported through
// common.ProblemDetailsJSON.
// @Summary Search currencies by region
// @Description Search for currencies by region
// @Tags currencies
// @Accept json
// @Produce json
// @Param region path string true "Region name"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/region/{region} [get]
func SearchCurrenciesByRegion(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		regionName := c.Params("region")
		if len(regionName) == 0 {
			const title, detail = "Region is required", "Missing region"
			return common.ProblemDetailsJSON(c, title, nil, detail, fiber.StatusBadRequest)
		}

		matches, searchErr := currencySvc.SearchByRegion(c.Context(), regionName)
		if searchErr != nil {
			const title = "Failed to search currencies by region"
			return common.ProblemDetailsJSON(c, title, searchErr)
		}

		const okMessage = "Currencies by region fetched successfully"
		return common.SuccessResponseJSON(c, fiber.StatusOK, okMessage, matches)
	}
}
// GetCurrencyStatistics returns a Fiber handler that serves the statistics
// produced by currencySvc.GetStatistics.
//
// A service error is reported through common.ProblemDetailsJSON.
// @Summary Get currency statistics
// @Description Get currency registry statistics
// @Tags currencies
// @Accept json
// @Produce json
// @Success 200 {object} common.Response
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/statistics [get]
func GetCurrencyStatistics(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		summary, statsErr := currencySvc.GetStatistics(c.Context())
		if statsErr != nil {
			const title = "Failed to get currency statistics"
			return common.ProblemDetailsJSON(c, title, statsErr)
		}

		const okMessage = "Currency statistics fetched successfully"
		return common.SuccessResponseJSON(c, fiber.StatusOK, okMessage, summary)
	}
}
// GetDefaultCurrency returns a Fiber handler that serves the currency
// returned by currencySvc.GetDefault.
//
// A service error is reported through common.ProblemDetailsJSON.
// @Summary Get default currency
// @Description Get the default currency information
// @Tags currencies
// @Accept json
// @Produce json
// @Success 200 {object} common.Response
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/default [get]
func GetDefaultCurrency(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		fallback, lookupErr := currencySvc.GetDefault(c.Context())
		if lookupErr != nil {
			const title = "Failed to get default currency"
			return common.ProblemDetailsJSON(c, title, lookupErr)
		}

		const okMessage = "Default currency fetched successfully"
		return common.SuccessResponseJSON(c, fiber.StatusOK, okMessage, fallback)
	}
}
// RegisterCurrency registers a new currency (admin only).
//
// The handler binds and validates a RegisterRequest, validates the code with
// the service, answers 409 when a currency with that code can already be
// fetched, and otherwise registers it and returns 201 with the created
// currency.
// @Summary Register currency
// @Description Register a new currency (admin only)
// @Tags currencies
// @Accept json
// @Produce json
// @Param currency body RegisterRequest true "Currency information"
// @Success 201 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 409 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/admin [post]
func RegisterCurrency(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		input, err := common.BindAndValidate[RegisterRequest](c)
		// BindAndValidate can return a nil input together with a nil error
		// (the error response is written either way), so the pointer must be
		// checked too before it is dereferenced below.
		if err != nil || input == nil {
			return nil // Error already written by BindAndValidate
		}
		// Validate currency code format
		if err = currencySvc.ValidateCode(c.Context(), input.Code); err != nil {
			return common.ProblemDetailsJSON(
				c,
				"Invalid currency code",
				err,
				fiber.StatusBadRequest,
			)
		}
		// Check if currency already exists: a successful Get means the code is taken
		if _, err := currencySvc.Get(c.Context(), input.Code); err == nil {
			return common.ProblemDetailsJSON(
				c,
				"Currency already exists",
				fmt.Errorf("currency with code %s already exists", input.Code),
				fiber.StatusConflict,
			)
		}
		// Register the currency
		// NOTE(review): the embedded base entity is not set here, unlike in
		// RegisterRequest.ToServiceEntity — confirm ToResponse below is safe
		// with an entity built this way.
		currEntity := currencysvc.Entity{
			Code:     money.Code(input.Code),
			Name:     input.Name,
			Symbol:   input.Symbol,
			Decimals: input.Decimals,
			Country:  input.Country,
			Region:   input.Region,
			Active:   input.Active,
		}
		if err = currencySvc.Register(c.Context(), currEntity); err != nil {
			return common.ProblemDetailsJSON(
				c,
				"Failed to register currency",
				err,
				fiber.StatusInternalServerError,
			)
		}
		// Convert to response DTO
		resp := ToResponse(&currEntity)
		return common.SuccessResponseJSON(
			c,
			fiber.StatusCreated,
			"Currency registered successfully",
			resp,
		)
	}
}
// UnregisterCurrency removes a currency from the registry (admin only).
//
// The "code" path parameter must be non-empty and pass
// currencySvc.ValidateCode, otherwise the handler answers 400. A failure of
// currencySvc.Unregister whose message contains "not found" is answered with
// 404; any other failure gets the status ProblemDetailsJSON derives from the
// error.
// @Summary Unregister currency
// @Description Remove a currency from the registry (admin only)
// @Tags currencies
// @Accept json
// @Produce json
// @Param code path string true "Currency code"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 404 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/admin/{code} [delete]
func UnregisterCurrency(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		code := c.Params("code")
		if code == "" {
			return common.ProblemDetailsJSON(
				c,
				"Currency code is required",
				nil,
			)
		}
		// Validate currency code format. The status is explicit so that a
		// validation error unknown to the status mapping is not reported as 500.
		if err := currencySvc.ValidateCode(c.Context(), code); err != nil {
			return common.ProblemDetailsJSON(
				c,
				"Invalid currency code",
				err,
				fiber.StatusBadRequest,
			)
		}
		if err := currencySvc.Unregister(c.Context(), code); err != nil {
			if strings.Contains(err.Error(), "not found") {
				// Explicit 404: the generic status mapping has no entry for a
				// missing currency.
				return common.ProblemDetailsJSON(
					c,
					"Failed to unregister currency: currency not found",
					err,
					fiber.StatusNotFound,
				)
			}
			return common.ProblemDetailsJSON(
				c,
				"Failed to unregister currency",
				err,
			)
		}
		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"Currency unregistered successfully",
			fiber.Map{"code": code},
		)
	}
}
// ActivateCurrency activates a currency (admin only).
//
// The "code" path parameter must be non-empty and pass
// currencySvc.ValidateCode, otherwise the handler answers 400. A failure of
// currencySvc.Activate whose message contains "not found" is answered with
// 404; any other failure gets the status ProblemDetailsJSON derives from the
// error.
// @Summary Activate currency
// @Description Activate a currency (admin only)
// @Tags currencies
// @Accept json
// @Produce json
// @Param code path string true "Currency code"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 404 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/admin/{code}/activate [put]
func ActivateCurrency(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		code := c.Params("code")
		if code == "" {
			return common.ProblemDetailsJSON(
				c,
				"Currency code is required",
				nil,
			)
		}
		// Validate currency code format. The status is explicit so that a
		// validation error unknown to the status mapping is not reported as 500.
		if err := currencySvc.ValidateCode(c.Context(), code); err != nil {
			return common.ProblemDetailsJSON(
				c,
				"Invalid currency code",
				err,
				fiber.StatusBadRequest,
			)
		}
		if err := currencySvc.Activate(c.Context(), code); err != nil {
			if strings.Contains(err.Error(), "not found") {
				// Explicit 404: the generic status mapping has no entry for a
				// missing currency.
				return common.ProblemDetailsJSON(
					c,
					"Failed to activate currency: currency not found",
					err,
					fiber.StatusNotFound,
				)
			}
			return common.ProblemDetailsJSON(
				c,
				"Failed to activate currency",
				err,
			)
		}
		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"Currency activated successfully",
			fiber.Map{"code": code},
		)
	}
}
// DeactivateCurrency deactivates a currency (admin only).
//
// The "code" path parameter must be non-empty and pass
// currencySvc.ValidateCode, otherwise the handler answers 400. A failure of
// currencySvc.Deactivate whose message contains "not found" is answered with
// 404; any other failure gets the status ProblemDetailsJSON derives from the
// error.
// @Summary Deactivate currency
// @Description Deactivate a currency (admin only)
// @Tags currencies
// @Accept json
// @Produce json
// @Param code path string true "Currency code"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 404 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /api/currencies/admin/{code}/deactivate [put]
func DeactivateCurrency(
	currencySvc *currencysvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		code := c.Params("code")
		if code == "" {
			return common.ProblemDetailsJSON(
				c,
				"Currency code is required",
				nil,
			)
		}
		// Validate currency code format. The status is explicit so that a
		// validation error unknown to the status mapping is not reported as 500.
		if err := currencySvc.ValidateCode(c.Context(), code); err != nil {
			return common.ProblemDetailsJSON(
				c,
				"Invalid currency code",
				err,
				fiber.StatusBadRequest,
			)
		}
		if err := currencySvc.Deactivate(c.Context(), code); err != nil {
			if strings.Contains(err.Error(), "not found") {
				// Explicit 404: the generic status mapping has no entry for a
				// missing currency.
				return common.ProblemDetailsJSON(
					c,
					"Failed to deactivate currency: currency not found",
					err,
					fiber.StatusNotFound,
				)
			}
			return common.ProblemDetailsJSON(
				c,
				"Failed to deactivate currency",
				err,
			)
		}
		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"Currency deactivated successfully",
			fiber.Map{"code": code},
		)
	}
}
package currency
import (
"time"
"github.com/amirasaad/fintech/pkg/money"
"github.com/amirasaad/fintech/pkg/registry"
currencysvc "github.com/amirasaad/fintech/pkg/service/currency"
)
// RegisterRequest represents the request body for registering a currency.
//
// The validate tags are enforced by common.BindAndValidate: Code must be a
// three-character upper-case string, Name and Symbol must be present, and
// Decimals must lie in [0, 8].
type RegisterRequest struct {
	Code   string `json:"code" validate:"required,len=3,uppercase"`
	Name   string `json:"name" validate:"required"`
	Symbol string `json:"symbol" validate:"required"`
	// Decimals deliberately has no "required" rule: for integers that rule
	// rejects the zero value, which would make currencies with zero decimal
	// places impossible to register even though min=0 allows them.
	Decimals int               `json:"decimals" validate:"min=0,max=8"`
	Country  string            `json:"country,omitempty"`
	Region   string            `json:"region,omitempty"`
	Active   bool              `json:"active"`
	Metadata map[string]string `json:"metadata,omitempty"`
}
// CurrencyResponse represents the response structure for currency data.
// ToResponse builds it from a service-layer currency entity.
type CurrencyResponse struct {
	Code     string `json:"code"`
	Name     string `json:"name"`
	Symbol   string `json:"symbol"`
	Decimals int    `json:"decimals"`
	// Country, Region and Metadata are left out of the JSON when empty.
	Country   string            `json:"country,omitempty"`
	Region    string            `json:"region,omitempty"`
	Active    bool              `json:"active"`
	Metadata  map[string]string `json:"metadata,omitempty"`
	CreatedAt time.Time         `json:"created_at"`
	// UpdatedAt is a pointer so that a nil value is omitted from the JSON.
	UpdatedAt *time.Time `json:"updated_at,omitempty"`
}
// ToResponse maps a service-layer currency entity onto the API response DTO.
//
// A nil entity yields a nil response. Otherwise every field is copied over
// and UpdatedAt points at a local copy of the entity's update timestamp.
func ToResponse(entity *currencysvc.Entity) *CurrencyResponse {
	if entity == nil {
		return nil
	}

	modified := entity.UpdatedAt()

	resp := new(CurrencyResponse)
	resp.Code = entity.Code.String()
	resp.Name = entity.Name
	resp.Symbol = entity.Symbol
	resp.Decimals = entity.Decimals
	resp.Country = entity.Country
	resp.Region = entity.Region
	resp.Active = entity.Active
	resp.Metadata = entity.Metadata()
	resp.CreatedAt = entity.CreatedAt()
	resp.UpdatedAt = &modified
	return resp
}
// ToServiceEntity converts a RegisterRequest to a service layer entity
func (r *RegisterRequest) ToServiceEntity() *currencysvc.Entity {
return ¤cysvc.Entity{
Entity: registry.NewBaseEntity(r.Code, r.Name),
Code: money.Code(r.Code),
Name: r.Name,
Symbol: r.Symbol,
Decimals: r.Decimals,
Country: r.Country,
Region: r.Region,
Active: r.Active,
}
}
package payment
import (
"fmt"
"github.com/amirasaad/fintech/pkg/provider/payment"
"github.com/gofiber/fiber/v2"
)
// StripeWebhookHandler builds the Fiber handler that receives Stripe webhook
// events.
//
// The handler requires a non-empty Stripe-Signature header and a non-empty
// body, then hands both to paymentProvider.HandleWebhook. Every failure is
// answered with 400 and a JSON object carrying an "error" message; success is
// acknowledged with a bare 200.
func StripeWebhookHandler(
	paymentProvider payment.Payment,
) fiber.Handler {
	// reject writes the shared 400 error payload.
	reject := func(c *fiber.Ctx, message string) error {
		return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
			"error": message,
		})
	}

	return func(c *fiber.Ctx) error {
		sig := c.Get("Stripe-Signature")
		if len(sig) == 0 {
			return reject(c, "Missing Stripe-Signature header")
		}

		// The raw body is passed on untouched, together with the signature.
		raw := c.Body()
		if len(raw) == 0 {
			return reject(c, "Empty request body")
		}

		if _, err := paymentProvider.HandleWebhook(c.Context(), raw, sig); err != nil {
			return reject(c, fmt.Sprintf("Error processing webhook: %v", err))
		}

		return c.SendStatus(fiber.StatusOK)
	}
}
// StripeWebhookRoutes registers the Stripe webhook endpoint,
// POST /api/v1/webhooks/stripe, on the given application.
func StripeWebhookRoutes(
	app *fiber.App,
	paymentProvider payment.Payment,
) {
	webhook := StripeWebhookHandler(paymentProvider)
	app.Post("/api/v1/webhooks/stripe", webhook)
}
package testutils
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/amirasaad/fintech/pkg/app"
"github.com/amirasaad/fintech/pkg/config"
pkgeventbus "github.com/amirasaad/fintech/pkg/eventbus"
"github.com/amirasaad/fintech/pkg/registry"
"github.com/amirasaad/fintech/infra/eventbus"
"github.com/amirasaad/fintech/infra/provider/exchangerateapi"
"github.com/amirasaad/fintech/infra/provider/mockpayment"
infrarepo "github.com/amirasaad/fintech/infra/repository"
infrarepoUser "github.com/amirasaad/fintech/infra/repository/user"
fixturescurrency "github.com/amirasaad/fintech/internal/fixtures/currency"
"github.com/amirasaad/fintech/pkg/domain"
"github.com/amirasaad/fintech/pkg/domain/user"
"github.com/amirasaad/fintech/webapi"
"github.com/amirasaad/fintech/webapi/common"
"github.com/gofiber/fiber/v2"
"github.com/golang-migrate/migrate/v4"
migratepostgres "github.com/golang-migrate/migrate/v4/database/postgres"
_ "github.com/golang-migrate/migrate/v4/source/file" // required for file-based migrations
"github.com/google/uuid"
"github.com/stretchr/testify/suite"
"github.com/testcontainers/testcontainers-go"
tcpostgres "github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
// E2ETestSuite provides a test suite with a real Postgres database using Testcontainers.
// All fields are populated by SetupSuite (app via setupApp).
type E2ETestSuite struct {
	suite.Suite
	pgContainer *tcpostgres.PostgresContainer // Postgres container; terminated in TearDownSuite
	db          *gorm.DB                      // GORM connection to the container database
	app         *fiber.App                    // application under test, driven by MakeRequest
	cfg         *config.App                   // loaded test config with DB.Url pointing at the container
}
// BeforeEachTest marks the current test as parallel by calling s.T().Parallel().
// NOTE(review): testify's per-test hooks are named SetupTest and BeforeTest;
// confirm something actually invokes this method before each test.
func (s *E2ETestSuite) BeforeEachTest() {
	s.T().Parallel()
}
// SetupSuite prepares everything the suite needs before any test runs:
// it starts a Postgres container, opens a GORM connection to it, applies the
// file-based migrations from internal/migrations, loads the .env.test
// configuration (overriding the DB URL with the container's DSN) and finally
// builds the application via setupApp. Any failure aborts the suite.
func (s *E2ETestSuite) SetupSuite() {
	var (
		ctx  = context.Background()
		must = s.Require()
	)

	// Start Postgres container and wait until it logs readiness twice.
	readiness := wait.ForLog("database system is ready to accept connections").
		WithOccurrence(2).
		WithStartupTimeout(30 * time.Second)
	container, err := tcpostgres.Run(
		ctx,
		"postgres:15-alpine",
		tcpostgres.WithDatabase("testdb"),
		tcpostgres.WithUsername("test"),
		tcpostgres.WithPassword("test"),
		testcontainers.WithWaitStrategy(readiness),
	)
	must.NoError(err)
	s.pgContainer = container

	// Connect to the containerised database.
	connStr, err := container.ConnectionString(ctx, "sslmode=disable")
	must.NoError(err)
	s.db, err = gorm.Open(postgres.Open(connStr), &gorm.Config{})
	must.NoError(err)

	// Run migrations, located relative to this source file.
	rawDB, err := s.db.DB()
	must.NoError(err)
	migrationDriver, err := migratepostgres.WithInstance(rawDB, &migratepostgres.Config{})
	must.NoError(err)
	_, thisFile, _, _ := runtime.Caller(0)
	migrationsDir := filepath.Join(filepath.Dir(thisFile), "../../internal/migrations")
	migrator, err := migrate.NewWithDatabaseInstance(
		"file://"+migrationsDir,
		"postgres",
		migrationDriver,
	)
	must.NoError(err)
	// "No change" only means the schema is already current.
	if upErr := migrator.Up(); upErr != nil && !errors.Is(upErr, migrate.ErrNoChange) {
		must.NoError(upErr)
	}

	// Load config and point it at the test database.
	envFile, err := config.FindEnvTest(".env.test")
	must.NoError(err)
	s.cfg, err = config.Load(envFile)
	must.NoError(err)
	s.cfg.DB.Url = connStr

	// Setup services and app
	s.setupApp()
}
// TearDownSuite releases the suite's resources: if a Postgres container was
// started it is terminated, and any termination error is deliberately
// ignored.
func (s *E2ETestSuite) TearDownSuite() {
	if s.pgContainer == nil {
		return
	}
	_ = s.pgContainer.Terminate(context.Background())
}
// setupApp builds the application's dependencies (unit of work, registries,
// event bus, fake exchange-rate provider and mock payment provider) and
// stores the resulting Fiber app in s.app.
//
// The event bus is chosen from s.cfg.EventBus.Driver (trimmed, lower-cased):
// an empty value or "memory" uses the in-memory async bus, "redis" starts a
// Redis container via testcontainers-go and builds a Redis bus on it, and
// "kafka" builds a Kafka bus from the configured brokers. Any other value
// fails the test.
func (s *E2ETestSuite) setupApp() {
	s.T().Helper()
	// Unit of work backed by the suite's database connection
	uow := infrarepo.NewUoW(s.db)
	// Debug-level logger whose output is discarded
	logger := slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{
		Level: slog.LevelDebug,
	}))
	// Setup currency service
	ctx := context.Background()
	currencyRegistry, err := registry.NewBuilder().
		WithName("test_currency").
		WithRedis("").
		WithCache(100, time.Minute).
		BuildRegistry()
	if err != nil {
		panic(fmt.Errorf("failed to create test currency registry provider: %w", err))
	}
	// Load currency fixtures (path resolved relative to this source file)
	_, filename, _, _ := runtime.Caller(0)
	fixturePath := filepath.Join(
		filepath.Dir(filename),
		"../../internal/fixtures/currency/meta.csv",
	)
	metas, err := fixturescurrency.LoadCurrencyMetaCSV(fixturePath)
	s.Require().NoError(err)
	for _, meta := range metas {
		s.Require().NoError(currencyRegistry.Register(ctx, meta))
	}
	// Normalise the configured event bus driver; it stays empty when no
	// event bus section is configured.
	driver := ""
	if s.cfg.EventBus != nil {
		driver = strings.TrimSpace(strings.ToLower(s.cfg.EventBus.Driver))
	}
	var eventBus pkgeventbus.Bus
	switch driver {
	case "", "memory":
		eventBus = eventbus.NewWithMemoryAsync(logger)
	case "redis":
		// Start a throwaway Redis container for the bus
		redisContainer, containerErr := testcontainers.GenericContainer(
			ctx,
			testcontainers.GenericContainerRequest{
				ContainerRequest: testcontainers.ContainerRequest{
					Image:        "redis:7-alpine",
					ExposedPorts: []string{"6379/tcp"},
					WaitingFor: wait.ForListeningPort(
						"6379/tcp",
					).WithStartupTimeout(10 * time.Second),
				},
				Started: true,
			},
		)
		s.Require().NoError(containerErr)
		endpoint, endpointErr := redisContainer.Endpoint(ctx, "")
		s.Require().NoError(endpointErr)
		eventBus, err = eventbus.NewWithRedis(
			"redis://"+endpoint,
			logger,
			&eventbus.RedisEventBusConfig{
				DLQRetryInterval: 5 * time.Minute,
				DLQBatchSize:     10,
			},
		)
		s.Require().NoError(err)
		// Terminate the container when the test finishes
		s.T().Cleanup(func() {
			_ = redisContainer.Terminate(ctx)
		})
	case "kafka":
		// No container is started here; the brokers come from the config
		kafkaBus, kafkaBusErr := eventbus.NewWithKafka(
			s.cfg.EventBus.KafkaBrokers,
			logger,
			&eventbus.KafkaEventBusConfig{
				GroupID:          s.cfg.EventBus.KafkaGroupID,
				TopicPrefix:      s.cfg.EventBus.KafkaTopic,
				DLQRetryInterval: 5 * time.Minute,
				DLQBatchSize:     10,
			},
		)
		s.Require().NoError(kafkaBusErr)
		eventBus = kafkaBus
	default:
		s.T().Fatalf("unsupported event bus driver: %s", driver)
	}
	// Create registry providers for each service with in-memory storage
	mainReg, err := registry.NewBuilder().
		WithName("test").
		WithRedis(""). // Empty URL for in-memory
		WithCache(100, time.Minute).
		BuildRegistry()
	if err != nil {
		panic(fmt.Errorf("failed to create test main registry provider: %w", err))
	}
	mainRegistry, ok := mainReg.(*registry.Enhanced)
	if !ok {
		panic("main registry is not of type *registry.Enhanced")
	}
	// Create currency registry
	currencyReg, err := registry.NewBuilder().
		WithName("test_currency").
		WithRedis("").
		WithCache(100, time.Minute).
		BuildRegistry()
	if err != nil {
		panic(fmt.Errorf("failed to create test currency registry provider: %w", err))
	}
	// NOTE(review): this reassigns currencyRegistry to the registry built just
	// above; the fixtures registered earlier went into the previous value.
	// Confirm the fixtures are not meant to reach the app through deps.
	currencyRegistry, ok = currencyReg.(*registry.Enhanced)
	if !ok {
		panic("currency registry is not of type *registry.Enhanced")
	}
	// Create checkout registry
	checkoutReg, err := registry.NewBuilder().
		WithName("test_checkout").
		WithRedis("").
		WithCache(100, time.Minute).
		BuildRegistry()
	if err != nil {
		panic(fmt.Errorf("failed to create test checkout registry provider: %w", err))
	}
	checkoutRegistry, ok := checkoutReg.(*registry.Enhanced)
	if !ok {
		panic("checkout registry is not of type *registry.Enhanced")
	}
	// Create exchange rate registry
	exchangeRateReg, err := registry.NewBuilder().
		WithName("test_exchange_rate").
		WithRedis("").
		WithCache(100, time.Minute).
		BuildRegistry()
	if err != nil {
		panic(fmt.Errorf("failed to create test exchange rate registry provider: %w", err))
	}
	exchangeRateRegistry, ok := exchangeRateReg.(*registry.Enhanced)
	if !ok {
		panic("exchange rate registry is not of type *registry.Enhanced")
	}
	// External providers are replaced by a fake and a mock
	exchangeRateProvider := exchangerateapi.NewFakeExchangeRate()
	mockPaymentProvider := mockpayment.NewMockPaymentProvider()
	deps := &app.Deps{
		RegistryProvider:     mainRegistry,
		CurrencyRegistry:     currencyRegistry,
		CheckoutRegistry:     checkoutRegistry,
		ExchangeRateRegistry: exchangeRateRegistry,
		ExchangeRateProvider: exchangeRateProvider,
		PaymentProvider:      mockPaymentProvider,
		Uow:                  uow,
		EventBus:             eventBus,
		Logger:               logger,
	}
	// Create test app
	s.app = webapi.SetupApp(app.New(
		deps,
		s.cfg,
	))
}
// MakeRequest sends an in-process HTTP request to the application under test
// and returns the response.
//
// A non-empty body is sent with Content-Type application/json; a non-empty
// token is sent as a Bearer Authorization header. The request goes through
// s.app.Test with a timeout argument of 1000000, and any error from it fails
// the test immediately.
func (s *E2ETestSuite) MakeRequest(
	method, path, body, token string,
) *http.Response {
	hasBody := body != ""

	var payload io.Reader
	if hasBody {
		payload = bytes.NewBufferString(body)
	}
	req := httptest.NewRequest(method, path, payload)
	if hasBody {
		req.Header.Set("Content-Type", "application/json")
	}
	if token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}

	resp, testErr := s.app.Test(req, 1000000)
	if testErr != nil {
		s.T().Fatal(testErr)
	}
	return resp
}
// CreateTestUser creates a unique test user via the POST /user endpoint.
//
// The username and e-mail embed the first eight characters of a fresh UUID;
// the password is always "password123". The test fails unless the endpoint
// answers 201. The returned user carries the ID reported by the API together
// with the plain-text credentials, so it can be passed to LoginUser.
//
// If the response's data is not a JSON object, a user built locally with
// user.New is returned instead.
func (s *E2ETestSuite) CreateTestUser() *domain.User {
	randomID := uuid.New().String()[:8]
	username := fmt.Sprintf("testuser_%s", randomID)
	email := fmt.Sprintf("test_%s@example.com", randomID)
	// Create user via HTTP POST request
	createUserBody := fmt.Sprintf(
		`{"username":"%s","email":"%s","password":"password123"}`,
		username,
		email,
	)
	resp := s.MakeRequest("POST", "/user", createUserBody, "")
	// Always release the response body, including on the fatal paths below.
	defer func() { _ = resp.Body.Close() }()
	if resp.StatusCode != http.StatusCreated {
		// Read the response body for more details
		body, _ := io.ReadAll(resp.Body)
		s.T().Logf("User creation failed with status %d.", resp.StatusCode)
		s.T().Logf("Response body: %s", string(body))
		s.T().Fatalf("Expected 201 Created for user creation, got %d", resp.StatusCode)
	}
	// Parse response to get the created user
	var response common.Response
	err := json.NewDecoder(resp.Body).Decode(&response)
	if err != nil {
		s.T().Fatal(err)
	}
	// Extract user data from response
	if userData, ok := response.Data.(map[string]any); ok {
		userIDStr, ok := userData["id"].(string)
		if !ok {
			s.T().Fatalf("User ID should be present in response")
		}
		userID, parseErr := uuid.Parse(userIDStr)
		if parseErr != nil {
			s.T().Fatalf("User ID should be a valid UUID")
		}
		return &domain.User{
			ID:       userID,
			Username: username,
			Email:    email,
			Password: "password123",
		}
	}
	// Fallback: create user directly
	testUser, err := user.New(username, email, "password123")
	if err != nil {
		s.T().Fatalf("Failed to create user: %v", err)
	}
	return testUser
}
// LoginUser performs a real POST /auth/login request with the user's email
// and password and returns the JWT token found in the response.
//
// The token is looked up under the "token" key when the response data is a
// map, or taken directly when the response data is a plain string. The test
// is failed immediately if the response cannot be decoded, if the token value
// is not a string, or if no token is present.
func (s *E2ETestSuite) LoginUser(testUser *domain.User) string {
	loginBody := fmt.Sprintf(`{"identity":"%s","password":"%s"}`, testUser.Email, testUser.Password)
	resp := s.MakeRequest("POST", "/auth/login", loginBody, "")
	// Always release the response body; Fatal/Fatalf still run deferred calls.
	defer func() { _ = resp.Body.Close() }()

	var response common.Response
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		s.T().Fatal(err)
	}

	// Extract token from response
	var token string
	switch data := response.Data.(type) {
	case map[string]any:
		if raw, exists := data["token"]; exists {
			// Use a checked assertion so an unexpected payload fails the
			// test with a clear message rather than panicking.
			tokenString, isString := raw.(string)
			if !isString {
				s.T().Fatalf("Token in response is not a string (got %T)", raw)
			}
			token = tokenString
		}
	case map[string]string:
		token = data["token"]
	case string:
		token = data
	}

	s.T().Logf("Extracted token: %s", token)
	if token == "" {
		s.T().Fatalf("No token found in response")
	}
	return token
}
// MarkUserOnboardingComplete flags the given user as having finished Stripe
// Connect onboarding by updating the user row directly through the suite's
// database handle. It sets the onboarding-completed column to true and the
// account status column to "active".
//
// Intended for E2E tests that need to exercise withdrawal flows. The test is
// failed if the update returns an error.
func (s *E2ETestSuite) MarkUserOnboardingComplete(userID uuid.UUID) {
	s.T().Helper()

	onboardedColumns := map[string]any{
		"stripe_connect_onboarding_completed": true,
		"stripe_connect_account_status":       "active",
	}

	result := s.db.
		Model(&infrarepoUser.User{}).
		Where("id = ?", userID).
		Updates(onboardedColumns)

	s.Require().NoError(result.Error, "Failed to mark user onboarding as complete")
}
package user
import (
"github.com/amirasaad/fintech/pkg/config"
"github.com/amirasaad/fintech/pkg/dto"
"github.com/amirasaad/fintech/pkg/middleware"
authsvc "github.com/amirasaad/fintech/pkg/service/auth"
usersvc "github.com/amirasaad/fintech/pkg/service/user"
"github.com/amirasaad/fintech/webapi/common"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/log"
"github.com/golang-jwt/jwt/v5"
"github.com/google/uuid"
)
// Routes wires the user endpoints onto the given Fiber app.
//
// POST /user (creation) is registered without authentication. The
// GET, PUT and DELETE routes on /user/:id are each placed behind the JWT
// middleware built from cfg.Auth.Jwt.
func Routes(
	app *fiber.App,
	userSvc *usersvc.Service,
	authSvc *authsvc.Service,
	cfg *config.App,
) {
	// requireJWT builds a fresh JWT guard for each protected route.
	requireJWT := func() fiber.Handler {
		return middleware.JwtProtected(cfg.Auth.Jwt)
	}

	app.Get("/user/:id", requireJWT(), GetUser(userSvc))
	app.Post("/user", CreateUser(userSvc))
	app.Put("/user/:id", requireJWT(), UpdateUser(userSvc, authSvc))
	app.Delete("/user/:id", requireJWT(), DeleteUser(userSvc, authSvc))
}
// GetUser returns a Fiber handler that looks up a user by the "id" path
// parameter. A malformed ID yields a 400 problem response; a lookup error or
// a missing user yields a generic 401 "Invalid credentials" problem response;
// otherwise the user is returned with status 200.
// @Summary Get user by ID
// @Description Retrieve a user by their ID
// @Tags users
// @Accept json
// @Produce json
// @Param id path string true "User ID"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 404 {object} common.ProblemDetails
// @Router /user/{id} [get]
// @Security Bearer
func GetUser(userSvc *usersvc.Service) fiber.Handler {
	return func(c *fiber.Ctx) error {
		userID, parseErr := uuid.Parse(c.Params("id"))
		if parseErr != nil {
			log.Errorf("Invalid user ID: %v", parseErr)
			return common.ProblemDetailsJSON(
				c,
				"Invalid user ID",
				parseErr,
				"User ID must be a valid UUID",
				fiber.StatusBadRequest,
			)
		}

		found, lookupErr := userSvc.GetUser(c.Context(), userID.String())
		if lookupErr == nil && found != nil {
			return common.SuccessResponseJSON(c, fiber.StatusOK, "User found", found)
		}

		// Deliberately generic answer for failed lookups so callers cannot
		// probe which user IDs exist.
		return common.ProblemDetailsJSON(
			c,
			"Invalid credentials",
			nil,
			fiber.StatusUnauthorized,
		)
	}
}
// CreateUser returns a Fiber handler that creates a user account from the
// request body. Requests whose password is longer than 72 bytes are rejected
// with a 400 problem response; a service failure is reported as a problem
// response built from the returned error; success answers 201 with the
// created user.
// @Summary Create a new user
// @Description Create a new user account with username, email, and password
// @Tags users
// @Accept json
// @Produce json
// @Param request body NewUser true "User creation data"
// @Success 201 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 429 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /user [post]
func CreateUser(userSvc *usersvc.Service) fiber.Handler {
	// Upper bound on the accepted password length, in bytes.
	const maxPasswordLen = 72

	return func(c *fiber.Ctx) error {
		payload, bindErr := common.BindAndValidate[NewUser](c)
		if payload == nil {
			// The binder has already written the error response.
			return bindErr
		}

		if len(payload.Password) > maxPasswordLen {
			return common.ProblemDetailsJSON(
				c,
				"Invalid request body",
				nil,
				"Password too long",
				fiber.StatusBadRequest,
			)
		}

		created, createErr := userSvc.CreateUser(
			c.Context(),
			payload.Username,
			payload.Email,
			payload.Password,
		)
		if createErr != nil {
			return common.ProblemDetailsJSON(c, "Couldn't create user", createErr)
		}

		return common.SuccessResponseJSON(c, fiber.StatusCreated, "Created user", created)
	}
}
// UpdateUser returns a Fiber handler that updates the names of the user
// identified by the "id" path parameter and responds with the refreshed user.
//
// The handler binds the request body, validates the ID (400 on failure),
// resolves the caller from the JWT stored under the "user" local, and only
// proceeds when the caller's ID equals the path ID. Every other failure,
// including an ID mismatch (titled "Forbidden"), is answered with status 401.
// @Summary Update user
// @Description Update user information by ID
// @Tags users
// @Accept json
// @Produce json
// @Param id path string true "User ID"
// @Param request body UpdateUserInput true "User update data"
// @Success 200 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 429 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /user/{id} [put]
// @Security Bearer
func UpdateUser(
	userSvc *usersvc.Service,
	authSvc *authsvc.Service,
) fiber.Handler {
	// unauthorized writes the detail-free 401 response shared by several branches.
	unauthorized := func(c *fiber.Ctx) error {
		return common.ProblemDetailsJSON(c, "Unauthorized", nil, fiber.StatusUnauthorized)
	}

	return func(c *fiber.Ctx) error {
		payload, bindErr := common.BindAndValidate[UpdateUserInput](c)
		if payload == nil {
			// The binder has already written the error response.
			return bindErr
		}

		targetID, parseErr := uuid.Parse(c.Params("id"))
		if parseErr != nil {
			log.Errorf("Invalid user ID: %v", parseErr)
			return common.ProblemDetailsJSON(
				c,
				"Invalid user ID",
				parseErr,
				"User ID must be a valid UUID",
				fiber.StatusBadRequest,
			)
		}

		jwtToken, hasToken := c.Locals("user").(*jwt.Token)
		if !hasToken {
			return unauthorized(c)
		}

		callerID, tokenErr := authSvc.GetCurrentUserId(jwtToken)
		if tokenErr != nil {
			log.Errorf("Failed to parse user ID from token: %v", tokenErr)
			return unauthorized(c)
		}

		// Callers may only modify their own record.
		if targetID != callerID {
			return common.ProblemDetailsJSON(c, "Forbidden", nil, fiber.StatusUnauthorized)
		}

		changes := &dto.UserUpdate{Names: &payload.Names}
		if updateErr := userSvc.UpdateUser(c.Context(), targetID.String(), changes); updateErr != nil {
			// The underlying failure is intentionally not exposed to the client.
			return common.ProblemDetailsJSON(
				c,
				"Unauthorized",
				nil,
				"missing user context",
				fiber.StatusUnauthorized,
			)
		}

		// Re-read the user so the response reflects the stored state.
		refreshed, fetchErr := userSvc.GetUser(c.Context(), targetID.String())
		if fetchErr != nil || refreshed == nil {
			return unauthorized(c)
		}

		return common.SuccessResponseJSON(
			c,
			fiber.StatusOK,
			"User updated successfully",
			refreshed,
		)
	}
}
// DeleteUser returns a Fiber handler that deletes the user identified by the
// "id" path parameter after re-confirming the caller's password.
//
// Flow: bind the request body, validate the ID (400 on failure), resolve the
// caller from the JWT stored under the "user" local, require that the caller's
// ID equals the path ID, load the user to obtain the email, validate the
// supplied password against it, then delete. Authentication and ownership
// failures answer with status 401; lookup, validation and deletion errors
// answer with status 500; success answers with status 204.
// @Summary Delete user
// @Description Delete a user account by ID with password confirmation
// @Tags users
// @Accept json
// @Produce json
// @Param id path string true "User ID"
// @Param request body PasswordInput true "Password confirmation"
// @Success 204 {object} common.Response
// @Failure 400 {object} common.ProblemDetails
// @Failure 401 {object} common.ProblemDetails
// @Failure 429 {object} common.ProblemDetails
// @Failure 500 {object} common.ProblemDetails
// @Router /user/{id} [delete]
// @Security Bearer
func DeleteUser(
	userSvc *usersvc.Service,
	authSvc *authsvc.Service,
) fiber.Handler {
	return func(c *fiber.Ctx) error {
		input, err := common.BindAndValidate[PasswordInput](c)
		if input == nil {
			return err // error response already written
		}

		id, err := uuid.Parse(c.Params("id"))
		if err != nil {
			log.Errorf("Invalid user ID: %v", err)
			return common.ProblemDetailsJSON(
				c,
				"Invalid user ID",
				err,
				"User ID must be a valid UUID",
				fiber.StatusBadRequest,
			)
		}

		token, ok := c.Locals("user").(*jwt.Token)
		if !ok {
			return common.ProblemDetailsJSON(
				c,
				"Unauthorized",
				nil,
				"missing user context",
				fiber.StatusUnauthorized,
			)
		}

		userID, err := authSvc.GetCurrentUserId(token)
		if err != nil {
			log.Errorf("Failed to parse user ID from token: %v", err)
			return common.ProblemDetailsJSON(
				c,
				"Unauthorized",
				nil,
				"missing user context",
				fiber.StatusUnauthorized,
			)
		}

		// Callers may only delete their own account.
		if id != userID {
			return common.ProblemDetailsJSON(
				c,
				"Forbidden",
				nil,
				"You are not allowed to delete this user",
				fiber.StatusUnauthorized,
			)
		}

		// Retrieve user to get email for password validation
		user, err := userSvc.GetUser(c.Context(), id.String())
		if err != nil {
			log.Errorf("Error getting user for password validation: %v", err)
			return common.ProblemDetailsJSON(
				c,
				"Invalid credentials",
				err,
				"Internal server error during user validation",
				fiber.StatusInternalServerError,
			)
		}
		if user == nil {
			return common.ProblemDetailsJSON(
				c,
				"Invalid credentials",
				nil,
				"User not found",
				fiber.StatusUnauthorized,
			)
		}

		isValid, validErr := userSvc.ValidUser(
			c.Context(),
			user.Email,
			input.Password,
		)
		// A negative validation result is always reported as bad credentials,
		// whether or not an error accompanied it.
		if !isValid {
			return common.ProblemDetailsJSON(
				c,
				"Invalid credentials",
				nil,
				"Invalid username or password",
				fiber.StatusUnauthorized,
			)
		}
		// Validation reported success but also returned an error: treat it as
		// an internal failure and surface the validation error itself (the
		// outer err is always nil at this point).
		if validErr != nil {
			log.Errorf("Error validating user credentials: %v", validErr)
			return common.ProblemDetailsJSON(
				c,
				"Failed to validate user",
				validErr,
				fiber.StatusInternalServerError,
			)
		}

		err = userSvc.DeleteUser(c.Context(), id.String())
		if err != nil {
			return common.ProblemDetailsJSON(
				c,
				"Couldn't delete user",
				err,
				"Internal server error during user deletion",
				fiber.StatusInternalServerError,
			)
		}

		return common.SuccessResponseJSON(
			c,
			fiber.StatusNoContent,
			"User successfully deleted",
			nil,
		)
	}
}
// Package webapi provides HTTP handlers and API endpoints for the fintech application.
// It is organized into sub-packages for different domains:
// - account: Account and transaction endpoints
// - auth: Authentication endpoints
// - user: User management endpoints
// - currency: Currency and exchange rate endpoints
// - checkout: Checkout endpoints
// - payment: Stripe webhook handler
package webapi
import (
"errors"
"strings"
"github.com/amirasaad/fintech/pkg/app"
accountweb "github.com/amirasaad/fintech/webapi/account"
authweb "github.com/amirasaad/fintech/webapi/auth"
checkoutweb "github.com/amirasaad/fintech/webapi/checkout"
"github.com/amirasaad/fintech/webapi/common"
currencyweb "github.com/amirasaad/fintech/webapi/currency"
"github.com/amirasaad/fintech/webapi/payment"
userweb "github.com/amirasaad/fintech/webapi/user"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/fiber/v2/middleware/limiter"
"github.com/gofiber/fiber/v2/middleware/logger"
"github.com/gofiber/fiber/v2/middleware/recover"
"github.com/gofiber/swagger"
)
// SetupApp builds the Fiber application for the given app container.
//
// It installs, in order: a problem-details error handler, the Swagger UI
// route, a per-client rate limiter, panic recovery and request logging
// middleware, a root health-check route, a debug route listing all registered
// routes, the Stripe webhook route, and finally the account, user, auth,
// currency and checkout route groups.
func SetupApp(app *app.App) *fiber.App {
	var (
		accountSvc  = app.AccountService
		userSvc     = app.UserService
		authSvc     = app.AuthService
		currencySvc = app.CurrencyService
		checkoutSvc = app.CheckoutService
	)

	fiberApp := fiber.New(fiber.Config{
		ErrorHandler: func(c *fiber.Ctx, err error) error {
			return common.ProblemDetailsJSON(c, "Internal Server Error", err)
		},
	})

	swaggerCfg := swagger.Config{
		TryItOutEnabled:      true,
		WithCredentials:      true,
		PersistAuthorization: true,
		OAuth2RedirectUrl:    "/auth/login",
	}
	fiberApp.Get("/swagger/*", swagger.New(swaggerCfg))

	// clientKey picks the identity used to bucket requests for rate limiting:
	// the first entry of X-Forwarded-For when present (load balancers and
	// proxies), otherwise X-Real-IP, otherwise the connection's IP.
	clientKey := func(c *fiber.Ctx) string {
		if chain := c.Get("X-Forwarded-For"); chain != "" {
			// Everything before the first comma, or the whole value if
			// there is no comma.
			first, _, _ := strings.Cut(chain, ",")
			return strings.TrimSpace(first)
		}
		if realIP := c.Get("X-Real-IP"); realIP != "" {
			return realIP
		}
		return c.IP()
	}

	// tooManyRequests is the response sent once a client exceeds its quota.
	tooManyRequests := func(c *fiber.Ctx) error {
		return common.ProblemDetailsJSON(
			c,
			"Too Many Requests",
			errors.New("rate limit exceeded"),
			fiber.StatusTooManyRequests,
		)
	}

	// Configure rate limiting middleware
	fiberApp.Use(limiter.New(limiter.Config{
		Max:          app.Config.RateLimit.MaxRequests,
		Expiration:   app.Config.RateLimit.Window,
		KeyGenerator: clientKey,
		LimitReached: tooManyRequests,
	}))
	fiberApp.Use(recover.New())
	fiberApp.Use(logger.New())

	// Health check endpoint
	fiberApp.Get("/", func(c *fiber.Ctx) error {
		return c.SendString("FinTech API is running! 🚀")
	})

	// Debug endpoint to list all routes
	fiberApp.Get("/debug/routes", func(c *fiber.Ctx) error {
		var listing []map[string]any
		for _, r := range fiberApp.GetRoutes() {
			if r.Path == "" {
				continue
			}
			listing = append(listing, map[string]any{
				"method": r.Method,
				"path":   r.Path,
			})
		}
		return c.JSON(listing)
	})

	// Payment event processor for Stripe webhooks
	stripeWebhook := payment.StripeWebhookHandler(app.Deps.PaymentProvider)
	fiberApp.Post("/api/v1/webhooks/stripe", stripeWebhook)

	// Initialize account routes which include Stripe Connect routes
	accountweb.Routes(fiberApp, accountSvc, authSvc, app.StripeConnectService, app.Config)
	userweb.Routes(fiberApp, userSvc, authSvc, app.Config)
	authweb.Routes(fiberApp, authSvc)
	currencyweb.Routes(fiberApp, currencySvc, authSvc, app.Config)
	checkoutweb.Routes(fiberApp, checkoutSvc, authSvc, app.Config)

	return fiberApp
}