package comments
import (
"bytes"
"crypto/subtle"
"encoding/json"
"fmt"
"html/template"
"io"
"net/http"
"strconv"
"sync"
"time"
"github.com/peterbourgon/diskv"
)
// Comment is a single comment left on a post.
type Comment struct {
	// Name is the commenter's name; ServeComments fills it with the HTTP
	// Basic Auth username.
	Name string
	// Text is the body of the comment.
	Text string
	// IP is the remote address of the request that submitted the comment.
	IP string
	// Date is when the comment was received.
	Date time.Time
	// Spam marks the comment as spam. No code visible in this file sets it.
	Spam bool
}
// MarshalJSON encodes c as a JSON document followed by a newline (the
// json.Encoder framing). Whatever bytes were produced are returned together
// with any encoding error.
func MarshalJSON(c Comment) ([]byte, error) {
	var out bytes.Buffer
	enc := json.NewEncoder(&out)
	err := enc.Encode(c)
	return out.Bytes(), err
}
// UnmarshalJSON decodes the first JSON value in b into a Comment. The
// Comment is returned alongside any decoding error.
func UnmarshalJSON(b []byte) (Comment, error) {
	var (
		decoded Comment
		dec     = json.NewDecoder(bytes.NewReader(b))
	)
	err := dec.Decode(&decoded)
	return decoded, err
}
// Default returns a Service backed by a diskv store rooted at the
// "comments" directory.
func Default() *Service {
	store := NewDiskv("comments")
	return WithStorage(store)
}
// WithStorage returns a Service that persists comments in storage. The
// Password is left empty, so User allows every request until it is set.
func WithStorage(storage Storage) *Service {
	svc := new(Service)
	svc.Storage = storage
	return svc
}
// Service serves and records the comments of posts on top of a Storage.
type Service struct {
// Storage is where comments are stored and loaded from.
Storage Storage
// Password, when non-empty, must match the HTTP Basic Auth password of a
// request for User to allow it. An empty Password allows every request.
Password string
}
// Storage persists comments, grouped by the post they belong to.
type Storage interface {
// Store saves c as a comment on post.
Store(post string, c Comment) error
// Retreive returns the comments stored for post. The misspelt name is part
// of the interface, so renaming it would break existing implementations.
Retreive(post string) ([]Comment, error)
// ReadAll returns a reader over the stored data for post. The diskv-backed
// implementation in this file streams each stored comment's JSON in turn.
ReadAll(post string) (io.Reader, error)
}
// User reports who is making request r and whether they may use the service.
//
// With no Password configured, everyone is allowed and the returned name is
// empty. Otherwise the request must carry HTTP Basic Auth credentials whose
// password matches Password; the Basic Auth username is returned as the name.
func (cs *Service) User(r *http.Request) (string, bool) {
	if cs.Password == "" {
		return "", true
	}
	name, supplied, hasAuth := r.BasicAuth()
	if !hasAuth {
		return "", false
	}
	// Constant-time comparison, so response timing does not reveal how much
	// of the supplied password matched.
	if subtle.ConstantTimeCompare([]byte(cs.Password), []byte(supplied)) != 1 {
		return "", false
	}
	return name, true
}
// ServeComments handles an HTTP request for the comments of post.
//
// A request that User does not allow is answered with status 401 and a short
// prompt; nil is returned because the response has already been written. An
// allowed POST stores a new comment: its text comes from the JSON request
// body, its name from the authenticated user, its IP from r.RemoteAddr and
// its date from the current time. Any other method renders the post's
// comments as HTML.
//
// The returned error is whatever storing or rendering reported.
func (cs *Service) ServeComments(post string, w http.ResponseWriter, r *http.Request) error {
	user, allowed := cs.User(r)
	if !allowed {
		// The status has to go out before the body: the first write to an
		// http.ResponseWriter implicitly commits a 200, after which a
		// WriteHeader call has no effect.
		w.WriteHeader(http.StatusUnauthorized)
		fmt.Fprint(w, "click here to enter your name and the comments password")
		return nil
	}
	if r.Method == "POST" {
		return cs.Comment(post, Comment{
			Name: user,
			Text: readCommentText(r),
			IP:   r.RemoteAddr,
			Date: time.Now(),
		})
	}
	w.Header().Set("Content-Type", "text/html")
	_, err := cs.WriteHTML(post, w)
	return err
}
// readCommentText extracts the "Text" field from the JSON object in the
// request body. It is best effort: a body that cannot be decoded simply
// yields whatever text was decoded, usually the empty string.
func readCommentText(r *http.Request) string {
	var payload struct{ Text string }
	// The decode error is deliberately discarded; callers only get a string.
	_ = json.NewDecoder(r.Body).Decode(&payload)
	return payload.Text
}
// Comment records c as a new comment on post in the service's Storage.
func (cs *Service) Comment(post string, c Comment) error {
	store := cs.Storage
	return store.Store(post, c)
}
// Load returns the comments that the service's Storage holds for post.
func (cs *Service) Load(post string) ([]Comment, error) {
	comments, err := cs.Storage.Retreive(post)
	return comments, err
}
// WriteHTML renders the comments of post as HTML and writes them to w.
// Comments with empty text are left out. Name and Text go through
// html/template, which escapes them for the HTML context.
//
// The page is rendered into a buffer first, so nothing reaches w when loading
// or rendering fails. It returns the number of bytes written to w.
func (cs *Service) WriteHTML(post string, w io.Writer) (int64, error) {
	page := template.Must(template.New("").Parse(
		`
{{range .}}{{ if .Text }}
{{.Name}}
{{.Text}}
{{end}}{{end}}
`))
	entries, err := cs.Load(post)
	if err != nil {
		return 0, err
	}
	var rendered bytes.Buffer
	if err := page.Execute(&rendered, entries); err != nil {
		return 0, err
	}
	return rendered.WriteTo(w)
}
// WriteJSON copies the stored data for post, as returned by Storage.ReadAll,
// to w. It returns the number of bytes copied.
func (cs *Service) WriteJSON(post string, w io.Writer) (int64, error) {
	r, err := cs.Storage.ReadAll(post)
	if err != nil {
		return 0, err
	}
	// ReadAll only promises an io.Reader, but an implementation may hand back
	// something holding open resources (the diskv one returns a ReadCloser).
	// Release it once the copy is done.
	if closer, ok := r.(io.Closer); ok {
		defer closer.Close()
	}
	return io.Copy(w, r)
}
// diskvStorage is the Storage implementation backed by a diskv store. Each
// comment is kept under its own key of the form "<post>-<index>".
type diskvStorage struct {
*diskv.Diskv
// keyMtx is held by nextKey while it picks and reserves a key.
keyMtx sync.Mutex
// keyLastKnown is not read or written by any code visible in this file.
keyLastKnown int64
}
// nextKey picks the key for the next comment on post and reserves it by
// writing an empty value under it, all while holding keyMtx.
//
// Keys have the form "<post>-<index>". The new index is one more than the
// highest index currently stored, or 1 when the post has no comments yet. An
// error is returned when an existing key's suffix is not a number, or when
// the reservation write fails.
func (d *diskvStorage) nextKey(post string) (string, error) {
	d.keyMtx.Lock()
	defer d.keyMtx.Unlock()

	prefix := post + "-"
	highest := 0
	malformed := ""
	// Look at every key rather than trusting the last one streamed: the keys
	// are not guaranteed to arrive in numeric order ("post-10" sorts before
	// "post-9" lexically), and reusing an index would overwrite a comment.
	// The loop always runs to completion so the key channel is fully drained.
	for key := range d.KeysPrefix(prefix, nil) {
		suffix := key[len(prefix):]
		if suffix == "" {
			// A key equal to the bare prefix carries no index; ignore it.
			continue
		}
		idx, err := strconv.Atoi(suffix)
		if err != nil {
			if malformed == "" {
				malformed = key
			}
			continue
		}
		if idx > highest {
			highest = idx
		}
	}
	if malformed != "" {
		return "", fmt.Errorf("comment key %q seems malformed", malformed)
	}

	key := prefix + strconv.Itoa(highest+1)
	// Reserve the key before releasing the lock, so a concurrent caller
	// cannot be handed the same one. This includes the very first key.
	err := d.WriteStream(key, &bytes.Buffer{}, true)
	return key, err
}
// Store serializes c to JSON and writes it under the next free key for post.
// Each failure is reported with a message naming the step that failed.
func (d *diskvStorage) Store(post string, c Comment) error {
	payload, err := MarshalJSON(c)
	if err != nil {
		return fmt.Errorf("error marshaling comment for storage: %v", err)
	}

	slot, err := d.nextKey(post)
	if err != nil {
		return fmt.Errorf("error recording comment: %v", err)
	}

	if err := d.Write(slot, payload); err != nil {
		return fmt.Errorf("error writing comment to storage: %v", err)
	}
	return nil
}
// multiReadCloser presents a sequence of ReadClosers as one continuous
// stream, closing each underlying reader once it has been read to its end.
type multiReadCloser struct {
	// rcs holds the readers not yet exhausted; rcs[0] is the current one.
	rcs []io.ReadCloser
}

// Read reads from the current reader. When that reader reports io.EOF it is
// closed and dropped, and reading continues with the next one. io.EOF is
// returned only after every reader has been exhausted.
//
// A reader is advanced past only on its own io.EOF, never because of a short
// read, so no data is skipped. A single call may return fewer bytes than
// len(b) even though more data follows, as the io.Reader contract permits.
// If closing an exhausted reader fails, that error is returned along with
// any bytes already read.
func (mrc *multiReadCloser) Read(b []byte) (int, error) {
	for len(mrc.rcs) > 0 {
		n, err := mrc.rcs[0].Read(b)
		if err != io.EOF {
			return n, err
		}
		// The current reader is exhausted: release it and move on.
		closeErr := mrc.rcs[0].Close()
		mrc.rcs = mrc.rcs[1:]
		if closeErr != nil {
			return n, closeErr
		}
		if n > 0 {
			// Hand over the final bytes now; EOF (or the next reader's
			// data) is reported by the following call.
			return n, nil
		}
	}
	return 0, io.EOF
}

// Close closes every reader that has not been closed yet. All of them are
// attempted even if some fail; the last error encountered is returned.
// Calling Close again afterwards is a no-op.
func (mrc *multiReadCloser) Close() error {
	var err error
	for _, rc := range mrc.rcs {
		if closeErr := rc.Close(); closeErr != nil {
			err = closeErr
		}
	}
	mrc.rcs = nil
	return err
}
// ReadAll opens every value stored under a key starting with "<post>-" and
// returns a single reader that streams them one after another, in the order
// diskv lists the keys. The returned value also implements io.Closer;
// callers should close it to release readers that were not read to the end.
//
// If any value cannot be opened, the readers opened so far are closed and the
// error is returned.
func (d *diskvStorage) ReadAll(post string) (io.Reader, error) {
	prefix := post + "-"

	// Closing cancel on return signals diskv's key listing to stop, so an
	// early return below does not leave it blocked sending keys nobody reads.
	cancel := make(chan struct{})
	defer close(cancel)

	var rcs []io.ReadCloser
	for key := range d.KeysPrefix(prefix, cancel) {
		rc, err := d.ReadStream(key, false)
		if err != nil {
			for _, opened := range rcs {
				// Best-effort cleanup; the open error is the one reported.
				_ = opened.Close()
			}
			return nil, err
		}
		rcs = append(rcs, rc)
	}
	return &multiReadCloser{rcs}, nil
}
// Retreive decodes and returns the comments stored for post, in the order
// ReadAll streams them.
//
// If a stored value fails to decode, the comments decoded up to that point
// are returned together with the error.
func (d *diskvStorage) Retreive(post string) ([]Comment, error) {
	var comments []Comment
	r, err := d.ReadAll(post)
	if err != nil {
		return comments, fmt.Errorf("error reading comment from storage: %v", err)
	}
	// Release the underlying readers on every path, including a decode
	// failure partway through the stream.
	if closer, ok := r.(io.Closer); ok {
		defer closer.Close()
	}

	dec := json.NewDecoder(r)
	for dec.More() {
		comment := Comment{}
		if err := dec.Decode(&comment); err != nil {
			return comments, fmt.Errorf("error unmarshaling comment: %v", err)
		}
		comments = append(comments, comment)
	}
	return comments, nil
}
// NewDiskv returns a Storage backed by a diskv store rooted at path, with a
// CacheSizeMax of 1 MiB.
func NewDiskv(path string) Storage {
	opts := diskv.Options{
		BasePath:     path,
		CacheSizeMax: 1 << 20,
	}
	return &diskvStorage{Diskv: diskv.New(opts)}
}