diff --git a/internal/actions/actions.go b/internal/actions/actions.go index 5c8b460..7a65ab6 100644 --- a/internal/actions/actions.go +++ b/internal/actions/actions.go @@ -141,17 +141,8 @@ func syncGistPreviews() { func resetHooks() { log.Info().Msg("Resetting Git server hooks for all repositories...") - entries, err := filepath.Glob(filepath.Join(config.GetHomeDir(), "repos", "*", "*")) - if err != nil { - log.Error().Err(err).Msg("Cannot read repos directories") - return - } - - for _, e := range entries { - path := strings.Split(e, string(os.PathSeparator)) - if err := git.CreateDotGitFiles(path[len(path)-2], path[len(path)-1]); err != nil { - log.Error().Err(err).Msgf("Cannot reset hooks for repository %s/%s", path[len(path)-2], path[len(path)-1]) - } + if err := git.ResetHooks(); err != nil { + log.Error().Err(err).Msg("Error resetting hooks for repositories") } } diff --git a/internal/utils/aes.go b/internal/auth/aes.go similarity index 98% rename from internal/utils/aes.go rename to internal/auth/aes.go index c401d35..e64c116 100644 --- a/internal/utils/aes.go +++ b/internal/auth/aes.go @@ -1,4 +1,4 @@ -package utils +package auth import ( "crypto/aes" diff --git a/internal/utils/argon2id.go b/internal/auth/argon2id.go similarity index 88% rename from internal/utils/argon2id.go rename to internal/auth/argon2id.go index 5cf15c2..765fceb 100644 --- a/internal/utils/argon2id.go +++ b/internal/auth/argon2id.go @@ -1,4 +1,4 @@ -package utils +package auth import ( "crypto/rand" @@ -10,7 +10,7 @@ import ( "strings" ) -type Argon2ID struct { +type argon2ID struct { format string version int time uint32 @@ -20,7 +20,7 @@ type Argon2ID struct { threads uint8 } -var Argon2id = Argon2ID{ +var Argon2id = argon2ID{ format: "$argon2id$v=%d$m=%d,t=%d,p=%d$%s$%s", version: argon2.Version, time: 1, @@ -30,7 +30,7 @@ var Argon2id = Argon2ID{ threads: 4, } -func (a Argon2ID) Hash(plain string) (string, error) { +func (a argon2ID) Hash(plain string) (string, error) { salt := make([]byte, a.saltLen) if _, err := rand.Read(salt); err != nil { return "", err @@ -44,7 +44,7 @@ func (a Argon2ID) Hash(plain string) (string, error) { ), nil } -func (a Argon2ID) Verify(plain, hash string) (bool, error) { +func (a argon2ID) Verify(plain, hash string) (bool, error) { if hash == "" { return false, nil } diff --git a/internal/auth/oauth/gitea.go b/internal/auth/oauth/gitea.go new file mode 100644 index 0000000..88db9bd --- /dev/null +++ b/internal/auth/oauth/gitea.go @@ -0,0 +1,117 @@ +package oauth + +import ( + gocontext "context" + gojson "encoding/json" + "github.com/markbates/goth" + "github.com/markbates/goth/gothic" + "github.com/markbates/goth/providers/gitea" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "io" + "net/http" +) + +type GiteaProvider struct { + Provider + URL string +} + +func (p *GiteaProvider) RegisterProvider() error { + goth.UseProviders( + gitea.NewCustomisedURL( + config.C.GiteaClientKey, + config.C.GiteaSecret, + urlJoin(p.URL, "/oauth/gitea/callback"), + urlJoin(config.C.GiteaUrl, "/login/oauth/authorize"), + urlJoin(config.C.GiteaUrl, "/login/oauth/access_token"), + urlJoin(config.C.GiteaUrl, "/api/v1/user"), + ), + ) + + return nil +} + +func (p *GiteaProvider) BeginAuthHandler(ctx *context.Context) { + ctxValue := gocontext.WithValue(ctx.Request().Context(), gothic.ProviderParamKey, GiteaProviderString) + ctx.SetRequest(ctx.Request().WithContext(ctxValue)) + + gothic.BeginAuthHandler(ctx.Response(), ctx.Request()) +} + +func (p *GiteaProvider) UserHasProvider(user *db.User) bool { + return user.GiteaID != "" +} + +func NewGiteaProvider(url string) *GiteaProvider { + return &GiteaProvider{ + URL: url, + } +} + +type GiteaCallbackProvider struct { + CallbackProvider + User *goth.User +} + +func (p *GiteaCallbackProvider) GetProvider() string { + return GiteaProviderString +} + +func (p *GiteaCallbackProvider) GetProviderUser() *goth.User { + return p.User +} + +func (p *GiteaCallbackProvider) GetProviderUserID(user *db.User) bool { + return user.GiteaID != "" +} + +func (p *GiteaCallbackProvider) GetProviderUserSSHKeys() ([]string, error) { + resp, err := http.Get(urlJoin(config.C.GiteaUrl, p.User.NickName+".keys")) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return readKeys(resp) +} + +func (p *GiteaCallbackProvider) UpdateUserDB(user *db.User) { + user.GiteaID = p.User.UserID + + resp, err := http.Get(urlJoin(config.C.GiteaUrl, "/api/v1/users/", p.User.UserID)) + if err != nil { + log.Error().Err(err).Msg("Cannot get user from Gitea") + return + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Error().Err(err).Msg("Cannot read Gitea response body") + return + } + + var result map[string]interface{} + err = gojson.Unmarshal(body, &result) + if err != nil { + log.Error().Err(err).Msg("Cannot unmarshal Gitea response body") + return + } + + field, ok := result["avatar_url"] + if !ok { + log.Error().Msg("Field 'avatar_url' not found in Gitea JSON response") + return + } + + user.AvatarURL = field.(string) +} + +func NewGiteaCallbackProvider(user *goth.User) CallbackProvider { + return &GiteaCallbackProvider{ + User: user, + } +} diff --git a/internal/auth/oauth/github.go b/internal/auth/oauth/github.go new file mode 100644 index 0000000..02cd018 --- /dev/null +++ b/internal/auth/oauth/github.go @@ -0,0 +1,84 @@ +package oauth + +import ( + gocontext "context" + "github.com/markbates/goth" + "github.com/markbates/goth/gothic" + "github.com/markbates/goth/providers/github" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "net/http" +) + +type GitHubProvider struct { + Provider + URL string +} + +func (p *GitHubProvider) RegisterProvider() error { + goth.UseProviders( + github.New( + config.C.GithubClientKey, + config.C.GithubSecret, + urlJoin(p.URL, "/oauth/github/callback"), + ), + ) + + return nil +} + +func (p *GitHubProvider) BeginAuthHandler(ctx *context.Context) { + ctxValue := gocontext.WithValue(ctx.Request().Context(), gothic.ProviderParamKey, GitHubProviderString) + ctx.SetRequest(ctx.Request().WithContext(ctxValue)) + + gothic.BeginAuthHandler(ctx.Response(), ctx.Request()) +} + +func (p *GitHubProvider) UserHasProvider(user *db.User) bool { + return user.GithubID != "" +} + +func NewGitHubProvider(url string) *GitHubProvider { + return &GitHubProvider{ + URL: url, + } +} + +type GitHubCallbackProvider struct { + CallbackProvider + User *goth.User +} + +func (p *GitHubCallbackProvider) GetProvider() string { + return GitHubProviderString +} + +func (p *GitHubCallbackProvider) GetProviderUser() *goth.User { + return p.User +} + +func (p *GitHubCallbackProvider) GetProviderUserID(user *db.User) bool { + return user.GithubID != "" +} + +func (p *GitHubCallbackProvider) GetProviderUserSSHKeys() ([]string, error) { + resp, err := http.Get("https://github.com/" + p.User.NickName + ".keys") + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return readKeys(resp) +} + +func (p *GitHubCallbackProvider) UpdateUserDB(user *db.User) { + user.GithubID = p.User.UserID + user.AvatarURL = "https://avatars.githubusercontent.com/u/" + p.User.UserID + "?v=4" +} + +func NewGitHubCallbackProvider(user *goth.User) CallbackProvider { + return &GitHubCallbackProvider{ + User: user, + } +} diff --git a/internal/auth/oauth/gitlab.go b/internal/auth/oauth/gitlab.go new file mode 100644 index 0000000..6c00d14 --- /dev/null +++ b/internal/auth/oauth/gitlab.go @@ -0,0 +1,87 @@ +package oauth + +import ( + gocontext "context" + "github.com/markbates/goth" + "github.com/markbates/goth/gothic" + "github.com/markbates/goth/providers/gitlab" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "net/http" +) + +type GitLabProvider struct { + Provider + URL string +} + +func (p *GitLabProvider) RegisterProvider() error { + goth.UseProviders( + gitlab.NewCustomisedURL( + config.C.GitlabClientKey, + config.C.GitlabSecret, + urlJoin(p.URL, "/oauth/gitlab/callback"), + urlJoin(config.C.GitlabUrl, "/oauth/authorize"), + urlJoin(config.C.GitlabUrl, "/oauth/token"), + urlJoin(config.C.GitlabUrl, "/api/v4/user"), + ), + ) + + return nil +} + +func (p *GitLabProvider) BeginAuthHandler(ctx *context.Context) { + ctxValue := gocontext.WithValue(ctx.Request().Context(), gothic.ProviderParamKey, GitLabProviderString) + ctx.SetRequest(ctx.Request().WithContext(ctxValue)) + + gothic.BeginAuthHandler(ctx.Response(), ctx.Request()) +} + +func (p *GitLabProvider) UserHasProvider(user *db.User) bool { + return user.GitlabID != "" +} + +func NewGitLabProvider(url string) *GitLabProvider { + return &GitLabProvider{ + URL: url, + } +} + +type GitLabCallbackProvider struct { + CallbackProvider + User *goth.User +} + +func (p *GitLabCallbackProvider) GetProvider() string { + return GitLabProviderString +} + +func (p *GitLabCallbackProvider) GetProviderUser() *goth.User { + return p.User +} + +func (p *GitLabCallbackProvider) GetProviderUserID(user *db.User) bool { + return user.GitlabID != "" +} + +func (p *GitLabCallbackProvider) GetProviderUserSSHKeys() ([]string, error) { + resp, err := http.Get(urlJoin(config.C.GitlabUrl, p.User.NickName+".keys")) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + return readKeys(resp) +} + +func (p *GitLabCallbackProvider) UpdateUserDB(user *db.User) { + user.GitlabID = p.User.UserID + user.AvatarURL = urlJoin(config.C.GitlabUrl, "/uploads/-/system/user/avatar/", p.User.UserID, "/avatar.png") + "?width=400" +} + +func NewGitLabCallbackProvider(user *goth.User) CallbackProvider { + return &GitLabCallbackProvider{ + User: user, + } +} diff --git a/internal/auth/oauth/openid.go b/internal/auth/oauth/openid.go new file mode 100644 index 0000000..8731ce3 --- /dev/null +++ b/internal/auth/oauth/openid.go @@ -0,0 +1,85 @@ +package oauth + +import ( + gocontext "context" + "errors" + "github.com/markbates/goth" + "github.com/markbates/goth/gothic" + "github.com/markbates/goth/providers/openidConnect" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" +) + +type OIDCProvider struct { + Provider + URL string +} + +func (p *OIDCProvider) RegisterProvider() error { + oidcProvider, err := openidConnect.New( + config.C.OIDCClientKey, + config.C.OIDCSecret, + urlJoin(p.URL, "/oauth/openid-connect/callback"), + config.C.OIDCDiscoveryUrl, + "openid", + "email", + "profile", + ) + + if err != nil { + return errors.New("Cannot create OIDC provider: " + err.Error()) + } + + goth.UseProviders(oidcProvider) + return nil +} + +func (p *OIDCProvider) BeginAuthHandler(ctx *context.Context) { + ctxValue := gocontext.WithValue(ctx.Request().Context(), gothic.ProviderParamKey, OpenIDConnectString) + ctx.SetRequest(ctx.Request().WithContext(ctxValue)) + + gothic.BeginAuthHandler(ctx.Response(), ctx.Request()) +} + +func (p *OIDCProvider) UserHasProvider(user *db.User) bool { + return user.OIDCID != "" +} + +func NewOIDCProvider(url string) *OIDCProvider { + return &OIDCProvider{ + URL: url, + } +} + +type OIDCCallbackProvider struct { + CallbackProvider + User *goth.User +} + +func (p *OIDCCallbackProvider) GetProvider() string { + return OpenIDConnectString +} + +func (p *OIDCCallbackProvider) GetProviderUser() *goth.User { + return p.User +} + +func (p *OIDCCallbackProvider) GetProviderUserID(user *db.User) bool { + return user.OIDCID != "" +} + +func (p *OIDCCallbackProvider) GetProviderUserSSHKeys() ([]string, error) { + return nil, nil +} + +func (p *OIDCCallbackProvider) UpdateUserDB(user *db.User) { + user.OIDCID = p.User.UserID + user.AvatarURL = p.User.AvatarURL +} + +func NewOIDCCallbackProvider(user *goth.User) CallbackProvider { + return &OIDCCallbackProvider{ + User: user, + } +} diff --git a/internal/auth/oauth/provider.go b/internal/auth/oauth/provider.go new file mode 100644 index 0000000..951e9f4 --- /dev/null +++ b/internal/auth/oauth/provider.go @@ -0,0 +1,93 @@ +package oauth + +import ( + "fmt" + "github.com/markbates/goth" + "github.com/markbates/goth/gothic" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "io" + "net/http" + "net/url" + "strings" +) + +const ( + GitHubProviderString = "github" + GitLabProviderString = "gitlab" + GiteaProviderString = "gitea" + OpenIDConnectString = "openid-connect" +) + +type Provider interface { + RegisterProvider() error + BeginAuthHandler(ctx *context.Context) + UserHasProvider(user *db.User) bool +} + +type CallbackProvider interface { + GetProvider() string + GetProviderUser() *goth.User + GetProviderUserID(user *db.User) bool + GetProviderUserSSHKeys() ([]string, error) + UpdateUserDB(user *db.User) +} + +func DefineProvider(provider string, url string) (Provider, error) { + switch provider { + case GitHubProviderString: + return NewGitHubProvider(url), nil + case GitLabProviderString: + return NewGitLabProvider(url), nil + case GiteaProviderString: + return NewGiteaProvider(url), nil + case OpenIDConnectString: + return NewOIDCProvider(url), nil + } + + return nil, fmt.Errorf("unsupported provider %s", provider) +} + +func CompleteUserAuth(ctx *context.Context) (CallbackProvider, error) { + user, err := gothic.CompleteUserAuth(ctx.Response(), ctx.Request()) + if err != nil { + return nil, err + } + + switch user.Provider { + case GitHubProviderString: + return NewGitHubCallbackProvider(&user), nil + case GitLabProviderString: + return NewGitLabCallbackProvider(&user), nil + case GiteaProviderString: + return NewGiteaCallbackProvider(&user), nil + case OpenIDConnectString: + return NewOIDCCallbackProvider(&user), nil + } + + return nil, fmt.Errorf("unsupported provider %s", user.Provider) +} + +func urlJoin(base string, elem ...string) string { + joined, err := url.JoinPath(base, elem...) + if err != nil { + log.Error().Err(err).Msg("Cannot join url") + } + + return joined +} + +func readKeys(response *http.Response) ([]string, error) { + body, err := io.ReadAll(response.Body) + if err != nil { + return nil, fmt.Errorf("could not get user keys %v", err) + } + + keys := strings.Split(string(body), "\n") + if len(keys[len(keys)-1]) == 0 { + keys = keys[:len(keys)-1] + } + + return keys, nil +} diff --git a/internal/auth/password/password.go b/internal/auth/password/password.go new file mode 100644 index 0000000..9e96130 --- /dev/null +++ b/internal/auth/password/password.go @@ -0,0 +1,11 @@ +package password + +import "github.com/thomiceli/opengist/internal/auth" + +func HashPassword(code string) (string, error) { + return auth.Argon2id.Hash(code) +} + +func VerifyPassword(code, hashedCode string) (bool, error) { + return auth.Argon2id.Verify(code, hashedCode) +} diff --git a/internal/cli/admin.go b/internal/cli/admin.go index 5abac41..a1d5bcd 100644 --- a/internal/cli/admin.go +++ b/internal/cli/admin.go @@ -2,8 +2,8 @@ package cli import ( "fmt" + "github.com/thomiceli/opengist/internal/auth/password" "github.com/thomiceli/opengist/internal/db" - "github.com/thomiceli/opengist/internal/utils" "github.com/urfave/cli/v2" ) @@ -33,7 +33,7 @@ var CmdAdminResetPassword = cli.Command{ fmt.Printf("Cannot get user %s: %s\n", username, err) return err } - password, err := utils.Argon2id.Hash(plainPassword) + password, err := password.HashPassword(plainPassword) if err != nil { fmt.Printf("Cannot hash password for user %s: %s\n", username, err) return err diff --git a/internal/cli/hook.go b/internal/cli/hook.go index 5af3b46..1d12cac 100644 --- a/internal/cli/hook.go +++ b/internal/cli/hook.go @@ -50,7 +50,7 @@ func initialize(ctx *cli.Context) { config.InitLog() db.DeprecationDBFilename() - if err := db.Setup(config.C.DBUri, false); err != nil { + if err := db.Setup(config.C.DBUri); err != nil { log.Fatal().Err(err).Msg("Failed to initialize database in hooks") } } diff --git a/internal/cli/main.go b/internal/cli/main.go index 4c980fd..f4f6b85 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -10,7 +10,7 @@ import ( "github.com/thomiceli/opengist/internal/index" "github.com/thomiceli/opengist/internal/memdb" "github.com/thomiceli/opengist/internal/ssh" - "github.com/thomiceli/opengist/internal/web" + "github.com/thomiceli/opengist/internal/web/server" "github.com/urfave/cli/v2" "os" "os/signal" @@ -37,7 +37,7 @@ var CmdStart = cli.Command{ Initialize(ctx) - go web.NewServer(os.Getenv("OG_DEV") == "1", path.Join(config.GetHomeDir(), "sessions"), false).Start() + go server.NewServer(os.Getenv("OG_DEV") == "1", path.Join(config.GetHomeDir(), "sessions"), false).Start() go ssh.Start() <-stopCtx.Done() @@ -117,7 +117,7 @@ func Initialize(ctx *cli.Context) { } db.DeprecationDBFilename() - if err := db.Setup(config.C.DBUri, false); err != nil { + if err := db.Setup(config.C.DBUri); err != nil { log.Fatal().Err(err).Msg("Failed to initialize database") } diff --git a/internal/config/config.go b/internal/config/config.go index 6093149..c2e868a 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -2,6 +2,7 @@ package config import ( "fmt" + "github.com/thomiceli/opengist/internal/session" "io" "net/url" "os" @@ -14,7 +15,6 @@ import ( "github.com/rs/zerolog" "github.com/rs/zerolog/log" - "github.com/thomiceli/opengist/internal/utils" "gopkg.in/yaml.v3" ) @@ -165,9 +165,9 @@ func InitLog() { } var logWriters []io.Writer - logOutputTypes := utils.RemoveDuplicates[string]( - strings.Split(strings.ToLower(C.LogOutput), ","), - ) + logOutputTypes := strings.Split(strings.ToLower(C.LogOutput), ",") + slices.Sort(logOutputTypes) + logOutputTypes = slices.Compact(logOutputTypes) consoleWriter := zerolog.NewConsoleWriter( func(w *zerolog.ConsoleWriter) { @@ -245,7 +245,7 @@ func GetHomeDir() string { func SetupSecretKey() { if C.SecretKey == "" { path := filepath.Join(GetHomeDir(), "opengist-secret.key") - SecretKey, _ = utils.GenerateSecretKey(path) + SecretKey, _ = session.GenerateSecretKey(path) } else { SecretKey = []byte(C.SecretKey) } diff --git a/internal/db/admin_setting.go b/internal/db/admin_setting.go index d8586ee..06edd6e 100644 --- a/internal/db/admin_setting.go +++ b/internal/db/admin_setting.go @@ -19,7 +19,13 @@ const ( func GetSetting(key string) (string, error) { var setting AdminSetting - err := db.Where("`key` = ?", key).First(&setting).Error + var err error + switch db.Dialector.Name() { + case "mysql", "sqlite": + err = db.Where("`key` = ?", key).First(&setting).Error + case "postgres": + err = db.Where("key = ?", key).First(&setting).Error + } return setting.Value, err } diff --git a/internal/db/db.go b/internal/db/db.go index 3b5c027..6c0cafa 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -46,6 +46,12 @@ var DatabaseInfo *databaseInfo func parseDBURI(uri string) (*databaseInfo, error) { info := &databaseInfo{} + if uri == ":memory:" { + info.Type = SQLite + info.Database = uri + return info, nil + } + u, err := url.Parse(uri) if err != nil { return nil, fmt.Errorf("invalid URI: %v", err) @@ -91,14 +97,14 @@ func parseDBURI(uri string) (*databaseInfo, error) { return info, nil } -func Setup(dbUri string, sharedCache bool) error { +func Setup(dbUri string) error { dbInfo, err := parseDBURI(dbUri) if err != nil { return err } log.Info().Msgf("Setting up a %s database connection", dbInfo.Type) - var setupFunc func(databaseInfo, bool) error + var setupFunc func(databaseInfo) error switch dbInfo.Type { case SQLite: setupFunc = setupSQLite @@ -114,7 +120,7 @@ func Setup(dbUri string, sharedCache bool) error { retryInterval := 1 * time.Second for attempt := 1; attempt <= maxAttempts; attempt++ { - err = setupFunc(*dbInfo, sharedCache) + err = setupFunc(*dbInfo) if err == nil { log.Info().Msg("Database connection established") break @@ -142,7 +148,7 @@ func Setup(dbUri string, sharedCache bool) error { return err } - if err = applyMigrations(db, dbInfo); err != nil { + if err = applyMigrations(dbInfo); err != nil { return err } @@ -183,37 +189,38 @@ func Ping() error { return sql.Ping() } -func setupSQLite(dbInfo databaseInfo, sharedCache bool) error { +func setupSQLite(dbInfo databaseInfo) error { var err error + var dsn string journalMode := strings.ToUpper(config.C.SqliteJournalMode) if !slices.Contains([]string{"DELETE", "TRUNCATE", "PERSIST", "MEMORY", "WAL", "OFF"}, journalMode) { log.Warn().Msg("Invalid SQLite journal mode: " + journalMode) } - u, err := url.Parse(dbInfo.Database) - if err != nil { - return err - } + if dbInfo.Database == ":memory:" { + dsn = ":memory:?_fk=true&cache=shared" + } else { + u, err := url.Parse(dbInfo.Database) + if err != nil { + return err + } - u.Scheme = "file" - q := u.Query() - q.Set("_fk", "true") - q.Set("_journal_mode", journalMode) - if sharedCache { - q.Set("cache", "shared") + u.Scheme = "file" + q := u.Query() + q.Set("_fk", "true") + q.Set("_journal_mode", journalMode) + u.RawQuery = q.Encode() + dsn = u.String() } - u.RawQuery = q.Encode() - dsn := u.String() db, err = gorm.Open(sqlite.Open(dsn), &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), TranslateError: true, }) - return err } -func setupPostgres(dbInfo databaseInfo, sharedCache bool) error { +func setupPostgres(dbInfo databaseInfo) error { var err error dsn := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", dbInfo.Host, dbInfo.Port, dbInfo.User, dbInfo.Password, dbInfo.Database) @@ -225,7 +232,7 @@ func setupPostgres(dbInfo databaseInfo, sharedCache bool) error { return err } -func setupMySQL(dbInfo databaseInfo, sharedCache bool) error { +func setupMySQL(dbInfo databaseInfo) error { var err error dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local", dbInfo.User, dbInfo.Password, dbInfo.Host, dbInfo.Port, dbInfo.Database) diff --git a/internal/db/migration.go b/internal/db/migration.go index 9510a87..a549d41 100644 --- a/internal/db/migration.go +++ b/internal/db/migration.go @@ -3,7 +3,6 @@ package db import ( "fmt" "github.com/rs/zerolog/log" - "gorm.io/gorm" ) type MigrationVersion struct { @@ -11,10 +10,10 @@ type MigrationVersion struct { Version uint } -func applyMigrations(db *gorm.DB, dbInfo *databaseInfo) error { +func applyMigrations(dbInfo *databaseInfo) error { switch dbInfo.Type { case SQLite: - return applySqliteMigrations(db) + return applySqliteMigrations() case PostgreSQL, MySQL: return nil default: @@ -23,7 +22,7 @@ func applyMigrations(db *gorm.DB, dbInfo *databaseInfo) error { } -func applySqliteMigrations(db *gorm.DB) error { +func applySqliteMigrations() error { // Create migration table if it doesn't exist if err := db.AutoMigrate(&MigrationVersion{}); err != nil { log.Fatal().Err(err).Msg("Error creating migration version table") @@ -37,7 +36,7 @@ func applySqliteMigrations(db *gorm.DB) error { // Define migrations migrations := []struct { Version uint - Func func(*gorm.DB) error + Func func() error }{ {1, v1_modifyConstraintToSSHKeys}, {2, v2_lowercaseEmails}, @@ -53,7 +52,7 @@ func applySqliteMigrations(db *gorm.DB) error { return err } - if err := m.Func(db); err != nil { + if err := m.Func(); err != nil { log.Fatal().Err(err).Msg(fmt.Sprintf("Error applying migration %d:", m.Version)) tx.Rollback() return err @@ -73,7 +72,7 @@ func applySqliteMigrations(db *gorm.DB) error { } // Modify the constraint on the ssh_keys table to use ON DELETE CASCADE -func v1_modifyConstraintToSSHKeys(db *gorm.DB) error { +func v1_modifyConstraintToSSHKeys() error { createSQL := ` CREATE TABLE ssh_keys_temp ( id integer primary key, @@ -108,7 +107,7 @@ func v1_modifyConstraintToSSHKeys(db *gorm.DB) error { return db.Exec(renameSQL).Error } -func v2_lowercaseEmails(db *gorm.DB) error { +func v2_lowercaseEmails() error { // Copy the lowercase emails into the new column copySQL := `UPDATE users SET email = lower(email);` return db.Exec(copySQL).Error diff --git a/internal/db/totp.go b/internal/db/totp.go index 7ae7dce..cc8421d 100644 --- a/internal/db/totp.go +++ b/internal/db/totp.go @@ -6,9 +6,10 @@ import ( "encoding/hex" "encoding/json" "fmt" + "github.com/thomiceli/opengist/internal/auth" + "github.com/thomiceli/opengist/internal/auth/password" ogtotp "github.com/thomiceli/opengist/internal/auth/totp" "github.com/thomiceli/opengist/internal/config" - "github.com/thomiceli/opengist/internal/utils" "slices" ) @@ -30,7 +31,7 @@ func GetTOTPByUserID(userID uint) (*TOTP, error) { func (totp *TOTP) StoreSecret(secret string) error { secretBytes := []byte(secret) - encrypted, err := utils.AESEncrypt(config.SecretKey, secretBytes) + encrypted, err := auth.AESEncrypt(config.SecretKey, secretBytes) if err != nil { return err } @@ -45,7 +46,7 @@ func (totp *TOTP) ValidateCode(code string) (bool, error) { return false, err } - secretBytes, err := utils.AESDecrypt(config.SecretKey, ciphertext) + secretBytes, err := auth.AESDecrypt(config.SecretKey, ciphertext) if err != nil { return false, err } @@ -60,7 +61,7 @@ func (totp *TOTP) ValidateRecoveryCode(code string) (bool, error) { } for i, hashedCode := range hashedCodes { - ok, err := utils.Argon2id.Verify(code, hashedCode) + ok, err := password.VerifyPassword(code, hashedCode) if err != nil { return false, err } @@ -106,7 +107,7 @@ func generateRandomCodes() ([]string, []string, error) { hexCode := hex.EncodeToString(bytes) code := fmt.Sprintf("%s-%s", hexCode[:length/2], hexCode[length/2:]) plainCodes[i] = code - hashed, err := utils.Argon2id.Hash(code) + hashed, err := password.HashPassword(code) if err != nil { return nil, nil, err } diff --git a/internal/db/user.go b/internal/db/user.go index d7ad00e..bff23c8 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -1,6 +1,7 @@ package db import ( + "github.com/thomiceli/opengist/internal/git" "gorm.io/gorm" ) @@ -25,31 +26,36 @@ type User struct { } func (user *User) BeforeDelete(tx *gorm.DB) error { - // Decrement likes counter for all gists liked by this user - // The likes will be automatically deleted by the foreign key constraint - err := tx.Model(&Gist{}). - Omit("updated_at"). - Where("id IN (?)", tx. - Select("gist_id"). - Table("likes"). - Where("user_id = ?", user.ID), - ). - UpdateColumn("nb_likes", gorm.Expr("nb_likes - 1")). - Error + // Decrement likes counter using derived table + err := tx.Exec(` + UPDATE gists + SET nb_likes = nb_likes - 1 + WHERE id IN ( + SELECT gist_id + FROM ( + SELECT gist_id + FROM likes + WHERE user_id = ? + ) AS derived_likes + ) + `, user.ID).Error if err != nil { return err } - // Decrement forks counter for all gists forked by this user - err = tx.Model(&Gist{}). - Omit("updated_at"). - Where("id IN (?)", tx. - Select("forked_id"). - Table("gists"). - Where("user_id = ?", user.ID), - ). - UpdateColumn("nb_forks", gorm.Expr("nb_forks - 1")). - Error + // Decrement forks counter using derived table + err = tx.Exec(` + UPDATE gists + SET nb_forks = nb_forks - 1 + WHERE id IN ( + SELECT forked_id + FROM ( + SELECT forked_id + FROM gists + WHERE user_id = ? AND forked_id IS NOT NULL + ) AS derived_forks + ) + `, user.ID).Error if err != nil { return err } @@ -64,8 +70,17 @@ func (user *User) BeforeDelete(tx *gorm.DB) error { return err } - // Delete all gists created by this user - return tx.Where("user_id = ?", user.ID).Delete(&Gist{}).Error + err = tx.Where("user_id = ?", user.ID).Delete(&Gist{}).Error + if err != nil { + return err + } + + // Delete user directory + if err = git.DeleteUserDirectory(user.Username); err != nil { + return err + } + + return nil } func UserExists(username string) (bool, error) { diff --git a/internal/db/webauth_credential.go b/internal/db/webauth_credential.go index 742c0b2..3d77415 100644 --- a/internal/db/webauth_credential.go +++ b/internal/db/webauth_credential.go @@ -73,8 +73,7 @@ func GetUserByCredentialID(credID binaryData) (*User, error) { if err = db.Preload("User").Where("credential_id = decode(?, 'hex')", hexCredID).First(&credential).Error; err != nil { return nil, err } - case "mysql": - case "sqlite": + case "mysql", "sqlite": hexCredID := hex.EncodeToString(credID) if err = db.Preload("User").Where("credential_id = unhex(?)", hexCredID).First(&credential).Error; err != nil { return nil, err @@ -100,8 +99,7 @@ func GetCredentialByID(id binaryData) (*WebAuthnCredential, error) { if err = db.Where("credential_id = decode(?, 'hex')", hexCredID).First(&cred).Error; err != nil { return nil, err } - case "mysql": - case "sqlite": + case "mysql", "sqlite": hexCredID := hex.EncodeToString(id) if err = db.Where("credential_id = unhex(?)", hexCredID).First(&cred).Error; err != nil { return nil, err diff --git a/internal/git/commands.go b/internal/git/commands.go index e1b8173..23a332c 100644 --- a/internal/git/commands.go +++ b/internal/git/commands.go @@ -485,6 +485,22 @@ func GcRepos() error { return err } +func ResetHooks() error { + entries, err := filepath.Glob(filepath.Join(config.GetHomeDir(), ReposDirectory, "*", "*")) + if err != nil { + return err + } + + for _, e := range entries { + repoPath := strings.Split(e, string(os.PathSeparator)) + if err := CreateDotGitFiles(repoPath[len(repoPath)-2], repoPath[len(repoPath)-1]); err != nil { + log.Error().Err(err).Msgf("Cannot reset hooks for repository %s/%s", repoPath[len(repoPath)-2], repoPath[len(repoPath)-1]) + } + } + + return nil +} + func HasNoCommits(user string, gist string) (bool, error) { repositoryPath := RepositoryPath(user, gist) @@ -540,6 +556,10 @@ func CreateDotGitFiles(user string, gist string) error { return nil } +func DeleteUserDirectory(user string) error { + return os.RemoveAll(filepath.Join(config.GetHomeDir(), ReposDirectory, user)) +} + func createDotGitHookFile(repositoryPath string, hook string, content string) error { preReceiveDst, err := os.OpenFile(filepath.Join(repositoryPath, "hooks", hook), os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0744) if err != nil { diff --git a/internal/hooks/post_receive.go b/internal/hooks/post_receive.go index f9ee908..8ab0433 100644 --- a/internal/hooks/post_receive.go +++ b/internal/hooks/post_receive.go @@ -5,7 +5,7 @@ import ( "fmt" "github.com/thomiceli/opengist/internal/db" "github.com/thomiceli/opengist/internal/git" - "github.com/thomiceli/opengist/internal/utils" + validatorpkg "github.com/thomiceli/opengist/internal/validator" "io" "os" "os/exec" @@ -18,7 +18,7 @@ func PostReceive(in io.Reader, out, er io.Writer) error { newGist := false opts := pushOptions() gistUrl := os.Getenv("OPENGIST_REPOSITORY_URL_INTERNAL") - validator := utils.NewValidator() + validator := validatorpkg.NewValidator() scanner := bufio.NewScanner(in) for scanner.Scan() { diff --git a/internal/utils/session.go b/internal/session/session.go similarity index 97% rename from internal/utils/session.go rename to internal/session/session.go index fec5fac..39f3aa4 100644 --- a/internal/utils/session.go +++ b/internal/session/session.go @@ -1,4 +1,4 @@ -package utils +package session import ( "github.com/gorilla/securecookie" diff --git a/internal/utils/slice.go b/internal/utils/slice.go deleted file mode 100644 index 37ba25f..0000000 --- a/internal/utils/slice.go +++ /dev/null @@ -1,13 +0,0 @@ -package utils - -func RemoveDuplicates[T string | int](sliceList []T) []T { - allKeys := make(map[T]bool) - list := []T{} - for _, item := range sliceList { - if _, value := allKeys[item]; !value { - allKeys[item] = true - list = append(list, item) - } - } - return list -} diff --git a/internal/utils/validator.go b/internal/validator/validator.go similarity index 97% rename from internal/utils/validator.go rename to internal/validator/validator.go index 8d13569..b4d174b 100644 --- a/internal/utils/validator.go +++ b/internal/validator/validator.go @@ -1,4 +1,4 @@ -package utils +package validator import ( "github.com/go-playground/validator/v10" @@ -40,8 +40,7 @@ func ValidationMessages(err *error, locale *i18n.Locale) string { messages[i] = locale.String("validation.should-not-include-sub-directory", e.Field()) case "alphanum": messages[i] = locale.String("validation.should-only-contain-alphanumeric-characters", e.Field()) - case "alphanumdash": - case "alphanumdashorempty": + case "alphanumdash", "alphanumdashorempty": messages[i] = locale.String("validation.should-only-contain-alphanumeric-characters-and-dashes", e.Field()) case "min": messages[i] = locale.String("validation.not-enough", e.Field()) diff --git a/internal/web/admin.go b/internal/web/admin.go deleted file mode 100644 index c49d778..0000000 --- a/internal/web/admin.go +++ /dev/null @@ -1,236 +0,0 @@ -package web - -import ( - "github.com/labstack/echo/v4" - "github.com/thomiceli/opengist/internal/actions" - "github.com/thomiceli/opengist/internal/config" - "github.com/thomiceli/opengist/internal/db" - "github.com/thomiceli/opengist/internal/git" - "runtime" - "strconv" - "time" -) - -func adminIndex(ctx echo.Context) error { - setData(ctx, "htmlTitle", trH(ctx, "admin.admin_panel")) - setData(ctx, "adminHeaderPage", "index") - - setData(ctx, "opengistVersion", config.OpengistVersion) - setData(ctx, "goVersion", runtime.Version()) - gitVersion, err := git.GetGitVersion() - if err != nil { - return errorRes(500, "Cannot get git version", err) - } - setData(ctx, "gitVersion", gitVersion) - - countUsers, err := db.CountAll(&db.User{}) - if err != nil { - return errorRes(500, "Cannot count users", err) - } - setData(ctx, "countUsers", countUsers) - - countGists, err := db.CountAll(&db.Gist{}) - if err != nil { - return errorRes(500, "Cannot count gists", err) - } - setData(ctx, "countGists", countGists) - - countKeys, err := db.CountAll(&db.SSHKey{}) - if err != nil { - return errorRes(500, "Cannot count SSH keys", err) - } - setData(ctx, "countKeys", countKeys) - - setData(ctx, "syncReposFromFS", actions.IsRunning(actions.SyncReposFromFS)) - setData(ctx, "syncReposFromDB", actions.IsRunning(actions.SyncReposFromDB)) - setData(ctx, "gitGcRepos", actions.IsRunning(actions.GitGcRepos)) - setData(ctx, "syncGistPreviews", actions.IsRunning(actions.SyncGistPreviews)) - setData(ctx, "resetHooks", actions.IsRunning(actions.ResetHooks)) - setData(ctx, "indexGists", actions.IsRunning(actions.IndexGists)) - return html(ctx, "admin_index.html") -} - -func adminUsers(ctx echo.Context) error { - setData(ctx, "htmlTitle", trH(ctx, "admin.users")+" - "+trH(ctx, "admin.admin_panel")) - setData(ctx, "adminHeaderPage", "users") - pageInt := getPage(ctx) - - var data []*db.User - var err error - if data, err = db.GetAllUsers(pageInt - 1); err != nil { - return errorRes(500, "Cannot get users", err) - } - - if err = paginate(ctx, data, pageInt, 10, "data", "admin-panel/users", 1); err != nil { - return errorRes(404, tr(ctx, "error.page-not-found"), nil) - } - - return html(ctx, "admin_users.html") -} - -func adminGists(ctx echo.Context) error { - setData(ctx, "htmlTitle", trH(ctx, "admin.gists")+" - "+trH(ctx, "admin.admin_panel")) - setData(ctx, "adminHeaderPage", "gists") - pageInt := getPage(ctx) - - var data []*db.Gist - var err error - if data, err = db.GetAllGists(pageInt - 1); err != nil { - return errorRes(500, "Cannot get gists", err) - } - - if err = paginate(ctx, data, pageInt, 10, "data", "admin-panel/gists", 1); err != nil { - return errorRes(404, tr(ctx, "error.page-not-found"), nil) - } - - return html(ctx, "admin_gists.html") -} - -func adminUserDelete(ctx echo.Context) error { - userId, _ := strconv.ParseUint(ctx.Param("user"), 10, 64) - user, err := db.GetUserById(uint(userId)) - if err != nil { - return errorRes(500, "Cannot retrieve user", err) - } - - if err := user.Delete(); err != nil { - return errorRes(500, "Cannot delete this user", err) - } - - addFlash(ctx, tr(ctx, "flash.admin.user-deleted"), "success") - return redirect(ctx, "/admin-panel/users") -} - -func adminGistDelete(ctx echo.Context) error { - gist, err := db.GetGistByID(ctx.Param("gist")) - if err != nil { - return errorRes(500, "Cannot retrieve gist", err) - } - - if err = gist.DeleteRepository(); err != nil { - return errorRes(500, "Cannot delete the repository", err) - } - - if err = gist.Delete(); err != nil { - return errorRes(500, "Cannot delete this gist", err) - } - - gist.RemoveFromIndex() - - addFlash(ctx, tr(ctx, "flash.admin.gist-deleted"), "success") - return redirect(ctx, "/admin-panel/gists") -} - -func adminSyncReposFromFS(ctx echo.Context) error { - addFlash(ctx, tr(ctx, "flash.admin.sync-fs"), "success") - go actions.Run(actions.SyncReposFromFS) - return redirect(ctx, "/admin-panel") -} - -func adminSyncReposFromDB(ctx echo.Context) error { - addFlash(ctx, tr(ctx, "flash.admin.sync-db"), "success") - go actions.Run(actions.SyncReposFromDB) - return redirect(ctx, "/admin-panel") -} - -func adminGcRepos(ctx echo.Context) error { - addFlash(ctx, tr(ctx, "flash.admin.git-gc"), "success") - go actions.Run(actions.GitGcRepos) - return redirect(ctx, "/admin-panel") -} - -func adminSyncGistPreviews(ctx echo.Context) error { - addFlash(ctx, tr(ctx, "flash.admin.sync-previews"), "success") - go actions.Run(actions.SyncGistPreviews) - return redirect(ctx, "/admin-panel") -} - -func adminResetHooks(ctx echo.Context) error { - addFlash(ctx, tr(ctx, "flash.admin.reset-hooks"), "success") - go actions.Run(actions.ResetHooks) - return redirect(ctx, "/admin-panel") -} - -func adminIndexGists(ctx echo.Context) error { - addFlash(ctx, tr(ctx, "flash.admin.index-gists"), "success") - go actions.Run(actions.IndexGists) - return redirect(ctx, "/admin-panel") -} - -func adminConfig(ctx echo.Context) error { - setData(ctx, "htmlTitle", trH(ctx, "admin.configuration")+" - "+trH(ctx, "admin.admin_panel")) - setData(ctx, "adminHeaderPage", "config") - - setData(ctx, "dbtype", db.DatabaseInfo.Type.String()) - setData(ctx, "dbname", db.DatabaseInfo.Database) - - return html(ctx, "admin_config.html") -} - -func adminSetConfig(ctx echo.Context) error { - key := ctx.FormValue("key") - value := ctx.FormValue("value") - - if err := db.UpdateSetting(key, value); err != nil { - return errorRes(500, "Cannot set setting", err) - } - - return ctx.JSON(200, map[string]interface{}{ - "success": true, - }) -} - -func adminInvitations(ctx echo.Context) error { - setData(ctx, "htmlTitle", trH(ctx, "admin.invitations")+" - "+trH(ctx, "admin.admin_panel")) - setData(ctx, "adminHeaderPage", "invitations") - - var invitations []*db.Invitation - var err error - if invitations, err = db.GetAllInvitations(); err != nil { - return errorRes(500, "Cannot get invites", err) - } - - setData(ctx, "invitations", invitations) - return html(ctx, "admin_invitations.html") -} - -func adminInvitationsCreate(ctx echo.Context) error { - code := ctx.FormValue("code") - nbMax, err := strconv.ParseUint(ctx.FormValue("nbMax"), 10, 64) - if err != nil { - nbMax = 10 - } - - expiresAtUnix, err := strconv.ParseInt(ctx.FormValue("expiredAtUnix"), 10, 64) - if err != nil { - expiresAtUnix = time.Now().Unix() + 604800 // 1 week - } - - invitation := &db.Invitation{ - Code: code, - ExpiresAt: expiresAtUnix, - NbMax: uint(nbMax), - } - - if err := invitation.Create(); err != nil { - return errorRes(500, "Cannot create invitation", err) - } - - addFlash(ctx, tr(ctx, "flash.admin.invitation-created"), "success") - return redirect(ctx, "/admin-panel/invitations") -} - -func adminInvitationsDelete(ctx echo.Context) error { - id, _ := strconv.ParseUint(ctx.Param("id"), 10, 64) - invitation, err := db.GetInvitationByID(uint(id)) - if err != nil { - return errorRes(500, "Cannot retrieve invitation", err) - } - - if err := invitation.Delete(); err != nil { - return errorRes(500, "Cannot delete this invitation", err) - } - - addFlash(ctx, tr(ctx, "flash.admin.invitation-deleted"), "success") - return redirect(ctx, "/admin-panel/invitations") -} diff --git a/internal/web/auth.go b/internal/web/auth.go deleted file mode 100644 index 39d3d28..0000000 --- a/internal/web/auth.go +++ /dev/null @@ -1,815 +0,0 @@ -package web - -import ( - "bytes" - "context" - "crypto/md5" - gojson "encoding/json" - "errors" - "fmt" - "github.com/labstack/echo/v4" - "github.com/markbates/goth" - "github.com/markbates/goth/gothic" - "github.com/markbates/goth/providers/gitea" - "github.com/markbates/goth/providers/github" - "github.com/markbates/goth/providers/gitlab" - "github.com/markbates/goth/providers/openidConnect" - "github.com/rs/zerolog/log" - "github.com/thomiceli/opengist/internal/auth/totp" - "github.com/thomiceli/opengist/internal/auth/webauthn" - "github.com/thomiceli/opengist/internal/config" - "github.com/thomiceli/opengist/internal/db" - "github.com/thomiceli/opengist/internal/i18n" - "github.com/thomiceli/opengist/internal/utils" - "golang.org/x/text/cases" - "golang.org/x/text/language" - "gorm.io/gorm" - "io" - "net/http" - "net/url" - "strings" -) - -const ( - GitHubProvider = "github" - GitLabProvider = "gitlab" - GiteaProvider = "gitea" - OpenIDConnect = "openid-connect" -) - -func register(ctx echo.Context) error { - disableSignup := getData(ctx, "DisableSignup") - disableForm := getData(ctx, "DisableLoginForm") - - code := ctx.QueryParam("code") - if code != "" { - if invitation, err := db.GetInvitationByCode(code); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return errorRes(500, "Cannot check for invitation code", err) - } else if invitation != nil && invitation.IsUsable() { - disableSignup = false - } - } - - setData(ctx, "title", trH(ctx, "auth.new-account")) - setData(ctx, "htmlTitle", trH(ctx, "auth.new-account")) - setData(ctx, "disableForm", disableForm) - setData(ctx, "disableSignup", disableSignup) - setData(ctx, "isLoginPage", false) - return html(ctx, "auth_form.html") -} - -func processRegister(ctx echo.Context) error { - disableSignup := getData(ctx, "DisableSignup") - - code := ctx.QueryParam("code") - invitation, err := db.GetInvitationByCode(code) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return errorRes(500, "Cannot check for invitation code", err) - } else if invitation.ID != 0 && invitation.IsUsable() { - disableSignup = false - } - - if disableSignup == true { - return errorRes(403, tr(ctx, "error.signup-disabled"), nil) - } - - if getData(ctx, "DisableLoginForm") == true { - return errorRes(403, tr(ctx, "error.signup-disabled-form"), nil) - } - - setData(ctx, "title", trH(ctx, "auth.new-account")) - setData(ctx, "htmlTitle", trH(ctx, "auth.new-account")) - - sess := getSession(ctx) - - dto := new(db.UserDTO) - if err := ctx.Bind(dto); err != nil { - return errorRes(400, tr(ctx, "error.cannot-bind-data"), err) - } - - if err := ctx.Validate(dto); err != nil { - addFlash(ctx, utils.ValidationMessages(&err, getData(ctx, "locale").(*i18n.Locale)), "error") - return html(ctx, "auth_form.html") - } - - if exists, err := db.UserExists(dto.Username); err != nil || exists { - addFlash(ctx, tr(ctx, "flash.auth.username-exists"), "error") - return html(ctx, "auth_form.html") - } - - user := dto.ToUser() - - password, err := utils.Argon2id.Hash(user.Password) - if err != nil { - return errorRes(500, "Cannot hash password", err) - } - user.Password = password - - if err = user.Create(); err != nil { - return errorRes(500, "Cannot create user", err) - } - - if user.ID == 1 { - if err = user.SetAdmin(); err != nil { - return errorRes(500, "Cannot set user admin", err) - } - } - - if invitation.ID != 0 { - if err := invitation.Use(); err != nil { - return errorRes(500, "Cannot use invitation", err) - } - } - - sess.Values["user"] = user.ID - saveSession(sess, ctx) - - return redirect(ctx, "/") -} - -func login(ctx echo.Context) error { - setData(ctx, "title", trH(ctx, "auth.login")) - setData(ctx, "htmlTitle", trH(ctx, "auth.login")) - setData(ctx, "disableForm", getData(ctx, "DisableLoginForm")) - setData(ctx, "isLoginPage", true) - return html(ctx, "auth_form.html") -} - -func processLogin(ctx echo.Context) error { - if getData(ctx, "DisableLoginForm") == true { - return errorRes(403, tr(ctx, "error.login-disabled-form"), nil) - } - - var err error - sess := getSession(ctx) - - dto := &db.UserDTO{} - if err = ctx.Bind(dto); err != nil { - return errorRes(400, tr(ctx, "error.cannot-bind-data"), err) - } - password := dto.Password - - var user *db.User - - if user, err = db.GetUserByUsername(dto.Username); err != nil { - if !errors.Is(err, gorm.ErrRecordNotFound) { - return errorRes(500, "Cannot get user", err) - } - log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) - addFlash(ctx, tr(ctx, "flash.auth.invalid-credentials"), "error") - return redirect(ctx, "/login") - } - - if ok, err := utils.Argon2id.Verify(password, user.Password); !ok { - if err != nil { - return errorRes(500, "Cannot check for password", err) - } - log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) - addFlash(ctx, tr(ctx, "flash.auth.invalid-credentials"), "error") - return redirect(ctx, "/login") - } - - // handle MFA - var hasWebauthn, hasTotp bool - if hasWebauthn, hasTotp, err = user.HasMFA(); err != nil { - return errorRes(500, "Cannot check for user MFA", err) - } - if hasWebauthn || hasTotp { - sess.Values["mfaID"] = user.ID - sess.Options.MaxAge = 5 * 60 // 5 minutes - saveSession(sess, ctx) - return redirect(ctx, "/mfa") - } - - sess.Values["user"] = user.ID - sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year - saveSession(sess, ctx) - deleteCsrfCookie(ctx) - - return redirect(ctx, "/") -} - -func mfa(ctx echo.Context) error { - var err error - - user := db.User{ID: getSession(ctx).Values["mfaID"].(uint)} - - var hasWebauthn, hasTotp bool - if hasWebauthn, hasTotp, err = user.HasMFA(); err != nil { - return errorRes(500, "Cannot check for user MFA", err) - } - - setData(ctx, "hasWebauthn", hasWebauthn) - setData(ctx, "hasTotp", hasTotp) - - return html(ctx, "mfa.html") -} - -func oauthCallback(ctx echo.Context) error { - user, err := gothic.CompleteUserAuth(ctx.Response(), ctx.Request()) - if err != nil { - return errorRes(400, tr(ctx, "error.complete-oauth-login", err.Error()), err) - } - - currUser := getUserLogged(ctx) - if currUser != nil { - // if user is logged in, link account to user and update its avatar URL - updateUserProviderInfo(currUser, user.Provider, user) - - if err = currUser.Update(); err != nil { - return errorRes(500, "Cannot update user "+cases.Title(language.English).String(user.Provider)+" id", err) - } - - addFlash(ctx, tr(ctx, "flash.auth.account-linked-oauth", cases.Title(language.English).String(user.Provider)), "success") - return redirect(ctx, "/settings") - } - - // if user is not in database, create it - userDB, err := db.GetUserByProvider(user.UserID, user.Provider) - if err != nil { - if getData(ctx, "DisableSignup") == true { - return errorRes(403, tr(ctx, "error.signup-disabled"), nil) - } - - if !errors.Is(err, gorm.ErrRecordNotFound) { - return errorRes(500, "Cannot get user", err) - } - - if user.NickName == "" { - user.NickName = strings.Split(user.Email, "@")[0] - } - - userDB = &db.User{ - Username: user.NickName, - Email: user.Email, - MD5Hash: fmt.Sprintf("%x", md5.Sum([]byte(strings.ToLower(strings.TrimSpace(user.Email))))), - } - - // set provider id and avatar URL - updateUserProviderInfo(userDB, user.Provider, user) - - if err = userDB.Create(); err != nil { - if db.IsUniqueConstraintViolation(err) { - addFlash(ctx, tr(ctx, "flash.auth.username-exists"), "error") - return redirect(ctx, "/login") - } - - return errorRes(500, "Cannot create user", err) - } - - if userDB.ID == 1 { - if err = userDB.SetAdmin(); err != nil { - return errorRes(500, "Cannot set user admin", err) - } - } - - var resp *http.Response - switch user.Provider { - case GitHubProvider: - resp, err = http.Get("https://github.com/" + user.NickName + ".keys") - case GitLabProvider: - resp, err = http.Get(urlJoin(config.C.GitlabUrl, user.NickName+".keys")) - case GiteaProvider: - resp, err = http.Get(urlJoin(config.C.GiteaUrl, user.NickName+".keys")) - case OpenIDConnect: - err = errors.New("cannot get keys from OIDC provider") - } - - if err == nil { - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - addFlash(ctx, tr(ctx, "flash.auth.user-sshkeys-not-retrievable"), "error") - log.Error().Err(err).Msg("Could not get user keys") - } - - keys := strings.Split(string(body), "\n") - if len(keys[len(keys)-1]) == 0 { - keys = keys[:len(keys)-1] - } - for _, key := range keys { - sshKey := db.SSHKey{ - Title: "Added from " + user.Provider, - Content: key, - User: *userDB, - } - - if err = sshKey.Create(); err != nil { - addFlash(ctx, tr(ctx, "flash.auth.user-sshkeys-not-created"), "error") - log.Error().Err(err).Msg("Could not create ssh key") - } - } - } - } - - sess := getSession(ctx) - sess.Values["user"] = userDB.ID - saveSession(sess, ctx) - deleteCsrfCookie(ctx) - - return redirect(ctx, "/") -} - -func oauth(ctx echo.Context) error { - provider := ctx.Param("provider") - - httpProtocol := "http" - if ctx.Request().TLS != nil || ctx.Request().Header.Get("X-Forwarded-Proto") == "https" { - httpProtocol = "https" - } - - forwarded_hdr := ctx.Request().Header.Get("Forwarded") - if forwarded_hdr != "" { - fields := strings.Split(forwarded_hdr, ";") - fwd := make(map[string]string) - for _, v := range fields { - p := strings.Split(v, "=") - fwd[p[0]] = p[1] - } - val, ok := fwd["proto"] - if ok && val == "https" { - httpProtocol = "https" - } - } - - var opengistUrl string - if config.C.ExternalUrl != "" { - opengistUrl = config.C.ExternalUrl - } else { - opengistUrl = httpProtocol + "://" + ctx.Request().Host - } - - switch provider { - case GitHubProvider: - goth.UseProviders( - github.New( - config.C.GithubClientKey, - config.C.GithubSecret, - urlJoin(opengistUrl, "/oauth/github/callback"), - ), - ) - - case GitLabProvider: - goth.UseProviders( - gitlab.NewCustomisedURL( - config.C.GitlabClientKey, - config.C.GitlabSecret, - urlJoin(opengistUrl, "/oauth/gitlab/callback"), - urlJoin(config.C.GitlabUrl, "/oauth/authorize"), - urlJoin(config.C.GitlabUrl, "/oauth/token"), - urlJoin(config.C.GitlabUrl, "/api/v4/user"), - ), - ) - - case GiteaProvider: - goth.UseProviders( - gitea.NewCustomisedURL( - config.C.GiteaClientKey, - config.C.GiteaSecret, - urlJoin(opengistUrl, "/oauth/gitea/callback"), - urlJoin(config.C.GiteaUrl, "/login/oauth/authorize"), - urlJoin(config.C.GiteaUrl, "/login/oauth/access_token"), - urlJoin(config.C.GiteaUrl, "/api/v1/user"), - ), - ) - case OpenIDConnect: - oidcProvider, err := openidConnect.New( - config.C.OIDCClientKey, - config.C.OIDCSecret, - urlJoin(opengistUrl, "/oauth/openid-connect/callback"), - config.C.OIDCDiscoveryUrl, - "openid", - "email", - "profile", - ) - - if err != nil { - return errorRes(500, "Cannot create OIDC provider", err) - } - - goth.UseProviders(oidcProvider) - } - - ctxValue := context.WithValue(ctx.Request().Context(), gothic.ProviderParamKey, provider) - ctx.SetRequest(ctx.Request().WithContext(ctxValue)) - if provider != GitHubProvider && provider != GitLabProvider && provider != GiteaProvider && provider != OpenIDConnect { - return errorRes(400, tr(ctx, "error.oauth-unsupported"), nil) - } - - gothic.BeginAuthHandler(ctx.Response(), ctx.Request()) - return nil -} - -func oauthUnlink(ctx echo.Context) error { - provider := ctx.Param("provider") - - currUser := getUserLogged(ctx) - // Map each provider to a function that checks the relevant ID in currUser - providerIDCheckMap := map[string]func() bool{ - GitHubProvider: func() bool { return currUser.GithubID != "" }, - GitLabProvider: func() bool { return currUser.GitlabID != "" }, - GiteaProvider: func() bool { return currUser.GiteaID != "" }, - OpenIDConnect: func() bool { return currUser.OIDCID != "" }, - } - - if checkFunc, exists := providerIDCheckMap[provider]; exists && checkFunc() { - if err := currUser.DeleteProviderID(provider); err != nil { - return errorRes(500, "Cannot unlink account from "+cases.Title(language.English).String(provider), err) - } - - addFlash(ctx, tr(ctx, "flash.auth.account-unlinked-oauth", cases.Title(language.English).String(provider)), "success") - return redirect(ctx, "/settings") - } - - return redirect(ctx, "/settings") -} - -func beginWebAuthnBinding(ctx echo.Context) error { - credsCreation, jsonWaSession, err := webauthn.BeginBinding(getUserLogged(ctx)) - if err != nil { - return errorRes(500, "Cannot begin WebAuthn registration", err) - } - - sess := getSession(ctx) - sess.Values["webauthn_registration_session"] = jsonWaSession - sess.Options.MaxAge = 5 * 60 // 5 minutes - saveSession(sess, ctx) - - return ctx.JSON(200, credsCreation) -} - -func finishWebAuthnBinding(ctx echo.Context) error { - sess := getSession(ctx) - jsonWaSession, ok := sess.Values["webauthn_registration_session"].([]byte) - if !ok { - return jsonErrorRes(401, "Cannot get WebAuthn registration session", nil) - } - - user := getUserLogged(ctx) - - // extract passkey name from request - body, err := io.ReadAll(ctx.Request().Body) - if err != nil { - return jsonErrorRes(400, "Failed to read request body", err) - } - ctx.Request().Body.Close() - ctx.Request().Body = io.NopCloser(bytes.NewBuffer(body)) - - dto := new(db.CrendentialDTO) - _ = gojson.Unmarshal(body, &dto) - - if err = ctx.Validate(dto); err != nil { - return jsonErrorRes(400, "Invalid request", err) - } - passkeyName := dto.PasskeyName - if passkeyName == "" { - passkeyName = "WebAuthn" - } - - waCredential, err := webauthn.FinishBinding(user, jsonWaSession, ctx.Request()) - if err != nil { - return jsonErrorRes(403, "Failed binding attempt for passkey", err) - } - - if _, err = db.CreateFromCrendential(user.ID, passkeyName, waCredential); err != nil { - return jsonErrorRes(500, "Cannot create WebAuthn credential on database", err) - } - - delete(sess.Values, "webauthn_registration_session") - saveSession(sess, ctx) - - addFlash(ctx, tr(ctx, "flash.auth.passkey-registred", passkeyName), "success") - return json(ctx, []string{"OK"}) -} - -func beginWebAuthnLogin(ctx echo.Context) error { - credsCreation, jsonWaSession, err := webauthn.BeginDiscoverableLogin() - if err != nil { - return jsonErrorRes(401, "Cannot begin WebAuthn login", err) - } - - sess := getSession(ctx) - sess.Values["webauthn_login_session"] = jsonWaSession - sess.Options.MaxAge = 5 * 60 // 5 minutes - saveSession(sess, ctx) - - return json(ctx, credsCreation) -} - -func finishWebAuthnLogin(ctx echo.Context) error { - sess := getSession(ctx) - sessionData, ok := sess.Values["webauthn_login_session"].([]byte) - if !ok { - return jsonErrorRes(401, "Cannot get WebAuthn login session", nil) - } - - userID, err := webauthn.FinishDiscoverableLogin(sessionData, ctx.Request()) - if err != nil { - return jsonErrorRes(403, "Failed authentication attempt for passkey", err) - } - - sess.Values["user"] = userID - sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year - - delete(sess.Values, "webauthn_login_session") - saveSession(sess, ctx) - - return json(ctx, []string{"OK"}) -} - -func beginWebAuthnAssertion(ctx echo.Context) error { - sess := getSession(ctx) - - ogUser, err := db.GetUserById(sess.Values["mfaID"].(uint)) - if err != nil { - return jsonErrorRes(500, "Cannot get user", err) - } - - credsCreation, jsonWaSession, err := webauthn.BeginLogin(ogUser) - if err != nil { - return jsonErrorRes(401, "Cannot begin WebAuthn login", err) - } - - sess.Values["webauthn_assertion_session"] = jsonWaSession - sess.Options.MaxAge = 5 * 60 // 5 minutes - saveSession(sess, ctx) - - return json(ctx, credsCreation) -} - -func finishWebAuthnAssertion(ctx echo.Context) error { - sess := getSession(ctx) - sessionData, ok := sess.Values["webauthn_assertion_session"].([]byte) - if !ok { - return jsonErrorRes(401, "Cannot get WebAuthn assertion session", nil) - } - - userId := sess.Values["mfaID"].(uint) - - ogUser, err := db.GetUserById(userId) - if err != nil { - return jsonErrorRes(500, "Cannot get user", err) - } - - if err = webauthn.FinishLogin(ogUser, sessionData, ctx.Request()); err != nil { - return jsonErrorRes(403, "Failed authentication attempt for passkey", err) - } - - sess.Values["user"] = userId - sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year - - delete(sess.Values, "webauthn_assertion_session") - delete(sess.Values, "mfaID") - saveSession(sess, ctx) - - return json(ctx, []string{"OK"}) -} - -func beginTotp(ctx echo.Context) error { - user := getUserLogged(ctx) - - if _, hasTotp, err := user.HasMFA(); err != nil { - return errorRes(500, "Cannot check for user MFA", err) - } else if hasTotp { - addFlash(ctx, tr(ctx, "auth.totp.already-enabled"), "error") - return redirect(ctx, "/settings") - } - - ogUrl, err := url.Parse(getData(ctx, "baseHttpUrl").(string)) - if err != nil { - return errorRes(500, "Cannot parse base URL", err) - } - - sess := getSession(ctx) - generatedSecret, _ := sess.Values["generatedSecret"].([]byte) - - totpSecret, qrcode, err, generatedSecret := totp.GenerateQRCode(getUserLogged(ctx).Username, ogUrl.Hostname(), generatedSecret) - if err != nil { - return errorRes(500, "Cannot generate TOTP QR code", err) - } - sess.Values["totpSecret"] = totpSecret - sess.Values["generatedSecret"] = generatedSecret - saveSession(sess, ctx) - - setData(ctx, "totpSecret", totpSecret) - setData(ctx, "totpQrcode", qrcode) - - return html(ctx, "totp.html") - -} - -func finishTotp(ctx echo.Context) error { - user := getUserLogged(ctx) - - if _, hasTotp, err := user.HasMFA(); err != nil { - return errorRes(500, "Cannot check for user MFA", err) - } else if hasTotp { - addFlash(ctx, tr(ctx, "auth.totp.already-enabled"), "error") - return redirect(ctx, "/settings") - } - - dto := &db.TOTPDTO{} - if err := ctx.Bind(dto); err != nil { - return errorRes(400, tr(ctx, "error.cannot-bind-data"), err) - } - - if err := ctx.Validate(dto); err != nil { - addFlash(ctx, "Invalid secret", "error") - return redirect(ctx, "/settings/totp/generate") - } - - sess := getSession(ctx) - secret, ok := sess.Values["totpSecret"].(string) - if !ok { - return errorRes(500, "Cannot get TOTP secret from session", nil) - } - - if !totp.Validate(dto.Code, secret) { - addFlash(ctx, tr(ctx, "auth.totp.invalid-code"), "error") - - return redirect(ctx, "/settings/totp/generate") - } - - userTotp := &db.TOTP{ - UserID: getUserLogged(ctx).ID, - } - if err := userTotp.StoreSecret(secret); err != nil { - return errorRes(500, "Cannot store TOTP secret", err) - } - - if err := userTotp.Create(); err != nil { - return errorRes(500, "Cannot create TOTP", err) - } - - addFlash(ctx, "TOTP successfully enabled", "success") - codes, err := userTotp.GenerateRecoveryCodes() - if err != nil { - return errorRes(500, "Cannot generate recovery codes", err) - } - - delete(sess.Values, "totpSecret") - delete(sess.Values, "generatedSecret") - saveSession(sess, ctx) - - setData(ctx, "recoveryCodes", codes) - return html(ctx, "totp.html") -} - -func assertTotp(ctx echo.Context) error { - var err error - dto := &db.TOTPDTO{} - if err := ctx.Bind(dto); err != nil { - return errorRes(400, tr(ctx, "error.cannot-bind-data"), err) - } - - if err := ctx.Validate(dto); err != nil { - addFlash(ctx, tr(ctx, "auth.totp.invalid-code"), "error") - return redirect(ctx, "/mfa") - } - - sess := getSession(ctx) - userId := sess.Values["mfaID"].(uint) - var userTotp *db.TOTP - if userTotp, err = db.GetTOTPByUserID(userId); err != nil { - return errorRes(500, "Cannot get TOTP by UID", err) - } - - redirectUrl := "/" - - var validCode, validRecoveryCode bool - if validCode, err = userTotp.ValidateCode(dto.Code); err != nil { - return errorRes(500, "Cannot validate TOTP code", err) - } - if !validCode { - validRecoveryCode, err = userTotp.ValidateRecoveryCode(dto.Code) - if err != nil { - return errorRes(500, "Cannot validate TOTP code", err) - } - - if !validRecoveryCode { - addFlash(ctx, tr(ctx, "auth.totp.invalid-code"), "error") - return redirect(ctx, "/mfa") - } - - addFlash(ctx, tr(ctx, "auth.totp.code-used", dto.Code), "warning") - redirectUrl = "/settings" - } - - sess.Values["user"] = userId - sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year - delete(sess.Values, "mfaID") - saveSession(sess, ctx) - - return redirect(ctx, redirectUrl) -} - -func disableTotp(ctx echo.Context) error { - user := getUserLogged(ctx) - userTotp, err := db.GetTOTPByUserID(user.ID) - if err != nil { - return errorRes(500, "Cannot get TOTP by UID", err) - } - - if err = userTotp.Delete(); err != nil { - return errorRes(500, "Cannot delete TOTP", err) - } - - addFlash(ctx, tr(ctx, "auth.totp.disabled"), "success") - return redirect(ctx, "/settings") -} - -func regenerateTotpRecoveryCodes(ctx echo.Context) error { - user := getUserLogged(ctx) - userTotp, err := db.GetTOTPByUserID(user.ID) - if err != nil { - return errorRes(500, "Cannot get TOTP by UID", err) - } - - codes, err := userTotp.GenerateRecoveryCodes() - if err != nil { - return errorRes(500, "Cannot generate recovery codes", err) - } - - setData(ctx, "recoveryCodes", codes) - return html(ctx, "totp.html") -} - -func logout(ctx echo.Context) error { - deleteSession(ctx) - deleteCsrfCookie(ctx) - return redirect(ctx, "/all") -} - -func urlJoin(base string, elem ...string) string { - joined, err := url.JoinPath(base, elem...) - if err != nil { - log.Error().Err(err).Msg("Cannot join url") - } - - return joined -} - -func updateUserProviderInfo(userDB *db.User, provider string, user goth.User) { - userDB.AvatarURL = getAvatarUrlFromProvider(provider, user.UserID) - switch provider { - case GitHubProvider: - userDB.GithubID = user.UserID - case GitLabProvider: - userDB.GitlabID = user.UserID - case GiteaProvider: - userDB.GiteaID = user.UserID - case OpenIDConnect: - userDB.OIDCID = user.UserID - userDB.AvatarURL = user.AvatarURL - } -} - -func getAvatarUrlFromProvider(provider string, identifier string) string { - switch provider { - case GitHubProvider: - return "https://avatars.githubusercontent.com/u/" + identifier + "?v=4" - case GitLabProvider: - return urlJoin(config.C.GitlabUrl, "/uploads/-/system/user/avatar/", identifier, "/avatar.png") + "?width=400" - case GiteaProvider: - resp, err := http.Get(urlJoin(config.C.GiteaUrl, "/api/v1/users/", identifier)) - if err != nil { - log.Error().Err(err).Msg("Cannot get user from Gitea") - return "" - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - log.Error().Err(err).Msg("Cannot read Gitea response body") - return "" - } - - var result map[string]interface{} - err = gojson.Unmarshal(body, &result) - if err != nil { - log.Error().Err(err).Msg("Cannot unmarshal Gitea response body") - return "" - } - - field, ok := result["avatar_url"] - if !ok { - log.Error().Msg("Field 'avatar_url' not found in Gitea JSON response") - return "" - } - return field.(string) - } - return "" -} - -type ContextAuthInfo struct { - context echo.Context -} - -func (auth ContextAuthInfo) RequireLogin() (bool, error) { - return getData(auth.context, "RequireLogin") == true, nil -} - -func (auth ContextAuthInfo) AllowGistsWithoutLogin() (bool, error) { - return getData(auth.context, "AllowGistsWithoutLogin") == true, nil -} diff --git a/internal/web/context/context.go b/internal/web/context/context.go new file mode 100644 index 0000000..e10590e --- /dev/null +++ b/internal/web/context/context.go @@ -0,0 +1,145 @@ +package context + +import ( + "context" + "github.com/gorilla/sessions" + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/i18n" + "html/template" + "net/http" + "sync" +) + +type dataKey string + +const DataKeyStr dataKey = "data" + +type Context struct { + echo.Context + + data echo.Map + lock sync.RWMutex + + store *Store + User *db.User +} + +func NewContext(c echo.Context, sessionPath string) *Context { + return &Context{ + Context: c, + data: make(echo.Map), + store: NewStore(sessionPath), + } +} + +func (ctx *Context) SetData(key string, value any) { + ctx.lock.Lock() + defer ctx.lock.Unlock() + + ctx.data[key] = value +} + +func (ctx *Context) GetData(key string) any { + ctx.lock.RLock() + defer ctx.lock.RUnlock() + + return ctx.data[key] +} + +func (ctx *Context) DataMap() echo.Map { + return ctx.data +} + +func (ctx *Context) ErrorRes(code int, message string, err error) error { + if code >= 500 { + var skipLogger = log.With().CallerWithSkipFrameCount(3).Logger() + skipLogger.Error().Err(err).Msg(message) + } + + ctx.SetRequest(ctx.Request().WithContext(context.WithValue(ctx.Request().Context(), DataKeyStr, ctx.data))) + + return &echo.HTTPError{Code: code, Message: message, Internal: err} +} + +func (ctx *Context) RedirectTo(location string) error { + return ctx.Context.Redirect(302, config.C.ExternalUrl+location) +} + +func (ctx *Context) Html(template string) error { + return ctx.HtmlWithCode(200, template) +} + +func (ctx *Context) HtmlWithCode(code int, template string) error { + ctx.setErrorFlashes() + return ctx.Render(code, template, ctx.DataMap()) +} + +func (ctx *Context) Json(data any) error { + return ctx.JsonWithCode(200, data) +} + +func (ctx *Context) JsonWithCode(code int, data any) error { + return ctx.JSON(code, data) +} + +func (ctx *Context) PlainText(code int, message string) error { + return ctx.String(code, message) +} + +func (ctx *Context) NotFound(message string) error { + return ctx.ErrorRes(404, message, nil) +} + +func (ctx *Context) GetSession() *sessions.Session { + sess, _ := ctx.store.UserStore.Get(ctx.Request(), "session") + return sess +} + +func (ctx *Context) SaveSession(sess *sessions.Session) { + _ = sess.Save(ctx.Request(), ctx.Response()) +} + +func (ctx *Context) DeleteSession() { + sess := ctx.GetSession() + sess.Options.MaxAge = -1 + ctx.SaveSession(sess) +} + +func (ctx *Context) AddFlash(flashMessage string, flashType string) { + sess, _ := ctx.store.flashStore.Get(ctx.Request(), "flash") + sess.AddFlash(flashMessage, flashType) + _ = sess.Save(ctx.Request(), ctx.Response()) +} + +func (ctx *Context) setErrorFlashes() { + sess, _ := ctx.store.flashStore.Get(ctx.Request(), "flash") + + ctx.SetData("flashErrors", sess.Flashes("error")) + ctx.SetData("flashSuccess", sess.Flashes("success")) + ctx.SetData("flashWarnings", sess.Flashes("warning")) + + _ = sess.Save(ctx.Request(), ctx.Response()) +} + +func (ctx *Context) DeleteCsrfCookie() { + ctx.SetCookie(&http.Cookie{Name: "_csrf", Path: "/", MaxAge: -1}) +} + +func (ctx *Context) TrH(key string, args ...any) template.HTML { + l := ctx.GetData("locale").(*i18n.Locale) + return l.Tr(key, args...) +} + +func (ctx *Context) Tr(key string, args ...any) string { + l := ctx.GetData("locale").(*i18n.Locale) + return l.String(key, args...) +} + +var ManifestEntries map[string]Asset + +type Asset struct { + File string `json:"file"` +} diff --git a/internal/web/context/store.go b/internal/web/context/store.go new file mode 100644 index 0000000..5998f22 --- /dev/null +++ b/internal/web/context/store.go @@ -0,0 +1,28 @@ +package context + +import ( + "github.com/gorilla/sessions" + "github.com/markbates/goth/gothic" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/session" + "path/filepath" +) + +type Store struct { + sessionsPath string + + flashStore *sessions.CookieStore + UserStore *sessions.FilesystemStore +} + +func NewStore(sessionsPath string) *Store { + s := &Store{sessionsPath: sessionsPath} + + s.flashStore = sessions.NewCookieStore([]byte("opengist")) + encryptKey, _ := session.GenerateSecretKey(filepath.Join(s.sessionsPath, "session-encrypt.key")) + s.UserStore = sessions.NewFilesystemStore(s.sessionsPath, config.SecretKey, encryptKey) + s.UserStore.MaxLength(10 * 1024) + gothic.Store = s.UserStore + + return s +} diff --git a/internal/web/gist.go b/internal/web/gist.go deleted file mode 100644 index f7d3564..0000000 --- a/internal/web/gist.go +++ /dev/null @@ -1,917 +0,0 @@ -package web - -import ( - "archive/zip" - "bufio" - "bytes" - gojson "encoding/json" - "errors" - "fmt" - "html/template" - "net/url" - "path/filepath" - "regexp" - "strconv" - "strings" - "time" - - "github.com/rs/zerolog/log" - "github.com/thomiceli/opengist/internal/git" - "github.com/thomiceli/opengist/internal/i18n" - "github.com/thomiceli/opengist/internal/index" - "github.com/thomiceli/opengist/internal/render" - "github.com/thomiceli/opengist/internal/utils" - - "github.com/google/uuid" - "github.com/labstack/echo/v4" - "github.com/thomiceli/opengist/internal/config" - "github.com/thomiceli/opengist/internal/db" - "gorm.io/gorm" -) - -func gistInit(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - currUser := getUserLogged(ctx) - - userName := ctx.Param("user") - gistName := ctx.Param("gistname") - - switch filepath.Ext(gistName) { - case ".js": - setData(ctx, "gistpage", "js") - gistName = strings.TrimSuffix(gistName, ".js") - case ".json": - setData(ctx, "gistpage", "json") - gistName = strings.TrimSuffix(gistName, ".json") - case ".git": - setData(ctx, "gistpage", "git") - gistName = strings.TrimSuffix(gistName, ".git") - } - - gist, err := db.GetGist(userName, gistName) - if err != nil { - return notFound("Gist not found") - } - - if gist.Private == db.PrivateVisibility { - if currUser == nil || currUser.ID != gist.UserID { - return notFound("Gist not found") - } - } - - setData(ctx, "gist", gist) - - if config.C.SshGit { - var sshDomain string - - if config.C.SshExternalDomain != "" { - sshDomain = config.C.SshExternalDomain - } else { - sshDomain = strings.Split(ctx.Request().Host, ":")[0] - } - - if config.C.SshPort == "22" { - setData(ctx, "sshCloneUrl", sshDomain+":"+userName+"/"+gistName+".git") - } else { - setData(ctx, "sshCloneUrl", "ssh://"+sshDomain+":"+config.C.SshPort+"/"+userName+"/"+gistName+".git") - } - } - - baseHttpUrl := getData(ctx, "baseHttpUrl").(string) - - if config.C.HttpGit { - setData(ctx, "httpCloneUrl", baseHttpUrl+"/"+userName+"/"+gistName+".git") - } - - setData(ctx, "httpCopyUrl", baseHttpUrl+"/"+userName+"/"+gistName) - setData(ctx, "currentUrl", template.URL(ctx.Request().URL.Path)) - setData(ctx, "embedScript", fmt.Sprintf(``, baseHttpUrl+"/"+userName+"/"+gistName+".js")) - - nbCommits, err := gist.NbCommits() - if err != nil { - return errorRes(500, "Error fetching number of commits", err) - } - setData(ctx, "nbCommits", nbCommits) - - if currUser != nil { - hasLiked, err := currUser.HasLiked(gist) - if err != nil { - return errorRes(500, "Cannot get user like status", err) - } - setData(ctx, "hasLiked", hasLiked) - } - - if gist.Private > 0 { - setData(ctx, "NoIndex", true) - } - - return next(ctx) - } -} - -// gistSoftInit try to load a gist (same as gistInit) but does not return a 404 if the gist is not found -// useful for git clients using HTTP to obfuscate the existence of a private gist -func gistSoftInit(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - userName := ctx.Param("user") - gistName := ctx.Param("gistname") - - gistName = strings.TrimSuffix(gistName, ".git") - - gist, _ := db.GetGist(userName, gistName) - setData(ctx, "gist", gist) - - return next(ctx) - } -} - -// gistNewPushInit has the same behavior as gistSoftInit but create a new gist empty instead -func gistNewPushSoftInit(next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - setData(c, "gist", new(db.Gist)) - return next(c) - } -} - -func allGists(ctx echo.Context) error { - var err error - var urlPage string - - fromUserStr := ctx.Param("user") - userLogged := getUserLogged(ctx) - pageInt := getPage(ctx) - - sort := "created" - sortText := trH(ctx, "gist.list.sort-by-created") - order := "desc" - orderText := trH(ctx, "gist.list.order-by-desc") - - if ctx.QueryParam("sort") == "updated" { - sort = "updated" - sortText = trH(ctx, "gist.list.sort-by-updated") - } - - if ctx.QueryParam("order") == "asc" { - order = "asc" - orderText = trH(ctx, "gist.list.order-by-asc") - } - - setData(ctx, "sort", sortText) - setData(ctx, "order", orderText) - - var gists []*db.Gist - var currentUserId uint - if userLogged != nil { - currentUserId = userLogged.ID - } else { - currentUserId = 0 - } - - if fromUserStr == "" { - urlctx := ctx.Request().URL.Path - if strings.HasSuffix(urlctx, "search") { - setData(ctx, "htmlTitle", trH(ctx, "gist.list.search-results")) - setData(ctx, "mode", "search") - setData(ctx, "searchQuery", ctx.QueryParam("q")) - setData(ctx, "searchQueryUrl", template.URL("&q="+ctx.QueryParam("q"))) - urlPage = "search" - gists, err = db.GetAllGistsFromSearch(currentUserId, ctx.QueryParam("q"), pageInt-1, sort, order) - } else if strings.HasSuffix(urlctx, "all") { - setData(ctx, "htmlTitle", trH(ctx, "gist.list.all")) - setData(ctx, "mode", "all") - urlPage = "all" - gists, err = db.GetAllGistsForCurrentUser(currentUserId, pageInt-1, sort, order) - } - } else { - liked := false - forked := false - - liked, err = regexp.MatchString(`/[^/]*/liked`, ctx.Request().URL.Path) - if err != nil { - return errorRes(500, "Error matching regexp", err) - } - - forked, err = regexp.MatchString(`/[^/]*/forked`, ctx.Request().URL.Path) - if err != nil { - return errorRes(500, "Error matching regexp", err) - } - - var fromUser *db.User - - fromUser, err = db.GetUserByUsername(fromUserStr) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return notFound("User not found") - } - return errorRes(500, "Error fetching user", err) - } - setData(ctx, "fromUser", fromUser) - - if countFromUser, err := db.CountAllGistsFromUser(fromUser.ID, currentUserId); err != nil { - return errorRes(500, "Error counting gists", err) - } else { - setData(ctx, "countFromUser", countFromUser) - } - - if countLiked, err := db.CountAllGistsLikedByUser(fromUser.ID, currentUserId); err != nil { - return errorRes(500, "Error counting liked gists", err) - } else { - setData(ctx, "countLiked", countLiked) - } - - if countForked, err := db.CountAllGistsForkedByUser(fromUser.ID, currentUserId); err != nil { - return errorRes(500, "Error counting forked gists", err) - } else { - setData(ctx, "countForked", countForked) - } - - if liked { - urlPage = fromUserStr + "/liked" - setData(ctx, "htmlTitle", trH(ctx, "gist.list.all-liked-by", fromUserStr)) - setData(ctx, "mode", "liked") - gists, err = db.GetAllGistsLikedByUser(fromUser.ID, currentUserId, pageInt-1, sort, order) - } else if forked { - urlPage = fromUserStr + "/forked" - setData(ctx, "htmlTitle", trH(ctx, "gist.list.all-forked-by", fromUserStr)) - setData(ctx, "mode", "forked") - gists, err = db.GetAllGistsForkedByUser(fromUser.ID, currentUserId, pageInt-1, sort, order) - } else { - urlPage = fromUserStr - setData(ctx, "htmlTitle", trH(ctx, "gist.list.all-from", fromUserStr)) - setData(ctx, "mode", "fromUser") - gists, err = db.GetAllGistsFromUser(fromUser.ID, currentUserId, pageInt-1, sort, order) - } - } - - renderedGists := make([]*render.RenderedGist, 0, len(gists)) - for _, gist := range gists { - rendered, err := render.HighlightGistPreview(gist) - if err != nil { - log.Error().Err(err).Msg("Error rendering gist preview for " + gist.Identifier() + " - " + gist.PreviewFilename) - } - renderedGists = append(renderedGists, &rendered) - } - - if err != nil { - return errorRes(500, "Error fetching gists", err) - } - - if err = paginate(ctx, renderedGists, pageInt, 10, "gists", fromUserStr, 2, "&sort="+sort+"&order="+order); err != nil { - return errorRes(404, tr(ctx, "error.page-not-found"), nil) - } - - setData(ctx, "urlPage", urlPage) - return html(ctx, "all.html") -} - -func search(ctx echo.Context) error { - var err error - - content, meta := parseSearchQueryStr(ctx.QueryParam("q")) - pageInt := getPage(ctx) - - var currentUserId uint - userLogged := getUserLogged(ctx) - if userLogged != nil { - currentUserId = userLogged.ID - } else { - currentUserId = 0 - } - - var visibleGistsIds []uint - visibleGistsIds, err = db.GetAllGistsVisibleByUser(currentUserId) - if err != nil { - return errorRes(500, "Error fetching gists", err) - } - - gistsIds, nbHits, langs, err := index.SearchGists(content, index.SearchGistMetadata{ - Username: meta["user"], - Title: meta["title"], - Filename: meta["filename"], - Extension: meta["extension"], - Language: meta["language"], - }, visibleGistsIds, pageInt) - if err != nil { - return errorRes(500, "Error searching gists", err) - } - - gists, err := db.GetAllGistsByIds(gistsIds) - if err != nil { - return errorRes(500, "Error fetching gists", err) - } - - renderedGists := make([]*render.RenderedGist, 0, len(gists)) - for _, gist := range gists { - rendered, err := render.HighlightGistPreview(gist) - if err != nil { - log.Error().Err(err).Msg("Error rendering gist preview for " + gist.Identifier() + " - " + gist.PreviewFilename) - } - renderedGists = append(renderedGists, &rendered) - } - - if pageInt > 1 && len(renderedGists) != 0 { - setData(ctx, "prevPage", pageInt-1) - } - if 10*pageInt < int(nbHits) { - setData(ctx, "nextPage", pageInt+1) - } - setData(ctx, "prevLabel", trH(ctx, "pagination.previous")) - setData(ctx, "nextLabel", trH(ctx, "pagination.next")) - setData(ctx, "urlPage", "search") - setData(ctx, "urlParams", template.URL("&q="+ctx.QueryParam("q"))) - setData(ctx, "htmlTitle", trH(ctx, "gist.list.search-results")) - setData(ctx, "nbHits", nbHits) - setData(ctx, "gists", renderedGists) - setData(ctx, "langs", langs) - setData(ctx, "searchQuery", ctx.QueryParam("q")) - return html(ctx, "search.html") -} - -func gistIndex(ctx echo.Context) error { - if getData(ctx, "gistpage") == "js" { - return gistJs(ctx) - } else if getData(ctx, "gistpage") == "json" { - return gistJson(ctx) - } - - gist := getData(ctx, "gist").(*db.Gist) - revision := ctx.Param("revision") - - if revision == "" { - revision = "HEAD" - } - - files, err := gist.Files(revision, true) - if _, ok := err.(*git.RevisionNotFoundError); ok { - return notFound("Revision not found") - } else if err != nil { - return errorRes(500, "Error fetching files", err) - } - - renderedFiles := render.HighlightFiles(files) - - setData(ctx, "page", "code") - setData(ctx, "commit", revision) - setData(ctx, "files", renderedFiles) - setData(ctx, "revision", revision) - setData(ctx, "htmlTitle", gist.Title) - return html(ctx, "gist.html") -} - -func gistJson(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - files, err := gist.Files("HEAD", true) - if err != nil { - return errorRes(500, "Error fetching files", err) - } - - renderedFiles := render.HighlightFiles(files) - setData(ctx, "files", renderedFiles) - - htmlbuf := bytes.Buffer{} - w := bufio.NewWriter(&htmlbuf) - if err = ctx.Echo().Renderer.Render(w, "gist_embed.html", dataMap(ctx), ctx); err != nil { - return err - } - _ = w.Flush() - - jsUrl, err := url.JoinPath(getData(ctx, "baseHttpUrl").(string), gist.User.Username, gist.Identifier()+".js") - if err != nil { - return errorRes(500, "Error joining js url", err) - } - - cssUrl, err := url.JoinPath(getData(ctx, "baseHttpUrl").(string), manifestEntries["embed.css"].File) - if err != nil { - return errorRes(500, "Error joining css url", err) - } - - return ctx.JSON(200, map[string]interface{}{ - "owner": gist.User.Username, - "id": gist.Identifier(), - "uuid": gist.Uuid, - "title": gist.Title, - "description": gist.Description, - "created_at": time.Unix(gist.CreatedAt, 0).Format(time.RFC3339), - "visibility": gist.VisibilityStr(), - "files": renderedFiles, - "embed": map[string]string{ - "html": htmlbuf.String(), - "css": cssUrl, - "js": jsUrl, - "js_dark": jsUrl + "?dark", - }, - }) -} - -func gistJs(ctx echo.Context) error { - if _, exists := ctx.QueryParams()["dark"]; exists { - setData(ctx, "dark", "dark") - } - - gist := getData(ctx, "gist").(*db.Gist) - files, err := gist.Files("HEAD", true) - if err != nil { - return errorRes(500, "Error fetching files", err) - } - - renderedFiles := render.HighlightFiles(files) - setData(ctx, "files", renderedFiles) - - htmlbuf := bytes.Buffer{} - w := bufio.NewWriter(&htmlbuf) - if err = ctx.Echo().Renderer.Render(w, "gist_embed.html", dataMap(ctx), ctx); err != nil { - return err - } - _ = w.Flush() - - cssUrl, err := url.JoinPath(getData(ctx, "baseHttpUrl").(string), manifestEntries["embed.css"].File) - if err != nil { - return errorRes(500, "Error joining css url", err) - } - - js, err := escapeJavaScriptContent(htmlbuf.String(), cssUrl) - if err != nil { - return errorRes(500, "Error escaping JavaScript content", err) - } - ctx.Response().Header().Set("Content-Type", "application/javascript") - return plainText(ctx, 200, js) -} - -func revisions(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - userName := gist.User.Username - gistName := gist.Identifier() - - pageInt := getPage(ctx) - - commits, err := gist.Log((pageInt - 1) * 10) - if err != nil { - return errorRes(500, "Error fetching commits log", err) - } - - if err := paginate(ctx, commits, pageInt, 10, "commits", userName+"/"+gistName+"/revisions", 2); err != nil { - return errorRes(404, tr(ctx, "error.page-not-found"), nil) - } - - emailsSet := map[string]struct{}{} - for _, commit := range commits { - if commit.AuthorEmail == "" { - continue - } - emailsSet[strings.ToLower(commit.AuthorEmail)] = struct{}{} - } - - emailsUsers, err := db.GetUsersFromEmails(emailsSet) - if err != nil { - return errorRes(500, "Error fetching users emails", err) - } - - setData(ctx, "page", "revisions") - setData(ctx, "revision", "HEAD") - setData(ctx, "emails", emailsUsers) - setData(ctx, "htmlTitle", trH(ctx, "gist.revision-of", gist.Title)) - - return html(ctx, "revisions.html") -} - -func create(ctx echo.Context) error { - setData(ctx, "htmlTitle", trH(ctx, "gist.new.create-a-new-gist")) - return html(ctx, "create.html") -} - -func processCreate(ctx echo.Context) error { - isCreate := false - if ctx.Request().URL.Path == "/" { - isCreate = true - } - - err := ctx.Request().ParseForm() - if err != nil { - return errorRes(400, tr(ctx, "error.bad-request"), err) - } - - dto := new(db.GistDTO) - var gist *db.Gist - - if isCreate { - setData(ctx, "htmlTitle", trH(ctx, "gist.new.create-a-new-gist")) - } else { - gist = getData(ctx, "gist").(*db.Gist) - setData(ctx, "htmlTitle", trH(ctx, "gist.edit.edit-gist", gist.Title)) - } - - if err := ctx.Bind(dto); err != nil { - return errorRes(400, tr(ctx, "error.cannot-bind-data"), err) - } - - dto.Files = make([]db.FileDTO, 0) - fileCounter := 0 - for i := 0; i < len(ctx.Request().PostForm["content"]); i++ { - name := ctx.Request().PostForm["name"][i] - content := ctx.Request().PostForm["content"][i] - - if name == "" { - fileCounter += 1 - name = "gistfile" + strconv.Itoa(fileCounter) + ".txt" - } - - escapedValue, err := url.QueryUnescape(content) - if err != nil { - return errorRes(400, tr(ctx, "error.invalid-character-unescaped"), err) - } - - dto.Files = append(dto.Files, db.FileDTO{ - Filename: strings.Trim(name, " "), - Content: escapedValue, - }) - } - - err = ctx.Validate(dto) - if err != nil { - addFlash(ctx, utils.ValidationMessages(&err, getData(ctx, "locale").(*i18n.Locale)), "error") - if isCreate { - return html(ctx, "create.html") - } else { - files, err := gist.Files("HEAD", false) - if err != nil { - return errorRes(500, "Error fetching files", err) - } - setData(ctx, "files", files) - return html(ctx, "edit.html") - } - } - - if isCreate { - gist = dto.ToGist() - } else { - gist = dto.ToExistingGist(gist) - } - - user := getUserLogged(ctx) - gist.NbFiles = len(dto.Files) - - if isCreate { - uuidGist, err := uuid.NewRandom() - if err != nil { - return errorRes(500, "Error creating an UUID", err) - } - gist.Uuid = strings.Replace(uuidGist.String(), "-", "", -1) - - gist.UserID = user.ID - gist.User = *user - } - - if gist.Title == "" { - if ctx.Request().PostForm["name"][0] == "" { - gist.Title = "gist:" + gist.Uuid - } else { - gist.Title = ctx.Request().PostForm["name"][0] - } - } - - if len(dto.Files) > 0 { - split := strings.Split(dto.Files[0].Content, "\n") - if len(split) > 10 { - gist.Preview = strings.Join(split[:10], "\n") - } else { - gist.Preview = dto.Files[0].Content - } - - gist.PreviewFilename = dto.Files[0].Filename - } - - if err = gist.InitRepository(); err != nil { - return errorRes(500, "Error creating the repository", err) - } - - if err = gist.AddAndCommitFiles(&dto.Files); err != nil { - return errorRes(500, "Error adding and committing files", err) - } - - if isCreate { - if err = gist.Create(); err != nil { - return errorRes(500, "Error creating the gist", err) - } - } else { - if err = gist.Update(); err != nil { - return errorRes(500, "Error updating the gist", err) - } - } - - gist.AddInIndex() - - return redirect(ctx, "/"+user.Username+"/"+gist.Identifier()) -} - -func editVisibility(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - - dto := new(db.VisibilityDTO) - if err := ctx.Bind(dto); err != nil { - return errorRes(400, tr(ctx, "error.cannot-bind-data"), err) - } - - gist.Private = dto.Private - if err := gist.UpdateNoTimestamps(); err != nil { - return errorRes(500, "Error updating this gist", err) - } - - addFlash(ctx, tr(ctx, "flash.gist.visibility-changed"), "success") - return redirect(ctx, "/"+gist.User.Username+"/"+gist.Identifier()) -} - -func deleteGist(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - - if err := gist.Delete(); err != nil { - return errorRes(500, "Error deleting this gist", err) - } - gist.RemoveFromIndex() - - addFlash(ctx, tr(ctx, "flash.gist.deleted"), "success") - return redirect(ctx, "/") -} - -func like(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - currentUser := getUserLogged(ctx) - - hasLiked, err := currentUser.HasLiked(gist) - if err != nil { - return errorRes(500, "Error checking if user has liked a gist", err) - } - - if hasLiked { - err = gist.RemoveUserLike(getUserLogged(ctx)) - } else { - err = gist.AppendUserLike(getUserLogged(ctx)) - } - - if err != nil { - return errorRes(500, "Error liking/dislking this gist", err) - } - - redirectTo := "/" + gist.User.Username + "/" + gist.Identifier() - if r := ctx.QueryParam("redirecturl"); r != "" { - redirectTo = r - } - return redirect(ctx, redirectTo) -} - -func fork(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - currentUser := getUserLogged(ctx) - - alreadyForked, err := gist.GetForkParent(currentUser) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return errorRes(500, "Error checking if gist is already forked", err) - } - - if gist.User.ID == currentUser.ID { - addFlash(ctx, tr(ctx, "flash.gist.fork-own-gist"), "error") - return redirect(ctx, "/"+gist.User.Username+"/"+gist.Identifier()) - } - - if alreadyForked.ID != 0 { - return redirect(ctx, "/"+alreadyForked.User.Username+"/"+alreadyForked.Identifier()) - } - - uuidGist, err := uuid.NewRandom() - if err != nil { - return errorRes(500, "Error creating an UUID", err) - } - - newGist := &db.Gist{ - Uuid: strings.Replace(uuidGist.String(), "-", "", -1), - Title: gist.Title, - Preview: gist.Preview, - PreviewFilename: gist.PreviewFilename, - Description: gist.Description, - Private: gist.Private, - UserID: currentUser.ID, - ForkedID: gist.ID, - NbFiles: gist.NbFiles, - } - - if err = newGist.CreateForked(); err != nil { - return errorRes(500, "Error forking the gist in database", err) - } - - if err = gist.ForkClone(currentUser.Username, newGist.Uuid); err != nil { - return errorRes(500, "Error cloning the repository while forking", err) - } - if err = gist.IncrementForkCount(); err != nil { - return errorRes(500, "Error incrementing the fork count", err) - } - - addFlash(ctx, tr(ctx, "flash.gist.forked"), "success") - - return redirect(ctx, "/"+currentUser.Username+"/"+newGist.Identifier()) -} - -func rawFile(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - file, err := gist.File(ctx.Param("revision"), ctx.Param("file"), false) - if err != nil { - return errorRes(500, "Error getting file content", err) - } - - if file == nil { - return notFound("File not found") - } - - return plainText(ctx, 200, file.Content) -} - -func downloadFile(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - file, err := gist.File(ctx.Param("revision"), ctx.Param("file"), false) - if err != nil { - return errorRes(500, "Error getting file content", err) - } - - if file == nil { - return notFound("File not found") - } - - ctx.Response().Header().Set("Content-Type", "text/plain") - ctx.Response().Header().Set("Content-Disposition", "attachment; filename="+file.Filename) - ctx.Response().Header().Set("Content-Length", strconv.Itoa(len(file.Content))) - _, err = ctx.Response().Write([]byte(file.Content)) - if err != nil { - return errorRes(500, "Error downloading the file", err) - } - - return nil -} - -func edit(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - - files, err := gist.Files("HEAD", false) - if err != nil { - return errorRes(500, "Error fetching files from repository", err) - } - - setData(ctx, "files", files) - setData(ctx, "htmlTitle", trH(ctx, "gist.edit.edit-gist", gist.Title)) - - return html(ctx, "edit.html") -} - -func downloadZip(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - revision := ctx.Param("revision") - - files, err := gist.Files(revision, false) - if err != nil { - return errorRes(500, "Error fetching files from repository", err) - } - if len(files) == 0 { - return notFound("No files found in this revision") - } - - zipFile := new(bytes.Buffer) - - zipWriter := zip.NewWriter(zipFile) - - for _, file := range files { - fh := &zip.FileHeader{ - Name: file.Filename, - Method: zip.Deflate, - } - f, err := zipWriter.CreateHeader(fh) - if err != nil { - return errorRes(500, "Error adding a file the to the zip archive", err) - } - _, err = f.Write([]byte(file.Content)) - if err != nil { - return errorRes(500, "Error adding file content the to the zip archive", err) - } - } - err = zipWriter.Close() - if err != nil { - return errorRes(500, "Error closing the zip archive", err) - } - - ctx.Response().Header().Set("Content-Type", "application/zip") - ctx.Response().Header().Set("Content-Disposition", "attachment; filename="+gist.Identifier()+".zip") - ctx.Response().Header().Set("Content-Length", strconv.Itoa(len(zipFile.Bytes()))) - _, err = ctx.Response().Write(zipFile.Bytes()) - if err != nil { - return errorRes(500, "Error writing the zip archive", err) - } - return nil -} - -func likes(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - - pageInt := getPage(ctx) - - likers, err := gist.GetUsersLikes(pageInt - 1) - if err != nil { - return errorRes(500, "Error getting users who liked this gist", err) - } - - if err = paginate(ctx, likers, pageInt, 30, "likers", gist.User.Username+"/"+gist.Identifier()+"/likes", 1); err != nil { - return errorRes(404, tr(ctx, "error.page-not-found"), nil) - } - - setData(ctx, "htmlTitle", trH(ctx, "gist.likes.for", gist.Title)) - setData(ctx, "revision", "HEAD") - return html(ctx, "likes.html") -} - -func forks(ctx echo.Context) error { - gist := getData(ctx, "gist").(*db.Gist) - pageInt := getPage(ctx) - - currentUser := getUserLogged(ctx) - var fromUserID uint = 0 - if currentUser != nil { - fromUserID = currentUser.ID - } - - forks, err := gist.GetForks(fromUserID, pageInt-1) - if err != nil { - return errorRes(500, "Error getting users who liked this gist", err) - } - - if err = paginate(ctx, forks, pageInt, 30, "forks", gist.User.Username+"/"+gist.Identifier()+"/forks", 2); err != nil { - return errorRes(404, tr(ctx, "error.page-not-found"), nil) - } - - setData(ctx, "htmlTitle", trH(ctx, "gist.forks.for", gist.Title)) - setData(ctx, "revision", "HEAD") - return html(ctx, "forks.html") -} - -func checkbox(ctx echo.Context) error { - filename := ctx.FormValue("file") - checkboxNb := ctx.FormValue("checkbox") - - i, err := strconv.Atoi(checkboxNb) - if err != nil { - return errorRes(400, tr(ctx, "error.invalid-number"), nil) - } - - gist := getData(ctx, "gist").(*db.Gist) - file, err := gist.File("HEAD", filename, false) - if err != nil { - return errorRes(500, "Error getting file content", err) - } else if file == nil { - return notFound("File not found") - } - - markdown, err := render.Checkbox(file.Content, i) - if err != nil { - return errorRes(500, "Error checking checkbox", err) - } - - if err = gist.AddAndCommitFile(&db.FileDTO{ - Filename: filename, - Content: markdown, - }); err != nil { - return errorRes(500, "Error adding and committing files", err) - } - - if err = gist.UpdatePreviewAndCount(true); err != nil { - return errorRes(500, "Error updating the gist", err) - } - - return plainText(ctx, 200, "ok") -} - -func preview(ctx echo.Context) error { - content := ctx.FormValue("content") - - previewStr, err := render.MarkdownString(content) - if err != nil { - return errorRes(500, "Error rendering markdown", err) - } - - return plainText(ctx, 200, previewStr) -} - -func escapeJavaScriptContent(htmlContent, cssUrl string) (string, error) { - jsonContent, err := gojson.Marshal(htmlContent) - if err != nil { - return "", fmt.Errorf("failed to encode content: %w", err) - } - - jsonCssUrl, err := gojson.Marshal(cssUrl) - if err != nil { - return "", fmt.Errorf("failed to encode CSS URL: %w", err) - } - - js := fmt.Sprintf(` - document.write(''); - document.write(%s); - `, - string(jsonCssUrl), - string(jsonContent), - ) - - return js, nil -} diff --git a/internal/web/handlers/admin/actions.go b/internal/web/handlers/admin/actions.go new file mode 100644 index 0000000..3683930 --- /dev/null +++ b/internal/web/handlers/admin/actions.go @@ -0,0 +1,42 @@ +package admin + +import ( + "github.com/thomiceli/opengist/internal/actions" + "github.com/thomiceli/opengist/internal/web/context" +) + +func AdminSyncReposFromFS(ctx *context.Context) error { + ctx.AddFlash(ctx.Tr("flash.admin.sync-fs"), "success") + go actions.Run(actions.SyncReposFromFS) + return ctx.RedirectTo("/admin-panel") +} + +func AdminSyncReposFromDB(ctx *context.Context) error { + ctx.AddFlash(ctx.Tr("flash.admin.sync-db"), "success") + go actions.Run(actions.SyncReposFromDB) + return ctx.RedirectTo("/admin-panel") +} + +func AdminGcRepos(ctx *context.Context) error { + ctx.AddFlash(ctx.Tr("flash.admin.git-gc"), "success") + go actions.Run(actions.GitGcRepos) + return ctx.RedirectTo("/admin-panel") +} + +func AdminSyncGistPreviews(ctx *context.Context) error { + ctx.AddFlash(ctx.Tr("flash.admin.sync-previews"), "success") + go actions.Run(actions.SyncGistPreviews) + return ctx.RedirectTo("/admin-panel") +} + +func AdminResetHooks(ctx *context.Context) error { + ctx.AddFlash(ctx.Tr("flash.admin.reset-hooks"), "success") + go actions.Run(actions.ResetHooks) + return ctx.RedirectTo("/admin-panel") +} + +func AdminIndexGists(ctx *context.Context) error { + ctx.AddFlash(ctx.Tr("flash.admin.index-gists"), "success") + go actions.Run(actions.IndexGists) + return ctx.RedirectTo("/admin-panel") +} diff --git a/internal/web/handlers/admin/admin.go b/internal/web/handlers/admin/admin.go new file mode 100644 index 0000000..13a4744 --- /dev/null +++ b/internal/web/handlers/admin/admin.go @@ -0,0 +1,203 @@ +package admin + +import ( + "github.com/thomiceli/opengist/internal/actions" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/git" + "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handlers" + "runtime" + "strconv" + "time" +) + +func AdminIndex(ctx *context.Context) error { + ctx.SetData("htmlTitle", ctx.TrH("admin.admin_panel")) + ctx.SetData("adminHeaderPage", "index") + + ctx.SetData("opengistVersion", config.OpengistVersion) + ctx.SetData("goVersion", runtime.Version()) + gitVersion, err := git.GetGitVersion() + if err != nil { + return ctx.ErrorRes(500, "Cannot get git version", err) + } + ctx.SetData("gitVersion", gitVersion) + + countUsers, err := db.CountAll(&db.User{}) + if err != nil { + return ctx.ErrorRes(500, "Cannot count users", err) + } + ctx.SetData("countUsers", countUsers) + + countGists, err := db.CountAll(&db.Gist{}) + if err != nil { + return ctx.ErrorRes(500, "Cannot count gists", err) + } + ctx.SetData("countGists", countGists) + + countKeys, err := db.CountAll(&db.SSHKey{}) + if err != nil { + return ctx.ErrorRes(500, "Cannot count SSH keys", err) + } + ctx.SetData("countKeys", countKeys) + + ctx.SetData("syncReposFromFS", actions.IsRunning(actions.SyncReposFromFS)) + ctx.SetData("syncReposFromDB", actions.IsRunning(actions.SyncReposFromDB)) + ctx.SetData("gitGcRepos", actions.IsRunning(actions.GitGcRepos)) + ctx.SetData("syncGistPreviews", actions.IsRunning(actions.SyncGistPreviews)) + ctx.SetData("resetHooks", actions.IsRunning(actions.ResetHooks)) + ctx.SetData("indexGists", actions.IsRunning(actions.IndexGists)) + return ctx.Html("admin_index.html") +} + +func AdminUsers(ctx *context.Context) error { + ctx.SetData("htmlTitle", ctx.TrH("admin.users")+" - "+ctx.TrH("admin.admin_panel")) + ctx.SetData("adminHeaderPage", "users") + ctx.SetData("loadStartTime", time.Now()) + + pageInt := handlers.GetPage(ctx) + + var data []*db.User + var err error + if data, err = db.GetAllUsers(pageInt - 1); err != nil { + return ctx.ErrorRes(500, "Cannot get users", err) + } + + if err = handlers.Paginate(ctx, data, pageInt, 10, "data", "admin-panel/users", 1); err != nil { + return ctx.ErrorRes(404, ctx.Tr("error.page-not-found"), nil) + } + + return ctx.Html("admin_users.html") +} + +func AdminGists(ctx *context.Context) error { + ctx.SetData("htmlTitle", ctx.TrH("admin.gists")+" - "+ctx.TrH("admin.admin_panel")) + ctx.SetData("adminHeaderPage", "gists") + pageInt := handlers.GetPage(ctx) + + var data []*db.Gist + var err error + if data, err = db.GetAllGists(pageInt - 1); err != nil { + return ctx.ErrorRes(500, "Cannot get gists", err) + } + + if err = handlers.Paginate(ctx, data, pageInt, 10, "data", "admin-panel/gists", 1); err != nil { + return ctx.ErrorRes(404, ctx.Tr("error.page-not-found"), nil) + } + + return ctx.Html("admin_gists.html") +} + +func AdminUserDelete(ctx *context.Context) error { + userId, _ := strconv.ParseUint(ctx.Param("user"), 10, 64) + user, err := db.GetUserById(uint(userId)) + if err != nil { + return ctx.ErrorRes(500, "Cannot retrieve user", err) + } + + if err := user.Delete(); err != nil { + return ctx.ErrorRes(500, "Cannot delete this user", err) + } + + ctx.AddFlash(ctx.Tr("flash.admin.user-deleted"), "success") + return ctx.RedirectTo("/admin-panel/users") +} + +func AdminGistDelete(ctx *context.Context) error { + gist, err := db.GetGistByID(ctx.Param("gist")) + if err != nil { + return ctx.ErrorRes(500, "Cannot retrieve gist", err) + } + + if err = gist.DeleteRepository(); err != nil { + return ctx.ErrorRes(500, "Cannot delete the repository", err) + } + + if err = gist.Delete(); err != nil { + return ctx.ErrorRes(500, "Cannot delete this gist", err) + } + + gist.RemoveFromIndex() + + ctx.AddFlash(ctx.Tr("flash.admin.gist-deleted"), "success") + return ctx.RedirectTo("/admin-panel/gists") +} + +func AdminConfig(ctx *context.Context) error { + ctx.SetData("htmlTitle", ctx.TrH("admin.configuration")+" - "+ctx.TrH("admin.admin_panel")) + ctx.SetData("adminHeaderPage", "config") + + ctx.SetData("dbtype", db.DatabaseInfo.Type.String()) + ctx.SetData("dbname", db.DatabaseInfo.Database) + + return ctx.Html("admin_config.html") +} + +func AdminSetConfig(ctx *context.Context) error { + key := ctx.FormValue("key") + value := ctx.FormValue("value") + + if err := db.UpdateSetting(key, value); err != nil { + return ctx.ErrorRes(500, "Cannot set setting", err) + } + + return ctx.JSON(200, map[string]interface{}{ + "success": true, + }) +} + +func AdminInvitations(ctx *context.Context) error { + ctx.SetData("htmlTitle", ctx.TrH("admin.invitations")+" - "+ctx.TrH("admin.admin_panel")) + ctx.SetData("adminHeaderPage", "invitations") + + var invitations []*db.Invitation + var err error + if invitations, err = db.GetAllInvitations(); err != nil { + return ctx.ErrorRes(500, "Cannot get invites", err) + } + + ctx.SetData("invitations", invitations) + return ctx.Html("admin_invitations.html") +} + +func AdminInvitationsCreate(ctx *context.Context) error { + code := ctx.FormValue("code") + nbMax, err := strconv.ParseUint(ctx.FormValue("nbMax"), 10, 64) + if err != nil { + nbMax = 10 + } + + expiresAtUnix, err := strconv.ParseInt(ctx.FormValue("expiredAtUnix"), 10, 64) + if err != nil { + expiresAtUnix = time.Now().Unix() + 604800 // 1 week + } + + invitation := &db.Invitation{ + Code: code, + ExpiresAt: expiresAtUnix, + NbMax: uint(nbMax), + } + + if err := invitation.Create(); err != nil { + return ctx.ErrorRes(500, "Cannot create invitation", err) + } + + ctx.AddFlash(ctx.Tr("flash.admin.invitation-created"), "success") + return ctx.RedirectTo("/admin-panel/invitations") +} + +func AdminInvitationsDelete(ctx *context.Context) error { + id, _ := strconv.ParseUint(ctx.Param("id"), 10, 64) + invitation, err := db.GetInvitationByID(uint(id)) + if err != nil { + return ctx.ErrorRes(500, "Cannot retrieve invitation", err) + } + + if err := invitation.Delete(); err != nil { + return ctx.ErrorRes(500, "Cannot delete this invitation", err) + } + + ctx.AddFlash(ctx.Tr("flash.admin.invitation-deleted"), "success") + return ctx.RedirectTo("/admin-panel/invitations") +} diff --git a/internal/web/handlers/auth.go b/internal/web/handlers/auth.go new file mode 100644 index 0000000..24314fd --- /dev/null +++ b/internal/web/handlers/auth.go @@ -0,0 +1,17 @@ +package handlers + +import ( + "github.com/thomiceli/opengist/internal/web/context" +) + +type ContextAuthInfo struct { + Context *context.Context +} + +func (auth ContextAuthInfo) RequireLogin() (bool, error) { + return auth.Context.GetData("RequireLogin") == true, nil +} + +func (auth ContextAuthInfo) AllowGistsWithoutLogin() (bool, error) { + return auth.Context.GetData("AllowGistsWithoutLogin") == true, nil +} diff --git a/internal/web/handlers/auth/mfa.go b/internal/web/handlers/auth/mfa.go new file mode 100644 index 0000000..51da153 --- /dev/null +++ b/internal/web/handlers/auth/mfa.go @@ -0,0 +1,22 @@ +package auth + +import ( + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" +) + +func Mfa(ctx *context.Context) error { + var err error + + user := db.User{ID: ctx.GetSession().Values["mfaID"].(uint)} + + var hasWebauthn, hasTotp bool + if hasWebauthn, hasTotp, err = user.HasMFA(); err != nil { + return ctx.ErrorRes(500, "Cannot check for user MFA", err) + } + + ctx.SetData("hasWebauthn", hasWebauthn) + ctx.SetData("hasTotp", hasTotp) + + return ctx.Html("mfa.html") +} diff --git a/internal/web/handlers/auth/oauth.go b/internal/web/handlers/auth/oauth.go new file mode 100644 index 0000000..f846155 --- /dev/null +++ b/internal/web/handlers/auth/oauth.go @@ -0,0 +1,166 @@ +package auth + +import ( + "crypto/md5" + "errors" + "fmt" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/auth/oauth" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "golang.org/x/text/cases" + "golang.org/x/text/language" + "gorm.io/gorm" + "strings" +) + +func Oauth(ctx *context.Context) error { + providerStr := ctx.Param("provider") + + httpProtocol := "http" + if ctx.Request().TLS != nil || ctx.Request().Header.Get("X-Forwarded-Proto") == "https" { + httpProtocol = "https" + } + + forwarded_hdr := ctx.Request().Header.Get("Forwarded") + if forwarded_hdr != "" { + fields := strings.Split(forwarded_hdr, ";") + fwd := make(map[string]string) + for _, v := range fields { + p := strings.Split(v, "=") + fwd[p[0]] = p[1] + } + val, ok := fwd["proto"] + if ok && val == "https" { + httpProtocol = "https" + } + } + + var opengistUrl string + if config.C.ExternalUrl != "" { + opengistUrl = config.C.ExternalUrl + } else { + opengistUrl = httpProtocol + "://" + ctx.Request().Host + } + + provider, err := oauth.DefineProvider(providerStr, opengistUrl) + if err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.oauth-unsupported"), nil) + } + + if err = provider.RegisterProvider(); err != nil { + return ctx.ErrorRes(500, "Cannot create provider", err) + } + + provider.BeginAuthHandler(ctx) + return nil +} + +func OauthCallback(ctx *context.Context) error { + provider, err := oauth.CompleteUserAuth(ctx) + if err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.complete-oauth-login", err.Error()), err) + } + + currUser := ctx.User + // if user is logged in, link account to user and update its avatar URL + if currUser != nil { + provider.UpdateUserDB(currUser) + + if err = currUser.Update(); err != nil { + return ctx.ErrorRes(500, "Cannot update user "+cases.Title(language.English).String(provider.GetProvider())+" id", err) + } + + ctx.AddFlash(ctx.Tr("flash.auth.account-linked-oauth", cases.Title(language.English).String(provider.GetProvider())), "success") + return ctx.RedirectTo("/settings") + } + + user := provider.GetProviderUser() + userDB, err := db.GetUserByProvider(user.UserID, provider.GetProvider()) + // if user is not in database, create it + if err != nil { + if ctx.GetData("DisableSignup") == true { + return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled"), nil) + } + + if !errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.ErrorRes(500, "Cannot get user", err) + } + + if user.NickName == "" { + user.NickName = strings.Split(user.Email, "@")[0] + } + + userDB = &db.User{ + Username: user.NickName, + Email: user.Email, + MD5Hash: fmt.Sprintf("%x", md5.Sum([]byte(strings.ToLower(strings.TrimSpace(user.Email))))), + } + + // set provider id and avatar URL + provider.UpdateUserDB(userDB) + + if err = userDB.Create(); err != nil { + if db.IsUniqueConstraintViolation(err) { + ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") + return ctx.RedirectTo("/login") + } + + return ctx.ErrorRes(500, "Cannot create user", err) + } + + if userDB.ID == 1 { + if err = userDB.SetAdmin(); err != nil { + return ctx.ErrorRes(500, "Cannot set user admin", err) + } + } + + keys, err := provider.GetProviderUserSSHKeys() + if err != nil { + ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-retrievable"), "error") + log.Error().Err(err).Msg("Could not get user keys") + } else { + for _, key := range keys { + sshKey := db.SSHKey{ + Title: "Added from " + user.Provider, + Content: key, + User: *userDB, + } + + if err = sshKey.Create(); err != nil { + ctx.AddFlash(ctx.Tr("flash.auth.user-sshkeys-not-created"), "error") + log.Error().Err(err).Msg("Could not create ssh key") + } + } + } + } + + sess := ctx.GetSession() + sess.Values["user"] = userDB.ID + ctx.SaveSession(sess) + ctx.DeleteCsrfCookie() + + return ctx.RedirectTo("/") +} + +func OauthUnlink(ctx *context.Context) error { + providerStr := ctx.Param("provider") + provider, err := oauth.DefineProvider(ctx.Param("provider"), "") + if err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.oauth-unsupported"), nil) + } + + currUser := ctx.User + + if provider.UserHasProvider(currUser) { + if err := currUser.DeleteProviderID(providerStr); err != nil { + return ctx.ErrorRes(500, "Cannot unlink account from "+cases.Title(language.English).String(providerStr), err) + } + + ctx.AddFlash(ctx.Tr("flash.auth.account-unlinked-oauth", cases.Title(language.English).String(providerStr)), "success") + return ctx.RedirectTo("/settings") + } + + return ctx.RedirectTo("/settings") +} diff --git a/internal/web/handlers/auth/password.go b/internal/web/handlers/auth/password.go new file mode 100644 index 0000000..6405fa4 --- /dev/null +++ b/internal/web/handlers/auth/password.go @@ -0,0 +1,170 @@ +package auth + +import ( + "errors" + "github.com/rs/zerolog/log" + passwordpkg "github.com/thomiceli/opengist/internal/auth/password" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/i18n" + "github.com/thomiceli/opengist/internal/validator" + "github.com/thomiceli/opengist/internal/web/context" + "gorm.io/gorm" +) + +func Register(ctx *context.Context) error { + disableSignup := ctx.GetData("DisableSignup") + disableForm := ctx.GetData("DisableLoginForm") + + code := ctx.QueryParam("code") + if code != "" { + if invitation, err := db.GetInvitationByCode(code); err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.ErrorRes(500, "Cannot check for invitation code", err) + } else if invitation != nil && invitation.IsUsable() { + disableSignup = false + } + } + + ctx.SetData("title", ctx.TrH("auth.new-account")) + ctx.SetData("htmlTitle", ctx.TrH("auth.new-account")) + ctx.SetData("disableForm", disableForm) + ctx.SetData("disableSignup", disableSignup) + ctx.SetData("isLoginPage", false) + return ctx.Html("auth_form.html") +} + +func ProcessRegister(ctx *context.Context) error { + disableSignup := ctx.GetData("DisableSignup") + + code := ctx.QueryParam("code") + invitation, err := db.GetInvitationByCode(code) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.ErrorRes(500, "Cannot check for invitation code", err) + } else if invitation.ID != 0 && invitation.IsUsable() { + disableSignup = false + } + + if disableSignup == true { + return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled"), nil) + } + + if ctx.GetData("DisableLoginForm") == true { + return ctx.ErrorRes(403, ctx.Tr("error.signup-disabled-form"), nil) + } + + ctx.SetData("title", ctx.TrH("auth.new-account")) + ctx.SetData("htmlTitle", ctx.TrH("auth.new-account")) + + sess := ctx.GetSession() + + dto := new(db.UserDTO) + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash(validator.ValidationMessages(&err, ctx.GetData("locale").(*i18n.Locale)), "error") + return ctx.Html("auth_form.html") + } + + if exists, err := db.UserExists(dto.Username); err != nil || exists { + ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") + return ctx.Html("auth_form.html") + } + + user := dto.ToUser() + + password, err := passwordpkg.HashPassword(user.Password) + if err != nil { + return ctx.ErrorRes(500, "Cannot hash password", err) + } + user.Password = password + + if err = user.Create(); err != nil { + return ctx.ErrorRes(500, "Cannot create user", err) + } + + if user.ID == 1 { + if err = user.SetAdmin(); err != nil { + return ctx.ErrorRes(500, "Cannot set user admin", err) + } + } + + if invitation.ID != 0 { + if err := invitation.Use(); err != nil { + return ctx.ErrorRes(500, "Cannot use invitation", err) + } + } + + sess.Values["user"] = user.ID + ctx.SaveSession(sess) + + return ctx.RedirectTo("/") +} + +func Login(ctx *context.Context) error { + ctx.SetData("title", ctx.TrH("auth.login")) + ctx.SetData("htmlTitle", ctx.TrH("auth.login")) + ctx.SetData("disableForm", ctx.GetData("DisableLoginForm")) + ctx.SetData("isLoginPage", true) + return ctx.Html("auth_form.html") +} + +func ProcessLogin(ctx *context.Context) error { + if ctx.GetData("DisableLoginForm") == true { + return ctx.ErrorRes(403, ctx.Tr("error.login-disabled-form"), nil) + } + + var err error + sess := ctx.GetSession() + + dto := &db.UserDTO{} + if err = ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + password := dto.Password + + var user *db.User + + if user, err = db.GetUserByUsername(dto.Username); err != nil { + if !errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.ErrorRes(500, "Cannot get user", err) + } + log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) + ctx.AddFlash(ctx.Tr("flash.auth.invalid-credentials"), "error") + return ctx.RedirectTo("/login") + } + + if ok, err := passwordpkg.VerifyPassword(password, user.Password); !ok { + if err != nil { + return ctx.ErrorRes(500, "Cannot check for password", err) + } + log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) + ctx.AddFlash(ctx.Tr("flash.auth.invalid-credentials"), "error") + return ctx.RedirectTo("/login") + } + + // handle MFA + var hasWebauthn, hasTotp bool + if hasWebauthn, hasTotp, err = user.HasMFA(); err != nil { + return ctx.ErrorRes(500, "Cannot check for user MFA", err) + } + if hasWebauthn || hasTotp { + sess.Values["mfaID"] = user.ID + sess.Options.MaxAge = 5 * 60 // 5 minutes + ctx.SaveSession(sess) + return ctx.RedirectTo("/mfa") + } + + sess.Values["user"] = user.ID + sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year + ctx.SaveSession(sess) + ctx.DeleteCsrfCookie() + + return ctx.RedirectTo("/") +} + +func Logout(ctx *context.Context) error { + ctx.DeleteSession() + ctx.DeleteCsrfCookie() + return ctx.RedirectTo("/all") +} diff --git a/internal/web/handlers/auth/totp.go b/internal/web/handlers/auth/totp.go new file mode 100644 index 0000000..8be704c --- /dev/null +++ b/internal/web/handlers/auth/totp.go @@ -0,0 +1,177 @@ +package auth + +import ( + "github.com/thomiceli/opengist/internal/auth/totp" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "net/url" +) + +func BeginTotp(ctx *context.Context) error { + user := ctx.User + + if _, hasTotp, err := user.HasMFA(); err != nil { + return ctx.ErrorRes(500, "Cannot check for user MFA", err) + } else if hasTotp { + ctx.AddFlash(ctx.Tr("auth.totp.already-enabled"), "error") + return ctx.RedirectTo("/settings") + } + + ogUrl, err := url.Parse(ctx.GetData("baseHttpUrl").(string)) + if err != nil { + return ctx.ErrorRes(500, "Cannot parse base URL", err) + } + + sess := ctx.GetSession() + generatedSecret, _ := sess.Values["generatedSecret"].([]byte) + + totpSecret, qrcode, err, generatedSecret := totp.GenerateQRCode(ctx.User.Username, ogUrl.Hostname(), generatedSecret) + if err != nil { + return ctx.ErrorRes(500, "Cannot generate TOTP QR code", err) + } + sess.Values["totpSecret"] = totpSecret + sess.Values["generatedSecret"] = generatedSecret + ctx.SaveSession(sess) + + ctx.SetData("totpSecret", totpSecret) + ctx.SetData("totpQrcode", qrcode) + + return ctx.Html("totp.html") + +} + +func FinishTotp(ctx *context.Context) error { + user := ctx.User + + if _, hasTotp, err := user.HasMFA(); err != nil { + return ctx.ErrorRes(500, "Cannot check for user MFA", err) + } else if hasTotp { + ctx.AddFlash(ctx.Tr("auth.totp.already-enabled"), "error") + return ctx.RedirectTo("/settings") + } + + dto := &db.TOTPDTO{} + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash("Invalid secret", "error") + return ctx.RedirectTo("/settings/totp/generate") + } + + sess := ctx.GetSession() + secret, ok := sess.Values["totpSecret"].(string) + if !ok { + return ctx.ErrorRes(500, "Cannot get TOTP secret from session", nil) + } + + if !totp.Validate(dto.Code, secret) { + ctx.AddFlash(ctx.Tr("auth.totp.invalid-code"), "error") + + return ctx.RedirectTo("/settings/totp/generate") + } + + userTotp := &db.TOTP{ + UserID: ctx.User.ID, + } + if err := userTotp.StoreSecret(secret); err != nil { + return ctx.ErrorRes(500, "Cannot store TOTP secret", err) + } + + if err := userTotp.Create(); err != nil { + return ctx.ErrorRes(500, "Cannot create TOTP", err) + } + + ctx.AddFlash("TOTP successfully enabled", "success") + codes, err := userTotp.GenerateRecoveryCodes() + if err != nil { + return ctx.ErrorRes(500, "Cannot generate recovery codes", err) + } + + delete(sess.Values, "totpSecret") + delete(sess.Values, "generatedSecret") + ctx.SaveSession(sess) + + ctx.SetData("recoveryCodes", codes) + return ctx.Html("totp.html") +} + +func AssertTotp(ctx *context.Context) error { + var err error + dto := &db.TOTPDTO{} + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash(ctx.Tr("auth.totp.invalid-code"), "error") + return ctx.RedirectTo("/mfa") + } + + sess := ctx.GetSession() + userId := sess.Values["mfaID"].(uint) + var userTotp *db.TOTP + if userTotp, err = db.GetTOTPByUserID(userId); err != nil { + return ctx.ErrorRes(500, "Cannot get TOTP by UID", err) + } + + redirectUrl := "/" + + var validCode, validRecoveryCode bool + if validCode, err = userTotp.ValidateCode(dto.Code); err != nil { + return ctx.ErrorRes(500, "Cannot validate TOTP code", err) + } + if !validCode { + validRecoveryCode, err = userTotp.ValidateRecoveryCode(dto.Code) + if err != nil { + return ctx.ErrorRes(500, "Cannot validate TOTP code", err) + } + + if !validRecoveryCode { + ctx.AddFlash(ctx.Tr("auth.totp.invalid-code"), "error") + return ctx.RedirectTo("/mfa") + } + + ctx.AddFlash(ctx.Tr("auth.totp.code-used", dto.Code), "warning") + redirectUrl = "/settings" + } + + sess.Values["user"] = userId + sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year + delete(sess.Values, "mfaID") + ctx.SaveSession(sess) + + return ctx.RedirectTo(redirectUrl) +} + +func DisableTotp(ctx *context.Context) error { + user := ctx.User + userTotp, err := db.GetTOTPByUserID(user.ID) + if err != nil { + return ctx.ErrorRes(500, "Cannot get TOTP by UID", err) + } + + if err = userTotp.Delete(); err != nil { + return ctx.ErrorRes(500, "Cannot delete TOTP", err) + } + + ctx.AddFlash(ctx.Tr("auth.totp.disabled"), "success") + return ctx.RedirectTo("/settings") +} + +func RegenerateTotpRecoveryCodes(ctx *context.Context) error { + user := ctx.User + userTotp, err := db.GetTOTPByUserID(user.ID) + if err != nil { + return ctx.ErrorRes(500, "Cannot get TOTP by UID", err) + } + + codes, err := userTotp.GenerateRecoveryCodes() + if err != nil { + return ctx.ErrorRes(500, "Cannot generate recovery codes", err) + } + + ctx.SetData("recoveryCodes", codes) + return ctx.Html("totp.html") +} diff --git a/internal/web/handlers/auth/webauthn.go b/internal/web/handlers/auth/webauthn.go new file mode 100644 index 0000000..5740a0a --- /dev/null +++ b/internal/web/handlers/auth/webauthn.go @@ -0,0 +1,151 @@ +package auth + +import ( + "bytes" + gojson "encoding/json" + "github.com/thomiceli/opengist/internal/auth/webauthn" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "io" +) + +func BeginWebAuthnBinding(ctx *context.Context) error { + credsCreation, jsonWaSession, err := webauthn.BeginBinding(ctx.User) + if err != nil { + return ctx.ErrorRes(500, "Cannot begin WebAuthn registration", err) + } + + sess := ctx.GetSession() + sess.Values["webauthn_registration_session"] = jsonWaSession + sess.Options.MaxAge = 5 * 60 // 5 minutes + ctx.SaveSession(sess) + + return ctx.JSON(200, credsCreation) +} + +func FinishWebAuthnBinding(ctx *context.Context) error { + sess := ctx.GetSession() + jsonWaSession, ok := sess.Values["webauthn_registration_session"].([]byte) + if !ok { + return ctx.ErrorRes(401, "Cannot get WebAuthn registration session", nil) + } + + user := ctx.User + + // extract passkey name from request + body, err := io.ReadAll(ctx.Request().Body) + if err != nil { + return ctx.ErrorRes(400, "Failed to read request body", err) + } + ctx.Request().Body.Close() + ctx.Request().Body = io.NopCloser(bytes.NewBuffer(body)) + + dto := new(db.CrendentialDTO) + _ = gojson.Unmarshal(body, &dto) + + if err = ctx.Validate(dto); err != nil { + return ctx.ErrorRes(400, "Invalid request", err) + } + passkeyName := dto.PasskeyName + if passkeyName == "" { + passkeyName = "WebAuthn" + } + + waCredential, err := webauthn.FinishBinding(user, jsonWaSession, ctx.Request()) + if err != nil { + return ctx.ErrorRes(403, "Failed binding attempt for passkey", err) + } + + if _, err = db.CreateFromCrendential(user.ID, passkeyName, waCredential); err != nil { + return ctx.ErrorRes(500, "Cannot create WebAuthn credential on database", err) + } + + delete(sess.Values, "webauthn_registration_session") + ctx.SaveSession(sess) + + ctx.AddFlash(ctx.Tr("flash.auth.passkey-registred", passkeyName), "success") + return ctx.Json([]string{"OK"}) +} + +func BeginWebAuthnLogin(ctx *context.Context) error { + credsCreation, jsonWaSession, err := webauthn.BeginDiscoverableLogin() + if err != nil { + return ctx.ErrorRes(401, "Cannot begin WebAuthn login", err) + } + + sess := ctx.GetSession() + sess.Values["webauthn_login_session"] = jsonWaSession + sess.Options.MaxAge = 5 * 60 // 5 minutes + ctx.SaveSession(sess) + + return ctx.Json(credsCreation) +} + +func FinishWebAuthnLogin(ctx *context.Context) error { + sess := ctx.GetSession() + sessionData, ok := sess.Values["webauthn_login_session"].([]byte) + if !ok { + return ctx.ErrorRes(401, "Cannot get WebAuthn login session", nil) + } + + userID, err := webauthn.FinishDiscoverableLogin(sessionData, ctx.Request()) + if err != nil { + return ctx.ErrorRes(403, "Failed authentication attempt for passkey", err) + } + + sess.Values["user"] = userID + sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year + + delete(sess.Values, "webauthn_login_session") + ctx.SaveSession(sess) + + return ctx.Json([]string{"OK"}) +} + +func BeginWebAuthnAssertion(ctx *context.Context) error { + sess := ctx.GetSession() + + ogUser, err := db.GetUserById(sess.Values["mfaID"].(uint)) + if err != nil { + return ctx.ErrorRes(500, "Cannot get user", err) + } + + credsCreation, jsonWaSession, err := webauthn.BeginLogin(ogUser) + if err != nil { + return ctx.ErrorRes(401, "Cannot begin WebAuthn login", err) + } + + sess.Values["webauthn_assertion_session"] = jsonWaSession + sess.Options.MaxAge = 5 * 60 // 5 minutes + ctx.SaveSession(sess) + + return ctx.Json(credsCreation) +} + +func FinishWebAuthnAssertion(ctx *context.Context) error { + sess := ctx.GetSession() + sessionData, ok := sess.Values["webauthn_assertion_session"].([]byte) + if !ok { + return ctx.ErrorRes(401, "Cannot get WebAuthn assertion session", nil) + } + + userId := sess.Values["mfaID"].(uint) + + ogUser, err := db.GetUserById(userId) + if err != nil { + return ctx.ErrorRes(500, "Cannot get user", err) + } + + if err = webauthn.FinishLogin(ogUser, sessionData, ctx.Request()); err != nil { + return ctx.ErrorRes(403, "Failed authentication attempt for passkey", err) + } + + sess.Values["user"] = userId + sess.Options.MaxAge = 60 * 60 * 24 * 365 // 1 year + + delete(sess.Values, "webauthn_assertion_session") + delete(sess.Values, "mfaID") + ctx.SaveSession(sess) + + return ctx.Json([]string{"OK"}) +} diff --git a/internal/web/handlers/gist/all.go b/internal/web/handlers/gist/all.go new file mode 100644 index 0000000..2b45511 --- /dev/null +++ b/internal/web/handlers/gist/all.go @@ -0,0 +1,209 @@ +package gist + +import ( + "errors" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/index" + "github.com/thomiceli/opengist/internal/render" + "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handlers" + "gorm.io/gorm" + "html/template" + "regexp" + "strings" +) + +func AllGists(ctx *context.Context) error { + var err error + var urlPage string + + fromUserStr := ctx.Param("user") + userLogged := ctx.User + pageInt := handlers.GetPage(ctx) + + sort := "created" + sortText := ctx.TrH("gist.list.sort-by-created") + order := "desc" + orderText := ctx.TrH("gist.list.order-by-desc") + + if ctx.QueryParam("sort") == "updated" { + sort = "updated" + sortText = ctx.TrH("gist.list.sort-by-updated") + } + + if ctx.QueryParam("order") == "asc" { + order = "asc" + orderText = ctx.TrH("gist.list.order-by-asc") + } + + ctx.SetData("sort", sortText) + ctx.SetData("order", orderText) + + var gists []*db.Gist + var currentUserId uint + if userLogged != nil { + currentUserId = userLogged.ID + } else { + currentUserId = 0 + } + + if fromUserStr == "" { + urlctx := ctx.Request().URL.Path + if strings.HasSuffix(urlctx, "search") { + ctx.SetData("htmlTitle", ctx.TrH("gist.list.search-results")) + ctx.SetData("mode", "search") + ctx.SetData("searchQuery", ctx.QueryParam("q")) + ctx.SetData("searchQueryUrl", template.URL("&q="+ctx.QueryParam("q"))) + urlPage = "search" + gists, err = db.GetAllGistsFromSearch(currentUserId, ctx.QueryParam("q"), pageInt-1, sort, order) + } else if strings.HasSuffix(urlctx, "all") { + ctx.SetData("htmlTitle", ctx.TrH("gist.list.all")) + ctx.SetData("mode", "all") + urlPage = "all" + gists, err = db.GetAllGistsForCurrentUser(currentUserId, pageInt-1, sort, order) + } + } else { + liked := false + forked := false + + liked, err = regexp.MatchString(`/[^/]*/liked`, ctx.Request().URL.Path) + if err != nil { + return ctx.ErrorRes(500, "Error matching regexp", err) + } + + forked, err = regexp.MatchString(`/[^/]*/forked`, ctx.Request().URL.Path) + if err != nil { + return ctx.ErrorRes(500, "Error matching regexp", err) + } + + var fromUser *db.User + + fromUser, err = db.GetUserByUsername(fromUserStr) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.NotFound("User not found") + } + return ctx.ErrorRes(500, "Error fetching user", err) + } + ctx.SetData("fromUser", fromUser) + + if countFromUser, err := db.CountAllGistsFromUser(fromUser.ID, currentUserId); err != nil { + return ctx.ErrorRes(500, "Error counting gists", err) + } else { + ctx.SetData("countFromUser", countFromUser) + } + + if countLiked, err := db.CountAllGistsLikedByUser(fromUser.ID, currentUserId); err != nil { + return ctx.ErrorRes(500, "Error counting liked gists", err) + } else { + ctx.SetData("countLiked", countLiked) + } + + if countForked, err := db.CountAllGistsForkedByUser(fromUser.ID, currentUserId); err != nil { + return ctx.ErrorRes(500, "Error counting forked gists", err) + } else { + ctx.SetData("countForked", countForked) + } + + if liked { + urlPage = fromUserStr + "/liked" + ctx.SetData("htmlTitle", ctx.TrH("gist.list.all-liked-by", fromUserStr)) + ctx.SetData("mode", "liked") + gists, err = db.GetAllGistsLikedByUser(fromUser.ID, currentUserId, pageInt-1, sort, order) + } else if forked { + urlPage = fromUserStr + "/forked" + ctx.SetData("htmlTitle", ctx.TrH("gist.list.all-forked-by", fromUserStr)) + ctx.SetData("mode", "forked") + gists, err = db.GetAllGistsForkedByUser(fromUser.ID, currentUserId, pageInt-1, sort, order) + } else { + urlPage = fromUserStr + ctx.SetData("htmlTitle", ctx.TrH("gist.list.all-from", fromUserStr)) + ctx.SetData("mode", "fromUser") + gists, err = db.GetAllGistsFromUser(fromUser.ID, currentUserId, pageInt-1, sort, order) + } + } + + renderedGists := make([]*render.RenderedGist, 0, len(gists)) + for _, gist := range gists { + rendered, err := render.HighlightGistPreview(gist) + if err != nil { + log.Error().Err(err).Msg("Error rendering gist preview for " + gist.Identifier() + " - " + gist.PreviewFilename) + } + renderedGists = append(renderedGists, &rendered) + } + + if err != nil { + return ctx.ErrorRes(500, "Error fetching gists", err) + } + + if err = handlers.Paginate(ctx, renderedGists, pageInt, 10, "gists", fromUserStr, 2, "&sort="+sort+"&order="+order); err != nil { + return ctx.ErrorRes(404, ctx.Tr("error.page-not-found"), nil) + } + + ctx.SetData("urlPage", urlPage) + return ctx.Html("all.html") +} + +func Search(ctx *context.Context) error { + var err error + + content, meta := handlers.ParseSearchQueryStr(ctx.QueryParam("q")) + pageInt := handlers.GetPage(ctx) + + var currentUserId uint + userLogged := ctx.User + if userLogged != nil { + currentUserId = userLogged.ID + } else { + currentUserId = 0 + } + + var visibleGistsIds []uint + visibleGistsIds, err = db.GetAllGistsVisibleByUser(currentUserId) + if err != nil { + return ctx.ErrorRes(500, "Error fetching gists", err) + } + + gistsIds, nbHits, langs, err := index.SearchGists(content, index.SearchGistMetadata{ + Username: meta["user"], + Title: meta["title"], + Filename: meta["filename"], + Extension: meta["extension"], + Language: meta["language"], + }, visibleGistsIds, pageInt) + if err != nil { + return ctx.ErrorRes(500, "Error searching gists", err) + } + + gists, err := db.GetAllGistsByIds(gistsIds) + if err != nil { + return ctx.ErrorRes(500, "Error fetching gists", err) + } + + renderedGists := make([]*render.RenderedGist, 0, len(gists)) + for _, gist := range gists { + rendered, err := render.HighlightGistPreview(gist) + if err != nil { + log.Error().Err(err).Msg("Error rendering gist preview for " + gist.Identifier() + " - " + gist.PreviewFilename) + } + renderedGists = append(renderedGists, &rendered) + } + + if pageInt > 1 && len(renderedGists) != 0 { + ctx.SetData("prevPage", pageInt-1) + } + if 10*pageInt < int(nbHits) { + ctx.SetData("nextPage", pageInt+1) + } + ctx.SetData("prevLabel", ctx.TrH("pagination.previous")) + ctx.SetData("nextLabel", ctx.TrH("pagination.next")) + ctx.SetData("urlPage", "search") + ctx.SetData("urlParams", template.URL("&q="+ctx.QueryParam("q"))) + ctx.SetData("htmlTitle", ctx.TrH("gist.list.search-results")) + ctx.SetData("nbHits", nbHits) + ctx.SetData("gists", renderedGists) + ctx.SetData("langs", langs) + ctx.SetData("searchQuery", ctx.QueryParam("q")) + return ctx.Html("search.html") +} diff --git a/internal/web/handlers/gist/create.go b/internal/web/handlers/gist/create.go new file mode 100644 index 0000000..33d3132 --- /dev/null +++ b/internal/web/handlers/gist/create.go @@ -0,0 +1,141 @@ +package gist + +import ( + "github.com/google/uuid" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/i18n" + "github.com/thomiceli/opengist/internal/validator" + "github.com/thomiceli/opengist/internal/web/context" + "net/url" + "strconv" + "strings" +) + +func Create(ctx *context.Context) error { + ctx.SetData("htmlTitle", ctx.TrH("gist.new.create-a-new-gist")) + return ctx.Html("create.html") +} + +func ProcessCreate(ctx *context.Context) error { + isCreate := false + if ctx.Request().URL.Path == "/" { + isCreate = true + } + + err := ctx.Request().ParseForm() + if err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.bad-request"), err) + } + + dto := new(db.GistDTO) + var gist *db.Gist + + if isCreate { + ctx.SetData("htmlTitle", ctx.TrH("gist.new.create-a-new-gist")) + } else { + gist = ctx.GetData("gist").(*db.Gist) + ctx.SetData("htmlTitle", ctx.TrH("gist.edit.edit-gist", gist.Title)) + } + + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + dto.Files = make([]db.FileDTO, 0) + fileCounter := 0 + for i := 0; i < len(ctx.Request().PostForm["content"]); i++ { + name := ctx.Request().PostForm["name"][i] + content := ctx.Request().PostForm["content"][i] + + if name == "" { + fileCounter += 1 + name = "gistfile" + strconv.Itoa(fileCounter) + ".txt" + } + + escapedValue, err := url.QueryUnescape(content) + if err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.invalid-character-unescaped"), err) + } + + dto.Files = append(dto.Files, db.FileDTO{ + Filename: strings.Trim(name, " "), + Content: escapedValue, + }) + } + + err = ctx.Validate(dto) + if err != nil { + ctx.AddFlash(validator.ValidationMessages(&err, ctx.GetData("locale").(*i18n.Locale)), "error") + if isCreate { + return ctx.Html("create.html") + } else { + files, err := gist.Files("HEAD", false) + if err != nil { + return ctx.ErrorRes(500, "Error fetching files", err) + } + ctx.SetData("files", files) + return ctx.Html("edit.html") + } + } + + if isCreate { + gist = dto.ToGist() + } else { + gist = dto.ToExistingGist(gist) + } + + user := ctx.User + gist.NbFiles = len(dto.Files) + + if isCreate { + uuidGist, err := uuid.NewRandom() + if err != nil { + return ctx.ErrorRes(500, "Error creating an UUID", err) + } + gist.Uuid = strings.Replace(uuidGist.String(), "-", "", -1) + + gist.UserID = user.ID + gist.User = *user + } + + if gist.Title == "" { + if ctx.Request().PostForm["name"][0] == "" { + gist.Title = "gist:" + gist.Uuid + } else { + gist.Title = ctx.Request().PostForm["name"][0] + } + } + + if len(dto.Files) > 0 { + split := strings.Split(dto.Files[0].Content, "\n") + if len(split) > 10 { + gist.Preview = strings.Join(split[:10], "\n") + } else { + gist.Preview = dto.Files[0].Content + } + + gist.PreviewFilename = dto.Files[0].Filename + } + + if err = gist.InitRepository(); err != nil { + return ctx.ErrorRes(500, "Error creating the repository", err) + } + + if err = gist.AddAndCommitFiles(&dto.Files); err != nil { + return ctx.ErrorRes(500, "Error adding and committing files", err) + } + + if isCreate { + if err = gist.Create(); err != nil { + return ctx.ErrorRes(500, "Error creating the gist", err) + } + } else { + if err = gist.Update(); err != nil { + return ctx.ErrorRes(500, "Error updating the gist", err) + } + } + + gist.AddInIndex() + + return ctx.RedirectTo("/" + user.Username + "/" + gist.Identifier()) +} diff --git a/internal/web/handlers/gist/delete.go b/internal/web/handlers/gist/delete.go new file mode 100644 index 0000000..8f9874d --- /dev/null +++ b/internal/web/handlers/gist/delete.go @@ -0,0 +1,18 @@ +package gist + +import ( + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" +) + +func DeleteGist(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + + if err := gist.Delete(); err != nil { + return ctx.ErrorRes(500, "Error deleting this gist", err) + } + gist.RemoveFromIndex() + + ctx.AddFlash(ctx.Tr("flash.gist.deleted"), "success") + return ctx.RedirectTo("/") +} diff --git a/internal/web/handlers/gist/download.go b/internal/web/handlers/gist/download.go new file mode 100644 index 0000000..fed370c --- /dev/null +++ b/internal/web/handlers/gist/download.go @@ -0,0 +1,90 @@ +package gist + +import ( + "archive/zip" + "bytes" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "strconv" +) + +func RawFile(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + file, err := gist.File(ctx.Param("revision"), ctx.Param("file"), false) + if err != nil { + return ctx.ErrorRes(500, "Error getting file content", err) + } + + if file == nil { + return ctx.NotFound("File not found") + } + + return ctx.PlainText(200, file.Content) +} + +func DownloadFile(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + file, err := gist.File(ctx.Param("revision"), ctx.Param("file"), false) + if err != nil { + return ctx.ErrorRes(500, "Error getting file content", err) + } + + if file == nil { + return ctx.NotFound("File not found") + } + + ctx.Response().Header().Set("Content-Type", "text/plain") + ctx.Response().Header().Set("Content-Disposition", "attachment; filename="+file.Filename) + ctx.Response().Header().Set("Content-Length", strconv.Itoa(len(file.Content))) + _, err = ctx.Response().Write([]byte(file.Content)) + if err != nil { + return ctx.ErrorRes(500, "Error downloading the file", err) + } + + return nil +} + +func DownloadZip(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + revision := ctx.Param("revision") + + files, err := gist.Files(revision, false) + if err != nil { + return ctx.ErrorRes(500, "Error fetching files from repository", err) + } + if len(files) == 0 { + return ctx.NotFound("No files found in this revision") + } + + zipFile := new(bytes.Buffer) + + zipWriter := zip.NewWriter(zipFile) + + for _, file := range files { + fh := &zip.FileHeader{ + Name: file.Filename, + Method: zip.Deflate, + } + f, err := zipWriter.CreateHeader(fh) + if err != nil { + return ctx.ErrorRes(500, "Error adding a file the to the zip archive", err) + } + _, err = f.Write([]byte(file.Content)) + if err != nil { + return ctx.ErrorRes(500, "Error adding file content the to the zip archive", err) + } + } + err = zipWriter.Close() + if err != nil { + return ctx.ErrorRes(500, "Error closing the zip archive", err) + } + + ctx.Response().Header().Set("Content-Type", "application/zip") + ctx.Response().Header().Set("Content-Disposition", "attachment; filename="+gist.Identifier()+".zip") + ctx.Response().Header().Set("Content-Length", strconv.Itoa(len(zipFile.Bytes()))) + _, err = ctx.Response().Write(zipFile.Bytes()) + if err != nil { + return ctx.ErrorRes(500, "Error writing the zip archive", err) + } + return nil +} diff --git a/internal/web/handlers/gist/edit.go b/internal/web/handlers/gist/edit.go new file mode 100644 index 0000000..2b571d0 --- /dev/null +++ b/internal/web/handlers/gist/edit.go @@ -0,0 +1,75 @@ +package gist + +import ( + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/render" + "github.com/thomiceli/opengist/internal/web/context" + "strconv" +) + +func Edit(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + + files, err := gist.Files("HEAD", false) + if err != nil { + return ctx.ErrorRes(500, "Error fetching files from repository", err) + } + + ctx.SetData("files", files) + ctx.SetData("htmlTitle", ctx.TrH("gist.edit.edit-gist", gist.Title)) + + return ctx.Html("edit.html") +} + +func Checkbox(ctx *context.Context) error { + filename := ctx.FormValue("file") + checkboxNb := ctx.FormValue("checkbox") + + i, err := strconv.Atoi(checkboxNb) + if err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.invalid-number"), nil) + } + + gist := ctx.GetData("gist").(*db.Gist) + file, err := gist.File("HEAD", filename, false) + if err != nil { + return ctx.ErrorRes(500, "Error getting file content", err) + } else if file == nil { + return ctx.NotFound("File not found") + } + + markdown, err := render.Checkbox(file.Content, i) + if err != nil { + return ctx.ErrorRes(500, "Error checking checkbox", err) + } + + if err = gist.AddAndCommitFile(&db.FileDTO{ + Filename: filename, + Content: markdown, + }); err != nil { + return ctx.ErrorRes(500, "Error adding and committing files", err) + } + + if err = gist.UpdatePreviewAndCount(true); err != nil { + return ctx.ErrorRes(500, "Error updating the gist", err) + } + + return ctx.PlainText(200, "ok") +} + +func EditVisibility(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + + dto := new(db.VisibilityDTO) + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + gist.Private = dto.Private + if err := gist.UpdateNoTimestamps(); err != nil { + return ctx.ErrorRes(500, "Error updating this gist", err) + } + + ctx.AddFlash(ctx.Tr("flash.gist.visibility-changed"), "success") + return ctx.RedirectTo("/" + gist.User.Username + "/" + gist.Identifier()) +} diff --git a/internal/web/handlers/gist/fork.go b/internal/web/handlers/gist/fork.go new file mode 100644 index 0000000..1dac4d4 --- /dev/null +++ b/internal/web/handlers/gist/fork.go @@ -0,0 +1,86 @@ +package gist + +import ( + "errors" + "github.com/google/uuid" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handlers" + "gorm.io/gorm" + "strings" +) + +func Fork(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + currentUser := ctx.User + + alreadyForked, err := gist.GetForkParent(currentUser) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return ctx.ErrorRes(500, "Error checking if gist is already forked", err) + } + + if gist.User.ID == currentUser.ID { + ctx.AddFlash(ctx.Tr("flash.gist.fork-own-gist"), "error") + return ctx.RedirectTo("/" + gist.User.Username + "/" + gist.Identifier()) + } + + if alreadyForked.ID != 0 { + return ctx.RedirectTo("/" + alreadyForked.User.Username + "/" + alreadyForked.Identifier()) + } + + uuidGist, err := uuid.NewRandom() + if err != nil { + return ctx.ErrorRes(500, "Error creating an UUID", err) + } + + newGist := &db.Gist{ + Uuid: strings.Replace(uuidGist.String(), "-", "", -1), + Title: gist.Title, + Preview: gist.Preview, + PreviewFilename: gist.PreviewFilename, + Description: gist.Description, + Private: gist.Private, + UserID: currentUser.ID, + ForkedID: gist.ID, + NbFiles: gist.NbFiles, + } + + if err = newGist.CreateForked(); err != nil { + return ctx.ErrorRes(500, "Error forking the gist in database", err) + } + + if err = gist.ForkClone(currentUser.Username, newGist.Uuid); err != nil { + return ctx.ErrorRes(500, "Error cloning the repository while forking", err) + } + if err = gist.IncrementForkCount(); err != nil { + return ctx.ErrorRes(500, "Error incrementing the fork count", err) + } + + ctx.AddFlash(ctx.Tr("flash.gist.forked"), "success") + + return ctx.RedirectTo("/" + currentUser.Username + "/" + newGist.Identifier()) +} + +func Forks(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + pageInt := handlers.GetPage(ctx) + + currentUser := ctx.User + var fromUserID uint = 0 + if currentUser != nil { + fromUserID = currentUser.ID + } + + forks, err := gist.GetForks(fromUserID, pageInt-1) + if err != nil { + return ctx.ErrorRes(500, "Error getting users who liked this gist", err) + } + + if err = handlers.Paginate(ctx, forks, pageInt, 30, "forks", gist.User.Username+"/"+gist.Identifier()+"/forks", 2); err != nil { + return ctx.ErrorRes(404, ctx.Tr("error.page-not-found"), nil) + } + + ctx.SetData("htmlTitle", ctx.TrH("gist.forks.for", gist.Title)) + ctx.SetData("revision", "HEAD") + return ctx.Html("forks.html") +} diff --git a/internal/web/handlers/gist/gist.go b/internal/web/handlers/gist/gist.go new file mode 100644 index 0000000..fc01c1a --- /dev/null +++ b/internal/web/handlers/gist/gist.go @@ -0,0 +1,157 @@ +package gist + +import ( + "bufio" + "bytes" + gojson "encoding/json" + "fmt" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/git" + "github.com/thomiceli/opengist/internal/render" + "github.com/thomiceli/opengist/internal/web/context" + "net/url" + "time" +) + +func GistIndex(ctx *context.Context) error { + if ctx.GetData("gistpage") == "js" { + return GistJs(ctx) + } else if ctx.GetData("gistpage") == "json" { + return GistJson(ctx) + } + + gist := ctx.GetData("gist").(*db.Gist) + revision := ctx.Param("revision") + + if revision == "" { + revision = "HEAD" + } + + files, err := gist.Files(revision, true) + if _, ok := err.(*git.RevisionNotFoundError); ok { + return ctx.NotFound("Revision not found") + } else if err != nil { + return ctx.ErrorRes(500, "Error fetching files", err) + } + + renderedFiles := render.HighlightFiles(files) + + ctx.SetData("page", "code") + ctx.SetData("commit", revision) + ctx.SetData("files", renderedFiles) + ctx.SetData("revision", revision) + ctx.SetData("htmlTitle", gist.Title) + return ctx.Html("gist.html") +} + +func GistJson(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + files, err := gist.Files("HEAD", true) + if err != nil { + return ctx.ErrorRes(500, "Error fetching files", err) + } + + renderedFiles := render.HighlightFiles(files) + ctx.SetData("files", renderedFiles) + + htmlbuf := bytes.Buffer{} + w := bufio.NewWriter(&htmlbuf) + if err = ctx.Echo().Renderer.Render(w, "gist_embed.html", ctx.DataMap(), ctx); err != nil { + return err + } + _ = w.Flush() + + jsUrl, err := url.JoinPath(ctx.GetData("baseHttpUrl").(string), gist.User.Username, gist.Identifier()+".js") + if err != nil { + return ctx.ErrorRes(500, "Error joining js url", err) + } + + cssUrl, err := url.JoinPath(ctx.GetData("baseHttpUrl").(string), context.ManifestEntries["embed.css"].File) + if err != nil { + return ctx.ErrorRes(500, "Error joining css url", err) + } + + return ctx.JSON(200, map[string]interface{}{ + "owner": gist.User.Username, + "id": gist.Identifier(), + "uuid": gist.Uuid, + "title": gist.Title, + "description": gist.Description, + "created_at": time.Unix(gist.CreatedAt, 0).Format(time.RFC3339), + "visibility": gist.VisibilityStr(), + "files": renderedFiles, + "embed": map[string]string{ + "html": htmlbuf.String(), + "css": cssUrl, + "js": jsUrl, + "js_dark": jsUrl + "?dark", + }, + }) +} + +func GistJs(ctx *context.Context) error { + if _, exists := ctx.QueryParams()["dark"]; exists { + ctx.SetData("dark", "dark") + } + + gist := ctx.GetData("gist").(*db.Gist) + files, err := gist.Files("HEAD", true) + if err != nil { + return ctx.ErrorRes(500, "Error fetching files", err) + } + + renderedFiles := render.HighlightFiles(files) + ctx.SetData("files", renderedFiles) + + htmlbuf := bytes.Buffer{} + w := bufio.NewWriter(&htmlbuf) + if err = ctx.Echo().Renderer.Render(w, "gist_embed.html", ctx.DataMap(), ctx); err != nil { + return err + } + _ = w.Flush() + + cssUrl, err := url.JoinPath(ctx.GetData("baseHttpUrl").(string), context.ManifestEntries["embed.css"].File) + if err != nil { + return ctx.ErrorRes(500, "Error joining css url", err) + } + + js, err := escapeJavaScriptContent(htmlbuf.String(), cssUrl) + if err != nil { + return ctx.ErrorRes(500, "Error escaping JavaScript content", err) + } + ctx.Response().Header().Set("Content-Type", "application/javascript") + return ctx.PlainText(200, js) +} + +func Preview(ctx *context.Context) error { + content := ctx.FormValue("content") + + previewStr, err := render.MarkdownString(content) + if err != nil { + return ctx.ErrorRes(500, "Error rendering markdown", err) + } + + return ctx.PlainText(200, previewStr) +} + +func escapeJavaScriptContent(htmlContent, cssUrl string) (string, error) { + jsonContent, err := gojson.Marshal(htmlContent) + if err != nil { + return "", fmt.Errorf("failed to encode content: %w", err) + } + + jsonCssUrl, err := gojson.Marshal(cssUrl) + if err != nil { + return "", fmt.Errorf("failed to encode CSS URL: %w", err) + } + + js := fmt.Sprintf(` + document.write(''); + document.write(%s); + `, + string(jsonCssUrl), + string(jsonContent), + ) + + return js, nil +} diff --git a/internal/web/handlers/gist/like.go b/internal/web/handlers/gist/like.go new file mode 100644 index 0000000..177d358 --- /dev/null +++ b/internal/web/handlers/gist/like.go @@ -0,0 +1,52 @@ +package gist + +import ( + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handlers" +) + +func Like(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + currentUser := ctx.User + + hasLiked, err := currentUser.HasLiked(gist) + if err != nil { + return ctx.ErrorRes(500, "Error checking if user has liked a gist", err) + } + + if hasLiked { + err = gist.RemoveUserLike(ctx.User) + } else { + err = gist.AppendUserLike(ctx.User) + } + + if err != nil { + return ctx.ErrorRes(500, "Error liking/dislking this gist", err) + } + + redirectTo := "/" + gist.User.Username + "/" + gist.Identifier() + if r := ctx.QueryParam("redirecturl"); r != "" { + redirectTo = r + } + return ctx.RedirectTo(redirectTo) +} + +func Likes(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + + pageInt := handlers.GetPage(ctx) + + likers, err := gist.GetUsersLikes(pageInt - 1) + if err != nil { + return ctx.ErrorRes(500, "Error getting users who liked this gist", err) + } + + if err = handlers.Paginate(ctx, likers, pageInt, 30, "likers", gist.User.Username+"/"+gist.Identifier()+"/likes", 1); err != nil { + return ctx.ErrorRes(404, ctx.Tr("error.page-not-found"), nil) + } + + ctx.SetData("htmlTitle", ctx.TrH("gist.likes.for", gist.Title)) + ctx.SetData("revision", "HEAD") + return ctx.Html("likes.html") +} diff --git a/internal/web/handlers/gist/revisions.go b/internal/web/handlers/gist/revisions.go new file mode 100644 index 0000000..7f94076 --- /dev/null +++ b/internal/web/handlers/gist/revisions.go @@ -0,0 +1,45 @@ +package gist + +import ( + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handlers" + "strings" +) + +func Revisions(ctx *context.Context) error { + gist := ctx.GetData("gist").(*db.Gist) + userName := gist.User.Username + gistName := gist.Identifier() + + pageInt := handlers.GetPage(ctx) + + commits, err := gist.Log((pageInt - 1) * 10) + if err != nil { + return ctx.ErrorRes(500, "Error fetching commits log", err) + } + + if err := handlers.Paginate(ctx, commits, pageInt, 10, "commits", userName+"/"+gistName+"/revisions", 2); err != nil { + return ctx.ErrorRes(404, ctx.Tr("error.page-not-found"), nil) + } + + emailsSet := map[string]struct{}{} + for _, commit := range commits { + if commit.AuthorEmail == "" { + continue + } + emailsSet[strings.ToLower(commit.AuthorEmail)] = struct{}{} + } + + emailsUsers, err := db.GetUsersFromEmails(emailsSet) + if err != nil { + return ctx.ErrorRes(500, "Error fetching users emails", err) + } + + ctx.SetData("page", "revisions") + ctx.SetData("revision", "HEAD") + ctx.SetData("emails", emailsUsers) + ctx.SetData("htmlTitle", ctx.TrH("gist.revision-of", gist.Title)) + + return ctx.Html("revisions.html") +} diff --git a/internal/web/git_http.go b/internal/web/handlers/git/http.go similarity index 71% rename from internal/web/git_http.go rename to internal/web/handlers/git/http.go index a0341ff..6e749a8 100644 --- a/internal/web/git_http.go +++ b/internal/web/handlers/git/http.go @@ -1,4 +1,4 @@ -package web +package git import ( "bytes" @@ -6,7 +6,9 @@ import ( "encoding/base64" "errors" "fmt" - "github.com/thomiceli/opengist/internal/utils" + "github.com/thomiceli/opengist/internal/auth/password" + "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handlers" "net/http" "os" "os/exec" @@ -17,7 +19,6 @@ import ( "time" "github.com/google/uuid" - "github.com/labstack/echo/v4" "github.com/rs/zerolog/log" "github.com/thomiceli/opengist/internal/auth" "github.com/thomiceli/opengist/internal/db" @@ -29,7 +30,7 @@ import ( var routes = []struct { gitUrl string method string - handler func(ctx echo.Context) error + handler func(ctx *context.Context) error }{ {"(.*?)/git-upload-pack$", "POST", uploadPack}, {"(.*?)/git-receive-pack$", "POST", receivePack}, @@ -44,7 +45,7 @@ var routes = []struct { {"(.*?)/objects/pack/pack-[0-9a-f]{40}\\.idx$", "GET", idxFile}, } -func gitHttp(ctx echo.Context) error { +func GitHttp(ctx *context.Context) error { for _, route := range routes { matched, _ := regexp.MatchString(route.gitUrl, ctx.Request().URL.Path) if ctx.Request().Method == route.method && matched { @@ -52,7 +53,7 @@ func gitHttp(ctx echo.Context) error { continue } - gist := getData(ctx, "gist").(*db.Gist) + gist := ctx.GetData("gist").(*db.Gist) isInit := strings.HasPrefix(ctx.Request().URL.Path, "/init/info/refs") isInitReceive := strings.HasPrefix(ctx.Request().URL.Path, "/init/git-receive-pack") @@ -65,13 +66,13 @@ func gitHttp(ctx echo.Context) error { if _, err := os.Stat(repositoryPath); os.IsNotExist(err) { if err != nil { log.Info().Err(err).Msg("Repository directory does not exist") - return errorRes(404, "Repository directory does not exist", err) + return ctx.ErrorRes(404, "Repository directory does not exist", err) } } - setData(ctx, "repositoryPath", repositoryPath) + ctx.SetData("repositoryPath", repositoryPath) - allow, err := auth.ShouldAllowUnauthenticatedGistAccess(ContextAuthInfo{ctx}, true) + allow, err := auth.ShouldAllowUnauthenticatedGistAccess(handlers.ContextAuthInfo{Context: ctx}, true) if err != nil { log.Fatal().Err(err).Msg("Cannot check if unauthenticated access is allowed") } @@ -102,7 +103,7 @@ func gitHttp(ctx echo.Context) error { if !isInit && !isInitReceive { if gist.ID == 0 { - return plainText(ctx, 404, "Check your credentials or make sure you have access to the Gist") + return ctx.PlainText(404, "Check your credentials or make sure you have access to the Gist") } var userToCheckPermissions *db.User @@ -112,29 +113,29 @@ func gitHttp(ctx echo.Context) error { userToCheckPermissions = &gist.User } - if ok, err := utils.Argon2id.Verify(authPassword, userToCheckPermissions.Password); !ok { + if ok, err := password.VerifyPassword(authPassword, userToCheckPermissions.Password); !ok { if err != nil { - return errorRes(500, "Cannot verify password", err) + return ctx.ErrorRes(500, "Cannot verify password", err) } log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) - return plainText(ctx, 404, "Check your credentials or make sure you have access to the Gist") + return ctx.PlainText(404, "Check your credentials or make sure you have access to the Gist") } } else { var user *db.User if user, err = db.GetUserByUsername(authUsername); err != nil { if !errors.Is(err, gorm.ErrRecordNotFound) { - return errorRes(500, "Cannot get user", err) + return ctx.ErrorRes(500, "Cannot get user", err) } log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) - return errorRes(401, "Invalid credentials", nil) + return ctx.ErrorRes(401, "Invalid credentials", nil) } - if ok, err := utils.Argon2id.Verify(authPassword, user.Password); !ok { + if ok, err := password.VerifyPassword(authPassword, user.Password); !ok { if err != nil { - return errorRes(500, "Cannot check for password", err) + return ctx.ErrorRes(500, "Cannot check for password", err) } log.Warn().Msg("Invalid HTTP authentication attempt from " + ctx.RealIP()) - return errorRes(401, "Invalid credentials", nil) + return ctx.ErrorRes(401, "Invalid credentials", nil) } if isInit { @@ -143,56 +144,56 @@ func gitHttp(ctx echo.Context) error { gist.User = *user uuidGist, err := uuid.NewRandom() if err != nil { - return errorRes(500, "Error creating an UUID", err) + return ctx.ErrorRes(500, "Error creating an UUID", err) } gist.Uuid = strings.Replace(uuidGist.String(), "-", "", -1) gist.Title = "gist:" + gist.Uuid if err = gist.InitRepository(); err != nil { - return errorRes(500, "Cannot init repository in the file system", err) + return ctx.ErrorRes(500, "Cannot init repository in the file system", err) } if err = gist.Create(); err != nil { - return errorRes(500, "Cannot init repository in database", err) + return ctx.ErrorRes(500, "Cannot init repository in database", err) } if err := memdb.InsertGistInit(user.ID, gist); err != nil { - return errorRes(500, "Cannot save the URL for the new Gist", err) + return ctx.ErrorRes(500, "Cannot save the URL for the new Gist", err) } - setData(ctx, "gist", gist) + ctx.SetData("gist", gist) } else { gistFromMemdb, err := memdb.GetGistInitAndDelete(user.ID) if err != nil { - return errorRes(500, "Cannot get the gist link from the in memory database", err) + return ctx.ErrorRes(500, "Cannot get the gist link from the in memory database", err) } gist := gistFromMemdb.Gist - setData(ctx, "gist", gist) - setData(ctx, "repositoryPath", git.RepositoryPath(gist.User.Username, gist.Uuid)) + ctx.SetData("gist", gist) + ctx.SetData("repositoryPath", git.RepositoryPath(gist.User.Username, gist.Uuid)) } } return route.handler(ctx) } } - return notFound("Gist not found") + return ctx.NotFound("Gist not found") } -func uploadPack(ctx echo.Context) error { +func uploadPack(ctx *context.Context) error { return pack(ctx, "upload-pack") } -func receivePack(ctx echo.Context) error { +func receivePack(ctx *context.Context) error { return pack(ctx, "receive-pack") } -func pack(ctx echo.Context, serviceType string) error { +func pack(ctx *context.Context, serviceType string) error { noCacheHeaders(ctx) defer ctx.Request().Body.Close() if ctx.Request().Header.Get("Content-Type") != "application/x-git-"+serviceType+"-request" { - return errorRes(401, "Git client unsupported", nil) + return ctx.ErrorRes(401, "Git client unsupported", nil) } ctx.Response().Header().Set("Content-Type", "application/x-git-"+serviceType+"-result") @@ -202,12 +203,12 @@ func pack(ctx echo.Context, serviceType string) error { if ctx.Request().Header.Get("Content-Encoding") == "gzip" { reqBody, err = gzip.NewReader(reqBody) if err != nil { - return errorRes(500, "Cannot create gzip reader", err) + return ctx.ErrorRes(500, "Cannot create gzip reader", err) } } - repositoryPath := getData(ctx, "repositoryPath").(string) - gist := getData(ctx, "gist").(*db.Gist) + repositoryPath := ctx.GetData("repositoryPath").(string) + gist := ctx.GetData("gist").(*db.Gist) var stderr bytes.Buffer cmd := exec.Command("git", serviceType, "--stateless-rpc", repositoryPath) @@ -220,17 +221,17 @@ func pack(ctx echo.Context, serviceType string) error { cmd.Env = append(cmd.Env, "OPENGIST_REPOSITORY_ID="+strconv.Itoa(int(gist.ID))) if err = cmd.Run(); err != nil { - return errorRes(500, "Cannot run git "+serviceType+" ; "+stderr.String(), err) + return ctx.ErrorRes(500, "Cannot run git "+serviceType+" ; "+stderr.String(), err) } return nil } -func infoRefs(ctx echo.Context) error { +func infoRefs(ctx *context.Context) error { noCacheHeaders(ctx) var service string - gist := getData(ctx, "gist").(*db.Gist) + gist := ctx.GetData("gist").(*db.Gist) serviceType := ctx.QueryParam("service") if strings.HasPrefix(serviceType, "git-") { @@ -239,14 +240,14 @@ func infoRefs(ctx echo.Context) error { if service != "upload-pack" && service != "receive-pack" { if err := gist.UpdateServerInfo(); err != nil { - return errorRes(500, "Cannot update server info", err) + return ctx.ErrorRes(500, "Cannot update server info", err) } return sendFile(ctx, "text/plain; charset=utf-8") } refs, err := gist.RPC(service) if err != nil { - return errorRes(500, "Cannot run git "+service, err) + return ctx.ErrorRes(500, "Cannot run git "+service, err) } ctx.Response().Header().Set("Content-Type", "application/x-git-"+service+"-advertisement") @@ -258,38 +259,38 @@ func infoRefs(ctx echo.Context) error { return nil } -func textFile(ctx echo.Context) error { +func textFile(ctx *context.Context) error { noCacheHeaders(ctx) return sendFile(ctx, "text/plain") } -func infoPacks(ctx echo.Context) error { +func infoPacks(ctx *context.Context) error { cacheHeadersForever(ctx) return sendFile(ctx, "text/plain; charset=utf-8") } -func looseObject(ctx echo.Context) error { +func looseObject(ctx *context.Context) error { cacheHeadersForever(ctx) return sendFile(ctx, "application/x-git-loose-object") } -func packFile(ctx echo.Context) error { +func packFile(ctx *context.Context) error { cacheHeadersForever(ctx) return sendFile(ctx, "application/x-git-packed-objects") } -func idxFile(ctx echo.Context) error { +func idxFile(ctx *context.Context) error { cacheHeadersForever(ctx) return sendFile(ctx, "application/x-git-packed-objects-toc") } -func noCacheHeaders(ctx echo.Context) { +func noCacheHeaders(ctx *context.Context) { ctx.Response().Header().Set("Expires", "Thu, 01 Jan 1970 00:00:00 UTC") ctx.Response().Header().Set("Pragma", "no-cache") ctx.Response().Header().Set("Cache-Control", "no-cache, max-age=0, must-revalidate") } -func cacheHeadersForever(ctx echo.Context) { +func cacheHeadersForever(ctx *context.Context) { now := time.Now().Unix() expires := now + 31536000 ctx.Response().Header().Set("Date", fmt.Sprintf("%d", now)) @@ -297,9 +298,9 @@ func cacheHeadersForever(ctx echo.Context) { ctx.Response().Header().Set("Cache-Control", "public, max-age=31536000") } -func basicAuth(ctx echo.Context) error { +func basicAuth(ctx *context.Context) error { ctx.Response().Header().Set("WWW-Authenticate", `Basic realm="."`) - return plainText(ctx, 401, "Requires authentication") + return ctx.PlainText(401, "Requires authentication") } func basicAuthDecode(encoded string) (string, string, error) { @@ -312,12 +313,12 @@ func basicAuthDecode(encoded string) (string, string, error) { return auth[0], auth[1], nil } -func sendFile(ctx echo.Context, contentType string) error { +func sendFile(ctx *context.Context, contentType string) error { gitFile := "/" + strings.Join(strings.Split(ctx.Request().URL.Path, "/")[3:], "/") - gitFile = path.Join(getData(ctx, "repositoryPath").(string), gitFile) + gitFile = path.Join(ctx.GetData("repositoryPath").(string), gitFile) fi, err := os.Stat(gitFile) if os.IsNotExist(err) { - return errorRes(404, "File not found", nil) + return ctx.ErrorRes(404, "File not found", nil) } ctx.Response().Header().Set("Content-Type", contentType) ctx.Response().Header().Set("Content-Length", fmt.Sprintf("%d", fi.Size())) diff --git a/internal/web/healthcheck.go b/internal/web/handlers/health/healthcheck.go similarity index 53% rename from internal/web/healthcheck.go rename to internal/web/handlers/health/healthcheck.go index 6927d28..4a2c3e5 100644 --- a/internal/web/healthcheck.go +++ b/internal/web/handlers/health/healthcheck.go @@ -1,12 +1,12 @@ -package web +package health import ( - "github.com/labstack/echo/v4" "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" "time" ) -func healthcheck(ctx echo.Context) error { +func Healthcheck(ctx *context.Context) error { // Check database connection dbOk := "ok" httpStatus := 200 @@ -23,9 +23,3 @@ func healthcheck(ctx echo.Context) error { "time": time.Now().Format(time.RFC3339), }) } - -// metrics is a dummy handler to satisfy the /metrics endpoint (for Prometheus, Openmetrics, etc.) -// until we have a proper metrics endpoint -func metrics(ctx echo.Context) error { - return ctx.String(200, "") -} diff --git a/internal/web/handlers/health/metrics.go b/internal/web/handlers/health/metrics.go new file mode 100644 index 0000000..9630cfb --- /dev/null +++ b/internal/web/handlers/health/metrics.go @@ -0,0 +1,9 @@ +package health + +import "github.com/thomiceli/opengist/internal/web/context" + +// Metrics is a dummy handler to satisfy the /metrics endpoint (for Prometheus, Openmetrics, etc.) +// until we have a proper metrics endpoint +func Metrics(ctx *context.Context) error { + return ctx.String(200, "") +} diff --git a/internal/web/handlers/settings/account.go b/internal/web/handlers/settings/account.go new file mode 100644 index 0000000..8ac9d7d --- /dev/null +++ b/internal/web/handlers/settings/account.go @@ -0,0 +1,88 @@ +package settings + +import ( + "crypto/md5" + "fmt" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/git" + "github.com/thomiceli/opengist/internal/i18n" + "github.com/thomiceli/opengist/internal/validator" + "github.com/thomiceli/opengist/internal/web/context" + "os" + "path/filepath" + "strings" + "time" +) + +func EmailProcess(ctx *context.Context) error { + user := ctx.User + email := ctx.FormValue("email") + var hash string + + if email == "" { + // generate random md5 string + hash = fmt.Sprintf("%x", md5.Sum([]byte(time.Now().String()))) + } else { + hash = fmt.Sprintf("%x", md5.Sum([]byte(strings.ToLower(strings.TrimSpace(email))))) + } + + user.Email = strings.ToLower(email) + user.MD5Hash = hash + + if err := user.Update(); err != nil { + return ctx.ErrorRes(500, "Cannot update email", err) + } + + ctx.AddFlash(ctx.Tr("flash.user.email-updated"), "success") + return ctx.RedirectTo("/settings") +} + +func AccountDeleteProcess(ctx *context.Context) error { + user := ctx.User + + if err := user.Delete(); err != nil { + return ctx.ErrorRes(500, "Cannot delete this user", err) + } + + return ctx.RedirectTo("/all") +} + +func UsernameProcess(ctx *context.Context) error { + user := ctx.User + + dto := new(db.UserDTO) + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + dto.Password = user.Password + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash(validator.ValidationMessages(&err, ctx.GetData("locale").(*i18n.Locale)), "error") + return ctx.RedirectTo("/settings") + } + + if exists, err := db.UserExists(dto.Username); err != nil || exists { + ctx.AddFlash(ctx.Tr("flash.auth.username-exists"), "error") + return ctx.RedirectTo("/settings") + } + + sourceDir := filepath.Join(config.GetHomeDir(), git.ReposDirectory, strings.ToLower(user.Username)) + destinationDir := filepath.Join(config.GetHomeDir(), git.ReposDirectory, strings.ToLower(dto.Username)) + + if _, err := os.Stat(sourceDir); !os.IsNotExist(err) { + err := os.Rename(sourceDir, destinationDir) + if err != nil { + return ctx.ErrorRes(500, "Cannot rename user directory", err) + } + } + + user.Username = dto.Username + + if err := user.Update(); err != nil { + return ctx.ErrorRes(500, "Cannot update username", err) + } + + ctx.AddFlash(ctx.Tr("flash.user.username-updated"), "success") + return ctx.RedirectTo("/settings") +} diff --git a/internal/web/handlers/settings/auth.go b/internal/web/handlers/settings/auth.go new file mode 100644 index 0000000..06afeb1 --- /dev/null +++ b/internal/web/handlers/settings/auth.go @@ -0,0 +1,58 @@ +package settings + +import ( + passwordpkg "github.com/thomiceli/opengist/internal/auth/password" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/i18n" + "github.com/thomiceli/opengist/internal/validator" + "github.com/thomiceli/opengist/internal/web/context" + "strconv" +) + +func PasskeyDelete(ctx *context.Context) error { + user := ctx.User + keyId, err := strconv.Atoi(ctx.Param("id")) + if err != nil { + return ctx.RedirectTo("/settings") + } + + passkey, err := db.GetCredentialByIDDB(uint(keyId)) + if err != nil || passkey.UserID != user.ID { + return ctx.RedirectTo("/settings") + } + + if err := passkey.Delete(); err != nil { + return ctx.ErrorRes(500, "Cannot delete passkey", err) + } + + ctx.AddFlash(ctx.Tr("flash.auth.passkey-deleted"), "success") + return ctx.RedirectTo("/settings") +} + +func PasswordProcess(ctx *context.Context) error { + user := ctx.User + + dto := new(db.UserDTO) + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + dto.Username = user.Username + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash(validator.ValidationMessages(&err, ctx.GetData("locale").(*i18n.Locale)), "error") + return ctx.Html("settings.html") + } + + password, err := passwordpkg.HashPassword(dto.Password) + if err != nil { + return ctx.ErrorRes(500, "Cannot hash password", err) + } + user.Password = password + + if err = user.Update(); err != nil { + return ctx.ErrorRes(500, "Cannot update password", err) + } + + ctx.AddFlash(ctx.Tr("flash.user.password-updated"), "success") + return ctx.RedirectTo("/settings") +} diff --git a/internal/web/handlers/settings/settings.go b/internal/web/handlers/settings/settings.go new file mode 100644 index 0000000..06b3d4c --- /dev/null +++ b/internal/web/handlers/settings/settings.go @@ -0,0 +1,34 @@ +package settings + +import ( + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/web/context" +) + +func UserSettings(ctx *context.Context) error { + user := ctx.User + + keys, err := db.GetSSHKeysByUserID(user.ID) + if err != nil { + return ctx.ErrorRes(500, "Cannot get SSH keys", err) + } + + passkeys, err := db.GetAllCredentialsForUser(user.ID) + if err != nil { + return ctx.ErrorRes(500, "Cannot get WebAuthn credentials", err) + } + + _, hasTotp, err := user.HasMFA() + if err != nil { + return ctx.ErrorRes(500, "Cannot get MFA status", err) + } + + ctx.SetData("email", user.Email) + ctx.SetData("sshKeys", keys) + ctx.SetData("passkeys", passkeys) + ctx.SetData("hasTotp", hasTotp) + ctx.SetData("hasPassword", user.Password != "") + ctx.SetData("disableForm", ctx.GetData("DisableLoginForm")) + ctx.SetData("htmlTitle", ctx.TrH("settings")) + return ctx.Html("settings.html") +} diff --git a/internal/web/handlers/settings/sshkey.go b/internal/web/handlers/settings/sshkey.go new file mode 100644 index 0000000..db8bdc3 --- /dev/null +++ b/internal/web/handlers/settings/sshkey.go @@ -0,0 +1,71 @@ +package settings + +import ( + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/i18n" + "github.com/thomiceli/opengist/internal/validator" + "github.com/thomiceli/opengist/internal/web/context" + "golang.org/x/crypto/ssh" + "strconv" + "strings" +) + +func SshKeysProcess(ctx *context.Context) error { + user := ctx.User + + dto := new(db.SSHKeyDTO) + if err := ctx.Bind(dto); err != nil { + return ctx.ErrorRes(400, ctx.Tr("error.cannot-bind-data"), err) + } + + if err := ctx.Validate(dto); err != nil { + ctx.AddFlash(validator.ValidationMessages(&err, ctx.GetData("locale").(*i18n.Locale)), "error") + return ctx.RedirectTo("/settings") + } + key := dto.ToSSHKey() + + key.UserID = user.ID + + pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key.Content)) + if err != nil { + ctx.AddFlash(ctx.Tr("flash.user.invalid-ssh-key"), "error") + return ctx.RedirectTo("/settings") + } + key.Content = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey))) + + if exists, err := db.SSHKeyDoesExists(key.Content); exists { + if err != nil { + return ctx.ErrorRes(500, "Cannot check if SSH key exists", err) + } + ctx.AddFlash(ctx.Tr("settings.ssh-key-exists"), "error") + return ctx.RedirectTo("/settings") + } + + if err := key.Create(); err != nil { + return ctx.ErrorRes(500, "Cannot add SSH key", err) + } + + ctx.AddFlash(ctx.Tr("flash.user.ssh-key-added"), "success") + return ctx.RedirectTo("/settings") +} + +func SshKeysDelete(ctx *context.Context) error { + user := ctx.User + keyId, err := strconv.Atoi(ctx.Param("id")) + if err != nil { + return ctx.RedirectTo("/settings") + } + + key, err := db.GetSSHKeyByID(uint(keyId)) + + if err != nil || key.UserID != user.ID { + return ctx.RedirectTo("/settings") + } + + if err := key.Delete(); err != nil { + return ctx.ErrorRes(500, "Cannot delete SSH key", err) + } + + ctx.AddFlash(ctx.Tr("flash.user.ssh-key-deleted"), "success") + return ctx.RedirectTo("/settings") +} diff --git a/internal/web/handlers/util.go b/internal/web/handlers/util.go new file mode 100644 index 0000000..25b161c --- /dev/null +++ b/internal/web/handlers/util.go @@ -0,0 +1,79 @@ +package handlers + +import ( + "errors" + "github.com/thomiceli/opengist/internal/web/context" + "html/template" + "strconv" + "strings" +) + +func GetPage(ctx *context.Context) int { + page := ctx.QueryParam("page") + if page == "" { + page = "1" + } + pageInt, err := strconv.Atoi(page) + if err != nil { + pageInt = 1 + } + ctx.SetData("currPage", pageInt) + + return pageInt +} + +func Paginate[T any](ctx *context.Context, data []*T, pageInt int, perPage int, templateDataName string, urlPage string, labels int, urlParams ...string) error { + lenData := len(data) + if lenData == 0 && pageInt != 1 { + return errors.New("page not found") + } + + if lenData > perPage { + if lenData > 1 { + data = data[:lenData-1] + } + ctx.SetData("nextPage", pageInt+1) + } + if pageInt > 1 { + ctx.SetData("prevPage", pageInt-1) + } + + if len(urlParams) > 0 { + ctx.SetData("urlParams", template.URL(urlParams[0])) + } + + switch labels { + case 1: + ctx.SetData("prevLabel", ctx.TrH("pagination.previous")) + ctx.SetData("nextLabel", ctx.TrH("pagination.next")) + case 2: + ctx.SetData("prevLabel", ctx.TrH("pagination.newer")) + ctx.SetData("nextLabel", ctx.TrH("pagination.older")) + } + + ctx.SetData("urlPage", urlPage) + ctx.SetData(templateDataName, data) + return nil +} + +func ParseSearchQueryStr(query string) (string, map[string]string) { + words := strings.Fields(query) + metadata := make(map[string]string) + var contentBuilder strings.Builder + + for _, word := range words { + if strings.Contains(word, ":") { + keyValue := strings.SplitN(word, ":", 2) + if len(keyValue) == 2 { + key := keyValue[0] + value := keyValue[1] + metadata[key] = value + } + } else { + contentBuilder.WriteString(word + " ") + } + } + + content := strings.TrimSpace(contentBuilder.String()) + return content, metadata +} diff --git a/internal/web/server.go b/internal/web/server.go deleted file mode 100644 index f3d4850..0000000 --- a/internal/web/server.go +++ /dev/null @@ -1,626 +0,0 @@ -package web - -import ( - "context" - gojson "encoding/json" - "errors" - "fmt" - htmlpkg "html" - "html/template" - "io" - "net/http" - "net/url" - "os" - "path" - "path/filepath" - "regexp" - "strconv" - "strings" - "time" - - "github.com/thomiceli/opengist/internal/index" - "github.com/thomiceli/opengist/internal/utils" - "github.com/thomiceli/opengist/templates" - - "github.com/gorilla/sessions" - "github.com/labstack/echo/v4" - "github.com/labstack/echo/v4/middleware" - "github.com/markbates/goth/gothic" - "github.com/rs/zerolog/log" - "github.com/thomiceli/opengist/internal/auth" - "github.com/thomiceli/opengist/internal/config" - "github.com/thomiceli/opengist/internal/db" - "github.com/thomiceli/opengist/internal/git" - "github.com/thomiceli/opengist/internal/i18n" - "github.com/thomiceli/opengist/public" - "golang.org/x/text/language" -) - -var ( - dev bool - flashStore *sessions.CookieStore // session store for flash messages - userStore *sessions.FilesystemStore // session store for user sessions - re = regexp.MustCompile("[^a-z0-9]+") - fm = template.FuncMap{ - "split": strings.Split, - "indexByte": strings.IndexByte, - "toInt": func(i string) int { - val, _ := strconv.Atoi(i) - return val - }, - "inc": func(i int) int { - return i + 1 - }, - "splitGit": func(i string) []string { - return strings.FieldsFunc(i, func(r rune) bool { - return r == ',' || r == ' ' - }) - }, - "lines": func(i string) []string { - return strings.Split(i, "\n") - }, - "isMarkdown": func(i string) bool { - return strings.ToLower(filepath.Ext(i)) == ".md" - }, - "isCsv": func(i string) bool { - return strings.ToLower(filepath.Ext(i)) == ".csv" - }, - "isSvg": func(i string) bool { - return strings.ToLower(filepath.Ext(i)) == ".svg" - }, - "csvFile": func(file *git.File) *git.CsvFile { - if strings.ToLower(filepath.Ext(file.Filename)) != ".csv" { - return nil - } - - csvFile, err := git.ParseCsv(file) - if err != nil { - return nil - } - - return csvFile - }, - "httpStatusText": http.StatusText, - "loadedTime": func(startTime time.Time) string { - return fmt.Sprint(time.Since(startTime).Nanoseconds()/1e6) + "ms" - }, - "slug": func(s string) string { - return strings.Trim(re.ReplaceAllString(strings.ToLower(s), "-"), "-") - }, - "avatarUrl": func(user *db.User, noGravatar bool) string { - if user.AvatarURL != "" { - return user.AvatarURL - } - - if user.MD5Hash != "" && !noGravatar { - return "https://www.gravatar.com/avatar/" + user.MD5Hash + "?d=identicon&s=200" - } - - return defaultAvatar() - }, - "asset": asset, - "custom": customAsset, - "dev": func() bool { - return dev - }, - "defaultAvatar": defaultAvatar, - "visibilityStr": func(visibility db.Visibility, lowercase bool) string { - s := "Public" - switch visibility { - case 1: - s = "Unlisted" - case 2: - s = "Private" - } - - if lowercase { - return strings.ToLower(s) - } - return s - }, - "unescape": htmlpkg.UnescapeString, - "join": func(s ...string) string { - return strings.Join(s, "") - }, - "toStr": func(i interface{}) string { - return fmt.Sprint(i) - }, - "safe": func(s string) template.HTML { - return template.HTML(s) - }, - "dict": func(values ...interface{}) (map[string]interface{}, error) { - if len(values)%2 != 0 { - return nil, errors.New("invalid dict call") - } - dict := make(map[string]interface{}) - for i := 0; i < len(values); i += 2 { - key, ok := values[i].(string) - if !ok { - return nil, errors.New("dict keys must be strings") - } - dict[key] = values[i+1] - } - return dict, nil - }, - "addMetadataToSearchQuery": addMetadataToSearchQuery, - "indexEnabled": index.Enabled, - "isUrl": func(s string) bool { - _, err := url.ParseRequestURI(s) - return err == nil - }, - } -) - -type Template struct { - templates *template.Template -} - -func (t *Template) Render(w io.Writer, name string, data interface{}, _ echo.Context) error { - return t.templates.ExecuteTemplate(w, name, data) -} - -type Server struct { - echo *echo.Echo - dev bool -} - -func NewServer(isDev bool, sessionsPath string, ignoreCsrf bool) *Server { - dev = isDev - flashStore = sessions.NewCookieStore([]byte("opengist")) - encryptKey, _ := utils.GenerateSecretKey(filepath.Join(sessionsPath, "session-encrypt.key")) - userStore = sessions.NewFilesystemStore(sessionsPath, config.SecretKey, encryptKey) - userStore.MaxLength(10 * 1024) - gothic.Store = userStore - - e := echo.New() - e.HideBanner = true - e.HidePort = true - - if err := i18n.Locales.LoadAll(); err != nil { - log.Fatal().Err(err).Msg("Failed to load locales") - } - - e.Use(dataInit) - e.Use(locale) - e.Pre(middleware.MethodOverrideWithConfig(middleware.MethodOverrideConfig{ - Getter: middleware.MethodFromForm("_method"), - })) - e.Pre(middleware.RemoveTrailingSlash()) - e.Pre(middleware.CORS()) - e.Pre(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ - LogURI: true, LogStatus: true, LogMethod: true, - LogValuesFunc: func(ctx echo.Context, v middleware.RequestLoggerValues) error { - log.Info().Str("uri", v.URI).Int("status", v.Status).Str("method", v.Method). - Str("ip", ctx.RealIP()).TimeDiff("duration", time.Now(), v.StartTime). - Msg("HTTP") - return nil - }, - })) - e.Use(middleware.Recover()) - e.Use(middleware.Secure()) - - t := template.Must(template.New("t").Funcs(fm).ParseFS(templates.Files, "*/*.html")) - customPattern := filepath.Join(config.GetHomeDir(), "custom", "*.html") - matches, err := filepath.Glob(customPattern) - if err != nil { - log.Fatal().Err(err).Msg("Failed to check for custom templates") - } - if len(matches) > 0 { - t, err = t.ParseGlob(customPattern) - if err != nil { - log.Fatal().Err(err).Msg("Failed to parse custom templates") - } - } - e.Renderer = &Template{ - templates: t, - } - - e.HTTPErrorHandler = func(er error, ctx echo.Context) { - var httpErr *echo.HTTPError - if errors.As(er, &httpErr) { - acceptJson := strings.Contains(ctx.Request().Header.Get("Accept"), "application/json") - setData(ctx, "error", er) - if acceptJson { - if fatalErr := jsonWithCode(ctx, httpErr.Code, httpErr); fatalErr != nil { - log.Fatal().Err(fatalErr).Send() - } - } else { - if fatalErr := htmlWithCode(ctx, httpErr.Code, "error.html"); fatalErr != nil { - log.Fatal().Err(fatalErr).Send() - } - } - } else { - log.Fatal().Err(er).Send() - } - } - - e.Use(sessionInit) - - e.Validator = utils.NewValidator() - - if !dev { - parseManifestEntries() - } - - // Web based routes - g1 := e.Group("") - { - if !ignoreCsrf { - g1.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ - TokenLookup: "form:_csrf,header:X-CSRF-Token", - CookiePath: "/", - CookieHTTPOnly: true, - CookieSameSite: http.SameSiteStrictMode, - })) - g1.Use(csrfInit) - } - - g1.GET("/", create, logged) - g1.POST("/", processCreate, logged) - g1.POST("/preview", preview, logged) - - g1.GET("/healthcheck", healthcheck) - g1.GET("/metrics", metrics) - - g1.GET("/register", register) - g1.POST("/register", processRegister) - g1.GET("/login", login) - g1.POST("/login", processLogin) - g1.GET("/logout", logout) - g1.GET("/oauth/:provider", oauth) - g1.GET("/oauth/:provider/callback", oauthCallback) - g1.GET("/oauth/:provider/unlink", oauthUnlink, logged) - g1.POST("/webauthn/bind", beginWebAuthnBinding, logged) - g1.POST("/webauthn/bind/finish", finishWebAuthnBinding, logged) - g1.POST("/webauthn/login", beginWebAuthnLogin) - g1.POST("/webauthn/login/finish", finishWebAuthnLogin) - g1.POST("/webauthn/assertion", beginWebAuthnAssertion, inMFASession) - g1.POST("/webauthn/assertion/finish", finishWebAuthnAssertion, inMFASession) - g1.GET("/mfa", mfa, inMFASession) - g1.POST("/mfa/totp/assertion", assertTotp, inMFASession) - - g1.GET("/settings", userSettings, logged) - g1.POST("/settings/email", emailProcess, logged) - g1.DELETE("/settings/account", accountDeleteProcess, logged) - g1.POST("/settings/ssh-keys", sshKeysProcess, logged) - g1.DELETE("/settings/ssh-keys/:id", sshKeysDelete, logged) - g1.DELETE("/settings/passkeys/:id", passkeyDelete, logged) - g1.PUT("/settings/password", passwordProcess, logged) - g1.PUT("/settings/username", usernameProcess, logged) - g1.GET("/settings/totp/generate", beginTotp, logged) - g1.POST("/settings/totp/generate", finishTotp, logged) - g1.DELETE("/settings/totp", disableTotp, logged) - g1.POST("/settings/totp/regenerate", regenerateTotpRecoveryCodes, logged) - - g2 := g1.Group("/admin-panel") - { - g2.Use(adminPermission) - g2.GET("", adminIndex) - g2.GET("/users", adminUsers) - g2.POST("/users/:user/delete", adminUserDelete) - g2.GET("/gists", adminGists) - g2.POST("/gists/:gist/delete", adminGistDelete) - g2.GET("/invitations", adminInvitations) - g2.POST("/invitations", adminInvitationsCreate) - g2.POST("/invitations/:id/delete", adminInvitationsDelete) - g2.POST("/sync-fs", adminSyncReposFromFS) - g2.POST("/sync-db", adminSyncReposFromDB) - g2.POST("/gc-repos", adminGcRepos) - g2.POST("/sync-previews", adminSyncGistPreviews) - g2.POST("/reset-hooks", adminResetHooks) - g2.POST("/index-gists", adminIndexGists) - g2.GET("/configuration", adminConfig) - g2.PUT("/set-config", adminSetConfig) - } - - if config.C.HttpGit { - e.Any("/init/*", gitHttp, gistNewPushSoftInit) - } - - g1.GET("/all", allGists, checkRequireLogin) - - if index.Enabled() { - g1.GET("/search", search, checkRequireLogin) - } else { - g1.GET("/search", allGists, checkRequireLogin) - } - - g1.GET("/:user", allGists, checkRequireLogin) - g1.GET("/:user/liked", allGists, checkRequireLogin) - g1.GET("/:user/forked", allGists, checkRequireLogin) - - g3 := g1.Group("/:user/:gistname") - { - g3.Use(makeCheckRequireLogin(true), gistInit) - g3.GET("", gistIndex) - g3.GET("/rev/:revision", gistIndex) - g3.GET("/revisions", revisions) - g3.GET("/archive/:revision", downloadZip) - g3.POST("/visibility", editVisibility, logged, writePermission) - g3.POST("/delete", deleteGist, logged, writePermission) - g3.GET("/raw/:revision/:file", rawFile) - g3.GET("/download/:revision/:file", downloadFile) - g3.GET("/edit", edit, logged, writePermission) - g3.POST("/edit", processCreate, logged, writePermission) - g3.POST("/like", like, logged) - g3.GET("/likes", likes, checkRequireLogin) - g3.POST("/fork", fork, logged) - g3.GET("/forks", forks, checkRequireLogin) - g3.PUT("/checkbox", checkbox, logged, writePermission) - } - } - - customFs := os.DirFS(filepath.Join(config.GetHomeDir(), "custom")) - e.GET("/assets/*", func(ctx echo.Context) error { - if _, err := public.Files.Open(path.Join("assets", ctx.Param("*"))); !dev && err == nil { - ctx.Response().Header().Set("Cache-Control", "public, max-age=31536000") - ctx.Response().Header().Set("Expires", time.Now().AddDate(1, 0, 0).Format(http.TimeFormat)) - - return echo.WrapHandler(http.FileServer(http.FS(public.Files)))(ctx) - } - - // if the custom file is an .html template, render it - if strings.HasSuffix(ctx.Param("*"), ".html") { - if err := html(ctx, ctx.Param("*")); err != nil { - return notFound("Page not found") - } - return nil - } - - return echo.WrapHandler(http.StripPrefix("/assets/", http.FileServer(http.FS(customFs))))(ctx) - }) - - // Git HTTP routes - if config.C.HttpGit { - e.Any("/:user/:gistname/*", gitHttp, gistSoftInit) - } - - e.Any("/*", noRouteFound) - - return &Server{echo: e, dev: dev} -} - -func (s *Server) Start() { - addr := config.C.HttpHost + ":" + config.C.HttpPort - - log.Info().Msg("Starting HTTP server on http://" + addr) - if err := s.echo.Start(addr); err != nil && err != http.ErrServerClosed { - log.Fatal().Err(err).Msg("Failed to start HTTP server") - } -} - -func (s *Server) Stop() { - if err := s.echo.Close(); err != nil { - log.Fatal().Err(err).Msg("Failed to stop HTTP server") - } -} - -func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - s.echo.ServeHTTP(w, r) -} - -func dataInit(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - ctxValue := context.WithValue(ctx.Request().Context(), dataKey, echo.Map{}) - ctx.SetRequest(ctx.Request().WithContext(ctxValue)) - setData(ctx, "loadStartTime", time.Now()) - - if err := loadSettings(ctx); err != nil { - return errorRes(500, "Cannot read settings from database", err) - } - - setData(ctx, "c", config.C) - - setData(ctx, "githubOauth", config.C.GithubClientKey != "" && config.C.GithubSecret != "") - setData(ctx, "gitlabOauth", config.C.GitlabClientKey != "" && config.C.GitlabSecret != "") - setData(ctx, "giteaOauth", config.C.GiteaClientKey != "" && config.C.GiteaSecret != "") - setData(ctx, "oidcOauth", config.C.OIDCClientKey != "" && config.C.OIDCSecret != "" && config.C.OIDCDiscoveryUrl != "") - - httpProtocol := "http" - if ctx.Request().TLS != nil || ctx.Request().Header.Get("X-Forwarded-Proto") == "https" { - httpProtocol = "https" - } - setData(ctx, "httpProtocol", strings.ToUpper(httpProtocol)) - - var baseHttpUrl string - // if a custom external url is set, use it - if config.C.ExternalUrl != "" { - baseHttpUrl = config.C.ExternalUrl - } else { - baseHttpUrl = httpProtocol + "://" + ctx.Request().Host - } - - setData(ctx, "baseHttpUrl", baseHttpUrl) - - return next(ctx) - } -} - -func locale(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - // Check URL arguments - lang := ctx.Request().URL.Query().Get("lang") - changeLang := lang != "" - - // Then check cookies - if len(lang) == 0 { - cookie, _ := ctx.Request().Cookie("lang") - if cookie != nil { - lang = cookie.Value - } - } - - // Check again in case someone changes the supported language list. - if lang != "" && !i18n.Locales.HasLocale(lang) { - lang = "" - changeLang = false - } - - // 3.Then check from 'Accept-Language' header. - if len(lang) == 0 { - tags, _, _ := language.ParseAcceptLanguage(ctx.Request().Header.Get("Accept-Language")) - lang = i18n.Locales.MatchTag(tags) - } - - if changeLang { - ctx.SetCookie(&http.Cookie{Name: "lang", Value: lang, Path: "/", MaxAge: 1<<31 - 1}) - } - - localeUsed, err := i18n.Locales.GetLocale(lang) - if err != nil { - return errorRes(500, "Cannot get locale", err) - } - - setData(ctx, "localeName", localeUsed.Name) - setData(ctx, "locale", localeUsed) - setData(ctx, "allLocales", i18n.Locales.Locales) - - return next(ctx) - } -} - -func sessionInit(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - sess := getSession(ctx) - if sess.Values["user"] != nil { - var err error - var user *db.User - - if user, err = db.GetUserById(sess.Values["user"].(uint)); err != nil { - sess.Values["user"] = nil - saveSession(sess, ctx) - setData(ctx, "userLogged", nil) - return redirect(ctx, "/all") - } - if user != nil { - setData(ctx, "userLogged", user) - } - return next(ctx) - } - - setData(ctx, "userLogged", nil) - return next(ctx) - } -} - -func csrfInit(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - setCsrfHtmlForm(ctx) - return next(ctx) - } -} - -func writePermission(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - gist := getData(ctx, "gist") - user := getUserLogged(ctx) - if !gist.(*db.Gist).CanWrite(user) { - return redirect(ctx, "/"+gist.(*db.Gist).User.Username+"/"+gist.(*db.Gist).Identifier()) - } - return next(ctx) - } -} - -func adminPermission(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - user := getUserLogged(ctx) - if user == nil || !user.IsAdmin { - return notFound("User not found") - } - return next(ctx) - } -} - -func logged(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - user := getUserLogged(ctx) - if user != nil { - return next(ctx) - } - return redirect(ctx, "/all") - } -} - -func inMFASession(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - sess := getSession(ctx) - _, ok := sess.Values["mfaID"].(uint) - if !ok { - return errorRes(400, tr(ctx, "error.not-in-mfa-session"), nil) - } - return next(ctx) - } -} - -func makeCheckRequireLogin(isSingleGistAccess bool) echo.MiddlewareFunc { - return func(next echo.HandlerFunc) echo.HandlerFunc { - return func(ctx echo.Context) error { - if user := getUserLogged(ctx); user != nil { - return next(ctx) - } - - allow, err := auth.ShouldAllowUnauthenticatedGistAccess(ContextAuthInfo{ctx}, isSingleGistAccess) - if err != nil { - log.Fatal().Err(err).Msg("Failed to check if unauthenticated access is allowed") - } - - if !allow { - addFlash(ctx, tr(ctx, "flash.auth.must-be-logged-in"), "error") - return redirect(ctx, "/login") - } - return next(ctx) - } - } -} - -func checkRequireLogin(next echo.HandlerFunc) echo.HandlerFunc { - return makeCheckRequireLogin(false)(next) -} - -func noRouteFound(echo.Context) error { - return notFound("Page not found") -} - -// --- - -type Asset struct { - File string `json:"file"` -} - -var manifestEntries map[string]Asset - -func parseManifestEntries() { - file, err := public.Files.Open("manifest.json") - if err != nil { - log.Fatal().Err(err).Msg("Failed to open manifest.json") - } - byteValue, err := io.ReadAll(file) - if err != nil { - log.Fatal().Err(err).Msg("Failed to read manifest.json") - } - if err = gojson.Unmarshal(byteValue, &manifestEntries); err != nil { - log.Fatal().Err(err).Msg("Failed to unmarshal manifest.json") - } -} - -func defaultAvatar() string { - if dev { - return "http://localhost:16157/default.png" - } - return config.C.ExternalUrl + "/" + manifestEntries["default.png"].File -} - -func asset(file string) string { - if dev { - return "http://localhost:16157/" + file - } - return config.C.ExternalUrl + "/" + manifestEntries[file].File -} - -func customAsset(file string) string { - assetpath, err := url.JoinPath("/", "assets", file) - if err != nil { - log.Error().Err(err).Msgf("Failed to join path for custom file %s", file) - } - return config.C.ExternalUrl + assetpath -} diff --git a/internal/web/server/handler.go b/internal/web/server/handler.go new file mode 100644 index 0000000..e4b3e60 --- /dev/null +++ b/internal/web/server/handler.go @@ -0,0 +1,40 @@ +package server + +import ( + "github.com/labstack/echo/v4" + "github.com/thomiceli/opengist/internal/web/context" +) + +type Handler func(ctx *context.Context) error +type Middleware func(next Handler) Handler + +func (h Handler) toEcho() echo.HandlerFunc { + return func(c echo.Context) error { + return h(c.(*context.Context)) + } +} + +func (m Middleware) toEcho() echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return m(func(c *context.Context) error { + return next(c) + }).toEcho() + } +} + +func (h Handler) toEchoHandler() echo.HandlerFunc { + return func(c echo.Context) error { + if ogc, ok := c.(*context.Context); ok { + return h(ogc) + } + // Could also add error handling for incorrect context type + return h(c.(*context.Context)) + } +} + +func chain(h Handler, middleware ...Middleware) Handler { + for i := len(middleware) - 1; i >= 0; i-- { + h = middleware[i](h) + } + return h +} diff --git a/internal/web/server/middlewares.go b/internal/web/server/middlewares.go new file mode 100644 index 0000000..26a6f72 --- /dev/null +++ b/internal/web/server/middlewares.go @@ -0,0 +1,393 @@ +package server + +import ( + "errors" + "fmt" + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/auth" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/i18n" + "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handlers" + "golang.org/x/text/cases" + "golang.org/x/text/language" + "html/template" + "net/http" + "path/filepath" + "strings" + "time" +) + +func (s *Server) useCustomContext() { + s.echo.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + cc := context.NewContext(c, s.sessionsPath) + return next(cc) + } + }) +} + +func (s *Server) registerMiddlewares() { + s.echo.Use(Middleware(dataInit).toEcho()) + s.echo.Use(Middleware(locale).toEcho()) + + s.echo.Pre(middleware.MethodOverrideWithConfig(middleware.MethodOverrideConfig{ + Getter: middleware.MethodFromForm("_method"), + })) + s.echo.Pre(middleware.RemoveTrailingSlash()) + s.echo.Pre(middleware.CORS()) + s.echo.Pre(middleware.RequestLoggerWithConfig(middleware.RequestLoggerConfig{ + LogURI: true, LogStatus: true, LogMethod: true, + LogValuesFunc: func(ctx echo.Context, v middleware.RequestLoggerValues) error { + log.Info().Str("uri", v.URI).Int("status", v.Status).Str("method", v.Method). + Str("ip", ctx.RealIP()).TimeDiff("duration", time.Now(), v.StartTime). + Msg("HTTP") + return nil + }, + })) + s.echo.Use(middleware.Recover()) + s.echo.Use(middleware.Secure()) + s.echo.Use(Middleware(sessionInit).toEcho()) + + if !s.ignoreCsrf { + s.echo.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ + TokenLookup: "form:_csrf,header:X-CSRF-Token", + CookiePath: "/", + CookieHTTPOnly: true, + CookieSameSite: http.SameSiteStrictMode, + })) + s.echo.Use(Middleware(csrfInit).toEcho()) + } +} + +func (s *Server) errorHandler(err error, ctx echo.Context) { + var httpErr *echo.HTTPError + if errors.As(err, &httpErr) { + acceptJson := strings.Contains(ctx.Request().Header.Get("Accept"), "application/json") + data := ctx.Request().Context().Value(context.DataKeyStr).(echo.Map) + data["error"] = err + if acceptJson { + if err := ctx.JSON(httpErr.Code, httpErr); err != nil { + log.Fatal().Err(err).Send() + } + return + } + + if err := ctx.Render(httpErr.Code, "error", data); err != nil { + log.Fatal().Err(err).Send() + } + return + } + + log.Fatal().Err(err).Send() +} + +func dataInit(next Handler) Handler { + return func(ctx *context.Context) error { + ctx.SetData("loadStartTime", time.Now()) + + if err := loadSettings(ctx); err != nil { + return ctx.ErrorRes(500, "Cannot load settings", err) + } + + ctx.SetData("c", config.C) + + ctx.SetData("githubOauth", config.C.GithubClientKey != "" && config.C.GithubSecret != "") + ctx.SetData("gitlabOauth", config.C.GitlabClientKey != "" && config.C.GitlabSecret != "") + ctx.SetData("giteaOauth", config.C.GiteaClientKey != "" && config.C.GiteaSecret != "") + ctx.SetData("oidcOauth", config.C.OIDCClientKey != "" && config.C.OIDCSecret != "" && config.C.OIDCDiscoveryUrl != "") + + httpProtocol := "http" + if ctx.Request().TLS != nil || ctx.Request().Header.Get("X-Forwarded-Proto") == "https" { + httpProtocol = "https" + } + ctx.SetData("httpProtocol", strings.ToUpper(httpProtocol)) + + var baseHttpUrl string + // if a custom external url is set, use it + if config.C.ExternalUrl != "" { + baseHttpUrl = config.C.ExternalUrl + } else { + baseHttpUrl = httpProtocol + "://" + ctx.Request().Host + } + + ctx.SetData("baseHttpUrl", baseHttpUrl) + + return next(ctx) + } +} + +func writePermission(next Handler) Handler { + return func(ctx *context.Context) error { + gist := ctx.GetData("gist") + user := ctx.User + if !gist.(*db.Gist).CanWrite(user) { + return ctx.RedirectTo("/" + gist.(*db.Gist).User.Username + "/" + gist.(*db.Gist).Identifier()) + } + return next(ctx) + } +} + +func adminPermission(next Handler) Handler { + return func(ctx *context.Context) error { + user := ctx.User + if user == nil || !user.IsAdmin { + return ctx.NotFound("User not found") + } + return next(ctx) + } +} + +func logged(next Handler) Handler { + return func(ctx *context.Context) error { + user := ctx.User + if user != nil { + return next(ctx) + } + return ctx.RedirectTo("/all") + } +} + +func inMFASession(next Handler) Handler { + return func(ctx *context.Context) error { + sess := ctx.GetSession() + _, ok := sess.Values["mfaID"].(uint) + if !ok { + return ctx.ErrorRes(400, ctx.Tr("error.not-in-mfa-session"), nil) + } + return next(ctx) + } +} + +func makeCheckRequireLogin(isSingleGistAccess bool) Middleware { + return func(next Handler) Handler { + return func(ctx *context.Context) error { + if user := ctx.User; user != nil { + return next(ctx) + } + + allow, err := auth.ShouldAllowUnauthenticatedGistAccess(handlers.ContextAuthInfo{Context: ctx}, isSingleGistAccess) + if err != nil { + log.Fatal().Err(err).Msg("Failed to check if unauthenticated access is allowed") + } + + if !allow { + ctx.AddFlash(ctx.Tr("flash.auth.must-be-logged-in"), "error") + return ctx.RedirectTo("/login") + } + return next(ctx) + } + } +} + +func checkRequireLogin(next Handler) Handler { + return makeCheckRequireLogin(false)(next) +} + +func noRouteFound(ctx *context.Context) error { + return ctx.NotFound("Page not found") +} + +func locale(next Handler) Handler { + return func(ctx *context.Context) error { + // Check URL arguments + lang := ctx.Request().URL.Query().Get("lang") + changeLang := lang != "" + + // Then check cookies + if len(lang) == 0 { + cookie, _ := ctx.Request().Cookie("lang") + if cookie != nil { + lang = cookie.Value + } + } + + // Check again in case someone changes the supported language list. + if lang != "" && !i18n.Locales.HasLocale(lang) { + lang = "" + changeLang = false + } + + // 3.Then check from 'Accept-Language' header. + if len(lang) == 0 { + tags, _, _ := language.ParseAcceptLanguage(ctx.Request().Header.Get("Accept-Language")) + lang = i18n.Locales.MatchTag(tags) + } + + if changeLang { + ctx.SetCookie(&http.Cookie{Name: "lang", Value: lang, Path: "/", MaxAge: 1<<31 - 1}) + } + + localeUsed, err := i18n.Locales.GetLocale(lang) + if err != nil { + return ctx.ErrorRes(500, "Cannot get locale", err) + } + + ctx.SetData("localeName", localeUsed.Name) + ctx.SetData("locale", localeUsed) + ctx.SetData("allLocales", i18n.Locales.Locales) + + return next(ctx) + } +} + +func sessionInit(next Handler) Handler { + return func(ctx *context.Context) error { + sess := ctx.GetSession() + if sess.Values["user"] != nil { + var err error + var user *db.User + + if user, err = db.GetUserById(sess.Values["user"].(uint)); err != nil { + sess.Values["user"] = nil + ctx.SaveSession(sess) + ctx.User = nil + ctx.SetData("userLogged", nil) + return ctx.RedirectTo("/all") + } + if user != nil { + ctx.User = user + ctx.SetData("userLogged", user) + } + return next(ctx) + } + + ctx.User = nil + ctx.SetData("userLogged", nil) + return next(ctx) + } +} + +func csrfInit(next Handler) Handler { + return func(ctx *context.Context) error { + var csrf string + if csrfToken, ok := ctx.Get("csrf").(string); ok { + csrf = csrfToken + } + ctx.SetData("csrfHtml", template.HTML(``)) + ctx.SetData("csrfHtml", template.HTML(``)) + + return next(ctx) + } +} + +func loadSettings(ctx *context.Context) error { + settings, err := db.GetSettings() + if err != nil { + return err + } + + for key, value := range settings { + s := strings.ReplaceAll(key, "-", " ") + s = cases.Title(language.English).String(s) + ctx.SetData(strings.ReplaceAll(s, " ", ""), value == "1") + } + return nil +} + +func gistInit(next Handler) Handler { + return func(ctx *context.Context) error { + currUser := ctx.User + + userName := ctx.Param("user") + gistName := ctx.Param("gistname") + + switch filepath.Ext(gistName) { + case ".js": + ctx.SetData("gistpage", "js") + gistName = strings.TrimSuffix(gistName, ".js") + case ".json": + ctx.SetData("gistpage", "json") + gistName = strings.TrimSuffix(gistName, ".json") + case ".git": + ctx.SetData("gistpage", "git") + gistName = strings.TrimSuffix(gistName, ".git") + } + + gist, err := db.GetGist(userName, gistName) + if err != nil { + return ctx.NotFound("Gist not found") + } + + if gist.Private == db.PrivateVisibility { + if currUser == nil || currUser.ID != gist.UserID { + return ctx.NotFound("Gist not found") + } + } + + ctx.SetData("gist", gist) + + if config.C.SshGit { + var sshDomain string + + if config.C.SshExternalDomain != "" { + sshDomain = config.C.SshExternalDomain + } else { + sshDomain = strings.Split(ctx.Request().Host, ":")[0] + } + + if config.C.SshPort == "22" { + ctx.SetData("sshCloneUrl", sshDomain+":"+userName+"/"+gistName+".git") + } else { + ctx.SetData("sshCloneUrl", "ssh://"+sshDomain+":"+config.C.SshPort+"/"+userName+"/"+gistName+".git") + } + } + + baseHttpUrl := ctx.GetData("baseHttpUrl").(string) + + if config.C.HttpGit { + ctx.SetData("httpCloneUrl", baseHttpUrl+"/"+userName+"/"+gistName+".git") + } + + ctx.SetData("httpCopyUrl", baseHttpUrl+"/"+userName+"/"+gistName) + ctx.SetData("currentUrl", template.URL(ctx.Request().URL.Path)) + ctx.SetData("embedScript", fmt.Sprintf(``, baseHttpUrl+"/"+userName+"/"+gistName+".js")) + + nbCommits, err := gist.NbCommits() + if err != nil { + return ctx.ErrorRes(500, "Error fetching number of commits", err) + } + ctx.SetData("nbCommits", nbCommits) + + if currUser != nil { + hasLiked, err := currUser.HasLiked(gist) + if err != nil { + return ctx.ErrorRes(500, "Cannot get user like status", err) + } + ctx.SetData("hasLiked", hasLiked) + } + + if gist.Private > 0 { + ctx.SetData("NoIndex", true) + } + + return next(ctx) + } +} + +// gistSoftInit try to load a gist (same as gistInit) but does not return a 404 if the gist is not found +// useful for git clients using HTTP to obfuscate the existence of a private gist +func gistSoftInit(next Handler) Handler { + return func(ctx *context.Context) error { + userName := ctx.Param("user") + gistName := ctx.Param("gistname") + + gistName = strings.TrimSuffix(gistName, ".git") + + gist, _ := db.GetGist(userName, gistName) + ctx.SetData("gist", gist) + + return next(ctx) + } +} + +// gistNewPushSoftInit has the same behavior as gistSoftInit but create a new gist empty instead +func gistNewPushSoftInit(next Handler) Handler { + return func(ctx *context.Context) error { + ctx.SetData("gist", new(db.Gist)) + return next(ctx) + } +} diff --git a/internal/web/server/renderer.go b/internal/web/server/renderer.go new file mode 100644 index 0000000..ef34890 --- /dev/null +++ b/internal/web/server/renderer.go @@ -0,0 +1,213 @@ +package server + +import ( + gojson "encoding/json" + "errors" + "fmt" + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/git" + "github.com/thomiceli/opengist/internal/index" + "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handlers" + "github.com/thomiceli/opengist/public" + "github.com/thomiceli/opengist/templates" + htmlpkg "html" + "html/template" + "io" + "net/http" + "net/url" + "path/filepath" + "regexp" + "strconv" + "strings" + "time" +) + +type Template struct { + templates *template.Template +} + +func (t *Template) Render(w io.Writer, name string, data interface{}, _ echo.Context) error { + return t.templates.ExecuteTemplate(w, name, data) +} + +var re = regexp.MustCompile("[^a-z0-9]+") + +func (s *Server) setFuncMap() { + fm := template.FuncMap{ + "split": strings.Split, + "indexByte": strings.IndexByte, + "toInt": func(i string) int { + val, _ := strconv.Atoi(i) + return val + }, + "inc": func(i int) int { + return i + 1 + }, + "splitGit": func(i string) []string { + return strings.FieldsFunc(i, func(r rune) bool { + return r == ',' || r == ' ' + }) + }, + "lines": func(i string) []string { + return strings.Split(i, "\n") + }, + "isMarkdown": func(i string) bool { + return strings.ToLower(filepath.Ext(i)) == ".md" + }, + "isCsv": func(i string) bool { + return strings.ToLower(filepath.Ext(i)) == ".csv" + }, + "isSvg": func(i string) bool { + return strings.ToLower(filepath.Ext(i)) == ".svg" + }, + "csvFile": func(file *git.File) *git.CsvFile { + if strings.ToLower(filepath.Ext(file.Filename)) != ".csv" { + return nil + } + + csvFile, err := git.ParseCsv(file) + if err != nil { + return nil + } + + return csvFile + }, + "httpStatusText": http.StatusText, + "loadedTime": func(startTime time.Time) string { + return fmt.Sprint(time.Since(startTime).Nanoseconds()/1e6) + "ms" + }, + "slug": func(s string) string { + return strings.Trim(re.ReplaceAllString(strings.ToLower(s), "-"), "-") + }, + "avatarUrl": func(user *db.User, noGravatar bool) string { + if user.AvatarURL != "" { + return user.AvatarURL + } + + if user.MD5Hash != "" && !noGravatar { + return "https://www.gravatar.com/avatar/" + user.MD5Hash + "?d=identicon&s=200" + } + + if s.dev { + return "http://localhost:16157/default.png" + } + return config.C.ExternalUrl + "/" + context.ManifestEntries["default.png"].File + }, + "asset": func(file string) string { + if s.dev { + return "http://localhost:16157/" + file + } + return config.C.ExternalUrl + "/" + context.ManifestEntries[file].File + }, + "custom": func(file string) string { + assetpath, err := url.JoinPath("/", "assets", file) + if err != nil { + log.Error().Err(err).Msgf("Failed to join path for custom file %s", file) + } + return config.C.ExternalUrl + assetpath + }, + "dev": func() bool { + return s.dev + }, + "defaultAvatar": func() string { + if s.dev { + return "http://localhost:16157/default.png" + } + return config.C.ExternalUrl + "/" + context.ManifestEntries["default.png"].File + }, + "visibilityStr": func(visibility db.Visibility, lowercase bool) string { + s := "Public" + switch visibility { + case 1: + s = "Unlisted" + case 2: + s = "Private" + } + + if lowercase { + return strings.ToLower(s) + } + return s + }, + "unescape": htmlpkg.UnescapeString, + "join": func(s ...string) string { + return strings.Join(s, "") + }, + "toStr": func(i interface{}) string { + return fmt.Sprint(i) + }, + "safe": func(s string) template.HTML { + return template.HTML(s) + }, + "dict": func(values ...interface{}) (map[string]interface{}, error) { + if len(values)%2 != 0 { + return nil, errors.New("invalid dict call") + } + dict := make(map[string]interface{}) + for i := 0; i < len(values); i += 2 { + key, ok := values[i].(string) + if !ok { + return nil, errors.New("dict keys must be strings") + } + dict[key] = values[i+1] + } + return dict, nil + }, + "addMetadataToSearchQuery": func(input, key, value string) string { + content, metadata := handlers.ParseSearchQueryStr(input) + + metadata[key] = value + + var resultBuilder strings.Builder + resultBuilder.WriteString(content) + + for k, v := range metadata { + resultBuilder.WriteString(" ") + resultBuilder.WriteString(k) + resultBuilder.WriteString(":") + resultBuilder.WriteString(v) + } + + return strings.TrimSpace(resultBuilder.String()) + }, + "indexEnabled": index.Enabled, + "isUrl": func(s string) bool { + _, err := url.ParseRequestURI(s) + return err == nil + }, + } + + t := template.Must(template.New("t").Funcs(fm).ParseFS(templates.Files, "*/*.html")) + customPattern := filepath.Join(config.GetHomeDir(), "custom", "*.html") + matches, err := filepath.Glob(customPattern) + if err != nil { + log.Fatal().Err(err).Msg("Failed to check for custom templates") + } + if len(matches) > 0 { + t, err = t.ParseGlob(customPattern) + if err != nil { + log.Fatal().Err(err).Msg("Failed to parse custom templates") + } + } + s.echo.Renderer = &Template{ + templates: t, + } +} + +func (s *Server) parseManifestEntries() { + file, err := public.Files.Open("manifest.json") + if err != nil { + log.Fatal().Err(err).Msg("Failed to open manifest.json") + } + byteValue, err := io.ReadAll(file) + if err != nil { + log.Fatal().Err(err).Msg("Failed to read manifest.json") + } + if err = gojson.Unmarshal(byteValue, &context.ManifestEntries); err != nil { + log.Fatal().Err(err).Msg("Failed to unmarshal manifest.json") + } +} diff --git a/internal/web/server/router.go b/internal/web/server/router.go new file mode 100644 index 0000000..974516a --- /dev/null +++ b/internal/web/server/router.go @@ -0,0 +1,209 @@ +package server + +import ( + "github.com/labstack/echo/v4" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/index" + "github.com/thomiceli/opengist/internal/web/context" + "github.com/thomiceli/opengist/internal/web/handlers/admin" + "github.com/thomiceli/opengist/internal/web/handlers/auth" + "github.com/thomiceli/opengist/internal/web/handlers/gist" + "github.com/thomiceli/opengist/internal/web/handlers/git" + "github.com/thomiceli/opengist/internal/web/handlers/health" + "github.com/thomiceli/opengist/internal/web/handlers/settings" + "github.com/thomiceli/opengist/public" + "net/http" + "os" + "path" + "path/filepath" + "strings" + "time" +) + +func (s *Server) registerRoutes() { + r := NewRouter(s.echo.Group("")) + + { + r.GET("/", gist.Create, logged) + r.POST("/", gist.ProcessCreate, logged) + r.POST("/preview", gist.Preview, logged) + + r.GET("/healthcheck", health.Healthcheck) + r.GET("/metrics", health.Metrics) + + r.GET("/register", auth.Register) + r.POST("/register", auth.ProcessRegister) + r.GET("/login", auth.Login) + r.POST("/login", auth.ProcessLogin) + r.GET("/logout", auth.Logout) + r.GET("/oauth/:provider", auth.Oauth) + r.GET("/oauth/:provider/callback", auth.OauthCallback) + r.GET("/oauth/:provider/unlink", auth.OauthUnlink, logged) + r.POST("/webauthn/bind", auth.BeginWebAuthnBinding, logged) + r.POST("/webauthn/bind/finish", auth.FinishWebAuthnBinding, logged) + r.POST("/webauthn/login", auth.BeginWebAuthnLogin) + r.POST("/webauthn/login/finish", auth.FinishWebAuthnLogin) + r.POST("/webauthn/assertion", auth.BeginWebAuthnAssertion, inMFASession) + r.POST("/webauthn/assertion/finish", auth.FinishWebAuthnAssertion, inMFASession) + r.GET("/mfa", auth.Mfa, inMFASession) + r.POST("/mfa/totp/assertion", auth.AssertTotp, inMFASession) + + sA := r.SubGroup("/settings") + { + sA.Use(logged) + sA.GET("", settings.UserSettings) + sA.POST("/email", settings.EmailProcess) + sA.DELETE("/account", settings.AccountDeleteProcess) + sA.POST("/ssh-keys", settings.SshKeysProcess) + sA.DELETE("/ssh-keys/:id", settings.SshKeysDelete) + sA.DELETE("/passkeys/:id", settings.PasskeyDelete) + sA.PUT("/password", settings.PasswordProcess) + sA.PUT("/username", settings.UsernameProcess) + sA.GET("/totp/generate", auth.BeginTotp) + sA.POST("/totp/generate", auth.FinishTotp) + sA.DELETE("/totp", auth.DisableTotp) + sA.POST("/totp/regenerate", auth.RegenerateTotpRecoveryCodes) + } + + sB := r.SubGroup("/admin-panel") + { + sB.Use(adminPermission) + sB.GET("", admin.AdminIndex) + sB.GET("/users", admin.AdminUsers) + sB.POST("/users/:user/delete", admin.AdminUserDelete) + sB.GET("/gists", admin.AdminGists) + sB.POST("/gists/:gist/delete", admin.AdminGistDelete) + sB.GET("/invitations", admin.AdminInvitations) + sB.POST("/invitations", admin.AdminInvitationsCreate) + sB.POST("/invitations/:id/delete", admin.AdminInvitationsDelete) + sB.POST("/sync-fs", admin.AdminSyncReposFromFS) + sB.POST("/sync-db", admin.AdminSyncReposFromDB) + sB.POST("/gc-repos", admin.AdminGcRepos) + sB.POST("/sync-previews", admin.AdminSyncGistPreviews) + sB.POST("/reset-hooks", admin.AdminResetHooks) + sB.POST("/index-gists", admin.AdminIndexGists) + sB.GET("/configuration", admin.AdminConfig) + sB.PUT("/set-config", admin.AdminSetConfig) + } + + if config.C.HttpGit { + r.Any("/init/*", git.GitHttp, gistNewPushSoftInit) + } + + r.GET("/all", gist.AllGists, checkRequireLogin) + + if index.Enabled() { + r.GET("/search", gist.Search, checkRequireLogin) + } else { + r.GET("/search", gist.AllGists, checkRequireLogin) + } + + r.GET("/:user", gist.AllGists, checkRequireLogin) + r.GET("/:user/liked", gist.AllGists, checkRequireLogin) + r.GET("/:user/forked", gist.AllGists, checkRequireLogin) + + sC := r.SubGroup("/:user/:gistname") + { + sC.Use(makeCheckRequireLogin(true), gistInit) + sC.GET("", gist.GistIndex) + sC.GET("/rev/:revision", gist.GistIndex) + sC.GET("/revisions", gist.Revisions) + sC.GET("/archive/:revision", gist.DownloadZip) + sC.POST("/visibility", gist.EditVisibility, logged, writePermission) + sC.POST("/delete", gist.DeleteGist, logged, writePermission) + sC.GET("/raw/:revision/:file", gist.RawFile) + sC.GET("/download/:revision/:file", gist.DownloadFile) + sC.GET("/edit", gist.Edit, logged, writePermission) + sC.POST("/edit", gist.ProcessCreate, logged, writePermission) + sC.POST("/like", gist.Like, logged) + sC.GET("/likes", gist.Likes, checkRequireLogin) + sC.POST("/fork", gist.Fork, logged) + sC.GET("/forks", gist.Forks, checkRequireLogin) + sC.PUT("/checkbox", gist.Checkbox, logged, writePermission) + } + } + + customFs := os.DirFS(filepath.Join(config.GetHomeDir(), "custom")) + r.GET("/assets/*", func(ctx *context.Context) error { + if _, err := public.Files.Open(path.Join("assets", ctx.Param("*"))); !s.dev && err == nil { + ctx.Response().Header().Set("Cache-Control", "public, max-age=31536000") + ctx.Response().Header().Set("Expires", time.Now().AddDate(1, 0, 0).Format(http.TimeFormat)) + + return echo.WrapHandler(http.FileServer(http.FS(public.Files)))(ctx) + } + + // if the custom file is an .html template, render it + if strings.HasSuffix(ctx.Param("*"), ".html") { + if err := ctx.Html(ctx.Param("*")); err != nil { + return ctx.NotFound("Page not found") + } + return nil + } + + return echo.WrapHandler(http.StripPrefix("/assets/", http.FileServer(http.FS(customFs))))(ctx) + }) + + // Git HTTP routes + if config.C.HttpGit { + r.Any("/:user/:gistname/*", git.GitHttp, gistSoftInit) + } + + r.Any("/*", noRouteFound) +} + +// Router wraps echo.Group to provide custom Handler support +type Router struct { + *echo.Group +} + +func NewRouter(g *echo.Group) *Router { + return &Router{Group: g} +} + +func (r *Router) SubGroup(prefix string, m ...Middleware) *Router { + echoMiddleware := make([]echo.MiddlewareFunc, len(m)) + for i, mw := range m { + mw := mw // capture for closure + echoMiddleware[i] = func(next echo.HandlerFunc) echo.HandlerFunc { + return chain(func(c *context.Context) error { + return next(c) + }, mw).toEchoHandler() + } + } + return NewRouter(r.Group.Group(prefix, echoMiddleware...)) +} + +func (r *Router) GET(path string, h Handler, m ...Middleware) { + r.Group.GET(path, chain(h, m...).toEchoHandler()) +} + +func (r *Router) POST(path string, h Handler, m ...Middleware) { + r.Group.POST(path, chain(h, m...).toEchoHandler()) +} + +func (r *Router) PUT(path string, h Handler, m ...Middleware) { + r.Group.PUT(path, chain(h, m...).toEchoHandler()) +} + +func (r *Router) DELETE(path string, h Handler, m ...Middleware) { + r.Group.DELETE(path, chain(h, m...).toEchoHandler()) +} + +func (r *Router) PATCH(path string, h Handler, m ...Middleware) { + r.Group.PATCH(path, chain(h, m...).toEchoHandler()) +} + +func (r *Router) Any(path string, h Handler, m ...Middleware) { + r.Group.Any(path, chain(h, m...).toEchoHandler()) +} + +func (r *Router) Use(middleware ...Middleware) { + for _, m := range middleware { + m := m // capture for closure + r.Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return chain(func(c *context.Context) error { + return next(c) + }, m).toEchoHandler() + }) + } +} diff --git a/internal/web/server/server.go b/internal/web/server/server.go new file mode 100644 index 0000000..1c4092a --- /dev/null +++ b/internal/web/server/server.go @@ -0,0 +1,65 @@ +package server + +import ( + "github.com/thomiceli/opengist/internal/validator" + "net/http" + + "github.com/labstack/echo/v4" + "github.com/rs/zerolog/log" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/i18n" +) + +type Server struct { + echo *echo.Echo + + dev bool + sessionsPath string + ignoreCsrf bool +} + +func NewServer(isDev bool, sessionsPath string, ignoreCsrf bool) *Server { + e := echo.New() + e.HideBanner = true + e.HidePort = true + e.Validator = validator.NewValidator() + + s := &Server{echo: e, dev: isDev, sessionsPath: sessionsPath, ignoreCsrf: ignoreCsrf} + + s.useCustomContext() + + if err := i18n.Locales.LoadAll(); err != nil { + log.Fatal().Err(err).Msg("Failed to load locales") + } + + s.registerMiddlewares() + s.setFuncMap() + s.echo.HTTPErrorHandler = s.errorHandler + + if !s.dev { + s.parseManifestEntries() + } + + s.registerRoutes() + + return s +} + +func (s *Server) Start() { + addr := config.C.HttpHost + ":" + config.C.HttpPort + + log.Info().Msg("Starting HTTP server on http://" + addr) + if err := s.echo.Start(addr); err != nil && err != http.ErrServerClosed { + log.Fatal().Err(err).Msg("Failed to start HTTP server") + } +} + +func (s *Server) Stop() { + if err := s.echo.Close(); err != nil { + log.Fatal().Err(err).Msg("Failed to stop HTTP server") + } +} + +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + s.echo.ServeHTTP(w, r) +} diff --git a/internal/web/settings.go b/internal/web/settings.go deleted file mode 100644 index f6b6324..0000000 --- a/internal/web/settings.go +++ /dev/null @@ -1,227 +0,0 @@ -package web - -import ( - "crypto/md5" - "fmt" - "github.com/thomiceli/opengist/internal/config" - "github.com/thomiceli/opengist/internal/git" - "github.com/thomiceli/opengist/internal/i18n" - "github.com/thomiceli/opengist/internal/utils" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/labstack/echo/v4" - "github.com/thomiceli/opengist/internal/db" - "golang.org/x/crypto/ssh" -) - -func userSettings(ctx echo.Context) error { - user := getUserLogged(ctx) - - keys, err := db.GetSSHKeysByUserID(user.ID) - if err != nil { - return errorRes(500, "Cannot get SSH keys", err) - } - - passkeys, err := db.GetAllCredentialsForUser(user.ID) - if err != nil { - return errorRes(500, "Cannot get WebAuthn credentials", err) - } - - _, hasTotp, err := user.HasMFA() - if err != nil { - return errorRes(500, "Cannot get MFA status", err) - } - - setData(ctx, "email", user.Email) - setData(ctx, "sshKeys", keys) - setData(ctx, "passkeys", passkeys) - setData(ctx, "hasTotp", hasTotp) - setData(ctx, "hasPassword", user.Password != "") - setData(ctx, "disableForm", getData(ctx, "DisableLoginForm")) - setData(ctx, "htmlTitle", trH(ctx, "settings")) - return html(ctx, "settings.html") -} - -func emailProcess(ctx echo.Context) error { - user := getUserLogged(ctx) - email := ctx.FormValue("email") - var hash string - - if email == "" { - // generate random md5 string - hash = fmt.Sprintf("%x", md5.Sum([]byte(time.Now().String()))) - } else { - hash = fmt.Sprintf("%x", md5.Sum([]byte(strings.ToLower(strings.TrimSpace(email))))) - } - - user.Email = strings.ToLower(email) - user.MD5Hash = hash - - if err := user.Update(); err != nil { - return errorRes(500, "Cannot update email", err) - } - - addFlash(ctx, tr(ctx, "flash.user.email-updated"), "success") - return redirect(ctx, "/settings") -} - -func accountDeleteProcess(ctx echo.Context) error { - user := getUserLogged(ctx) - - if err := user.Delete(); err != nil { - return errorRes(500, "Cannot delete this user", err) - } - - return redirect(ctx, "/all") -} - -func sshKeysProcess(ctx echo.Context) error { - user := getUserLogged(ctx) - - dto := new(db.SSHKeyDTO) - if err := ctx.Bind(dto); err != nil { - return errorRes(400, tr(ctx, "error.cannot-bind-data"), err) - } - - if err := ctx.Validate(dto); err != nil { - addFlash(ctx, utils.ValidationMessages(&err, getData(ctx, "locale").(*i18n.Locale)), "error") - return redirect(ctx, "/settings") - } - key := dto.ToSSHKey() - - key.UserID = user.ID - - pubKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key.Content)) - if err != nil { - addFlash(ctx, tr(ctx, "flash.user.invalid-ssh-key"), "error") - return redirect(ctx, "/settings") - } - key.Content = strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey))) - - if exists, err := db.SSHKeyDoesExists(key.Content); exists { - if err != nil { - return errorRes(500, "Cannot check if SSH key exists", err) - } - addFlash(ctx, tr(ctx, "settings.ssh-key-exists"), "error") - return redirect(ctx, "/settings") - } - - if err := key.Create(); err != nil { - return errorRes(500, "Cannot add SSH key", err) - } - - addFlash(ctx, tr(ctx, "flash.user.ssh-key-added"), "success") - return redirect(ctx, "/settings") -} - -func sshKeysDelete(ctx echo.Context) error { - user := getUserLogged(ctx) - keyId, err := strconv.Atoi(ctx.Param("id")) - if err != nil { - return redirect(ctx, "/settings") - } - - key, err := db.GetSSHKeyByID(uint(keyId)) - - if err != nil || key.UserID != user.ID { - return redirect(ctx, "/settings") - } - - if err := key.Delete(); err != nil { - return errorRes(500, "Cannot delete SSH key", err) - } - - addFlash(ctx, tr(ctx, "flash.user.ssh-key-deleted"), "success") - return redirect(ctx, "/settings") -} - -func passkeyDelete(ctx echo.Context) error { - user := getUserLogged(ctx) - keyId, err := strconv.Atoi(ctx.Param("id")) - if err != nil { - return redirect(ctx, "/settings") - } - - passkey, err := db.GetCredentialByIDDB(uint(keyId)) - if err != nil || passkey.UserID != user.ID { - return redirect(ctx, "/settings") - } - - if err := passkey.Delete(); err != nil { - return errorRes(500, "Cannot delete passkey", err) - } - - addFlash(ctx, tr(ctx, "flash.auth.passkey-deleted"), "success") - return redirect(ctx, "/settings") -} - -func passwordProcess(ctx echo.Context) error { - user := getUserLogged(ctx) - - dto := new(db.UserDTO) - if err := ctx.Bind(dto); err != nil { - return errorRes(400, tr(ctx, "error.cannot-bind-data"), err) - } - dto.Username = user.Username - - if err := ctx.Validate(dto); err != nil { - addFlash(ctx, utils.ValidationMessages(&err, getData(ctx, "locale").(*i18n.Locale)), "error") - return html(ctx, "settings.html") - } - - password, err := utils.Argon2id.Hash(dto.Password) - if err != nil { - return errorRes(500, "Cannot hash password", err) - } - user.Password = password - - if err = user.Update(); err != nil { - return errorRes(500, "Cannot update password", err) - } - - addFlash(ctx, tr(ctx, "flash.user.password-updated"), "success") - return redirect(ctx, "/settings") -} - -func usernameProcess(ctx echo.Context) error { - user := getUserLogged(ctx) - - dto := new(db.UserDTO) - if err := ctx.Bind(dto); err != nil { - return errorRes(400, tr(ctx, "error.cannot-bind-data"), err) - } - dto.Password = user.Password - - if err := ctx.Validate(dto); err != nil { - addFlash(ctx, utils.ValidationMessages(&err, getData(ctx, "locale").(*i18n.Locale)), "error") - return redirect(ctx, "/settings") - } - - if exists, err := db.UserExists(dto.Username); err != nil || exists { - addFlash(ctx, tr(ctx, "flash.auth.username-exists"), "error") - return redirect(ctx, "/settings") - } - - sourceDir := filepath.Join(config.GetHomeDir(), git.ReposDirectory, strings.ToLower(user.Username)) - destinationDir := filepath.Join(config.GetHomeDir(), git.ReposDirectory, strings.ToLower(dto.Username)) - - if _, err := os.Stat(sourceDir); !os.IsNotExist(err) { - err := os.Rename(sourceDir, destinationDir) - if err != nil { - return errorRes(500, "Cannot rename user directory", err) - } - } - - user.Username = dto.Username - - if err := user.Update(); err != nil { - return errorRes(500, "Cannot update username", err) - } - - addFlash(ctx, tr(ctx, "flash.user.username-updated"), "success") - return redirect(ctx, "/settings") -} diff --git a/internal/web/test/actions_test.go b/internal/web/test/actions_test.go new file mode 100644 index 0000000..db7f6d0 --- /dev/null +++ b/internal/web/test/actions_test.go @@ -0,0 +1,41 @@ +package test + +import ( + "github.com/stretchr/testify/require" + "github.com/thomiceli/opengist/internal/db" + "testing" +) + +func TestAdminActions(t *testing.T) { + s := Setup(t) + defer Teardown(t, s) + urls := []string{ + "/admin-panel/sync-fs", + "/admin-panel/sync-db", + "/admin-panel/gc-repos", + "/admin-panel/sync-previews", + "/admin-panel/reset-hooks", + "/admin-panel/index-gists", + } + + for _, url := range urls { + err := s.Request("POST", url, nil, 404) + require.NoError(t, err) + } + + user1 := db.UserDTO{Username: "admin", Password: "admin"} + register(t, s, user1) + login(t, s, user1) + for _, url := range urls { + err := s.Request("POST", url, nil, 302) + require.NoError(t, err) + } + + user2 := db.UserDTO{Username: "nonadmin", Password: "nonadmin"} + register(t, s, user2) + login(t, s, user2) + for _, url := range urls { + err := s.Request("POST", url, nil, 404) + require.NoError(t, err) + } +} diff --git a/internal/web/test/admin_test.go b/internal/web/test/admin_test.go new file mode 100644 index 0000000..a18ebb4 --- /dev/null +++ b/internal/web/test/admin_test.go @@ -0,0 +1,261 @@ +package test + +import ( + "github.com/stretchr/testify/require" + "github.com/thomiceli/opengist/internal/config" + "github.com/thomiceli/opengist/internal/db" + "github.com/thomiceli/opengist/internal/git" + "os" + "path/filepath" + "strconv" + "testing" + "time" +) + +func TestAdminPages(t *testing.T) { + s := Setup(t) + defer Teardown(t, s) + urls := []string{ + "/admin-panel", + "/admin-panel/users", + "/admin-panel/gists", + "/admin-panel/invitations", + "/admin-panel/configuration", + } + + for _, url := range urls { + err := s.Request("GET", url, nil, 404) + require.NoError(t, err) + } + + user1 := db.UserDTO{Username: "admin", Password: "admin"} + register(t, s, user1) + login(t, s, user1) + for _, url := range urls { + err := s.Request("GET", url, nil, 200) + require.NoError(t, err) + } + + user2 := db.UserDTO{Username: "nonadmin", Password: "nonadmin"} + register(t, s, user2) + login(t, s, user2) + for _, url := range urls { + err := s.Request("GET", url, nil, 404) + require.NoError(t, err) + } +} + +func TestSetConfig(t *testing.T) { + s := Setup(t) + defer Teardown(t, s) + settings := []string{ + db.SettingDisableSignup, + db.SettingRequireLogin, + db.SettingAllowGistsWithoutLogin, + db.SettingDisableLoginForm, + db.SettingDisableGravatar, + } + + user1 := db.UserDTO{Username: "admin", Password: "admin"} + register(t, s, user1) + login(t, s, user1) + + for _, setting := range settings { + val, err := db.GetSetting(setting) + require.NoError(t, err) + require.Equal(t, "0", val) + + err = s.Request("PUT", "/admin-panel/set-config", settingSet{setting, "1"}, 200) + require.NoError(t, err) + + val, err = db.GetSetting(setting) + require.NoError(t, err) + require.Equal(t, "1", val) + + err = s.Request("PUT", "/admin-panel/set-config", settingSet{setting, "0"}, 200) + require.NoError(t, err) + + val, err = db.GetSetting(setting) + require.NoError(t, err) + require.Equal(t, "0", val) + } +} + +func TestPagination(t *testing.T) { + s := Setup(t) + defer Teardown(t, s) + + user1 := db.UserDTO{Username: "admin", Password: "admin"} + register(t, s, user1) + for i := 0; i < 11; i++ { + user := db.UserDTO{Username: "user" + strconv.Itoa(i), Password: "user" + strconv.Itoa(i)} + register(t, s, user) + } + + login(t, s, user1) + + err := s.Request("GET", "/admin-panel/users", nil, 200) + require.NoError(t, err) + + err = s.Request("GET", "/admin-panel/users?page=2", nil, 200) + require.NoError(t, err) + + err = s.Request("GET", "/admin-panel/users?page=3", nil, 404) + require.NoError(t, err) + + err = s.Request("GET", "/admin-panel/users?page=0", nil, 200) + require.NoError(t, err) + + err = s.Request("GET", "/admin-panel/users?page=-1", nil, 200) + require.NoError(t, err) + + err = s.Request("GET", "/admin-panel/users?page=a", nil, 200) + require.NoError(t, err) +} + +func TestAdminUser(t *testing.T) { + s := Setup(t) + defer Teardown(t, s) + + user1 := db.UserDTO{Username: "admin", Password: "admin"} + user2 := db.UserDTO{Username: "nonadmin", Password: "nonadmin"} + register(t, s, user1) + register(t, s, user2) + + login(t, s, user2) + + gist1 := db.GistDTO{ + Title: "gist", + VisibilityDTO: db.VisibilityDTO{ + Private: 0, + }, + Name: []string{"gist1.txt"}, + Content: []string{"yeah"}, + } + err := s.Request("POST", "/", gist1, 302) + require.NoError(t, err) + + _, err = os.Stat(filepath.Join(config.GetHomeDir(), git.ReposDirectory, user2.Username)) + require.NoError(t, err) + + count, err := db.CountAll(db.User{}) + require.NoError(t, err) + require.Equal(t, int64(2), count) + + login(t, s, user1) + + err = s.Request("POST", "/admin-panel/users/2/delete", nil, 302) + require.NoError(t, err) + + count, err = db.CountAll(db.User{}) + require.NoError(t, err) + require.Equal(t, int64(1), count) + + _, err = os.Stat(filepath.Join(config.GetHomeDir(), git.ReposDirectory, user2.Username)) + require.Error(t, err) +} + +func TestAdminGist(t *testing.T) { + s := Setup(t) + defer Teardown(t, s) + + user1 := db.UserDTO{Username: "admin", Password: "admin"} + register(t, s, user1) + login(t, s, user1) + + gist1 := db.GistDTO{ + Title: "gist", + VisibilityDTO: db.VisibilityDTO{ + Private: 0, + }, + Name: []string{"gist1.txt"}, + Content: []string{"yeah"}, + } + err := s.Request("POST", "/", gist1, 302) + require.NoError(t, err) + + count, err := db.CountAll(db.Gist{}) + require.NoError(t, err) + require.Equal(t, int64(1), count) + + gist1Db, err := db.GetGistByID("1") + require.NoError(t, err) + + _, err = os.Stat(filepath.Join(config.GetHomeDir(), git.ReposDirectory, user1.Username, gist1Db.Identifier())) + require.NoError(t, err) + + err = s.Request("POST", "/admin-panel/gists/1/delete", nil, 302) + require.NoError(t, err) + + count, err = db.CountAll(db.Gist{}) + require.NoError(t, err) + require.Equal(t, int64(0), count) + + _, err = os.Stat(filepath.Join(config.GetHomeDir(), git.ReposDirectory, user1.Username, gist1Db.Identifier())) + require.Error(t, err) +} + +func TestAdminInvitation(t *testing.T) { + s := Setup(t) + defer Teardown(t, s) + + user1 := db.UserDTO{Username: "admin", Password: "admin"} + register(t, s, user1) + login(t, s, user1) + + err := s.Request("POST", "/admin-panel/invitations", invitationAdmin{ + nbMax: "", + expiredAtUnix: "", + }, 302) + require.NoError(t, err) + invitation1, err := db.GetInvitationByID(1) + require.NoError(t, err) + require.Equal(t, invitation1, &db.Invitation{ + ID: 1, + Code: invitation1.Code, + ExpiresAt: time.Now().Unix() + 604800, + NbUsed: 0, + NbMax: 10, + }) + + err = s.Request("POST", "/admin-panel/invitations", invitationAdmin{ + nbMax: "aa", + expiredAtUnix: "1735722000", + }, 302) + require.NoError(t, err) + invitation2, err := db.GetInvitationByID(2) + require.NoError(t, err) + require.Equal(t, invitation2, &db.Invitation{ + ID: 2, + Code: invitation2.Code, + ExpiresAt: time.Unix(1735722000, 0).Unix(), + NbUsed: 0, + NbMax: 10, + }) + + err = s.Request("POST", "/admin-panel/invitations", invitationAdmin{ + nbMax: "20", + expiredAtUnix: "1735722000", + }, 302) + require.NoError(t, err) + invitation3, err := db.GetInvitationByID(3) + require.NoError(t, err) + require.Equal(t, invitation3, &db.Invitation{ + ID: 3, + Code: invitation3.Code, + ExpiresAt: time.Unix(1735722000, 0).Unix(), + NbUsed: 0, + NbMax: 20, + }) + + count, err := db.CountAll(db.Invitation{}) + require.NoError(t, err) + require.Equal(t, int64(3), count) + + err = s.Request("POST", "/admin-panel/invitations/1/delete", nil, 302) + require.NoError(t, err) + + count, err = db.CountAll(db.Invitation{}) + require.NoError(t, err) + require.Equal(t, int64(2), count) +} diff --git a/internal/web/test/auth_test.go b/internal/web/test/auth_test.go index 16bce8e..229a7c9 100644 --- a/internal/web/test/auth_test.go +++ b/internal/web/test/auth_test.go @@ -1,7 +1,7 @@ package test import ( - "fmt" + "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" "github.com/thomiceli/opengist/internal/config" "github.com/thomiceli/opengist/internal/db" @@ -12,13 +12,13 @@ import ( ) func TestRegister(t *testing.T) { - s := setup(t) - defer teardown(t, s) + s := Setup(t) + defer Teardown(t, s) - err := s.request("GET", "/", nil, 302) + err := s.Request("GET", "/", nil, 302) require.NoError(t, err) - err = s.request("GET", "/register", nil, 200) + err = s.Request("GET", "/register", nil, 200) require.NoError(t, err) user1 := db.UserDTO{Username: "thomas", Password: "thomas"} @@ -29,13 +29,13 @@ func TestRegister(t *testing.T) { require.Equal(t, user1.Username, user1db.Username) require.True(t, user1db.IsAdmin) - err = s.request("GET", "/", nil, 200) + err = s.Request("GET", "/", nil, 200) require.NoError(t, err) s.sessionCookie = "" user2 := db.UserDTO{Username: "thomas", Password: "azeaze"} - err = s.request("POST", "/register", user2, 200) + err = s.Request("POST", "/register", user2, 200) require.Error(t, err) user3 := db.UserDTO{Username: "kaguya", Password: "kaguya"} @@ -53,10 +53,10 @@ func TestRegister(t *testing.T) { } func TestLogin(t *testing.T) { - s := setup(t) - defer teardown(t, s) + s := Setup(t) + defer Teardown(t, s) - err := s.request("GET", "/login", nil, 200) + err := s.Request("GET", "/login", nil, 200) require.NoError(t, err) user1 := db.UserDTO{Username: "thomas", Password: "thomas"} @@ -72,38 +72,33 @@ func TestLogin(t *testing.T) { user2 := db.UserDTO{Username: "thomas", Password: "azeaze"} user3 := db.UserDTO{Username: "azeaze", Password: ""} - err = s.request("POST", "/login", user2, 302) + err = s.Request("POST", "/login", user2, 302) require.Empty(t, s.sessionCookie) require.Error(t, err) - err = s.request("POST", "/login", user3, 302) + err = s.Request("POST", "/login", user3, 302) require.Empty(t, s.sessionCookie) require.Error(t, err) } -func register(t *testing.T, s *testServer, user db.UserDTO) { - err := s.request("POST", "/register", user, 302) +func register(t *testing.T, s *TestServer, user db.UserDTO) { + err := s.Request("POST", "/register", user, 302) require.NoError(t, err) } -func login(t *testing.T, s *testServer, user db.UserDTO) { - err := s.request("POST", "/login", user, 302) +func login(t *testing.T, s *TestServer, user db.UserDTO) { + err := s.Request("POST", "/login", user, 302) require.NoError(t, err) } -type settingSet struct { - key string `form:"key"` - value string `form:"value"` -} - func TestAnonymous(t *testing.T) { - s := setup(t) - defer teardown(t, s) + s := Setup(t) + defer Teardown(t, s) user := db.UserDTO{Username: "thomas", Password: "azeaze"} register(t, s, user) - err := s.request("PUT", "/admin-panel/set-config", settingSet{"require-login", "1"}, 200) + err := s.Request("PUT", "/admin-panel/set-config", settingSet{"require-login", "1"}, 200) require.NoError(t, err) gist1 := db.GistDTO{ @@ -115,41 +110,41 @@ func TestAnonymous(t *testing.T) { Name: []string{"gist1.txt", "gist2.txt", "gist3.txt"}, Content: []string{"yeah", "yeah\ncool", "yeah\ncool gist actually"}, } - err = s.request("POST", "/", gist1, 302) + err = s.Request("POST", "/", gist1, 302) require.NoError(t, err) gist1db, err := db.GetGistByID("1") require.NoError(t, err) - err = s.request("GET", "/all", nil, 200) + err = s.Request("GET", "/all", nil, 200) require.NoError(t, err) cookie := s.sessionCookie s.sessionCookie = "" - err = s.request("GET", "/all", nil, 302) + err = s.Request("GET", "/all", nil, 302) require.NoError(t, err) // Should redirect to login if RequireLogin - err = s.request("GET", "/"+gist1db.User.Username+"/"+gist1db.Uuid, nil, 302) + err = s.Request("GET", "/"+gist1db.User.Username+"/"+gist1db.Uuid, nil, 302) require.NoError(t, err) s.sessionCookie = cookie - err = s.request("PUT", "/admin-panel/set-config", settingSet{"allow-gists-without-login", "1"}, 200) + err = s.Request("PUT", "/admin-panel/set-config", settingSet{"allow-gists-without-login", "1"}, 200) require.NoError(t, err) s.sessionCookie = "" // Should return results - err = s.request("GET", "/"+gist1db.User.Username+"/"+gist1db.Uuid, nil, 200) + err = s.Request("GET", "/"+gist1db.User.Username+"/"+gist1db.Uuid, nil, 200) require.NoError(t, err) } func TestGitOperations(t *testing.T) { - s := setup(t) - defer teardown(t, s) + s := Setup(t) + defer Teardown(t, s) admin := db.UserDTO{Username: "thomas", Password: "thomas"} register(t, s, admin) @@ -170,7 +165,7 @@ func TestGitOperations(t *testing.T) { "yeah", }, } - err := s.request("POST", "/", gist1, 302) + err := s.Request("POST", "/", gist1, 302) require.NoError(t, err) gist2 := db.GistDTO{ @@ -185,7 +180,7 @@ func TestGitOperations(t *testing.T) { "cool", }, } - err = s.request("POST", "/", gist2, 302) + err = s.Request("POST", "/", gist2, 302) require.NoError(t, err) gist3 := db.GistDTO{ @@ -200,11 +195,11 @@ func TestGitOperations(t *testing.T) { "super", }, } - err = s.request("POST", "/", gist3, 302) + err = s.Request("POST", "/", gist3, 302) require.NoError(t, err) gitOperations := func(credentials, owner, url, filename string, expectErrorClone, expectErrorCheck, expectErrorPush bool) { - fmt.Println("Testing", credentials, url, expectErrorClone, expectErrorCheck, expectErrorPush) + log.Debug().Msgf("Testing %s %s %t %t %t", credentials, url, expectErrorClone, expectErrorCheck, expectErrorPush) err := clientGitClone(credentials, owner, url) if expectErrorClone { require.Error(t, err) @@ -249,7 +244,7 @@ func TestGitOperations(t *testing.T) { } login(t, s, admin) - err = s.request("PUT", "/admin-panel/set-config", settingSet{"require-login", "1"}, 200) + err = s.Request("PUT", "/admin-panel/set-config", settingSet{"require-login", "1"}, 200) require.NoError(t, err) testsRequireLogin := []struct { @@ -276,7 +271,7 @@ func TestGitOperations(t *testing.T) { } login(t, s, admin) - err = s.request("PUT", "/admin-panel/set-config", settingSet{"allow-gists-without-login", "1"}, 200) + err = s.Request("PUT", "/admin-panel/set-config", settingSet{"allow-gists-without-login", "1"}, 200) require.NoError(t, err) for _, test := range tests { diff --git a/internal/web/test/gist_test.go b/internal/web/test/gist_test.go index 310cc66..b46621c 100644 --- a/internal/web/test/gist_test.go +++ b/internal/web/test/gist_test.go @@ -9,19 +9,19 @@ import ( ) func TestGists(t *testing.T) { - s := setup(t) - defer teardown(t, s) + s := Setup(t) + defer Teardown(t, s) - err := s.request("GET", "/", nil, 302) + err := s.Request("GET", "/", nil, 302) require.NoError(t, err) user1 := db.UserDTO{Username: "thomas", Password: "thomas"} register(t, s, user1) - err = s.request("GET", "/all", nil, 200) + err = s.Request("GET", "/all", nil, 200) require.NoError(t, err) - err = s.request("POST", "/", nil, 200) + err = s.Request("POST", "/", nil, 200) require.NoError(t, err) gist1 := db.GistDTO{ @@ -33,7 +33,7 @@ func TestGists(t *testing.T) { Name: []string{"gist1.txt", "gist2.txt", "gist3.txt"}, Content: []string{"yeah", "yeah\ncool", "yeah\ncool gist actually"}, } - err = s.request("POST", "/", gist1, 302) + err = s.Request("POST", "/", gist1, 302) require.NoError(t, err) gist1db, err := db.GetGistByID("1") @@ -44,7 +44,7 @@ func TestGists(t *testing.T) { require.Regexp(t, "[a-f0-9]{32}", gist1db.Uuid) require.Equal(t, user1.Username, gist1db.User.Username) - err = s.request("GET", "/"+gist1db.User.Username+"/"+gist1db.Uuid, nil, 200) + err = s.Request("GET", "/"+gist1db.User.Username+"/"+gist1db.Uuid, nil, 200) require.NoError(t, err) gist1files, err := git.GetFilesOfRepository(gist1db.User.Username, gist1db.Uuid, "HEAD") @@ -64,7 +64,7 @@ func TestGists(t *testing.T) { Name: []string{"", "gist2.txt", "gist3.txt"}, Content: []string{"", "yeah\ncool", "yeah\ncool gist actually"}, } - err = s.request("POST", "/", gist2, 200) + err = s.Request("POST", "/", gist2, 200) require.NoError(t, err) gist3 := db.GistDTO{ @@ -76,7 +76,7 @@ func TestGists(t *testing.T) { Name: []string{""}, Content: []string{"yeah"}, } - err = s.request("POST", "/", gist3, 302) + err = s.Request("POST", "/", gist3, 302) require.NoError(t, err) gist3db, err := db.GetGistByID("2") @@ -86,26 +86,26 @@ func TestGists(t *testing.T) { require.NoError(t, err) require.Equal(t, "gistfile1.txt", gist3files[0]) - err = s.request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/edit", nil, 200) + err = s.Request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/edit", nil, 200) require.NoError(t, err) gist1.Name = []string{"gist1.txt"} gist1.Content = []string{"only want one gist"} - err = s.request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/edit", gist1, 302) + err = s.Request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/edit", gist1, 302) require.NoError(t, err) gist1files, err = git.GetFilesOfRepository(gist1db.User.Username, gist1db.Uuid, "HEAD") require.NoError(t, err) require.Equal(t, 1, len(gist1files)) - err = s.request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/delete", nil, 302) + err = s.Request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/delete", nil, 302) require.NoError(t, err) } func TestVisibility(t *testing.T) { - s := setup(t) - defer teardown(t, s) + s := Setup(t) + defer Teardown(t, s) user1 := db.UserDTO{Username: "thomas", Password: "thomas"} register(t, s, user1) @@ -119,26 +119,26 @@ func TestVisibility(t *testing.T) { Name: []string{""}, Content: []string{"yeah"}, } - err := s.request("POST", "/", gist1, 302) + err := s.Request("POST", "/", gist1, 302) require.NoError(t, err) gist1db, err := db.GetGistByID("1") require.NoError(t, err) require.Equal(t, db.UnlistedVisibility, gist1db.Private) - err = s.request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/visibility", db.VisibilityDTO{Private: db.PrivateVisibility}, 302) + err = s.Request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/visibility", db.VisibilityDTO{Private: db.PrivateVisibility}, 302) require.NoError(t, err) gist1db, err = db.GetGistByID("1") require.NoError(t, err) require.Equal(t, db.PrivateVisibility, gist1db.Private) - err = s.request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/visibility", db.VisibilityDTO{Private: db.PublicVisibility}, 302) + err = s.Request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/visibility", db.VisibilityDTO{Private: db.PublicVisibility}, 302) require.NoError(t, err) gist1db, err = db.GetGistByID("1") require.NoError(t, err) require.Equal(t, db.PublicVisibility, gist1db.Private) - err = s.request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/visibility", db.VisibilityDTO{Private: db.UnlistedVisibility}, 302) + err = s.Request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/visibility", db.VisibilityDTO{Private: db.UnlistedVisibility}, 302) require.NoError(t, err) gist1db, err = db.GetGistByID("1") require.NoError(t, err) @@ -146,8 +146,8 @@ func TestVisibility(t *testing.T) { } func TestLikeFork(t *testing.T) { - s := setup(t) - defer teardown(t, s) + s := Setup(t) + defer Teardown(t, s) user1 := db.UserDTO{Username: "thomas", Password: "thomas"} register(t, s, user1) @@ -161,7 +161,7 @@ func TestLikeFork(t *testing.T) { Name: []string{""}, Content: []string{"yeah"}, } - err := s.request("POST", "/", gist1, 302) + err := s.Request("POST", "/", gist1, 302) require.NoError(t, err) s.sessionCookie = "" @@ -176,7 +176,7 @@ func TestLikeFork(t *testing.T) { require.NoError(t, err) require.Equal(t, int64(0), likeCount) - err = s.request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/like", nil, 302) + err = s.Request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/like", nil, 302) require.NoError(t, err) gist1db, err = db.GetGistByID("1") require.NoError(t, err) @@ -185,7 +185,7 @@ func TestLikeFork(t *testing.T) { require.NoError(t, err) require.Equal(t, int64(1), likeCount) - err = s.request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/like", nil, 302) + err = s.Request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/like", nil, 302) require.NoError(t, err) gist1db, err = db.GetGistByID("1") require.NoError(t, err) @@ -194,7 +194,7 @@ func TestLikeFork(t *testing.T) { require.NoError(t, err) require.Equal(t, int64(0), likeCount) - err = s.request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/fork", nil, 302) + err = s.Request("POST", "/"+gist1db.User.Username+"/"+gist1db.Uuid+"/fork", nil, 302) require.NoError(t, err) gist2db, err := db.GetGistByID("2") require.NoError(t, err) @@ -205,8 +205,8 @@ func TestLikeFork(t *testing.T) { } func TestCustomUrl(t *testing.T) { - s := setup(t) - defer teardown(t, s) + s := Setup(t) + defer Teardown(t, s) user1 := db.UserDTO{Username: "thomas", Password: "thomas"} register(t, s, user1) @@ -221,7 +221,7 @@ func TestCustomUrl(t *testing.T) { Name: []string{"gist1.txt", "gist2.txt", "gist3.txt"}, Content: []string{"yeah", "yeah\ncool", "yeah\ncool gist actually"}, } - err := s.request("POST", "/", gist1, 302) + err := s.Request("POST", "/", gist1, 302) require.NoError(t, err) gist1db, err := db.GetGistByID("1") @@ -252,7 +252,7 @@ func TestCustomUrl(t *testing.T) { Name: []string{"gist1.txt", "gist2.txt", "gist3.txt"}, Content: []string{"yeah", "yeah\ncool", "yeah\ncool gist actually"}, } - err = s.request("POST", "/", gist2, 302) + err = s.Request("POST", "/", gist2, 302) require.NoError(t, err) gist2db, err := db.GetGistByID("2") diff --git a/internal/web/test/server.go b/internal/web/test/server.go index ec9d6d0..e4740c1 100644 --- a/internal/web/test/server.go +++ b/internal/web/test/server.go @@ -11,9 +11,11 @@ import ( "path" "path/filepath" "reflect" + "runtime" "strconv" "strings" "testing" + "time" "github.com/rs/zerolog/log" "github.com/stretchr/testify/require" @@ -21,34 +23,34 @@ import ( "github.com/thomiceli/opengist/internal/db" "github.com/thomiceli/opengist/internal/git" "github.com/thomiceli/opengist/internal/memdb" - "github.com/thomiceli/opengist/internal/web" + "github.com/thomiceli/opengist/internal/web/server" ) var databaseType string -type testServer struct { - server *web.Server +type TestServer struct { + server *server.Server sessionCookie string } -func newTestServer() (*testServer, error) { - s := &testServer{ - server: web.NewServer(true, path.Join(config.GetHomeDir(), "tmp", "sessions"), true), +func newTestServer() (*TestServer, error) { + s := &TestServer{ + server: server.NewServer(true, path.Join(config.GetHomeDir(), "tmp", "sessions"), true), } go s.start() return s, nil } -func (s *testServer) start() { +func (s *TestServer) start() { s.server.Start() } -func (s *testServer) stop() { +func (s *TestServer) stop() { s.server.Stop() } -func (s *testServer) request(method, uri string, data interface{}, expectedCode int) error { +func (s *TestServer) Request(method, uri string, data interface{}, expectedCode int) error { var bodyReader io.Reader if method == http.MethodPost || method == http.MethodPut { values := structToURLValues(data) @@ -133,18 +135,7 @@ func structToURLValues(s interface{}) url.Values { return v } -func setup(t *testing.T) *testServer { - var databaseDsn string - databaseType = os.Getenv("OPENGIST_TEST_DB") - switch databaseType { - case "sqlite": - databaseDsn = "file::memory:" - case "postgres": - databaseDsn = "postgres://postgres:opengist@localhost:5432/opengist_test" - case "mysql": - databaseDsn = "mysql://root:opengist@localhost:3306/opengist_test" - } - +func Setup(t *testing.T) *TestServer { _ = os.Setenv("OPENGIST_SKIP_GIT_HOOKS", "1") err := config.InitConfig("", io.Discard) @@ -158,19 +149,35 @@ func setup(t *testing.T) *testServer { git.ReposDirectory = path.Join("tests") config.C.IndexEnabled = false - config.C.LogLevel = "debug" + config.C.LogLevel = "error" config.InitLog() homePath := config.GetHomeDir() log.Info().Msg("Data directory: " + homePath) + var databaseDsn string + databaseType = os.Getenv("OPENGIST_TEST_DB") + switch databaseType { + case "sqlite": + databaseDsn = "file:" + filepath.Join(homePath, "tmp", "opengist.db") + case "postgres": + databaseDsn = "postgres://postgres:opengist@localhost:5432/opengist_test" + case "mysql": + databaseDsn = "mysql://root:opengist@localhost:3306/opengist_test" + default: + databaseDsn = ":memory:" + } + + err = os.MkdirAll(filepath.Join(homePath, "tests"), 0755) + require.NoError(t, err, "Could not create tests directory") + err = os.MkdirAll(filepath.Join(homePath, "tmp", "sessions"), 0755) require.NoError(t, err, "Could not create sessions directory") err = os.MkdirAll(filepath.Join(homePath, "tmp", "repos"), 0755) require.NoError(t, err, "Could not create tmp repos directory") - err = db.Setup(databaseDsn, true) + err = db.Setup(databaseDsn) require.NoError(t, err, "Could not initialize database") if err != nil { @@ -189,27 +196,40 @@ func setup(t *testing.T) *testServer { return s } -func teardown(t *testing.T, s *testServer) { +func Teardown(t *testing.T, s *TestServer) { s.stop() //err := db.Close() //require.NoError(t, err, "Could not close database") - err := os.RemoveAll(path.Join(config.GetHomeDir(), "tests")) - require.NoError(t, err, "Could not remove repos directory") - - err = os.RemoveAll(path.Join(config.GetHomeDir(), "tmp", "repos")) - require.NoError(t, err, "Could not remove repos directory") - - err = os.RemoveAll(path.Join(config.GetHomeDir(), "tmp", "sessions")) - require.NoError(t, err, "Could not remove repos directory") - - err = db.TruncateDatabase() + err := db.TruncateDatabase() require.NoError(t, err, "Could not truncate database") + err = os.RemoveAll(path.Join(config.GetHomeDir(), "tests")) + require.NoError(t, err, "Could not remove repos directory") + + if runtime.GOOS == "windows" { + err = db.Close() + require.NoError(t, err, "Could not close database") + + time.Sleep(2 * time.Second) + } + err = os.RemoveAll(path.Join(config.GetHomeDir(), "tmp")) + require.NoError(t, err, "Could not remove tmp directory") + // err = os.RemoveAll(path.Join(config.C.OpengistHome, "testsindex")) // require.NoError(t, err, "Could not remove repos directory") // err = index.Close() // require.NoError(t, err, "Could not close index") } + +type settingSet struct { + key string `form:"key"` + value string `form:"value"` +} + +type invitationAdmin struct { + nbMax string `form:"nbMax"` + expiredAtUnix string `form:"expiredAtUnix"` +} diff --git a/internal/web/test/settings_test.go b/internal/web/test/settings_test.go new file mode 100644 index 0000000..5f8fe4a --- /dev/null +++ b/internal/web/test/settings_test.go @@ -0,0 +1,22 @@ +package test + +import ( + "github.com/stretchr/testify/require" + "github.com/thomiceli/opengist/internal/db" + "testing" +) + +func TestSettingsPage(t *testing.T) { + s := Setup(t) + defer Teardown(t, s) + + err := s.Request("GET", "/settings", nil, 302) + require.NoError(t, err) + + user1 := db.UserDTO{Username: "thomas", Password: "thomas"} + register(t, s, user1) + login(t, s, user1) + + err = s.Request("GET", "/settings", nil, 200) + require.NoError(t, err) +} diff --git a/internal/web/util.go b/internal/web/util.go deleted file mode 100644 index d0adc7b..0000000 --- a/internal/web/util.go +++ /dev/null @@ -1,248 +0,0 @@ -package web - -import ( - "context" - "errors" - "github.com/gorilla/sessions" - "github.com/labstack/echo/v4" - "github.com/rs/zerolog/log" - "github.com/thomiceli/opengist/internal/config" - "github.com/thomiceli/opengist/internal/db" - "github.com/thomiceli/opengist/internal/i18n" - "golang.org/x/text/cases" - "golang.org/x/text/language" - "html/template" - "net/http" - "strconv" - "strings" -) - -type dataTypeKey string - -const dataKey dataTypeKey = "data" - -func setData(ctx echo.Context, key string, value any) { - data := ctx.Request().Context().Value(dataKey).(echo.Map) - data[key] = value - ctxValue := context.WithValue(ctx.Request().Context(), dataKey, data) - ctx.SetRequest(ctx.Request().WithContext(ctxValue)) -} - -func getData(ctx echo.Context, key string) any { - data := ctx.Request().Context().Value(dataKey).(echo.Map) - return data[key] -} - -func dataMap(ctx echo.Context) echo.Map { - return ctx.Request().Context().Value(dataKey).(echo.Map) -} - -func html(ctx echo.Context, template string) error { - return htmlWithCode(ctx, 200, template) -} - -func htmlWithCode(ctx echo.Context, code int, template string) error { - setErrorFlashes(ctx) - return ctx.Render(code, template, ctx.Request().Context().Value(dataKey)) -} - -func json(ctx echo.Context, data any) error { - return jsonWithCode(ctx, 200, data) -} - -func jsonWithCode(ctx echo.Context, code int, data any) error { - return ctx.JSON(code, data) -} - -func redirect(ctx echo.Context, location string) error { - return ctx.Redirect(302, config.C.ExternalUrl+location) -} - -func plainText(ctx echo.Context, code int, message string) error { - return ctx.String(code, message) -} - -func notFound(message string) error { - return errorRes(404, message, nil) -} - -func errorRes(code int, message string, err error) error { - if code >= 500 { - var skipLogger = log.With().CallerWithSkipFrameCount(3).Logger() - skipLogger.Error().Err(err).Msg(message) - } - - return &echo.HTTPError{Code: code, Message: message, Internal: err} -} - -func jsonErrorRes(code int, message string, err error) error { - if code >= 500 { - var skipLogger = log.With().CallerWithSkipFrameCount(3).Logger() - skipLogger.Error().Err(err).Msg(message) - } - - return &echo.HTTPError{Code: code, Message: message, Internal: err} -} - -func getUserLogged(ctx echo.Context) *db.User { - user := getData(ctx, "userLogged") - if user != nil { - return user.(*db.User) - } - return nil -} - -func setErrorFlashes(ctx echo.Context) { - sess, _ := flashStore.Get(ctx.Request(), "flash") - - setData(ctx, "flashErrors", sess.Flashes("error")) - setData(ctx, "flashSuccess", sess.Flashes("success")) - setData(ctx, "flashWarnings", sess.Flashes("warning")) - - _ = sess.Save(ctx.Request(), ctx.Response()) -} - -func addFlash(ctx echo.Context, flashMessage string, flashType string) { - sess, _ := flashStore.Get(ctx.Request(), "flash") - sess.AddFlash(flashMessage, flashType) - _ = sess.Save(ctx.Request(), ctx.Response()) -} - -func getSession(ctx echo.Context) *sessions.Session { - sess, _ := userStore.Get(ctx.Request(), "session") - return sess -} - -func saveSession(sess *sessions.Session, ctx echo.Context) { - _ = sess.Save(ctx.Request(), ctx.Response()) -} - -func deleteSession(ctx echo.Context) { - sess := getSession(ctx) - sess.Options.MaxAge = -1 - saveSession(sess, ctx) -} - -func setCsrfHtmlForm(ctx echo.Context) { - var csrf string - if csrfToken, ok := ctx.Get("csrf").(string); ok { - csrf = csrfToken - } - setData(ctx, "csrfHtml", template.HTML(``)) -} - -func deleteCsrfCookie(ctx echo.Context) { - ctx.SetCookie(&http.Cookie{Name: "_csrf", Path: "/", MaxAge: -1}) -} - -func loadSettings(ctx echo.Context) error { - settings, err := db.GetSettings() - if err != nil { - return err - } - - for key, value := range settings { - s := strings.ReplaceAll(key, "-", " ") - s = cases.Title(language.English).String(s) - setData(ctx, strings.ReplaceAll(s, " ", ""), value == "1") - } - return nil -} - -func getPage(ctx echo.Context) int { - page := ctx.QueryParam("page") - if page == "" { - page = "1" - } - pageInt, err := strconv.Atoi(page) - if err != nil { - pageInt = 1 - } - setData(ctx, "currPage", pageInt) - - return pageInt -} - -func paginate[T any](ctx echo.Context, data []*T, pageInt int, perPage int, templateDataName string, urlPage string, labels int, urlParams ...string) error { - lenData := len(data) - if lenData == 0 && pageInt != 1 { - return errors.New("page not found") - } - - if lenData > perPage { - if lenData > 1 { - data = data[:lenData-1] - } - setData(ctx, "nextPage", pageInt+1) - } - if pageInt > 1 { - setData(ctx, "prevPage", pageInt-1) - } - - if len(urlParams) > 0 { - setData(ctx, "urlParams", template.URL(urlParams[0])) - } - - switch labels { - case 1: - setData(ctx, "prevLabel", trH(ctx, "pagination.previous")) - setData(ctx, "nextLabel", trH(ctx, "pagination.next")) - case 2: - setData(ctx, "prevLabel", trH(ctx, "pagination.newer")) - setData(ctx, "nextLabel", trH(ctx, "pagination.older")) - } - - setData(ctx, "urlPage", urlPage) - setData(ctx, templateDataName, data) - return nil -} - -func trH(ctx echo.Context, key string, args ...any) template.HTML { - l := getData(ctx, "locale").(*i18n.Locale) - return l.Tr(key, args...) -} - -func tr(ctx echo.Context, key string, args ...any) string { - l := getData(ctx, "locale").(*i18n.Locale) - return l.String(key, args...) -} - -func parseSearchQueryStr(query string) (string, map[string]string) { - words := strings.Fields(query) - metadata := make(map[string]string) - var contentBuilder strings.Builder - - for _, word := range words { - if strings.Contains(word, ":") { - keyValue := strings.SplitN(word, ":", 2) - if len(keyValue) == 2 { - key := keyValue[0] - value := keyValue[1] - metadata[key] = value - } - } else { - contentBuilder.WriteString(word + " ") - } - } - - content := strings.TrimSpace(contentBuilder.String()) - return content, metadata -} - -func addMetadataToSearchQuery(input, key, value string) string { - content, metadata := parseSearchQueryStr(input) - - metadata[key] = value - - var resultBuilder strings.Builder - resultBuilder.WriteString(content) - - for k, v := range metadata { - resultBuilder.WriteString(" ") - resultBuilder.WriteString(k) - resultBuilder.WriteString(":") - resultBuilder.WriteString(v) - } - - return strings.TrimSpace(resultBuilder.String()) -} diff --git a/templates/pages/error.html b/templates/pages/error.html index 5f868d3..fec6fd3 100644 --- a/templates/pages/error.html +++ b/templates/pages/error.html @@ -1,3 +1,4 @@ +{{ define "error" }} {{ template "header" .}}