diff --git a/helpers.go b/helpers.go index 3a9f800..716e9b0 100644 --- a/helpers.go +++ b/helpers.go @@ -10,7 +10,7 @@ import ( "time" ) -func parseDate(layout,datetime string) time.Time { +func parseDate(layout, datetime string) time.Time { t, err := time.Parse(layout, datetime) if err != nil { log.Println(err) @@ -18,11 +18,11 @@ func parseDate(layout,datetime string) time.Time { return t } -func parseDateToString(layout,datetime,format string) string { +func parseDateToString(layout, datetime, format string) string { return parseDate(layout, datetime).Format(format) } -func parseDateToUnix(layout,datetime string) int64 { +func parseDateToUnix(layout, datetime string) int64 { return parseDate(layout, datetime).Unix() } @@ -100,7 +100,7 @@ func fDownload(path, url string, basicAuth bool) error { } if resp.StatusCode != 200 { - log.Printf("WARNING: Download file operation for url %s finished with status code %d\n", url, resp.StatusCode ) + log.Printf("WARNING: Download file operation for url %s finished with status code %d\n", url, resp.StatusCode) } defer resp.Body.Close() diff --git a/main.go b/main.go index 327b14b..6bba8db 100644 --- a/main.go +++ b/main.go @@ -1,18 +1,24 @@ package main import ( + "archive/tar" "bufio" "bytes" + "compress/gzip" "context" "encoding/json" "fmt" + "io" "log" "net" "net/http" "os" + "os/signal" + "path/filepath" "regexp" "strconv" "strings" + "syscall" "text/template" "time" @@ -28,9 +34,6 @@ import ( ) const ( - usernameRegexp = `^([a-zA-Z0-9_.-@])+$` - passwordRegexp = `^([a-zA-Z0-9_.-@])+$` - passwordMinLength = 6 downloadCertsApiUrl = "/api/data/certs/download" downloadCcdApiUrl = "/api/data/ccd/download" certsArchiveFileName = "certs.tar.gz" @@ -44,33 +47,33 @@ const ( ) var ( - listenHost = kingpin.Flag("listen.host","host for ovpn-admin").Default("0.0.0.0").Envar("OVPN_LISTEN_HOST").String() - listenPort = kingpin.Flag("listen.port","port for ovpn-admin").Default("8080").Envar("OVPN_LISTEN_PORT").String() - serverRole = kingpin.Flag("role","server role, master or slave").Default("master").Envar("OVPN_ROLE").HintOptions("master", "slave").String() - masterHost = kingpin.Flag("master.host","URL for the master server").Default("http://127.0.0.1").Envar("OVPN_MASTER_HOST").String() - masterBasicAuthUser = kingpin.Flag("master.basic-auth.user","user for master server's Basic Auth").Default("").Envar("OVPN_MASTER_USER").String() - masterBasicAuthPassword = kingpin.Flag("master.basic-auth.password","password for master server's Basic Auth").Default("").Envar("OVPN_MASTER_PASSWORD").String() + listenHost = kingpin.Flag("listen.host", "host for ovpn-admin").Default("0.0.0.0").Envar("OVPN_LISTEN_HOST").String() + listenPort = kingpin.Flag("listen.port", "port for ovpn-admin").Default("8080").Envar("OVPN_LISTEN_PORT").String() + serverRole = kingpin.Flag("role", "server role, master or slave").Default("master").Envar("OVPN_ROLE").HintOptions("master", "slave").String() + masterHost = kingpin.Flag("master.host", "URL for the master server").Default("http://127.0.0.1").Envar("OVPN_MASTER_HOST").String() + masterBasicAuthUser = kingpin.Flag("master.basic-auth.user", "user for master server's Basic Auth").Default("").Envar("OVPN_MASTER_USER").String() + masterBasicAuthPassword = kingpin.Flag("master.basic-auth.password", "password for master server's Basic Auth").Default("").Envar("OVPN_MASTER_PASSWORD").String() masterSyncFrequency = kingpin.Flag("master.sync-frequency", "master host data sync frequency in seconds").Default("600").Envar("OVPN_MASTER_SYNC_FREQUENCY").Int() masterSyncToken = kingpin.Flag("master.sync-token", "master host data sync security token").Default("VerySecureToken").Envar("OVPN_MASTER_TOKEN").PlaceHolder("TOKEN").String() - openvpnNetwork = kingpin.Flag("ovpn.network","NETWORK/MASK_PREFIX for OpenVPN server").Default("172.16.100.0/24").Envar("OVPN_NETWORK").String() - openvpnServer = kingpin.Flag("ovpn.server","HOST:PORT:PROTOCOL for OpenVPN server; can have multiple values").Default("127.0.0.1:7777:tcp").Envar("OVPN_SERVER").PlaceHolder("HOST:PORT:PROTOCOL").Strings() - openvpnServerBehindLB = kingpin.Flag("ovpn.server.behindLB","enable if your OpenVPN server is behind Kubernetes Service having the LoadBalancer type").Default("false").Envar("OVPN_LB").Bool() - openvpnServiceName = kingpin.Flag("ovpn.service","the name of Kubernetes Service having the LoadBalancer type if your OpenVPN server is behind it").Default("openvpn-external").Envar("OVPN_LB_SERVICE").String() - mgmtAddress = kingpin.Flag("mgmt","ALIAS=HOST:PORT for OpenVPN server mgmt interface; can have multiple values").Default("main=127.0.0.1:8989").Envar("OVPN_MGMT").Strings() - metricsPath = kingpin.Flag("metrics.path", "URL path for exposing collected metrics").Default("/metrics").Envar("OVPN_METRICS_PATH").String() - easyrsaDirPath = kingpin.Flag("easyrsa.path", "path to easyrsa dir").Default("./easyrsa/").Envar("EASYRSA_PATH").String() - indexTxtPath = kingpin.Flag("easyrsa.index-path", "path to easyrsa index file").Default("./easyrsa/pki/index.txt").Envar("OVPN_INDEX_PATH").String() - ccdEnabled = kingpin.Flag("ccd", "enable client-config-dir").Default("false").Envar("OVPN_CCD").Bool() - ccdDir = kingpin.Flag("ccd.path", "path to client-config-dir").Default("./ccd").Envar("OVPN_CCD_PATH").String() + openvpnNetwork = kingpin.Flag("ovpn.network", "NETWORK/MASK_PREFIX for OpenVPN server").Default("172.16.100.0/24").Envar("OVPN_NETWORK").String() + openvpnServer = kingpin.Flag("ovpn.server", "HOST:PORT:PROTOCOL for OpenVPN server; can have multiple values").Default("127.0.0.1:7777:tcp").Envar("OVPN_SERVER").PlaceHolder("HOST:PORT:PROTOCOL").Strings() + openvpnServerBehindLB = kingpin.Flag("ovpn.server.behindLB", "enable if your OpenVPN server is behind Kubernetes Service having the LoadBalancer type").Default("false").Envar("OVPN_LB").Bool() + openvpnServiceName = kingpin.Flag("ovpn.service", "the name of Kubernetes Service having the LoadBalancer type if your OpenVPN server is behind it").Default("openvpn-external").Envar("OVPN_LB_SERVICE").String() + mgmtAddress = kingpin.Flag("mgmt", "ALIAS=HOST:PORT for OpenVPN server mgmt interface; can have multiple values").Default("main=127.0.0.1:8989").Envar("OVPN_MGMT").Strings() + metricsPath = kingpin.Flag("metrics.path", "URL path for exposing collected metrics").Default("/metrics").Envar("OVPN_METRICS_PATH").String() + easyrsaDirPath = kingpin.Flag("easyrsa.path", "path to easyrsa dir").Default("./easyrsa/").Envar("EASYRSA_PATH").String() + indexTxtPath = kingpin.Flag("easyrsa.index-path", "path to easyrsa index file").Default("./easyrsa/pki/index.txt").Envar("OVPN_INDEX_PATH").String() + ccdEnabled = kingpin.Flag("ccd", "enable client-config-dir").Default("false").Envar("OVPN_CCD").Bool() + ccdDir = kingpin.Flag("ccd.path", "path to client-config-dir").Default("./ccd").Envar("OVPN_CCD_PATH").String() clientConfigTemplatePath = kingpin.Flag("templates.clientconfig-path", "path to custom client.conf.tpl").Default("").Envar("OVPN_TEMPLATES_CC_PATH").String() ccdTemplatePath = kingpin.Flag("templates.ccd-path", "path to custom ccd.tpl").Default("").Envar("OVPN_TEMPLATES_CCD_PATH").String() - authByPassword = kingpin.Flag("auth.password", "enable additional password authentication").Default("false").Envar("OVPN_AUTH").Bool() - authDatabase = kingpin.Flag("auth.db", "database path for password authentication").Default("./easyrsa/pki/users.db").Envar("OVPN_AUTH_DB_PATH").String() - debug = kingpin.Flag("debug", "enable debug mode").Default("false").Envar("OVPN_DEBUG").Bool() - verbose = kingpin.Flag("verbose", "enable verbose mode").Default("false").Envar("OVPN_VERBOSE").Bool() + authByPassword = kingpin.Flag("auth.password", "enable additional password authentication").Default("false").Envar("OVPN_AUTH").Bool() + authDatabase = kingpin.Flag("auth.db", "database path for password authentication").Default("./easyrsa/pki/users.db").Envar("OVPN_AUTH_DB_PATH").String() + debug = kingpin.Flag("debug", "enable debug mode").Default("false").Envar("OVPN_DEBUG").Bool() + verbose = kingpin.Flag("verbose", "enable verbose mode").Default("false").Envar("OVPN_VERBOSE").Bool() - certsArchivePath = "/tmp/" + certsArchiveFileName - ccdArchivePath = "/tmp/" + ccdArchiveFileName + certsArchivePath = "/tmp/" + certsArchiveFileName + ccdArchivePath = "/tmp/" + ccdArchiveFileName version = "1.7.5" ) @@ -160,6 +163,7 @@ type OvpnAdmin struct { mgmtInterfaces map[string]string templates *packr.Box modules []string + httpValidator validator } type OpenvpnServer struct { @@ -178,14 +182,12 @@ type openvpnClientConfig struct { } type OpenvpnClient struct { - - Identity string `json:"Identity"` - AccountStatus string `json:"AccountStatus"` - ExpirationDate string `json:"ExpirationDate"` - RevocationDate string `json:"RevocationDate"` - ConnectionStatus string `json:"ConnectionStatus"` - ConnectionServer string `json:"ConnectionServer"` - + Identity string `json:"Identity"` + AccountStatus string `json:"AccountStatus"` + ExpirationDate string `json:"ExpirationDate"` + RevocationDate string `json:"RevocationDate"` + ConnectionStatus string `json:"ConnectionStatus"` + ConnectionServer string `json:"ConnectionServer"` } type ccdRoute struct { @@ -194,7 +196,7 @@ type ccdRoute struct { Description string `json:"Description"` } -type Ccd struct { +type CCD struct { User string `json:"User"` ClientAddress string `json:"ClientAddress"` CustomRoutes []ccdRoute `json:"CustomRoutes"` @@ -239,16 +241,32 @@ func (oAdmin *OvpnAdmin) userCreateHandler(w http.ResponseWriter, r *http.Reques http.Error(w, `{"status":"error"}`, http.StatusLocked) return } - r.ParseForm() - userCreated, userCreateStatus := oAdmin.userCreate(r.FormValue("username"), r.FormValue("password")) - - if userCreated { - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, userCreateStatus) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) return - } else { - http.Error(w, userCreateStatus, http.StatusUnprocessableEntity) } + userName := r.FormValue("username") + password := r.FormValue("password") + + if err := oAdmin.httpValidator.validateUsername(userName); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if *authByPassword { + if err := oAdmin.httpValidator.validatePassword(password); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } + if err := oAdmin.userCreate(userName, password); err != nil { + if *debug { + log.Printf("ERROR: userCreate: %s already exist\n", userName) + } + http.Error(w, err.Error(), http.StatusUnprocessableEntity) + return + } + fmt.Fprintf(w, `User %s created`, userName) } func (oAdmin *OvpnAdmin) userRevokeHandler(w http.ResponseWriter, r *http.Request) { @@ -256,8 +274,21 @@ func (oAdmin *OvpnAdmin) userRevokeHandler(w http.ResponseWriter, r *http.Reques http.Error(w, `{"status":"error"}`, http.StatusLocked) return } - r.ParseForm() - fmt.Fprintf(w, "%s", oAdmin.userRevoke(r.FormValue("username"))) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userName := r.FormValue("username") + + if err := oAdmin.httpValidator.validateUsername(userName); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := oAdmin.userRevoke(userName); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + fmt.Fprintf(w, `{"msg":"User %s successfully revoked"}`, userName) } func (oAdmin *OvpnAdmin) userUnrevokeHandler(w http.ResponseWriter, r *http.Request) { @@ -266,44 +297,96 @@ func (oAdmin *OvpnAdmin) userUnrevokeHandler(w http.ResponseWriter, r *http.Requ return } - r.ParseForm() - fmt.Fprintf(w, "%s", oAdmin.userUnrevoke(r.FormValue("username"))) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userName := r.FormValue("username") + + if err := oAdmin.httpValidator.validateUsername(userName); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if err := oAdmin.userUnrevoke(userName); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + fmt.Fprintf(w, `{"msg":"User %s successfully unrevoked"}`, userName) } func (oAdmin *OvpnAdmin) userChangePasswordHandler(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - if *authByPassword { - passwordChanged, passwordChangeMessage := oAdmin.userChangePassword(r.FormValue("username"), r.FormValue("password")) - if passwordChanged { - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, `{"status":"ok", "message": "%s"}`, passwordChangeMessage) - return - } else { - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, `{"status":"error", "message": "%s"}`, passwordChangeMessage) - return - } - } else { + if !*authByPassword { http.Error(w, `{"status":"error"}`, http.StatusNotImplemented) } + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userName := r.FormValue("username") + password := r.FormValue("password") + if err := oAdmin.httpValidator.validateUsername(userName); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if *authByPassword { + if err := oAdmin.httpValidator.validatePassword(password); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + } + if err := oAdmin.userChangePassword(userName, password); err != nil { + w.WriteHeader(http.StatusInternalServerError) + fmt.Fprintf(w, `{"status":"error", "message": "%s"}`, err.Error()) + return + } + fmt.Fprint(w, `{"status":"ok", "message": "Password changed"}`) } func (oAdmin *OvpnAdmin) userShowConfigHandler(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - fmt.Fprintf(w, "%s", oAdmin.renderClientConfig(r.FormValue("username"))) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userName := r.FormValue("username") + if err := oAdmin.httpValidator.validateUsername(userName); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + fmt.Fprint(w, oAdmin.renderClientConfig(userName)) } func (oAdmin *OvpnAdmin) userDisconnectHandler(w http.ResponseWriter, r *http.Request) { - r.ParseForm() + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userName := r.FormValue("username") + if err := oAdmin.httpValidator.validateUsername(userName); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + //TODO need implements // fmt.Fprintf(w, "%s", userDisconnect(r.FormValue("username"))) - fmt.Fprintf(w, "%s", r.FormValue("username")) + fmt.Fprint(w, userName) } func (oAdmin *OvpnAdmin) userShowCcdHandler(w http.ResponseWriter, r *http.Request) { - r.ParseForm() - ccd, _ := json.Marshal(oAdmin.getCcd(r.FormValue("username"))) - fmt.Fprintf(w, "%s", ccd) + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + userName := r.FormValue("username") + if err := oAdmin.httpValidator.validateUsername(userName); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + if err := json.NewEncoder(w).Encode(oAdmin.getCcd(userName)); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } func (oAdmin *OvpnAdmin) userApplyCcdHandler(w http.ResponseWriter, r *http.Request) { @@ -311,26 +394,29 @@ func (oAdmin *OvpnAdmin) userApplyCcdHandler(w http.ResponseWriter, r *http.Requ http.Error(w, `{"status":"error"}`, http.StatusLocked) return } - var ccd Ccd - if r.Body == nil { - http.Error(w, "Please send a request body", http.StatusBadRequest) - return - } - + var ccd CCD err := json.NewDecoder(r.Body).Decode(&ccd) if err != nil { - log.Println(err) - } - - ccdApplied, applyStatus := oAdmin.modifyCcd(ccd) - - if ccdApplied { - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, applyStatus) + if err == io.EOF { + http.Error(w, "please send a request body", http.StatusBadRequest) + } else { + http.Error(w, err.Error(), http.StatusBadRequest) + } return - } else { - http.Error(w, applyStatus, http.StatusUnprocessableEntity) } + if err := oAdmin.httpValidator.validateCCD(ccd); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if err := oAdmin.modifyCcd(ccd); err != nil { + if *debug { + log.Printf("ERROR: Modify ccd for user %s: %s\n", ccd.User, err) + } + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + fmt.Fprint(w, "ccd updated successfully") } func (oAdmin *OvpnAdmin) serverSettingsHandler(w http.ResponseWriter, r *http.Request) { @@ -354,7 +440,10 @@ func (oAdmin *OvpnAdmin) downloadCertsHandler(w http.ResponseWriter, r *http.Req http.Error(w, `{"status":"error"}`, http.StatusLocked) return } - r.ParseForm() + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } token := r.Form.Get("token") if token != oAdmin.masterSyncToken { @@ -362,7 +451,10 @@ func (oAdmin *OvpnAdmin) downloadCertsHandler(w http.ResponseWriter, r *http.Req return } - archiveCerts() + if err := archive(*easyrsaDirPath+"/pki", certsArchivePath); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } w.Header().Set("Content-Disposition", "attachment; filename="+certsArchiveFileName) http.ServeFile(w, r, certsArchivePath) } @@ -372,7 +464,10 @@ func (oAdmin *OvpnAdmin) downloadCcdHandler(w http.ResponseWriter, r *http.Reque http.Error(w, `{"status":"error"}`, http.StatusLocked) return } - r.ParseForm() + if err := r.ParseForm(); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } token := r.Form.Get("token") if token != oAdmin.masterSyncToken { @@ -380,7 +475,10 @@ func (oAdmin *OvpnAdmin) downloadCcdHandler(w http.ResponseWriter, r *http.Reque return } - archiveCcd() + if err := archive(*ccdDir, ccdArchivePath); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } w.Header().Set("Content-Disposition", "attachment; filename="+ccdArchiveFileName) http.ServeFile(w, r, ccdArchivePath) } @@ -397,6 +495,12 @@ func main() { ovpnAdmin.promRegistry = prometheus.NewRegistry() ovpnAdmin.modules = []string{} + validator, err := newValidator() + if err != nil { + log.Fatal(err) + } + ovpnAdmin.httpValidator = validator + ovpnAdmin.mgmtInterfaces = make(map[string]string) for _, mgmtInterface := range *mgmtAddress { @@ -463,11 +567,31 @@ func main() { http.Handle(*metricsPath, promhttp.HandlerFor(ovpnAdmin.promRegistry, promhttp.HandlerOpts{})) http.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "pong") + fmt.Fprint(w, "pong") }) - log.Printf("Bind: http://%s:%s\n", *listenHost, *listenPort) - log.Fatal(http.ListenAndServe(*listenHost+":"+*listenPort, nil)) + done := make(chan os.Signal, 1) + signal.Notify(done, os.Interrupt, syscall.SIGINT, syscall.SIGTERM) + + bind := net.JoinHostPort(*listenHost, *listenPort) + srv := &http.Server{Addr: bind, Handler: nil} + go func() { + if err := srv.ListenAndServe(); err != nil { + log.Fatalf("listen: %s\n", err) + } + }() + log.Printf("http server started on http://%s\n", bind) + + <-done + log.Print("http server stopped...") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer func() { + cancel() + }() + if err := srv.Shutdown(ctx); err != nil { + log.Fatalf("Server Shutdown Failed:%+v", err) + } } func CacheControlWrapper(h http.Handler) http.Handler { @@ -585,9 +709,9 @@ func (oAdmin *OvpnAdmin) renderClientConfig(username string) string { var tmp bytes.Buffer err := t.Execute(&tmp, conf) if err != nil { - log.Printf("ERROR: something goes wrong during rendering config for %s\n", username ) + log.Printf("ERROR: something goes wrong during rendering config for %s\n", username) if *debug { - log.Printf("DEBUG: rendering config for %s failed with error %v\n", username, err ) + log.Printf("DEBUG: rendering config for %s failed with error %v\n", username, err) } } @@ -601,20 +725,19 @@ func (oAdmin *OvpnAdmin) renderClientConfig(username string) string { return fmt.Sprintf("User \"%s\" not found", username) } -func (oAdmin *OvpnAdmin) getCcdTemplate() *template.Template { +func (oAdmin *OvpnAdmin) getCcdTemplate() (*template.Template, error) { if *ccdTemplatePath != "" { - return template.Must(template.ParseFiles(*ccdTemplatePath)) - } else { - ccdTpl, ccdTplErr := oAdmin.templates.FindString("ccd.tpl") - if ccdTplErr != nil { - log.Printf("ERROR: ccdTpl not found in templates box") - } - return template.Must(template.New("ccd").Parse(ccdTpl)) + return template.ParseFiles(*ccdTemplatePath) } + ccdTpl, err := oAdmin.templates.FindString("ccd.tpl") + if err != nil { + return nil, fmt.Errorf("%w: %s", err, "ccd.tpl not found in templates box") + } + return template.New("ccd").Parse(ccdTpl) } -func (oAdmin *OvpnAdmin) parseCcd(username string) Ccd { - ccd := Ccd{} +func (oAdmin *OvpnAdmin) parseCcd(username string) CCD { + ccd := CCD{} ccd.User = username ccd.ClientAddress = "dynamic" ccd.CustomRoutes = []ccdRoute{} @@ -636,88 +759,41 @@ func (oAdmin *OvpnAdmin) parseCcd(username string) Ccd { return ccd } -func (oAdmin *OvpnAdmin) modifyCcd(ccd Ccd) (bool, string) { - ccdErr := "something goes wrong" - - if fCreate(*ccdDir + "/" + ccd.User) { - ccdValid, ccdErr := validateCcd(ccd) - if ccdErr != "" { - return false, ccdErr +func (oAdmin *OvpnAdmin) modifyCcd(ccd CCD) error { + if ccd.ClientAddress != "dynamic" { + _, ovpnNet, err := net.ParseCIDR(*openvpnNetwork) + if err != nil { + return err } - - if ccdValid { - t := oAdmin.getCcdTemplate() - var tmp bytes.Buffer - tplErr := t.Execute(&tmp, ccd) - if tplErr != nil { - log.Println(tplErr) - } - fWrite(*ccdDir+"/"+ccd.User, tmp.String()) - return true, "ccd updated successfully" + if !ovpnNet.Contains(net.ParseIP(ccd.ClientAddress)) { + return fmt.Errorf("clientAddress \"%s\" not belongs to openvpn server network", ccd.ClientAddress) + } + if !checkStaticAddressIsFree(ccd.ClientAddress, ccd.User) { + return fmt.Errorf("clientAddress \"%s\" already assigned to another user", ccd.ClientAddress) } } - - return false, ccdErr + tmpl, err := oAdmin.getCcdTemplate() + if err != nil { + return err + } + var buf bytes.Buffer + if err := tmpl.Execute(&buf, ccd); err != nil { + return err + } + userPath := filepath.Join(*ccdDir, "/", ccd.User) + file, err := os.OpenFile(userPath, os.O_CREATE|os.O_RDWR, 0644) + if err != nil { + return err + } + defer file.Close() + if _, err := file.Write(buf.Bytes()); err != nil { + return err + } + return nil } -func validateCcd(ccd Ccd) (bool, string) { - - ccdErr := "" - - if ccd.ClientAddress != "dynamic" { - _, ovpnNet, err := net.ParseCIDR(*openvpnNetwork) - if err != nil { - log.Println(err) - } - - if ! checkStaticAddressIsFree(ccd.ClientAddress, ccd.User) { - ccdErr = fmt.Sprintf("ClientAddress \"%s\" already assigned to another user", ccd.ClientAddress) - if *debug { - log.Printf("ERROR: Modify ccd for user %s: %s\n", ccd.User, ccdErr) - } - return false, ccdErr - } - - if net.ParseIP(ccd.ClientAddress) == nil { - ccdErr = fmt.Sprintf("ClientAddress \"%s\" not a valid IP address", ccd.ClientAddress) - if *debug { - log.Printf("ERROR: Modify ccd for user %s: %s\n", ccd.User, ccdErr) - } - return false, ccdErr - } - - if ! ovpnNet.Contains(net.ParseIP(ccd.ClientAddress)) { - ccdErr = fmt.Sprintf("ClientAddress \"%s\" not belongs to openvpn server network", ccd.ClientAddress) - if *debug { - log.Printf("ERROR: Modify ccd for user %s: %s\n", ccd.User, ccdErr) - } - return false, ccdErr - } - } - - for _, route := range ccd.CustomRoutes { - if net.ParseIP(route.Address) == nil { - ccdErr = fmt.Sprintf("CustomRoute.Address \"%s\" must be a valid IP address", route.Address) - if *debug { - log.Printf("ERROR: Modify ccd for user %s: %s\n", ccd.User, ccdErr) - } - return false, ccdErr - } - - if net.ParseIP(route.Mask) == nil { - ccdErr = fmt.Sprintf("CustomRoute.Mask \"%s\" must be a valid IP address", route.Mask) - if *debug { - log.Printf("ERROR: Modify ccd for user %s: %s\n", ccd.User, ccdErr) - } - return false, ccdErr - } - } - - return true, ccdErr -} - -func (oAdmin *OvpnAdmin) getCcd(username string) Ccd { - ccd := Ccd{} +func (oAdmin *OvpnAdmin) getCcd(username string) CCD { + ccd := CCD{} ccd.User = username ccd.ClientAddress = "dynamic" ccd.CustomRoutes = []ccdRoute{} @@ -737,19 +813,6 @@ func checkStaticAddressIsFree(staticAddress string, username string) bool { return false } -func validateUsername(username string) bool { - var validUsername = regexp.MustCompile(usernameRegexp) - return validUsername.MatchString(username) -} - -func validatePassword(password string) bool { - if len(password) < passwordMinLength { - return false - } else { - return true - } -} - func checkUserExist(username string) bool { for _, u := range indexTxtParser(fRead(*indexTxtPath)) { if u.DistinguishedName == ("/CN=" + username) { @@ -819,33 +882,9 @@ func (oAdmin *OvpnAdmin) usersList() []OpenvpnClient { return users } -func (oAdmin *OvpnAdmin) userCreate(username, password string) (bool, string) { - ucErr := fmt.Sprintf("User \"%s\" created", username) - +func (oAdmin *OvpnAdmin) userCreate(username, password string) error { if checkUserExist(username) { - ucErr = fmt.Sprintf("User \"%s\" already exists\n", username) - if *debug { - log.Printf("ERROR: userCreate: %s\n", ucErr) - } - return false, ucErr - } - - if !validateUsername(username) { - ucErr = fmt.Sprintf("Username \"%s\" incorrect, you can use only %s\n", username, usernameRegexp) - if *debug { - log.Printf("ERROR: userCreate: %s\n", ucErr) - } - return false, ucErr - } - - if *authByPassword { - if !validatePassword(password) { - ucErr = fmt.Sprintf("Password too short, password length must be greater or equal %d", passwordMinLength) - if *debug { - log.Printf("ERROR: userCreate: %s\n", ucErr) - } - return false, ucErr - } + return fmt.Errorf("User \"%s\" already exists\n", username) } o := runBash(fmt.Sprintf("date +%%Y-%%m-%%d\\ %%H:%%M:%%S && cd %s && easyrsa build-client-full %s nopass", *easyrsaDirPath, username)) @@ -862,23 +901,14 @@ func (oAdmin *OvpnAdmin) userCreate(username, password string) (bool, string) { oAdmin.clients = oAdmin.usersList() - return true, ucErr + return nil } -func (oAdmin *OvpnAdmin) userChangePassword(username, password string) (bool, string) { - +func (oAdmin *OvpnAdmin) userChangePassword(username, password string) error { if checkUserExist(username) { o := runBash(fmt.Sprintf("openvpn-user check --db.path %s --user %s | grep %s | wc -l", *authDatabase, username, username)) log.Println(o) - if !validatePassword(password) { - ucpErr := fmt.Sprintf("Password for too short, password length must be greater or equal %d", passwordMinLength) - if *debug { - log.Printf("ERROR: userChangePassword: %s\n", ucpErr) - } - return false, ucpErr - } - if strings.TrimSpace(o) == "0" { log.Println("Creating user in users.db") o = runBash(fmt.Sprintf("openvpn-user create --db.path %s --user %s --password %s", *authDatabase, username, password)) @@ -891,10 +921,9 @@ func (oAdmin *OvpnAdmin) userChangePassword(username, password string) (bool, st if *verbose { log.Printf("INFO: password for user %s was changed\n", username) } - return true, "Password changed" + return nil } - - return false, "User does not exist" + return fmt.Errorf("user does not exist") } func (oAdmin *OvpnAdmin) getUserStatistic(username string) clientStatus { @@ -906,12 +935,12 @@ func (oAdmin *OvpnAdmin) getUserStatistic(username string) clientStatus { return clientStatus{} } -func (oAdmin *OvpnAdmin) userRevoke(username string) string { +func (oAdmin *OvpnAdmin) userRevoke(username string) error { if checkUserExist(username) { // check certificate valid flag 'V' - o := runBash(fmt.Sprintf("date +%%Y-%%m-%%d\\ %%H:%%M:%%S && cd %s && echo yes | easyrsa revoke %s && easyrsa gen-crl", *easyrsaDirPath, username)) + runBash(fmt.Sprintf("date +%%Y-%%m-%%d\\ %%H:%%M:%%S && cd %s && echo yes | easyrsa revoke %s && easyrsa gen-crl", *easyrsaDirPath, username)) if *authByPassword { - o = runBash(fmt.Sprintf("openvpn-user revoke --db-path %s --user %s", *authDatabase, username)) + runBash(fmt.Sprintf("openvpn-user revoke --db-path %s --user %s", *authDatabase, username)) //fmt.Println(o) } @@ -922,13 +951,12 @@ func (oAdmin *OvpnAdmin) userRevoke(username string) string { log.Printf("Session for user \"%s\" session killed\n", username) } oAdmin.clients = oAdmin.usersList() - return fmt.Sprintln(o) + return nil } - log.Printf("User \"%s\" not found\n", username) - return fmt.Sprintf("User \"%s\" not found", username) + return fmt.Errorf(`user "%s" not found`, username) } -func (oAdmin *OvpnAdmin) userUnrevoke(username string) string { +func (oAdmin *OvpnAdmin) userUnrevoke(username string) error { if checkUserExist(username) { // check certificate revoked flag 'R' usersFromIndexTxt := indexTxtParser(fRead(*indexTxtPath)) @@ -964,9 +992,9 @@ func (oAdmin *OvpnAdmin) userUnrevoke(username string) string { fmt.Print(renderIndexTxt(usersFromIndexTxt)) crlFix() oAdmin.clients = oAdmin.usersList() - return fmt.Sprintf("{\"msg\":\"User %s successfully unrevoked\"}", username) + return nil } - return fmt.Sprintf("{\"msg\":\"User \"%s\" not found\"}", username) + return fmt.Errorf(`user "%s" not found`, username) } func (oAdmin *OvpnAdmin) mgmtRead(conn net.Conn) string { @@ -1096,26 +1124,86 @@ func (oAdmin *OvpnAdmin) downloadCcd() bool { return true } -func archiveCerts() { - o := runBash(fmt.Sprintf("cd %s && tar -czf %s *", *easyrsaDirPath+"/pki", certsArchivePath)) - fmt.Println(o) +func unArchive(src, dst string) error { + file, err := os.Open(src) + if err != nil { + return err + } + defer file.Close() + + gz, err := gzip.NewReader(file) + if err != nil { + return err + } + tr := tar.NewReader(gz) + + for { + header, err := tr.Next() + if err != nil { + if err == io.EOF { + break + } + return err + } + dstPath := filepath.Join(dst, header.Name) + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(dstPath, 0755); err != nil { + return err + } + case tar.TypeReg: + outFile, err := os.Create(dstPath) + if err != nil { + return err + } + if _, err := io.Copy(outFile, tr); err != nil { + outFile.Close() + return err + } + outFile.Close() + default: + return fmt.Errorf("uknown type: %v in %s", header.Typeflag, header.Name) + } + } + return nil } -func archiveCcd() { - o := runBash(fmt.Sprintf("cd %s && tar -czf %s *", *ccdDir, ccdArchivePath)) - fmt.Println(o) -} +func archive(src, dst string) error { + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() -func unArchiveCerts() { - runBash(fmt.Sprintf("mkdir -p %s", *easyrsaDirPath+"/pki")) - o := runBash(fmt.Sprintf("cd %s && tar -xzf %s", *easyrsaDirPath+"/pki", certsArchivePath)) - fmt.Println(o) -} - -func unArchiveCcd() { - runBash(fmt.Sprintf("mkdir -p %s", *ccdDir)) - o := runBash(fmt.Sprintf("cd %s && tar -xzf %s", *ccdDir, ccdArchivePath)) - fmt.Println(o) + zr := gzip.NewWriter(out) + tw := tar.NewWriter(zr) + defer tw.Close() + defer zr.Close() + err = filepath.Walk(src, func(file string, fi os.FileInfo, err error) error { + header, err := tar.FileInfoHeader(fi, file) + if err != nil { + return err + } + header.Name = filepath.ToSlash(file) + if err := tw.WriteHeader(header); err != nil { + return err + } + if !fi.IsDir() { + f, err := os.Open(file) + if err != nil { + return err + } + defer f.Close() + if _, err := io.Copy(tw, f); err != nil { + return err + } + } + return nil + }) + if err != nil { + return err + } + return nil } func (oAdmin *OvpnAdmin) syncDataFromMaster() { @@ -1131,7 +1219,9 @@ func (oAdmin *OvpnAdmin) syncDataFromMaster() { if oAdmin.downloadCerts() { certsDownloadFailed = false log.Println("Decompression certs archive from master") - unArchiveCerts() + if err := unArchive(certsArchivePath, *easyrsaDirPath+"/pki"); err != nil { + log.Printf("unArchive %s error: %s\n", certsArchivePath, err) + } } else { log.Printf("WARNING: something goes wrong during downloading certs from master. Attempt %d\n", certsDownloadRetries) } @@ -1143,7 +1233,9 @@ func (oAdmin *OvpnAdmin) syncDataFromMaster() { if oAdmin.downloadCcd() { ccdDownloadFailed = false log.Println("Decompression ccd archive from master") - unArchiveCcd() + if err := unArchive(ccdArchivePath, *ccdDir); err != nil { + log.Printf("unArchive %s error: %s\n", ccdArchivePath, err) + } } else { log.Printf("WARNING: something goes wrong during downloading certs from master. Attempt %d\n", ccdDownloadRetries) } @@ -1162,7 +1254,6 @@ func (oAdmin *OvpnAdmin) syncWithMaster() { } } - func getOvpnServerHostsFromKubeApi() []OpenvpnServer { var hosts []OpenvpnServer var lbHost string @@ -1194,7 +1285,7 @@ func getOvpnServerHostsFromKubeApi() []OpenvpnServer { if service.Status.LoadBalancer.Ingress[0].IP != "" { lbHost = service.Status.LoadBalancer.Ingress[0].IP } - hosts = append(hosts, OpenvpnServer{lbHost,strconv.Itoa(int(service.Spec.Ports[0].Port)),strings.ToLower(string(service.Spec.Ports[0].Protocol))}) + hosts = append(hosts, OpenvpnServer{lbHost, strconv.Itoa(int(service.Spec.Ports[0].Port)), strings.ToLower(string(service.Spec.Ports[0].Protocol))}) return hosts } diff --git a/validation.go b/validation.go new file mode 100644 index 0000000..518a72b --- /dev/null +++ b/validation.go @@ -0,0 +1,89 @@ +package main + +import ( + "fmt" + "net" + "regexp" + "unicode/utf8" +) + +const ( + usernameRegexp = `^[a-zA-Z0-9@]+([._]?[a-zA-Z0-9]+)*$` + passwordRegexp = `^[a-zA-Z0-9@$!#%^_\-]+([._]?[a-zA-Z0-9]+)*$` + CCDDescriptionRegexp = `^[a-zA-Z0-9 ]*$` + usernameMinLength = 3 + passwordMinLength = 6 +) + +type ( + validator interface { + validateUsername(string) error + validatePassword(string) error + validateCCD(ccd CCD) error + } + + validators struct { + reUserName *regexp.Regexp + rePassword *regexp.Regexp + reCCDDescription *regexp.Regexp + } +) + +func (v *validators) validateCCD(ccd CCD) error { + if err := v.validateUsername(ccd.User); err != nil { + return err + } + + if ccd.ClientAddress != "dynamic" { + if net.ParseIP(ccd.ClientAddress) == nil { + return fmt.Errorf("invalid ClientAddress") + } + } + + for _, route := range ccd.CustomRoutes { + if net.ParseIP(route.Address) == nil { + return fmt.Errorf("invalid CustomRoute.Address") + } + if net.ParseIP(route.Mask) == nil { + return fmt.Errorf("invalid CustomRoute.Mask") + } + if utf8.RuneCountInString(route.Description) > 0 && !v.reCCDDescription.MatchString(route.Description) { + return fmt.Errorf("invalid CustomRoute.Description") + } + } + return nil +} + +func newValidator() (validator, error) { + reUsername, err := regexp.Compile(usernameRegexp) + if err != nil { + return nil, fmt.Errorf("username regexp compile failed %s", err) + } + rePassword, err := regexp.Compile(passwordRegexp) + if err != nil { + return nil, fmt.Errorf("password regexp compile failed %s", err) + } + reCCDDescription, err := regexp.Compile(CCDDescriptionRegexp) + if err != nil { + return nil, fmt.Errorf("ccd desctiprion regexp compile failed %s", err) + } + return &validators{ + reUserName: reUsername, + rePassword: rePassword, + reCCDDescription: reCCDDescription, + }, nil +} + +func (v *validators) validateUsername(username string) error { + if usernameMinLength > utf8.RuneCountInString(username) || !v.reUserName.MatchString(username) { + return fmt.Errorf("invalid username") + } + return nil +} + +func (v *validators) validatePassword(passwd string) error { + if passwordMinLength > utf8.RuneCountInString(passwd) || !v.rePassword.MatchString(passwd) { + return fmt.Errorf("invalid password") + } + return nil +} diff --git a/validation_test.go b/validation_test.go new file mode 100644 index 0000000..8b665a8 --- /dev/null +++ b/validation_test.go @@ -0,0 +1,144 @@ +package main + +import ( + "testing" +) + +func Test_validator_validateUsername(t *testing.T) { + type args struct { + username string + } + + tests := []struct { + name string + args args + wantErr bool + }{ + { + name: "success (qwerty)", + args: args{"qwerty"}, + wantErr: false, + }, + { + name: "success (qwerty@gmail.com)", + args: args{"qwerty@gmail.com"}, + wantErr: false, + }, + { + name: "success (luck.cage)", + args: args{"luck.cage"}, + wantErr: false, + }, + { + name: "error (empty)", + args: args{""}, + wantErr: true, + }, + { + name: "error (/bin/bash)", + args: args{"/bin/bash"}, + wantErr: true, + }, + { + name: "error (./exploit.sh)", + args: args{"./exploit.sh"}, + wantErr: true, + }, + { + name: "error (rm rf /)", + args: args{"rm rf /"}, + wantErr: true, + }, + { + name: "error (; ls)", + args: args{"; ls"}, + wantErr: true, + }, + { + name: "error (&echo>passwd)", + args: args{"&echo>passwd"}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v, err := newValidator() + if err != nil { + t.Error(err) + } + if err := v.validateUsername(tt.args.username); (err != nil) != tt.wantErr { + t.Errorf("validateUsername() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func Test_validator_validatePassword(t *testing.T) { + type args struct { + passwd string + } + + tests := []struct { + name string + args args + wantErr bool + }{ + + { + name: "success (qwe1239%$0123924qw)", + args: args{"qwe1239%$0123924qw"}, + wantErr: false, + }, + { + name: "success (pa$$w0rd)", + args: args{"pa$$w0rd"}, + wantErr: false, + }, + { + name: "short (qwe)", + args: args{"qwe"}, + wantErr: true, + }, + { + name: "error (empty)", + args: args{""}, + wantErr: true, + }, + { + name: "error (/bin/bash)", + args: args{"/bin/bash"}, + wantErr: true, + }, + { + name: "error (./exploit.sh)", + args: args{"./exploit.sh"}, + wantErr: true, + }, + { + name: "error (rm rf /)", + args: args{"rm rf /"}, + wantErr: true, + }, + { + name: "error (; ls)", + args: args{"; ls"}, + wantErr: true, + }, + { + name: "error (&echo>passwd)", + args: args{"&echo>passwd"}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + v, err := newValidator() + if err != nil { + t.Error(err) + } + if err := v.validatePassword(tt.args.passwd); (err != nil) != tt.wantErr { + t.Errorf("validatePassword() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +}