mirror of
https://github.com/thomiceli/opengist.git
synced 2025-07-09 01:18:04 +02:00
Refactor server code (#407)
This commit is contained in:
@ -19,7 +19,13 @@ const (
|
||||
|
||||
func GetSetting(key string) (string, error) {
|
||||
var setting AdminSetting
|
||||
err := db.Where("`key` = ?", key).First(&setting).Error
|
||||
var err error
|
||||
switch db.Dialector.Name() {
|
||||
case "mysql", "sqlite":
|
||||
err = db.Where("`key` = ?", key).First(&setting).Error
|
||||
case "postgres":
|
||||
err = db.Where("key = ?", key).First(&setting).Error
|
||||
}
|
||||
return setting.Value, err
|
||||
}
|
||||
|
||||
|
@ -46,6 +46,12 @@ var DatabaseInfo *databaseInfo
|
||||
func parseDBURI(uri string) (*databaseInfo, error) {
|
||||
info := &databaseInfo{}
|
||||
|
||||
if uri == ":memory:" {
|
||||
info.Type = SQLite
|
||||
info.Database = uri
|
||||
return info, nil
|
||||
}
|
||||
|
||||
u, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URI: %v", err)
|
||||
@ -91,14 +97,14 @@ func parseDBURI(uri string) (*databaseInfo, error) {
|
||||
return info, nil
|
||||
}
|
||||
|
||||
func Setup(dbUri string, sharedCache bool) error {
|
||||
func Setup(dbUri string) error {
|
||||
dbInfo, err := parseDBURI(dbUri)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info().Msgf("Setting up a %s database connection", dbInfo.Type)
|
||||
var setupFunc func(databaseInfo, bool) error
|
||||
var setupFunc func(databaseInfo) error
|
||||
switch dbInfo.Type {
|
||||
case SQLite:
|
||||
setupFunc = setupSQLite
|
||||
@ -114,7 +120,7 @@ func Setup(dbUri string, sharedCache bool) error {
|
||||
retryInterval := 1 * time.Second
|
||||
|
||||
for attempt := 1; attempt <= maxAttempts; attempt++ {
|
||||
err = setupFunc(*dbInfo, sharedCache)
|
||||
err = setupFunc(*dbInfo)
|
||||
if err == nil {
|
||||
log.Info().Msg("Database connection established")
|
||||
break
|
||||
@ -142,7 +148,7 @@ func Setup(dbUri string, sharedCache bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = applyMigrations(db, dbInfo); err != nil {
|
||||
if err = applyMigrations(dbInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -183,37 +189,38 @@ func Ping() error {
|
||||
return sql.Ping()
|
||||
}
|
||||
|
||||
func setupSQLite(dbInfo databaseInfo, sharedCache bool) error {
|
||||
func setupSQLite(dbInfo databaseInfo) error {
|
||||
var err error
|
||||
var dsn string
|
||||
journalMode := strings.ToUpper(config.C.SqliteJournalMode)
|
||||
|
||||
if !slices.Contains([]string{"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}, journalMode) {
|
||||
log.Warn().Msg("Invalid SQLite journal mode: " + journalMode)
|
||||
}
|
||||
|
||||
u, err := url.Parse(dbInfo.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if dbInfo.Database == ":memory:" {
|
||||
dsn = ":memory:?_fk=true&cache=shared"
|
||||
} else {
|
||||
u, err := url.Parse(dbInfo.Database)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
u.Scheme = "file"
|
||||
q := u.Query()
|
||||
q.Set("_fk", "true")
|
||||
q.Set("_journal_mode", journalMode)
|
||||
if sharedCache {
|
||||
q.Set("cache", "shared")
|
||||
u.Scheme = "file"
|
||||
q := u.Query()
|
||||
q.Set("_fk", "true")
|
||||
q.Set("_journal_mode", journalMode)
|
||||
u.RawQuery = q.Encode()
|
||||
dsn = u.String()
|
||||
}
|
||||
u.RawQuery = q.Encode()
|
||||
dsn := u.String()
|
||||
db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
TranslateError: true,
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func setupPostgres(dbInfo databaseInfo, sharedCache bool) error {
|
||||
func setupPostgres(dbInfo databaseInfo) error {
|
||||
var err error
|
||||
dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", dbInfo.Host, dbInfo.Port, dbInfo.User, dbInfo.Password, dbInfo.Database)
|
||||
|
||||
@ -225,7 +232,7 @@ func setupPostgres(dbInfo databaseInfo, sharedCache bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func setupMySQL(dbInfo databaseInfo, sharedCache bool) error {
|
||||
func setupMySQL(dbInfo databaseInfo) error {
|
||||
var err error
|
||||
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", dbInfo.User, dbInfo.Password, dbInfo.Host, dbInfo.Port, dbInfo.Database)
|
||||
|
||||
|
@ -3,7 +3,6 @@ package db
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type MigrationVersion struct {
|
||||
@ -11,10 +10,10 @@ type MigrationVersion struct {
|
||||
Version uint
|
||||
}
|
||||
|
||||
func applyMigrations(db *gorm.DB, dbInfo *databaseInfo) error {
|
||||
func applyMigrations(dbInfo *databaseInfo) error {
|
||||
switch dbInfo.Type {
|
||||
case SQLite:
|
||||
return applySqliteMigrations(db)
|
||||
return applySqliteMigrations()
|
||||
case PostgreSQL, MySQL:
|
||||
return nil
|
||||
default:
|
||||
@ -23,7 +22,7 @@ func applyMigrations(db *gorm.DB, dbInfo *databaseInfo) error {
|
||||
|
||||
}
|
||||
|
||||
func applySqliteMigrations(db *gorm.DB) error {
|
||||
func applySqliteMigrations() error {
|
||||
// Create migration table if it doesn't exist
|
||||
if err := db.AutoMigrate(&MigrationVersion{}); err != nil {
|
||||
log.Fatal().Err(err).Msg("Error creating migration version table")
|
||||
@ -37,7 +36,7 @@ func applySqliteMigrations(db *gorm.DB) error {
|
||||
// Define migrations
|
||||
migrations := []struct {
|
||||
Version uint
|
||||
Func func(*gorm.DB) error
|
||||
Func func() error
|
||||
}{
|
||||
{1, v1_modifyConstraintToSSHKeys},
|
||||
{2, v2_lowercaseEmails},
|
||||
@ -53,7 +52,7 @@ func applySqliteMigrations(db *gorm.DB) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := m.Func(db); err != nil {
|
||||
if err := m.Func(); err != nil {
|
||||
log.Fatal().Err(err).Msg(fmt.Sprintf("Error applying migration %d:", m.Version))
|
||||
tx.Rollback()
|
||||
return err
|
||||
@ -73,7 +72,7 @@ func applySqliteMigrations(db *gorm.DB) error {
|
||||
}
|
||||
|
||||
// Modify the constraint on the ssh_keys table to use ON DELETE CASCADE
|
||||
func v1_modifyConstraintToSSHKeys(db *gorm.DB) error {
|
||||
func v1_modifyConstraintToSSHKeys() error {
|
||||
createSQL := `
|
||||
CREATE TABLE ssh_keys_temp (
|
||||
id integer primary key,
|
||||
@ -108,7 +107,7 @@ func v1_modifyConstraintToSSHKeys(db *gorm.DB) error {
|
||||
return db.Exec(renameSQL).Error
|
||||
}
|
||||
|
||||
func v2_lowercaseEmails(db *gorm.DB) error {
|
||||
func v2_lowercaseEmails() error {
|
||||
// Copy the lowercase emails into the new column
|
||||
copySQL := `UPDATE users SET email = lower(email);`
|
||||
return db.Exec(copySQL).Error
|
||||
|
@ -6,9 +6,10 @@ import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/thomiceli/opengist/internal/auth"
|
||||
"github.com/thomiceli/opengist/internal/auth/password"
|
||||
ogtotp "github.com/thomiceli/opengist/internal/auth/totp"
|
||||
"github.com/thomiceli/opengist/internal/config"
|
||||
"github.com/thomiceli/opengist/internal/utils"
|
||||
"slices"
|
||||
)
|
||||
|
||||
@ -30,7 +31,7 @@ func GetTOTPByUserID(userID uint) (*TOTP, error) {
|
||||
|
||||
func (totp *TOTP) StoreSecret(secret string) error {
|
||||
secretBytes := []byte(secret)
|
||||
encrypted, err := utils.AESEncrypt(config.SecretKey, secretBytes)
|
||||
encrypted, err := auth.AESEncrypt(config.SecretKey, secretBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -45,7 +46,7 @@ func (totp *TOTP) ValidateCode(code string) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
secretBytes, err := utils.AESDecrypt(config.SecretKey, ciphertext)
|
||||
secretBytes, err := auth.AESDecrypt(config.SecretKey, ciphertext)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -60,7 +61,7 @@ func (totp *TOTP) ValidateRecoveryCode(code string) (bool, error) {
|
||||
}
|
||||
|
||||
for i, hashedCode := range hashedCodes {
|
||||
ok, err := utils.Argon2id.Verify(code, hashedCode)
|
||||
ok, err := password.VerifyPassword(code, hashedCode)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -106,7 +107,7 @@ func generateRandomCodes() ([]string, []string, error) {
|
||||
hexCode := hex.EncodeToString(bytes)
|
||||
code := fmt.Sprintf("%s-%s", hexCode[:length/2], hexCode[length/2:])
|
||||
plainCodes[i] = code
|
||||
hashed, err := utils.Argon2id.Hash(code)
|
||||
hashed, err := password.HashPassword(code)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/thomiceli/opengist/internal/git"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
@ -25,31 +26,36 @@ type User struct {
|
||||
}
|
||||
|
||||
func (user *User) BeforeDelete(tx *gorm.DB) error {
|
||||
// Decrement likes counter for all gists liked by this user
|
||||
// The likes will be automatically deleted by the foreign key constraint
|
||||
err := tx.Model(&Gist{}).
|
||||
Omit("updated_at").
|
||||
Where("id IN (?)", tx.
|
||||
Select("gist_id").
|
||||
Table("likes").
|
||||
Where("user_id = ?", user.ID),
|
||||
).
|
||||
UpdateColumn("nb_likes", gorm.Expr("nb_likes - 1")).
|
||||
Error
|
||||
// Decrement likes counter using derived table
|
||||
err := tx.Exec(`
|
||||
UPDATE gists
|
||||
SET nb_likes = nb_likes - 1
|
||||
WHERE id IN (
|
||||
SELECT gist_id
|
||||
FROM (
|
||||
SELECT gist_id
|
||||
FROM likes
|
||||
WHERE user_id = ?
|
||||
) AS derived_likes
|
||||
)
|
||||
`, user.ID).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decrement forks counter for all gists forked by this user
|
||||
err = tx.Model(&Gist{}).
|
||||
Omit("updated_at").
|
||||
Where("id IN (?)", tx.
|
||||
Select("forked_id").
|
||||
Table("gists").
|
||||
Where("user_id = ?", user.ID),
|
||||
).
|
||||
UpdateColumn("nb_forks", gorm.Expr("nb_forks - 1")).
|
||||
Error
|
||||
// Decrement forks counter using derived table
|
||||
err = tx.Exec(`
|
||||
UPDATE gists
|
||||
SET nb_forks = nb_forks - 1
|
||||
WHERE id IN (
|
||||
SELECT forked_id
|
||||
FROM (
|
||||
SELECT forked_id
|
||||
FROM gists
|
||||
WHERE user_id = ? AND forked_id IS NOT NULL
|
||||
) AS derived_forks
|
||||
)
|
||||
`, user.ID).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -64,8 +70,17 @@ func (user *User) BeforeDelete(tx *gorm.DB) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete all gists created by this user
|
||||
return tx.Where("user_id = ?", user.ID).Delete(&Gist{}).Error
|
||||
err = tx.Where("user_id = ?", user.ID).Delete(&Gist{}).Error
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete user directory
|
||||
if err = git.DeleteUserDirectory(user.Username); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func UserExists(username string) (bool, error) {
|
||||
|
@ -73,8 +73,7 @@ func GetUserByCredentialID(credID binaryData) (*User, error) {
|
||||
if err = db.Preload("User").Where("credential_id = decode(?, 'hex')", hexCredID).First(&credential).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "mysql":
|
||||
case "sqlite":
|
||||
case "mysql", "sqlite":
|
||||
hexCredID := hex.EncodeToString(credID)
|
||||
if err = db.Preload("User").Where("credential_id = unhex(?)", hexCredID).First(&credential).Error; err != nil {
|
||||
return nil, err
|
||||
@ -100,8 +99,7 @@ func GetCredentialByID(id binaryData) (*WebAuthnCredential, error) {
|
||||
if err = db.Where("credential_id = decode(?, 'hex')", hexCredID).First(&cred).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
case "mysql":
|
||||
case "sqlite":
|
||||
case "mysql", "sqlite":
|
||||
hexCredID := hex.EncodeToString(id)
|
||||
if err = db.Where("credential_id = unhex(?)", hexCredID).First(&cred).Error; err != nil {
|
||||
return nil, err
|
||||
|
Reference in New Issue
Block a user