diff --git a/deltas.go b/deltas.go index f1cda61..9c81797 100644 --- a/deltas.go +++ b/deltas.go @@ -2,6 +2,7 @@ package caddyhugo import ( "fmt" + "log" "net/http" "time" @@ -9,6 +10,12 @@ import ( "github.com/gorilla/websocket" ) +// DeltaConn identifies the methods used in the original websocket implementation. +type DeltaConn interface { + ReadJSON(v interface{}) error + WriteJSON(v interface{}) error +} + const ( IdleWebsocketTimeout = 10 * time.Minute WebsocketFileTicker = 1 * time.Second @@ -50,25 +57,43 @@ func (ch *CaddyHugo) DeltaWebsocket(w http.ResponseWriter, r *http.Request) (int return http.StatusBadRequest, err } + return ch.handleConn(conn, doc) +} + +func (ch *CaddyHugo) Message(deltas ...acedoc.Delta) Message { + return Message{ + Deltas: deltas, + LTime: ch.LTime(), + } +} + +func (ch *CaddyHugo) handleConn(conn DeltaConn, doc *docref) (int, error) { const idlePing = 15 * time.Second const idlePingShort = 1 * time.Millisecond - var timer *time.Timer - timer = time.AfterFunc(idlePing, func() { - conn.WriteJSON(Message{ - Deltas: []acedoc.Delta{}, - LTime: ch.LTime(), - }) - timer.Reset(idlePing) - }) + errCh := make(chan error) + doneCh := make(chan struct{}) + defer func() { + close(doneCh) + close(errCh) + for err := range errCh { + log.Println(err) + } + }() + + timer := time.NewTimer(idlePing) + resetTimer := func(d time.Duration) { + if !timer.Stop() { + <-timer.C + } + timer.Reset(d) + } + + wroteMessagesCh := make(chan Message, 2) client := doc.doc.Client(acedoc.DeltaHandlerFunc(func(ds []acedoc.Delta) error { - timer.Reset(idlePing) - err := conn.WriteJSON(Message{ - Deltas: ds, - LTime: ch.LTime(), - }) - return err + wroteMessagesCh <- Message{Deltas: ds} + return conn.WriteJSON(ch.Message(ds...)) })) ch.mtx.Lock() @@ -82,21 +107,44 @@ func (ch *CaddyHugo) DeltaWebsocket(w http.ResponseWriter, r *http.Request) (int ch.mtx.Unlock() }() - for { - var message Message - err := conn.ReadJSON(&message) - if err != nil { - return http.StatusBadRequest, err - } - - ch.ObserveLTime(message.LTime) - timer.Reset(idlePingShort) + readMessagesCh := make(chan Message, 2) + go func() { + for { + var message Message + + err := conn.ReadJSON(&message) + if err != nil { + errCh <- err + return + } + ch.ObserveLTime(message.LTime) + + err = client.PushDeltas(message.Deltas...) + if err != nil { + errCh <- err + return + } + + select { + case readMessagesCh <- message: + case <-doneCh: + return + } - err = client.PushDeltas(message.Deltas...) - if err != nil { - return http.StatusBadRequest, err } + }() + for { + select { + case <-timer.C: + conn.WriteJSON(ch.Message()) + case <-readMessagesCh: + resetTimer(idlePingShort) + case <-wroteMessagesCh: + resetTimer(idlePing) + case <-doneCh: + return 200, nil + } } } diff --git a/doc_test.go b/doc_test.go index e9e3beb..7de27e6 100644 --- a/doc_test.go +++ b/doc_test.go @@ -1,6 +1,7 @@ package caddyhugo import ( + "encoding/json" "io/ioutil" "os" "os/exec" @@ -81,12 +82,224 @@ func TestEdits(t *testing.T) { doc.doc.Apply(send...) <-time.After(5 * time.Second) + + mtx.Lock() + defer mtx.Unlock() if len(received) != len(send) { t.Errorf("expected %d deltas, received %d; expected: %v, received: %v", len(send), len(received), send, received) } } +type WebsocketTester struct { + receivedPointer int + received [][]byte + wroteMessages []Message + wroteDeltas []acedoc.Delta + mtx sync.Mutex +} + +func (ws *WebsocketTester) ReadJSON(v interface{}) error { + ws.mtx.Lock() + defer ws.mtx.Unlock() + + if len(ws.received) <= ws.receivedPointer { + return nil + } + + err := json.Unmarshal(ws.received[ws.receivedPointer], v) + ws.receivedPointer++ + return err +} + +func (ws *WebsocketTester) WriteJSON(v interface{}) error { + ws.mtx.Lock() + defer ws.mtx.Unlock() + + m, ok := v.(Message) + if !ok { + panic("wrong type written to WebsocketTester") + } + + if len(m.Deltas) == 0 { + return nil + } + + ws.wroteMessages = append(ws.wroteMessages, m) + ws.wroteDeltas = append(ws.wroteDeltas, m.Deltas...) + + return nil +} + +func (ws *WebsocketTester) ReceiveJSON(v interface{}) error { + ws.mtx.Lock() + defer ws.mtx.Unlock() + + out, err := json.Marshal(v) + if err != nil { + return err + } + + ws.received = append(ws.received, out) + return nil +} + +func TestDeltasSingle(t *testing.T) { + w := NewWorld(t) + defer w.Clean() + + const title = "test" + + _, err := w.CH.NewContent(title, "") + if err != nil { + t.Fatal("couldn't create new content:", err) + } + + client := new(WebsocketTester) + + doc, err := w.CH.client("content/" + title + ".md") + if err != nil { + t.Fatal("couldn't establish docref for client 0:", err) + } + + go w.CH.handleConn(client, doc) + + a := acedoc.Insert(0, 0, "a") + + // pretend to get one sent from the "browser" + client.ReceiveJSON(w.CH.Message(a)) + + // wait to make sure it was processed + time.Sleep(50 * time.Millisecond) + + // we shouldn't have written back to the client, + // so we expect to have written 0 messages + if len(client.wroteMessages) != 0 { + t.Errorf("client wrote %d messages, should have written %d", len(client.wroteMessages), 0) + } + + // we received one, so make sure that's counted properly + if len(client.received) != 1 { + t.Errorf("client has %d messages, should have received %d", len(client.received), 1) + } +} + +func TestDeltasDouble(t *testing.T) { + w := NewWorld(t) + defer w.Clean() + + const title = "test" + + _, err := w.CH.NewContent(title, "") + if err != nil { + t.Fatal("couldn't create new content:", err) + } + + clientA := new(WebsocketTester) + clientB := new(WebsocketTester) + + doc, err := w.CH.client("content/" + title + ".md") + if err != nil { + t.Fatal("couldn't establish docref for client 0:", err) + } + + go w.CH.handleConn(clientA, doc) + go w.CH.handleConn(clientB, doc) + + // send the first message, simulating the browser on clientA + clientA.ReceiveJSON(w.CH.Message(acedoc.Insert(0, 0, "a"))) + + time.Sleep(100 * time.Millisecond) + + clientA.mtx.Lock() + clientB.mtx.Lock() + + // so we expect clientA to have written 0 messages, and + // clientB to have written 1 + if len(clientA.wroteMessages) != 0 || len(clientB.wroteMessages) != 1 { + t.Errorf("clientA wrote %d messages, should have written 0. clientB wrote %d, should have written 1", len(clientA.wroteMessages), len(clientB.wroteMessages)) + } + + // we received one via clientA and zero via clientB, so make sure + // that's counted properly + if len(clientA.received) != 1 || len(clientB.received) != 0 { + t.Errorf("clientA has %d messages, should have received 1; clientB has %d messages, should have received 0", len(clientA.received), len(clientB.received)) + } + + clientA.mtx.Unlock() + clientB.mtx.Unlock() + + // send the second message, via clientB + clientB.ReceiveJSON(w.CH.Message(acedoc.Insert(0, 0, "b"))) + + time.Sleep(100 * time.Millisecond) + + clientA.mtx.Lock() + clientB.mtx.Lock() + + // so we expect clientA to have written 1 message this time, and + // clientB to have written nothing new, so 1 still + if len(clientA.wroteMessages) != 1 || len(clientB.wroteMessages) != 1 { + t.Errorf("clientA wrote %d messages, should have written 1. clientB wrote %d, should have written 1 (just from before)", len(clientA.wroteMessages), len(clientB.wroteMessages)) + } + + // we received zero (new) via clientA and one via clientB, so make sure + // that's counted properly + if len(clientA.received) != 1 || len(clientB.received) != 1 { + t.Errorf("clientA has %d messages, should have received 1; clientB has %d messages, should have received 1", len(clientA.received), len(clientB.received)) + } + clientA.mtx.Unlock() + clientB.mtx.Unlock() +} + +func TestDeltasMulti(t *testing.T) { + w := NewWorld(t) + defer w.Clean() + + const title = "test" + + _, err := w.CH.NewContent(title, "") + if err != nil { + t.Fatal("couldn't create new content:", err) + } + + clients := []*WebsocketTester{{}, {}, {}} + + doc, err := w.CH.client("content/" + title + ".md") + if err != nil { + t.Fatal("couldn't establish docref:", err) + } + + go w.CH.handleConn(clients[0], doc) + go w.CH.handleConn(clients[1], doc) + go w.CH.handleConn(clients[2], doc) + + a := acedoc.Insert(0, 0, "a") + b := acedoc.Insert(0, 0, "b") + c := acedoc.Insert(0, 0, "c") + + clients[0].ReceiveJSON(w.CH.Message(a)) + clients[1].ReceiveJSON(w.CH.Message(b)) + clients[2].ReceiveJSON(w.CH.Message(c)) + + time.Sleep(400 * time.Millisecond) + + for i, client := range clients { + client.mtx.Lock() + // all clients should have "written" 2 deltas (could be the same + // message) that came from the other clients + if len(client.wroteDeltas) != 2 { + t.Errorf("client %d wrote %d deltas, should have written 2", i, len(client.wroteDeltas)) + } + + // all clients "received" 1 message from the "browser" + if len(client.received) != 1 { + t.Errorf("client %d has %d messages, should have received 1", i, len(client.received)) + } + client.mtx.Unlock() + } +} + func TestPagesInPagesOut(t *testing.T) { w := NewWorld(t) defer w.Clean()