db-wait/pkg/dbwait/dbwait.go

76 lines
1.5 KiB
Go
Raw Normal View History

2021-09-15 16:15:43 +00:00
package dbwait
import (
"context"
"database/sql"
"fmt"
"net/url"
"os"
"time"
)
// Wait blocks until the database identified by databaseURL answers a trivial
// probe query, or until timeout elapses.
//
// The URL scheme selects both the database/sql driver name and the probe
// query: "oracle" runs "SELECT 1 FROM dual" and "postgres" runs
// "SELECT 1 AS ROW"; any other scheme is rejected. The full URL string is
// passed to sql.Open as the data source name. The driver itself must be
// registered elsewhere (NOTE(review): presumably via a blank import in the
// calling program — confirm).
//
// The first probe is issued immediately. period is both the per-probe query
// timeout and the delay before the next probe after a failure; each failed
// probe is logged to stderr with a timestamp.
//
// Wait returns nil as soon as a probe's first row holds the value 1,
// context.DeadlineExceeded if timeout elapses first, the sql.Open error if the
// driver cannot be opened, or an error if databaseURL is nil, its scheme is
// unsupported, or period is not positive.
func Wait(databaseURL *url.URL, period time.Duration, timeout time.Duration) error {
	if databaseURL == nil {
		return fmt.Errorf("dbwait: database URL is nil")
	}
	// A non-positive period would make every probe time out instantly and
	// turn the retry loop into a busy spin.
	if period <= 0 {
		return fmt.Errorf("dbwait: period must be positive, got %v", period)
	}

	var query string
	switch databaseURL.Scheme {
	case "oracle":
		query = "SELECT 1 FROM dual"
	case "postgres":
		query = "SELECT 1 AS ROW"
	default:
		// Reject early: there is no driver name or probe query for this scheme.
		return fmt.Errorf("dbwait: unsupported database URL scheme %q", databaseURL.Scheme)
	}

	sqlDB, err := sql.Open(databaseURL.Scheme, databaseURL.String())
	if err != nil {
		return err
	}
	defer sqlDB.Close()

	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	// A zero-duration timer fires right away, so the first probe is not delayed.
	timer := time.NewTimer(0)
	defer timer.Stop()

	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-timer.C:
		}

		if err := probe(ctx, sqlDB, query, period); err != nil {
			fmt.Fprintf(os.Stderr, "%s: %s\n", time.Now().String(), err.Error())
			// The timer channel was just drained by the receive above, so
			// Reset is safe here.
			timer.Reset(period)
			continue
		}
		return nil
	}
}

// probe runs query on db, bounded by queryTimeout (and by ctx), and returns an
// error unless the first row's single column scans to the value 1.
func probe(ctx context.Context, db *sql.DB, query string, queryTimeout time.Duration) error {
	// Cancel the per-probe context as soon as this probe finishes instead of
	// deferring it to the end of Wait, where cancels would pile up.
	queryCtx, cancel := context.WithTimeout(ctx, queryTimeout)
	defer cancel()

	var n int
	// Scan surfaces query errors and releases the underlying rows.
	if err := db.QueryRowContext(queryCtx, query).Scan(&n); err != nil {
		return err
	}
	if n != 1 {
		return fmt.Errorf("probe returned %d, want 1", n)
	}
	return nil
}