From e8752f4e9f9be3d2932cd4835a5d72d17ac2338b Mon Sep 17 00:00:00 2001
From: Elad <theman@elad.im>
Date: Wed, 15 Aug 2018 17:41:52 +0200
Subject: cmd/swarm, swarm: added access control functionality (#17404)

Co-authored-by: Janos Guljas <janos@resenje.org>
Co-authored-by: Anton Evangelatov <anton.evangelatov@gmail.com>
Co-authored-by: Balint Gabor <balint.g@gmail.com>
---
 swarm/api/http/middleware.go |  12 ++++-
 swarm/api/http/response.go   |   2 +-
 swarm/api/http/server.go     | 111 ++++++++++++++++++++-----------------------
 3 files changed, 63 insertions(+), 62 deletions(-)

(limited to 'swarm/api/http')

diff --git a/swarm/api/http/middleware.go b/swarm/api/http/middleware.go
index c0d8d1a40..3b2dcc7d5 100644
--- a/swarm/api/http/middleware.go
+++ b/swarm/api/http/middleware.go
@@ -9,6 +9,7 @@ import (
 	"github.com/ethereum/go-ethereum/metrics"
 	"github.com/ethereum/go-ethereum/swarm/api"
 	"github.com/ethereum/go-ethereum/swarm/log"
+	"github.com/ethereum/go-ethereum/swarm/sctx"
 	"github.com/ethereum/go-ethereum/swarm/spancontext"
 	"github.com/pborman/uuid"
 )
@@ -35,6 +36,15 @@ func SetRequestID(h http.Handler) http.Handler {
 	})
 }
 
+func SetRequestHost(h http.Handler) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		r = r.WithContext(sctx.SetHost(r.Context(), r.Host))
+		log.Info("setting request host", "ruid", GetRUID(r.Context()), "host", sctx.GetHost(r.Context()))
+
+		h.ServeHTTP(w, r)
+	})
+}
+
 func ParseURI(h http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		uri, err := api.Parse(strings.TrimLeft(r.URL.Path, "/"))
@@ -87,7 +97,7 @@ func RecoverPanic(h http.Handler) http.Handler {
 	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 		defer func() {
 			if err := recover(); err != nil {
-				log.Error("panic recovery!", "stack trace", debug.Stack(), "url", r.URL.String(), "headers", r.Header)
+				log.Error("panic recovery!", "stack trace", string(debug.Stack()), "url", r.URL.String(), "headers", r.Header)
 			}
 		}()
 		h.ServeHTTP(w, r)
diff --git a/swarm/api/http/response.go b/swarm/api/http/response.go
index f050e706a..c9fb9d285 100644
--- a/swarm/api/http/response.go
+++ b/swarm/api/http/response.go
@@ -79,7 +79,7 @@ func RespondTemplate(w http.ResponseWriter, r *http.Request, templateName, msg s
 }
 
 func RespondError(w http.ResponseWriter, r *http.Request, msg string, code int) {
-	log.Debug("RespondError", "ruid", GetRUID(r.Context()), "uri", GetURI(r.Context()))
+	log.Debug("RespondError", "ruid", GetRUID(r.Context()), "uri", GetURI(r.Context()), "code", code)
 	RespondTemplate(w, r, "error", msg, code)
 }
 
diff --git a/swarm/api/http/server.go b/swarm/api/http/server.go
index 5a5c42adc..b5ea0c23d 100644
--- a/swarm/api/http/server.go
+++ b/swarm/api/http/server.go
@@ -23,7 +23,6 @@ import (
 	"bufio"
 	"bytes"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"io"
 	"io/ioutil"
@@ -97,6 +96,7 @@ func NewServer(api *api.API, corsString string) *Server {
 	defaultMiddlewares := []Adapter{
 		RecoverPanic,
 		SetRequestID,
+		SetRequestHost,
 		InitLoggingResponseWriter,
 		ParseURI,
 		InstrumentOpenTracing,
@@ -169,6 +169,7 @@ func NewServer(api *api.API, corsString string) *Server {
 }
 
 func (s *Server) ListenAndServe(addr string) error {
+	s.listenAddr = addr
 	return http.ListenAndServe(addr, s)
 }
 
@@ -178,16 +179,24 @@ func (s *Server) ListenAndServe(addr string) error {
 // https://github.com/atom/electron/blob/master/docs/api/protocol.md
 type Server struct {
 	http.Handler
-	api *api.API
+	api        *api.API
+	listenAddr string
 }
 
 func (s *Server) HandleBzzGet(w http.ResponseWriter, r *http.Request) {
-	log.Debug("handleBzzGet", "ruid", GetRUID(r.Context()))
+	log.Debug("handleBzzGet", "ruid", GetRUID(r.Context()), "uri", r.RequestURI)
 	if r.Header.Get("Accept") == "application/x-tar" {
 		uri := GetURI(r.Context())
-		reader, err := s.api.GetDirectoryTar(r.Context(), uri)
+		_, credentials, _ := r.BasicAuth()
+		reader, err := s.api.GetDirectoryTar(r.Context(), s.api.Decryptor(r.Context(), credentials), uri)
 		if err != nil {
+			if isDecryptError(err) {
+				w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", uri.Address().String()))
+				RespondError(w, r, err.Error(), http.StatusUnauthorized)
+				return
+			}
 			RespondError(w, r, fmt.Sprintf("Had an error building the tarball: %v", err), http.StatusInternalServerError)
+			return
 		}
 		defer reader.Close()
 
@@ -287,7 +296,7 @@ func (s *Server) HandlePostFiles(w http.ResponseWriter, r *http.Request) {
 
 	var addr storage.Address
 	if uri.Addr != "" && uri.Addr != "encrypt" {
-		addr, err = s.api.Resolve(r.Context(), uri)
+		addr, err = s.api.Resolve(r.Context(), uri.Addr)
 		if err != nil {
 			postFilesFail.Inc(1)
 			RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusInternalServerError)
@@ -563,7 +572,7 @@ func (s *Server) HandleGetResource(w http.ResponseWriter, r *http.Request) {
 	// resolve the content key.
 	manifestAddr := uri.Address()
 	if manifestAddr == nil {
-		manifestAddr, err = s.api.Resolve(r.Context(), uri)
+		manifestAddr, err = s.api.Resolve(r.Context(), uri.Addr)
 		if err != nil {
 			getFail.Inc(1)
 			RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound)
@@ -682,62 +691,21 @@ func (s *Server) HandleGet(w http.ResponseWriter, r *http.Request) {
 	uri := GetURI(r.Context())
 	log.Debug("handle.get", "ruid", ruid, "uri", uri)
 	getCount.Inc(1)
+	_, pass, _ := r.BasicAuth()
 
-	var err error
-	addr := uri.Address()
-	if addr == nil {
-		addr, err = s.api.Resolve(r.Context(), uri)
-		if err != nil {
-			getFail.Inc(1)
-			RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound)
-			return
-		}
-	} else {
-		w.Header().Set("Cache-Control", "max-age=2147483648, immutable") // url was of type bzz://<hex key>/path, so we are sure it is immutable.
+	addr, err := s.api.ResolveURI(r.Context(), uri, pass)
+	if err != nil {
+		getFail.Inc(1)
+		RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound)
+		return
 	}
+	w.Header().Set("Cache-Control", "max-age=2147483648, immutable") // url was of type bzz://<hex key>/path, so we are sure it is immutable.
 
 	log.Debug("handle.get: resolved", "ruid", ruid, "key", addr)
 
 	// if path is set, interpret <key> as a manifest and return the
 	// raw entry at the given path
-	if uri.Path != "" {
-		walker, err := s.api.NewManifestWalker(r.Context(), addr, nil)
-		if err != nil {
-			getFail.Inc(1)
-			RespondError(w, r, fmt.Sprintf("%s is not a manifest", addr), http.StatusBadRequest)
-			return
-		}
-		var entry *api.ManifestEntry
-		walker.Walk(func(e *api.ManifestEntry) error {
-			// if the entry matches the path, set entry and stop
-			// the walk
-			if e.Path == uri.Path {
-				entry = e
-				// return an error to cancel the walk
-				return errors.New("found")
-			}
-
-			// ignore non-manifest files
-			if e.ContentType != api.ManifestType {
-				return nil
-			}
-
-			// if the manifest's path is a prefix of the
-			// requested path, recurse into it by returning
-			// nil and continuing the walk
-			if strings.HasPrefix(uri.Path, e.Path) {
-				return nil
-			}
 
-			return api.ErrSkipManifest
-		})
-		if entry == nil {
-			getFail.Inc(1)
-			RespondError(w, r, fmt.Sprintf("manifest entry could not be loaded"), http.StatusNotFound)
-			return
-		}
-		addr = storage.Address(common.Hex2Bytes(entry.Hash))
-	}
 	etag := common.Bytes2Hex(addr)
 	noneMatchEtag := r.Header.Get("If-None-Match")
 	w.Header().Set("ETag", fmt.Sprintf("%q", etag)) // set etag to manifest key or raw entry key.
@@ -781,6 +749,7 @@ func (s *Server) HandleGet(w http.ResponseWriter, r *http.Request) {
 func (s *Server) HandleGetList(w http.ResponseWriter, r *http.Request) {
 	ruid := GetRUID(r.Context())
 	uri := GetURI(r.Context())
+	_, credentials, _ := r.BasicAuth()
 	log.Debug("handle.get.list", "ruid", ruid, "uri", uri)
 	getListCount.Inc(1)
 
@@ -790,7 +759,7 @@ func (s *Server) HandleGetList(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	addr, err := s.api.Resolve(r.Context(), uri)
+	addr, err := s.api.Resolve(r.Context(), uri.Addr)
 	if err != nil {
 		getListFail.Inc(1)
 		RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound)
@@ -798,9 +767,14 @@ func (s *Server) HandleGetList(w http.ResponseWriter, r *http.Request) {
 	}
 	log.Debug("handle.get.list: resolved", "ruid", ruid, "key", addr)
 
-	list, err := s.api.GetManifestList(r.Context(), addr, uri.Path)
+	list, err := s.api.GetManifestList(r.Context(), s.api.Decryptor(r.Context(), credentials), addr, uri.Path)
 	if err != nil {
 		getListFail.Inc(1)
+		if isDecryptError(err) {
+			w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", addr.String()))
+			RespondError(w, r, err.Error(), http.StatusUnauthorized)
+			return
+		}
 		RespondError(w, r, err.Error(), http.StatusInternalServerError)
 		return
 	}
@@ -833,7 +807,8 @@ func (s *Server) HandleGetList(w http.ResponseWriter, r *http.Request) {
 func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) {
 	ruid := GetRUID(r.Context())
 	uri := GetURI(r.Context())
-	log.Debug("handle.get.file", "ruid", ruid)
+	_, credentials, _ := r.BasicAuth()
+	log.Debug("handle.get.file", "ruid", ruid, "uri", r.RequestURI)
 	getFileCount.Inc(1)
 
 	// ensure the root path has a trailing slash so that relative URLs work
@@ -845,7 +820,7 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) {
 	manifestAddr := uri.Address()
 
 	if manifestAddr == nil {
-		manifestAddr, err = s.api.Resolve(r.Context(), uri)
+		manifestAddr, err = s.api.ResolveURI(r.Context(), uri, credentials)
 		if err != nil {
 			getFileFail.Inc(1)
 			RespondError(w, r, fmt.Sprintf("cannot resolve %s: %s", uri.Addr, err), http.StatusNotFound)
@@ -856,7 +831,8 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) {
 	}
 
 	log.Debug("handle.get.file: resolved", "ruid", ruid, "key", manifestAddr)
-	reader, contentType, status, contentKey, err := s.api.Get(r.Context(), manifestAddr, uri.Path)
+
+	reader, contentType, status, contentKey, err := s.api.Get(r.Context(), s.api.Decryptor(r.Context(), credentials), manifestAddr, uri.Path)
 
 	etag := common.Bytes2Hex(contentKey)
 	noneMatchEtag := r.Header.Get("If-None-Match")
@@ -869,6 +845,12 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) {
 	}
 
 	if err != nil {
+		if isDecryptError(err) {
+			w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", manifestAddr))
+			RespondError(w, r, err.Error(), http.StatusUnauthorized)
+			return
+		}
+
 		switch status {
 		case http.StatusNotFound:
 			getFileNotFound.Inc(1)
@@ -883,9 +865,14 @@ func (s *Server) HandleGetFile(w http.ResponseWriter, r *http.Request) {
 	//the request results in ambiguous files
 	//e.g. /read with readme.md and readinglist.txt available in manifest
 	if status == http.StatusMultipleChoices {
-		list, err := s.api.GetManifestList(r.Context(), manifestAddr, uri.Path)
+		list, err := s.api.GetManifestList(r.Context(), s.api.Decryptor(r.Context(), credentials), manifestAddr, uri.Path)
 		if err != nil {
 			getFileFail.Inc(1)
+			if isDecryptError(err) {
+				w.Header().Set("WWW-Authenticate", fmt.Sprintf("Basic realm=%q", manifestAddr))
+				RespondError(w, r, err.Error(), http.StatusUnauthorized)
+				return
+			}
 			RespondError(w, r, err.Error(), http.StatusInternalServerError)
 			return
 		}
@@ -951,3 +938,7 @@ func (lrw *loggingResponseWriter) WriteHeader(code int) {
 	lrw.statusCode = code
 	lrw.ResponseWriter.WriteHeader(code)
 }
+
+func isDecryptError(err error) bool {
+	return strings.Contains(err.Error(), api.ErrDecrypt.Error())
+}
-- 
cgit v1.2.3