179 lines
3.9 KiB
Go
179 lines
3.9 KiB
Go
|
package main
|
||
|
|
||
|
import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"path"
	"time"

	"github.com/gorilla/websocket"
)
|
||
|
|
||
|
// flagListen is the address this proxy accepts connections on.
var flagListen = flag.String("l", "localhost:9223", "listen address")

// flagRemote is the host:port of the endpoint being proxied; both the plain
// HTTP requests and the devtools websockets are forwarded there.
var flagRemote = flag.String("r", "localhost:9222", "remote address")

// flagNoLog turns off the per-connection log file (output still goes to stdout).
var flagNoLog = flag.Bool("n", false, "disable logging to file")

// flagLogMask is a fmt pattern for the per-connection log file name; its %s
// verb is filled in with the devtools page id.
var flagLogMask = flag.String("log", "logs/cdp-%s.log", "log file mask")
|
||
|
|
||
|
// Websocket I/O buffer sizes, in bytes.
const (
	// IncomingBufferSize (10 MiB) sizes the buffers on the client-to-remote
	// path: the upgrader's read buffer and the dialer's write buffer.
	IncomingBufferSize = 10 << 20
	// OutgoingBufferSize (25 MiB) sizes the buffers on the remote-to-client
	// path: the dialer's read buffer and the upgrader's write buffer.
	OutgoingBufferSize = 25 << 20
)
|
||
|
|
||
|
var wsUpgrader = &websocket.Upgrader{
|
||
|
ReadBufferSize: IncomingBufferSize,
|
||
|
WriteBufferSize: OutgoingBufferSize,
|
||
|
CheckOrigin: func(r *http.Request) bool {
|
||
|
return true
|
||
|
},
|
||
|
}
|
||
|
|
||
|
var wsDialer = &websocket.Dialer{
|
||
|
ReadBufferSize: OutgoingBufferSize,
|
||
|
WriteBufferSize: IncomingBufferSize,
|
||
|
}
|
||
|
|
||
|
func main() {
|
||
|
flag.Parse()
|
||
|
|
||
|
mux := http.NewServeMux()
|
||
|
simplep := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: "http", Host: *flagRemote})
|
||
|
mux.Handle("/json", simplep)
|
||
|
mux.Handle("/", simplep)
|
||
|
mux.HandleFunc("/devtools/page/", func(res http.ResponseWriter, req *http.Request) {
|
||
|
id := path.Base(req.URL.Path)
|
||
|
f, logger := createLog(id)
|
||
|
if f != nil {
|
||
|
defer f.Close()
|
||
|
}
|
||
|
logger.Printf("---------- connection from %s ----------", req.RemoteAddr)
|
||
|
|
||
|
ver, err := checkVersion()
|
||
|
if err != nil {
|
||
|
msg := fmt.Sprintf("version error, got: %v", err)
|
||
|
logger.Println(msg)
|
||
|
http.Error(res, msg, 500)
|
||
|
return
|
||
|
}
|
||
|
logger.Printf("endpoint %s reported: %s", *flagRemote, string(ver))
|
||
|
|
||
|
endpoint := "ws://" + *flagRemote + "/devtools/page/" + id
|
||
|
|
||
|
// connect outgoing websocket
|
||
|
logger.Printf("connecting to %s", endpoint)
|
||
|
out, pres, err := wsDialer.Dial(endpoint, nil)
|
||
|
if err != nil {
|
||
|
msg := fmt.Sprintf("could not connect to %s, got: %v", endpoint, err)
|
||
|
logger.Println(msg)
|
||
|
http.Error(res, msg, 500)
|
||
|
return
|
||
|
}
|
||
|
defer pres.Body.Close()
|
||
|
defer out.Close()
|
||
|
|
||
|
logger.Printf("connected to %s", endpoint)
|
||
|
|
||
|
// connect incoming websocket
|
||
|
logger.Printf("upgrading connection on %s", req.RemoteAddr)
|
||
|
in, err := wsUpgrader.Upgrade(res, req, nil)
|
||
|
if err != nil {
|
||
|
msg := fmt.Sprintf("could not upgrade websocket from %s, got: %v", req.RemoteAddr, err)
|
||
|
logger.Println(msg)
|
||
|
http.Error(res, msg, 500)
|
||
|
return
|
||
|
}
|
||
|
defer in.Close()
|
||
|
logger.Printf("upgraded connection on %s", req.RemoteAddr)
|
||
|
|
||
|
ctxt, cancel := context.WithCancel(context.Background())
|
||
|
defer cancel()
|
||
|
|
||
|
errc := make(chan error, 1)
|
||
|
go proxyWS(ctxt, logger, "<-", in, out, errc)
|
||
|
go proxyWS(ctxt, logger, "->", out, in, errc)
|
||
|
<-errc
|
||
|
logger.Printf("---------- closing %s ----------", req.RemoteAddr)
|
||
|
})
|
||
|
|
||
|
log.Fatal(http.ListenAndServe(*flagListen, mux))
|
||
|
}
|
||
|
|
||
|
func proxyWS(ctxt context.Context, logger *log.Logger, prefix string, in, out *websocket.Conn, errc chan error) {
|
||
|
var mt int
|
||
|
var buf []byte
|
||
|
var err error
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
default:
|
||
|
mt, buf, err = in.ReadMessage()
|
||
|
if err != nil {
|
||
|
errc <- err
|
||
|
return
|
||
|
}
|
||
|
|
||
|
logger.Printf("%s %s", prefix, string(buf))
|
||
|
|
||
|
err = out.WriteMessage(mt, buf)
|
||
|
if err != nil {
|
||
|
errc <- err
|
||
|
return
|
||
|
}
|
||
|
|
||
|
case <-ctxt.Done():
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func checkVersion() ([]byte, error) {
|
||
|
cl := &http.Client{}
|
||
|
req, err := http.NewRequest("GET", "http://"+*flagRemote+"/json/version", nil)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
res, err := cl.Do(req)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
defer res.Body.Close()
|
||
|
|
||
|
body, err := ioutil.ReadAll(res.Body)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
var v map[string]string
|
||
|
err = json.Unmarshal(body, &v)
|
||
|
if err != nil {
|
||
|
return nil, fmt.Errorf("expected json result")
|
||
|
}
|
||
|
|
||
|
return body, nil
|
||
|
}
|
||
|
|
||
|
func createLog(id string) (io.Closer, *log.Logger) {
|
||
|
var f io.Closer
|
||
|
var w io.Writer = os.Stdout
|
||
|
if !*flagNoLog && *flagLogMask != "" {
|
||
|
l, err := os.OpenFile(fmt.Sprintf(*flagLogMask, id), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
|
||
|
if err != nil {
|
||
|
log.Fatal(err)
|
||
|
}
|
||
|
|
||
|
f = l
|
||
|
w = io.MultiWriter(os.Stdout, l)
|
||
|
}
|
||
|
return f, log.New(w, "", log.LstdFlags)
|
||
|
}
|