diff --git a/helpers.go b/helpers.go index 00c786b..527a2c4 100644 --- a/helpers.go +++ b/helpers.go @@ -141,6 +141,21 @@ func fCopy(src, dst string) error { return err } +func fMove(src, dst string) error { + err := fCopy(src, dst) + if err != nil { + log.Warn(err) + return err + } + err = fDelete(src) + if err != nil { + log.Warn(err) + return err + } + + return nil +} + func fDownload(path, url string, basicAuth bool) error { client := &http.Client{} req, err := http.NewRequest("GET", url, nil) diff --git a/main.go b/main.go index 440662d..3de1780 100644 --- a/main.go +++ b/main.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "encoding/json" "encoding/pem" + "errors" "fmt" "github.com/google/uuid" "io/ioutil" @@ -22,6 +23,7 @@ import ( "sync" "text/template" "time" + "unicode/utf8" "github.com/gobuffalo/packr/v2" "github.com/prometheus/client_golang/prometheus" @@ -282,7 +284,13 @@ func (oAdmin *OvpnAdmin) userRotateHandler(w http.ResponseWriter, r *http.Reques return } _ = r.ParseForm() - fmt.Fprintf(w, "%s", oAdmin.userRotate(r.FormValue("username"), r.FormValue("password"))) + err, msg := oAdmin.userRotate(r.FormValue("username"), r.FormValue("password")) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } else { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, msg) + } } func (oAdmin *OvpnAdmin) userDeleteHandler(w http.ResponseWriter, r *http.Request) { @@ -292,7 +300,13 @@ func (oAdmin *OvpnAdmin) userDeleteHandler(w http.ResponseWriter, r *http.Reques return } _ = r.ParseForm() - fmt.Fprintf(w, "%s", oAdmin.userDelete(r.FormValue("username"))) + err, msg := oAdmin.userDelete(r.FormValue("username")) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } else { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, msg) + } } func (oAdmin *OvpnAdmin) userRevokeHandler(w http.ResponseWriter, r *http.Request) { @@ -302,7 +316,13 @@ func (oAdmin *OvpnAdmin) userRevokeHandler(w http.ResponseWriter, r *http.Reques return } _ = r.ParseForm() - fmt.Fprintf(w, "%s", oAdmin.userRevoke(r.FormValue("username"))) + err, msg := oAdmin.userRevoke(r.FormValue("username")) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } else { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, msg) + } } func (oAdmin *OvpnAdmin) userUnrevokeHandler(w http.ResponseWriter, r *http.Request) { @@ -311,24 +331,28 @@ func (oAdmin *OvpnAdmin) userUnrevokeHandler(w http.ResponseWriter, r *http.Requ http.Error(w, `{"status":"error"}`, http.StatusLocked) return } - _ = r.ParseForm() - fmt.Fprintf(w, "%s", oAdmin.userUnrevoke(r.FormValue("username"))) + err, msg := oAdmin.userUnrevoke(r.FormValue("username")) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + } else { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, msg) + } } func (oAdmin *OvpnAdmin) userChangePasswordHandler(w http.ResponseWriter, r *http.Request) { log.Info(r.RemoteAddr, " ", r.RequestURI) _ = r.ParseForm() if *authByPassword { - passwordChanged, passwordChangeMessage := oAdmin.userChangePassword(r.FormValue("username"), r.FormValue("password")) - if passwordChanged { - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"status":"ok", "message": "%s"}`, passwordChangeMessage) - return - } else { + err, msg := oAdmin.userChangePassword(r.FormValue("username"), r.FormValue("password")) + if err != nil { w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, `{"status":"error", "message": "%s"}`, passwordChangeMessage) - return + fmt.Fprintf(w, `{"status":"error", "message": "%s"}`, msg) + + } else { + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, `{"status":"ok", "message": "%s"}`, msg) } } else { http.Error(w, `{"status":"error"}`, http.StatusNotImplemented) @@ -827,16 +851,20 @@ func checkStaticAddressIsFree(staticAddress string, username string) bool { return false } -func validateUsername(username string) bool { +func validateUsername(username string) error { var validUsername = regexp.MustCompile(usernameRegexp) - return validUsername.MatchString(username) + if validUsername.MatchString(username) { + return nil + } else { + return errors.New(fmt.Sprintf("Username can only contains %s", usernameRegexp)) + } } -func validatePassword(password string) bool { - if len(password) < passwordMinLength { - return false +func validatePassword(password string) error { + if utf8.RuneCountInString(password) < passwordMinLength { + return errors.New(fmt.Sprintf("Password too short, password length must be greater or equal %d", passwordMinLength)) } else { - return true + return nil } } @@ -928,17 +956,15 @@ func (oAdmin *OvpnAdmin) userCreate(username, password string) (bool, string) { return false, ucErr } - if !validateUsername(username) { - ucErr = fmt.Sprintf("Username \"%s\" incorrect, you can use only %s\n", username, usernameRegexp) - log.Debugf("userCreate: validateUsername(): %s", ucErr) - return false, ucErr + if err := validateUsername(username); err != nil { + log.Debugf("userCreate: validateUsername(): %s", err.Error()) + return false, err.Error() } if *authByPassword { - if !validatePassword(password) { - ucErr = fmt.Sprintf("Password too short, password length must be greater or equal %d", passwordMinLength) - log.Debugf("userCreate: authByPassword(): %s", ucErr) - return false, ucErr + if err := validatePassword(password); err != nil { + log.Debugf("userCreate: authByPassword(): %s", err.Error()) + return false, err.Error() } } @@ -964,20 +990,18 @@ func (oAdmin *OvpnAdmin) userCreate(username, password string) (bool, string) { return true, ucErr } -func (oAdmin *OvpnAdmin) userChangePassword(username, password string) (bool, string) { +func (oAdmin *OvpnAdmin) userChangePassword(username, password string) (error, string) { if checkUserExist(username) { o := runBash(fmt.Sprintf("openvpn-user check --db.path %s --user %s | grep %s | wc -l", *authDatabase, username, username)) log.Debug(o) - if !validatePassword(password) { - ucpErr := fmt.Sprintf("Password for too short, password length must be greater or equal %d", passwordMinLength) - log.Warningf("userChangePassword: %s", ucpErr) - return false, ucpErr + if err := validatePassword(password); err != nil { + log.Warningf("userChangePassword: %s", err.Error()) + return err, err.Error() } if strings.TrimSpace(o) == "0" { - log.Debug("Creating user in users.db") o = runBash(fmt.Sprintf("openvpn-user create --db.path %s --user %s --password %s", *authDatabase, username, password)) log.Debug(o) } @@ -987,10 +1011,10 @@ func (oAdmin *OvpnAdmin) userChangePassword(username, password string) (bool, st log.Infof("Password for user %s was changed", username) - return true, "Password changed" + return nil, "Password changed" } - return false, "User does not exist" + return errors.New(fmt.Sprintf("User \"%s\" not found}", username)), fmt.Sprintf("{\"msg\":\"User \"%s\" not found\"}", username) } func (oAdmin *OvpnAdmin) getUserStatistic(username string) []clientStatus { @@ -1003,7 +1027,7 @@ func (oAdmin *OvpnAdmin) getUserStatistic(username string) []clientStatus { return userStatistic } -func (oAdmin *OvpnAdmin) userRevoke(username string) string { +func (oAdmin *OvpnAdmin) userRevoke(username string) (error, string) { log.Infof("Revoke certificate for user %s", username) if checkUserExist(username) { // check certificate valid flag 'V' @@ -1018,7 +1042,8 @@ func (oAdmin *OvpnAdmin) userRevoke(username string) string { } if *authByPassword { - _ = runBash(fmt.Sprintf("openvpn-user revoke --db-path %s --user %s", *authDatabase, username)) + o := runBash(fmt.Sprintf("openvpn-user revoke --db-path %s --user %s", *authDatabase, username)) + log.Debug(o) } crlFix() @@ -1032,13 +1057,13 @@ func (oAdmin *OvpnAdmin) userRevoke(username string) string { } oAdmin.setState() - return fmt.Sprintf("user \"%s\" revoked", username) + return nil, fmt.Sprintf("user \"%s\" revoked", username) } log.Infof("user \"%s\" not found", username) - return fmt.Sprintf("User \"%s\" not found", username) + return errors.New(fmt.Sprintf("User \"%s\" not found}", username)), fmt.Sprintf("User \"%s\" not found", username) } -func (oAdmin *OvpnAdmin) userUnrevoke(username string) string { +func (oAdmin *OvpnAdmin) userUnrevoke(username string) (error, string) { if checkUserExist(username) { if *storageBackend == "kubernetes.secrets" { err := app.easyrsaUnrevoke(username) @@ -1055,19 +1080,19 @@ func (oAdmin *OvpnAdmin) userUnrevoke(username string) string { usersFromIndexTxt[i].Flag = "V" usersFromIndexTxt[i].RevocationDate = "" - err := fCopy(fmt.Sprintf("%s/pki/revoked/certs_by_serial/%s.crt", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber), fmt.Sprintf("%s/pki/issued/%s.crt", *easyrsaDirPath, username)) + err := fMove(fmt.Sprintf("%s/pki/revoked/certs_by_serial/%s.crt", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber), fmt.Sprintf("%s/pki/issued/%s.crt", *easyrsaDirPath, username)) if err != nil { log.Error(err) } - err = fCopy(fmt.Sprintf("%s/pki/revoked/certs_by_serial/%s.crt", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber), fmt.Sprintf("%s/pki/certs_by_serial/%s.pem", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber)) + err = fMove(fmt.Sprintf("%s/pki/revoked/certs_by_serial/%s.crt", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber), fmt.Sprintf("%s/pki/certs_by_serial/%s.pem", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber)) if err != nil { log.Error(err) } - err = fCopy(fmt.Sprintf("%s/pki/revoked/private_by_serial/%s.key", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber), fmt.Sprintf("%s/pki/private/%s.key", *easyrsaDirPath, username)) + err = fMove(fmt.Sprintf("%s/pki/revoked/private_by_serial/%s.key", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber), fmt.Sprintf("%s/pki/private/%s.key", *easyrsaDirPath, username)) if err != nil { log.Error(err) } - err = fCopy(fmt.Sprintf("%s/pki/revoked/reqs_by_serial/%s.req", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber), fmt.Sprintf("%s/pki/reqs/%s.req", *easyrsaDirPath, username)) + err = fMove(fmt.Sprintf("%s/pki/revoked/reqs_by_serial/%s.req", *easyrsaDirPath, usersFromIndexTxt[i].SerialNumber), fmt.Sprintf("%s/pki/reqs/%s.req", *easyrsaDirPath, username)) if err != nil { log.Error(err) } @@ -1079,7 +1104,8 @@ func (oAdmin *OvpnAdmin) userUnrevoke(username string) string { _ = runBash(fmt.Sprintf("cd %s && easyrsa gen-crl 1>/dev/null", *easyrsaDirPath)) if *authByPassword { - _ = runBash(fmt.Sprintf("openvpn-user restore --db-path %s --user %s", *authDatabase, username)) + o := runBash(fmt.Sprintf("openvpn-user restore --db-path %s --user %s", *authDatabase, username)) + log.Debug(o) } crlFix() @@ -1096,12 +1122,12 @@ func (oAdmin *OvpnAdmin) userUnrevoke(username string) string { } crlFix() oAdmin.clients = oAdmin.usersList() - return fmt.Sprintf("{\"msg\":\"User %s successfully unrevoked\"}", username) + return nil, fmt.Sprintf("{\"msg\":\"User %s successfully unrevoked\"}", username) } - return fmt.Sprintf("{\"msg\":\"User \"%s\" not found\"}", username) + return errors.New(fmt.Sprintf("user \"%s\" not found", username)), fmt.Sprintf("{\"msg\":\"User \"%s\" not found\"}", username) } -func (oAdmin *OvpnAdmin) userRotate(username, newPassword string) string { +func (oAdmin *OvpnAdmin) userRotate(username, newPassword string) (error, string) { if checkUserExist(username) { if *storageBackend == "kubernetes.secrets" { err := app.easyrsaRotate(username, newPassword) @@ -1110,12 +1136,17 @@ func (oAdmin *OvpnAdmin) userRotate(username, newPassword string) string { } } else { - uniqHash := strings.Replace(uuid.New().String(), "-", "", -1) var oldUserIndex, newUserIndex int + var oldUserSerial string + + uniqHash := strings.Replace(uuid.New().String(), "-", "", -1) + usersFromIndexTxt := indexTxtParser(fRead(*indexTxtPath)) for i := range usersFromIndexTxt { if usersFromIndexTxt[i].DistinguishedName == "/CN="+username { + oldUserSerial = usersFromIndexTxt[i].SerialNumber usersFromIndexTxt[i].DistinguishedName = "/CN=REVOKED-" + username + "-" + uniqHash + oldUserIndex = i break } } @@ -1124,36 +1155,53 @@ func (oAdmin *OvpnAdmin) userRotate(username, newPassword string) string { log.Error(err) } - oAdmin.userCreate(username, newPassword) + if *authByPassword { + o := runBash(fmt.Sprintf("openvpn-user delete --force --db.path %s --user %s", *authDatabase, username)) + log.Debug(o) + } + + userCreated, userCreateMessage := oAdmin.userCreate(username, newPassword) + if !userCreated { + usersFromIndexTxt = indexTxtParser(fRead(*indexTxtPath)) + for i := range usersFromIndexTxt { + if usersFromIndexTxt[i].SerialNumber == oldUserSerial { + usersFromIndexTxt[i].DistinguishedName = "/CN=" + username + break + } + } + err = fWrite(*indexTxtPath, renderIndexTxt(usersFromIndexTxt)) + if err != nil { + log.Error(err) + } + return errors.New(fmt.Sprintf("error rotaing user due: %s", userCreateMessage)), userCreateMessage + } + usersFromIndexTxt = indexTxtParser(fRead(*indexTxtPath)) for i := range usersFromIndexTxt { if usersFromIndexTxt[i].DistinguishedName == "/CN="+username { newUserIndex = i } - if usersFromIndexTxt[i].DistinguishedName == "/CN=REVOKED-"+username+"-"+uniqHash { + if usersFromIndexTxt[i].SerialNumber == oldUserSerial { oldUserIndex = i } } usersFromIndexTxt[oldUserIndex], usersFromIndexTxt[newUserIndex] = usersFromIndexTxt[newUserIndex], usersFromIndexTxt[oldUserIndex] - if *authByPassword { - _ = runBash(fmt.Sprintf("openvpn-user change-password --db.path %s --user %s --password %s", *authDatabase, username, newPassword)) - } - err = fWrite(*indexTxtPath, renderIndexTxt(usersFromIndexTxt)) if err != nil { log.Error(err) } + _ = runBash(fmt.Sprintf("cd %s && easyrsa gen-crl 1>/dev/null", *easyrsaDirPath)) } crlFix() oAdmin.clients = oAdmin.usersList() - return fmt.Sprintf("{\"msg\":\"User %s successfully rotated\"}", username) + return nil, fmt.Sprintf("{\"msg\":\"User %s successfully rotated\"}", username) } - return fmt.Sprintf("{\"msg\":\"User \"%s\" not found\"}", username) + return errors.New(fmt.Sprintf("user \"%s\" not found", username)), fmt.Sprintf("{\"msg\":\"User \"%s\" not found\"}", username) } -func (oAdmin *OvpnAdmin) userDelete(username string) string { +func (oAdmin *OvpnAdmin) userDelete(username string) (error, string) { if checkUserExist(username) { if *storageBackend == "kubernetes.secrets" { err := app.easyrsaDelete(username) @@ -1180,9 +1228,9 @@ func (oAdmin *OvpnAdmin) userDelete(username string) string { } crlFix() oAdmin.clients = oAdmin.usersList() - return fmt.Sprintf("{\"msg\":\"User %s successfully deleted\"}", username) + return nil, fmt.Sprintf("{\"msg\":\"User %s successfully deleted\"}", username) } - return fmt.Sprintf("{\"msg\":\"User \"%s\" not found\"}", username) + return errors.New(fmt.Sprintf("User \"%s\" not found}", username)), fmt.Sprintf("{\"msg\":\"User \"%s\" not found\"}", username) } func (oAdmin *OvpnAdmin) mgmtRead(conn net.Conn) string {