mirror of
https://github.com/thomiceli/opengist.git
synced 2025-06-12 13:37:13 +02:00
Add Postgres and MySQL databases support (#335)
This commit is contained in:
@ -8,7 +8,6 @@ import (
|
||||
"github.com/urfave/cli/v2"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
var CmdHook = cli.Command{
|
||||
@ -50,7 +49,8 @@ func initialize(ctx *cli.Context) {
|
||||
}
|
||||
config.InitLog()
|
||||
|
||||
if err := db.Setup(filepath.Join(config.GetHomeDir(), config.C.DBFilename), false); err != nil {
|
||||
db.DeprecationDBFilename()
|
||||
if err := db.Setup(config.C.DBUri, false); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to initialize database in hooks")
|
||||
}
|
||||
}
|
||||
|
@ -108,8 +108,9 @@ func Initialize(ctx *cli.Context) {
|
||||
if err := os.MkdirAll(filepath.Join(homePath, "custom"), 0755); err != nil {
|
||||
log.Fatal().Err(err).Send()
|
||||
}
|
||||
log.Info().Msg("Database file: " + filepath.Join(homePath, config.C.DBFilename))
|
||||
if err := db.Setup(filepath.Join(homePath, config.C.DBFilename), false); err != nil {
|
||||
|
||||
db.DeprecationDBFilename()
|
||||
if err := db.Setup(config.C.DBUri, false); err != nil {
|
||||
log.Fatal().Err(err).Msg("Failed to initialize database")
|
||||
}
|
||||
|
||||
|
@ -29,7 +29,10 @@ type config struct {
|
||||
LogOutput string `yaml:"log-output" env:"OG_LOG_OUTPUT"`
|
||||
ExternalUrl string `yaml:"external-url" env:"OG_EXTERNAL_URL"`
|
||||
OpengistHome string `yaml:"opengist-home" env:"OG_OPENGIST_HOME"`
|
||||
DBFilename string `yaml:"db-filename" env:"OG_DB_FILENAME"`
|
||||
|
||||
DBUri string `yaml:"db-uri" env:"OG_DB_URI"`
|
||||
DBFilename string `yaml:"db-filename" env:"OG_DB_FILENAME"` // deprecated
|
||||
|
||||
IndexEnabled bool `yaml:"index.enabled" env:"OG_INDEX_ENABLED"`
|
||||
IndexDirname string `yaml:"index.dirname" env:"OG_INDEX_DIRNAME"`
|
||||
|
||||
@ -80,7 +83,7 @@ func configWithDefaults() (*config, error) {
|
||||
c.LogLevel = "warn"
|
||||
c.LogOutput = "stdout,file"
|
||||
c.OpengistHome = ""
|
||||
c.DBFilename = "opengist.db"
|
||||
c.DBUri = "opengist.db"
|
||||
c.IndexEnabled = true
|
||||
c.IndexDirname = "opengist.index"
|
||||
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
)
|
||||
|
||||
type AdminSetting struct {
|
||||
Key string `gorm:"uniqueIndex"`
|
||||
Key string `gorm:"index:,unique"`
|
||||
Value string
|
||||
}
|
||||
|
||||
@ -49,7 +49,7 @@ func UpdateSetting(key string, value string) error {
|
||||
}
|
||||
|
||||
func setSetting(key string, value string) error {
|
||||
return db.Create(&AdminSetting{Key: key, Value: value}).Error
|
||||
return db.FirstOrCreate(&AdminSetting{Key: key, Value: value}, &AdminSetting{Key: key}).Error
|
||||
}
|
||||
|
||||
func initAdminSettings(settings map[string]string) error {
|
||||
@ -64,9 +64,9 @@ func initAdminSettings(settings map[string]string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type DBAuthInfo struct{}
|
||||
type AuthInfo struct{}
|
||||
|
||||
func (auth DBAuthInfo) RequireLogin() (bool, error) {
|
||||
func (auth AuthInfo) RequireLogin() (bool, error) {
|
||||
s, err := GetSetting(SettingRequireLogin)
|
||||
if err != nil {
|
||||
return true, err
|
||||
@ -74,7 +74,7 @@ func (auth DBAuthInfo) RequireLogin() (bool, error) {
|
||||
return s == "1", nil
|
||||
}
|
||||
|
||||
func (auth DBAuthInfo) AllowGistsWithoutLogin() (bool, error) {
|
||||
func (auth AuthInfo) AllowGistsWithoutLogin() (bool, error) {
|
||||
s, err := GetSetting(SettingAllowGistsWithoutLogin)
|
||||
if err != nil {
|
||||
return false, err
|
||||
|
@ -2,38 +2,133 @@ package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/glebarez/sqlite"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm/logger"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
msqlite "github.com/glebarez/go-sqlite"
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/thomiceli/opengist/internal/config"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var db *gorm.DB
|
||||
|
||||
func Setup(dbPath string, sharedCache bool) error {
|
||||
var err error
|
||||
journalMode := strings.ToUpper(config.C.SqliteJournalMode)
|
||||
const (
|
||||
SQLite databaseType = iota
|
||||
PostgreSQL
|
||||
MySQL
|
||||
)
|
||||
|
||||
if !slices.Contains([]string{"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}, journalMode) {
|
||||
log.Warn().Msg("Invalid SQLite journal mode: " + journalMode)
|
||||
type databaseType int
|
||||
|
||||
func (d databaseType) String() string {
|
||||
return [...]string{"SQLite", "PostgreSQL", "MySQL"}[d]
|
||||
}
|
||||
|
||||
type databaseInfo struct {
|
||||
Type databaseType
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
}
|
||||
|
||||
var DatabaseInfo *databaseInfo
|
||||
|
||||
func parseDBURI(uri string) (*databaseInfo, error) {
|
||||
info := &databaseInfo{}
|
||||
|
||||
if !strings.Contains(uri, "://") {
|
||||
info.Type = SQLite
|
||||
if uri == "file::memory:" {
|
||||
info.Database = "file::memory:"
|
||||
return info, nil
|
||||
}
|
||||
info.Database = filepath.Join(config.GetHomeDir(), uri)
|
||||
return info, nil
|
||||
}
|
||||
u, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URI: %v", err)
|
||||
}
|
||||
|
||||
sharedCacheStr := ""
|
||||
if sharedCache {
|
||||
sharedCacheStr = "&cache=shared"
|
||||
switch u.Scheme {
|
||||
case "postgres", "postgresql":
|
||||
info.Type = PostgreSQL
|
||||
case "mysql", "mariadb":
|
||||
info.Type = MySQL
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown database: %v", err)
|
||||
}
|
||||
|
||||
if db, err = gorm.Open(sqlite.Open(dbPath+"?_fk=true&_journal_mode="+journalMode+sharedCacheStr), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
}); err != nil {
|
||||
if u.Host != "" {
|
||||
host, port, _ := strings.Cut(u.Host, ":")
|
||||
info.Host = host
|
||||
info.Port = port
|
||||
}
|
||||
|
||||
if u.User != nil {
|
||||
info.User = u.User.Username()
|
||||
info.Password, _ = u.User.Password()
|
||||
}
|
||||
|
||||
switch info.Type {
|
||||
case PostgreSQL, MySQL:
|
||||
info.Database = strings.TrimPrefix(u.Path, "/")
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown database: %v", err)
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func Setup(dbUri string, sharedCache bool) error {
|
||||
dbInfo, err := parseDBURI(dbUri)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Msgf("Setting up a %s database connection", dbInfo.Type)
|
||||
var setupFunc func(databaseInfo, bool) error
|
||||
switch dbInfo.Type {
|
||||
case SQLite:
|
||||
setupFunc = setupSQLite
|
||||
case PostgreSQL:
|
||||
setupFunc = setupPostgres
|
||||
case MySQL:
|
||||
setupFunc = setupMySQL
|
||||
default:
|
||||
return fmt.Errorf("unknown database type: %v", dbInfo.Type)
|
||||
}
|
||||
|
||||
maxAttempts := 60
|
||||
retryInterval := 1 * time.Second
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
err = setupFunc(*dbInfo, sharedCache)
|
||||
if err == nil {
|
||||
log.Info().Msg("Database connection established")
|
||||
break
|
||||
}
|
||||
|
||||
if attempt < maxAttempts {
|
||||
log.Warn().Err(err).Msgf("Failed to connect to database (attempt %d), retrying in %v...", attempt, retryInterval)
|
||||
time.Sleep(retryInterval)
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
DatabaseInfo = dbInfo
|
||||
|
||||
if err = db.SetupJoinTable(&Gist{}, "Likes", &Like{}); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -46,7 +141,7 @@ func Setup(dbPath string, sharedCache bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = ApplyMigrations(db); err != nil {
|
||||
if err = applyMigrations(db, dbInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -75,11 +170,7 @@ func CountAll(table interface{}) (int64, error) {
|
||||
}
|
||||
|
||||
func IsUniqueConstraintViolation(err error) bool {
|
||||
var sqliteErr *msqlite.Error
|
||||
if errors.As(err, &sqliteErr) && sqliteErr.Code() == 2067 {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
return errors.Is(err, gorm.ErrDuplicatedKey)
|
||||
}
|
||||
|
||||
func Ping() error {
|
||||
@ -90,3 +181,65 @@ func Ping() error {
|
||||
|
||||
return sql.Ping()
|
||||
}
|
||||
|
||||
func setupSQLite(dbInfo databaseInfo, sharedCache bool) error {
|
||||
var err error
|
||||
journalMode := strings.ToUpper(config.C.SqliteJournalMode)
|
||||
|
||||
if !slices.Contains([]string{"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}, journalMode) {
|
||||
log.Warn().Msg("Invalid SQLite journal mode: " + journalMode)
|
||||
}
|
||||
|
||||
sharedCacheStr := ""
|
||||
if sharedCache {
|
||||
sharedCacheStr = "&cache=shared"
|
||||
}
|
||||
|
||||
db, err = gorm.Open(sqlite.Open(dbInfo.Database+"?_fk=true&_journal_mode="+journalMode+sharedCacheStr), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
TranslateError: true,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func setupPostgres(dbInfo databaseInfo, sharedCache bool) error {
|
||||
var err error
|
||||
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", dbInfo.Host, dbInfo.Port, dbInfo.User, dbInfo.Password, dbInfo.Database)
|
||||
|
||||
db, err = gorm.Open(postgres.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
TranslateError: true,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func setupMySQL(dbInfo databaseInfo, sharedCache bool) error {
|
||||
var err error
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", dbInfo.User, dbInfo.Password, dbInfo.Host, dbInfo.Port, dbInfo.Database)
|
||||
|
||||
db, err = gorm.Open(mysql.New(mysql.Config{
|
||||
DSN: dsn,
|
||||
DontSupportRenameIndex: true,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
TranslateError: true,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func DeprecationDBFilename() {
|
||||
if config.C.DBFilename != "" {
|
||||
log.Warn().Msg("The 'db-filename'/'OG_DB_FILENAME' configuration option is deprecated and will be removed in a future version. Please use 'db-uri'/'OG_DB_URI' instead.")
|
||||
}
|
||||
|
||||
if config.C.DBUri == "" {
|
||||
config.C.DBUri = config.C.DBFilename
|
||||
}
|
||||
}
|
||||
|
||||
func TruncateDatabase() error {
|
||||
return db.Migrator().DropTable("likes", &User{}, "gists", &SSHKey{}, &AdminSetting{}, &Invitation{})
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
@ -15,10 +16,21 @@ type Invitation struct {
|
||||
|
||||
func GetAllInvitations() ([]*Invitation, error) {
|
||||
var invitations []*Invitation
|
||||
err := db.
|
||||
Order("(((expires_at >= strftime('%s', 'now')) AND ((nb_max <= 0) OR (nb_used < nb_max)))) desc").
|
||||
Order("id asc").
|
||||
Find(&invitations).Error
|
||||
dialect := db.Dialector.Name()
|
||||
query := db.Model(&Invitation{})
|
||||
|
||||
switch dialect {
|
||||
case "sqlite":
|
||||
query = query.Order("(((expires_at >= strftime('%s', 'now')) AND ((nb_max <= 0) OR (nb_used < nb_max)))) DESC")
|
||||
case "postgres":
|
||||
query = query.Order("(((expires_at >= EXTRACT(EPOCH FROM CURRENT_TIMESTAMP)) AND ((nb_max <= 0) OR (nb_used < nb_max)))) DESC")
|
||||
case "mysql":
|
||||
query = query.Order("(((expires_at >= UNIX_TIMESTAMP()) AND ((nb_max <= 0) OR (nb_used < nb_max)))) DESC")
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported database dialect: %s", dialect)
|
||||
}
|
||||
|
||||
err := query.Order("id ASC").Find(&invitations).Error
|
||||
|
||||
return invitations, err
|
||||
}
|
||||
|
@ -11,7 +11,19 @@ type MigrationVersion struct {
|
||||
Version uint
|
||||
}
|
||||
|
||||
func ApplyMigrations(db *gorm.DB) error {
|
||||
func applyMigrations(db *gorm.DB, dbInfo *databaseInfo) error {
|
||||
switch dbInfo.Type {
|
||||
case SQLite:
|
||||
return applySqliteMigrations(db)
|
||||
case PostgreSQL, MySQL:
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unknown database type: %s", dbInfo.Type)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func applySqliteMigrations(db *gorm.DB) error {
|
||||
// Create migration table if it doesn't exist
|
||||
if err := db.AutoMigrate(&MigrationVersion{}); err != nil {
|
||||
log.Fatal().Err(err).Msg("Error creating migration version table")
|
||||
|
@ -6,7 +6,7 @@ import (
|
||||
|
||||
type User struct {
|
||||
ID uint `gorm:"primaryKey"`
|
||||
Username string `gorm:"uniqueIndex"`
|
||||
Username string `gorm:"uniqueIndex,size:191"`
|
||||
Password string
|
||||
IsAdmin bool
|
||||
CreatedAt int64
|
||||
|
@ -39,7 +39,7 @@ func runGitCommand(ch ssh.Channel, gitCmd string, key string, ip string) error {
|
||||
return errors.New("gist not found")
|
||||
}
|
||||
|
||||
allowUnauthenticated, err := auth.ShouldAllowUnauthenticatedGistAccess(db.DBAuthInfo{}, true)
|
||||
allowUnauthenticated, err := auth.ShouldAllowUnauthenticatedGistAccess(db.AuthInfo{}, true)
|
||||
if err != nil {
|
||||
return errors.New("internal server error")
|
||||
}
|
||||
|
@ -161,6 +161,9 @@ func adminConfig(ctx echo.Context) error {
|
||||
setData(ctx, "htmlTitle", trH(ctx, "admin.configuration")+" - "+trH(ctx, "admin.admin_panel"))
|
||||
setData(ctx, "adminHeaderPage", "config")
|
||||
|
||||
setData(ctx, "dbtype", db.DatabaseInfo.Type.String())
|
||||
setData(ctx, "dbname", db.DatabaseInfo.Database)
|
||||
|
||||
return html(ctx, "admin_config.html")
|
||||
}
|
||||
|
||||
|
@ -24,6 +24,8 @@ import (
|
||||
"github.com/thomiceli/opengist/internal/web"
|
||||
)
|
||||
|
||||
var databaseType string
|
||||
|
||||
type testServer struct {
|
||||
server *web.Server
|
||||
sessionCookie string
|
||||
@ -132,6 +134,17 @@ func structToURLValues(s interface{}) url.Values {
|
||||
}
|
||||
|
||||
func setup(t *testing.T) {
|
||||
var databaseDsn string
|
||||
databaseType = os.Getenv("OPENGIST_TEST_DB")
|
||||
switch databaseType {
|
||||
case "sqlite":
|
||||
databaseDsn = "file::memory:"
|
||||
case "postgres":
|
||||
databaseDsn = "postgres://postgres:opengist@localhost:5432/opengist_test"
|
||||
case "mysql":
|
||||
databaseDsn = "mysql://root:opengist@localhost:3306/opengist_test"
|
||||
}
|
||||
|
||||
_ = os.Setenv("OPENGIST_SKIP_GIT_HOOKS", "1")
|
||||
|
||||
err := config.InitConfig("", io.Discard)
|
||||
@ -155,9 +168,13 @@ func setup(t *testing.T) {
|
||||
err = os.MkdirAll(filepath.Join(homePath, "tmp", "repos"), 0755)
|
||||
require.NoError(t, err, "Could not create tmp repos directory")
|
||||
|
||||
err = db.Setup("file::memory:", true)
|
||||
err = db.Setup(databaseDsn, true)
|
||||
require.NoError(t, err, "Could not initialize database")
|
||||
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("Could not initialize database")
|
||||
}
|
||||
|
||||
err = memdb.Setup()
|
||||
require.NoError(t, err, "Could not initialize in memory database")
|
||||
|
||||
@ -168,10 +185,10 @@ func setup(t *testing.T) {
|
||||
func teardown(t *testing.T, s *testServer) {
|
||||
s.stop()
|
||||
|
||||
err := db.Close()
|
||||
require.NoError(t, err, "Could not close database")
|
||||
//err := db.Close()
|
||||
//require.NoError(t, err, "Could not close database")
|
||||
|
||||
err = os.RemoveAll(path.Join(config.GetHomeDir(), "tests"))
|
||||
err := os.RemoveAll(path.Join(config.GetHomeDir(), "tests"))
|
||||
require.NoError(t, err, "Could not remove repos directory")
|
||||
|
||||
err = os.RemoveAll(path.Join(config.GetHomeDir(), "tmp", "repos"))
|
||||
@ -180,6 +197,9 @@ func teardown(t *testing.T, s *testServer) {
|
||||
err = os.RemoveAll(path.Join(config.GetHomeDir(), "tmp", "sessions"))
|
||||
require.NoError(t, err, "Could not remove repos directory")
|
||||
|
||||
err = db.TruncateDatabase()
|
||||
require.NoError(t, err, "Could not truncate database")
|
||||
|
||||
// err = os.RemoveAll(path.Join(config.C.OpengistHome, "testsindex"))
|
||||
// require.NoError(t, err, "Could not remove repos directory")
|
||||
|
||||
|
Reference in New Issue
Block a user