chromedp/cmd/chromedp-proxy/main.go

192 lines
4.5 KiB
Go
Raw Normal View History

// chromedp-proxy provides a command-line utility that proxies requests from a
// Chrome Debugging Protocol client to an application instance.
//
// chromedp-proxy is particularly useful for recording events/data from
// Selenium (ChromeDriver), Chrome DevTools in the browser, or for debugging
// remote application instances compatible with the debugging protocol.
//
// Please see README.md for more information on using chromedp-proxy.
package main
import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"os"
	"path"
	"path/filepath"
	"regexp"
	"time"

	"github.com/gorilla/websocket"
)
// flagListen is the local address the proxy accepts client connections on.
var flagListen = flag.String("l", "localhost:9223", "listen address")

// flagRemote is the address of the debugging-protocol endpoint being proxied.
var flagRemote = flag.String("r", "localhost:9222", "remote address")

// flagNoLog, when set, disables per-connection log files; logging to stdout
// still happens.
var flagNoLog = flag.Bool("n", false, "disable logging to file")

// flagLogMask is the fmt pattern used to name log files; its %s verb receives
// the sanitized page id.
var flagLogMask = flag.String("log", "logs/cdp-%s.log", "log file mask")
// WebSocket I/O buffer sizes, in bytes. "incoming" is traffic read from the
// client and written to the remote endpoint; "outgoing" flows the other way.
//
// NOTE(review): gorilla/websocket documents these as I/O buffer sizes rather
// than message-size limits; confirm values this large are actually needed.
const (
	incomingBufferSize = 10 << 20 // 10 MiB
	outgoingBufferSize = 25 << 20 // 25 MiB
)
// wsUpgrader accepts the client side of the proxy, so it reads incoming
// traffic and writes outgoing traffic. wsDialer connects to the remote
// endpoint and uses the mirror-image buffer sizes.
//
// NOTE(review): CheckOrigin accepts every Origin header, so any web page that
// can reach the listen address may be able to open a session through the
// proxy; consider restricting it if the proxy is reachable beyond localhost.
var (
	wsUpgrader = &websocket.Upgrader{
		CheckOrigin:     func(*http.Request) bool { return true },
		ReadBufferSize:  incomingBufferSize,
		WriteBufferSize: outgoingBufferSize,
	}

	wsDialer = &websocket.Dialer{
		WriteBufferSize: incomingBufferSize,
		ReadBufferSize:  outgoingBufferSize,
	}
)
// main parses the command-line flags and serves the proxy on *flagListen.
// Plain HTTP requests (including the /json endpoints) are reverse-proxied to
// *flagRemote, while WebSocket connections under /devtools/page/ are bridged
// to the remote by proxyPage.
func main() {
	flag.Parse()

	mux := http.NewServeMux()
	simplep := httputil.NewSingleHostReverseProxy(&url.URL{Scheme: "http", Host: *flagRemote})
	mux.Handle("/json", simplep)
	mux.Handle("/", simplep)
	mux.HandleFunc("/devtools/page/", proxyPage)

	log.Fatal(http.ListenAndServe(*flagListen, mux))
}

// proxyPage bridges one client WebSocket on /devtools/page/<id> to the same
// page on the remote endpoint, logging every message until either side fails.
func proxyPage(res http.ResponseWriter, req *http.Request) {
	id := path.Base(req.URL.Path)
	f, logger := createLog(id)
	if f != nil {
		defer f.Close()
	}
	logger.Printf("---------- connection from %s ----------", req.RemoteAddr)

	ver, err := checkVersion()
	if err != nil {
		msg := fmt.Sprintf("version error, got: %v", err)
		logger.Println(msg)
		http.Error(res, msg, http.StatusInternalServerError)
		return
	}
	logger.Printf("endpoint %s reported: %s", *flagRemote, string(ver))

	endpoint := "ws://" + *flagRemote + "/devtools/page/" + id

	// connect outgoing websocket
	logger.Printf("connecting to %s", endpoint)
	out, pres, err := wsDialer.Dial(endpoint, nil)
	if err != nil {
		msg := fmt.Sprintf("could not connect to %s, got: %v", endpoint, err)
		logger.Println(msg)
		http.Error(res, msg, http.StatusInternalServerError)
		return
	}
	defer pres.Body.Close()
	defer out.Close()
	logger.Printf("connected to %s", endpoint)

	// connect incoming websocket
	logger.Printf("upgrading connection on %s", req.RemoteAddr)
	in, err := wsUpgrader.Upgrade(res, req, nil)
	if err != nil {
		// Per the gorilla/websocket docs, Upgrade has already replied to the
		// client on failure (or the connection may be hijacked), so writing
		// another response here would be a superfluous/invalid write.
		logger.Printf("could not upgrade websocket from %s, got: %v", req.RemoteAddr, err)
		return
	}
	defer in.Close()
	logger.Printf("upgraded connection on %s", req.RemoteAddr)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// One slot per forwarding goroutine, so neither can ever block on its
	// final error send after this function has stopped receiving.
	errc := make(chan error, 2)
	go proxyWS(ctx, logger, "<-", in, out, errc)
	go proxyWS(ctx, logger, "->", out, in, errc)

	// The first failure in either direction ends the session; the deferred
	// Close calls then unblock the other goroutine's read.
	err = <-errc
	logger.Printf("proxy stopped: %v", err)
	logger.Printf("---------- closing %s ----------", req.RemoteAddr)
}
// proxyWS copies every message read from in to out, preserving its message
// type and logging its payload with prefix. It returns after the first read or
// write error, which it reports on errc, or once ctx is cancelled.
//
// The error send also watches ctx, so the goroutine cannot block forever when
// errc is full or nobody is receiving any more. ReadMessage takes no context,
// so a read that is already blocked is not interrupted by ctx; the caller must
// close the connections to unblock it.
func proxyWS(ctx context.Context, logger *log.Logger, prefix string, in, out *websocket.Conn, errc chan error) {
	report := func(err error) {
		select {
		case errc <- err:
		case <-ctx.Done():
		}
	}

	for {
		// Stop between messages once the session has been torn down.
		select {
		case <-ctx.Done():
			return
		default:
		}

		mt, buf, err := in.ReadMessage()
		if err != nil {
			report(err)
			return
		}
		logger.Printf("%s %s", prefix, string(buf))

		if err := out.WriteMessage(mt, buf); err != nil {
			report(err)
			return
		}
	}
}
// checkVersion fetches /json/version from the endpoint named by -r and returns
// the raw response body.
//
// It returns an error in any of these cases:
//   - the request cannot be built or sent;
//   - the endpoint does not respond within versionTimeout;
//   - the status is not 200 OK;
//   - the body is not a JSON object.
func checkVersion() ([]byte, error) {
	// Bound the whole exchange so a remote that accepts the TCP connection but
	// never answers cannot hang the calling handler indefinitely.
	const versionTimeout = 10 * time.Second
	cl := &http.Client{Timeout: versionTimeout}

	req, err := http.NewRequest(http.MethodGet, "http://"+*flagRemote+"/json/version", nil)
	if err != nil {
		return nil, err
	}
	res, err := cl.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}
	if res.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status %s", res.Status)
	}

	// Only the shape is validated (a JSON object); values may be of any JSON
	// type because the body is returned verbatim, not interpreted.
	var v map[string]json.RawMessage
	if err := json.Unmarshal(body, &v); err != nil {
		return nil, fmt.Errorf("expected json result: %w", err)
	}
	return body, nil
}
// cleanRE matches every character outside [a-zA-Z0-9_.-]; createLog removes
// all matches from the page id before using it in a log file name.
var cleanRE = regexp.MustCompile(`[^a-zA-Z0-9_\-\.]`)
// createLog returns a logger for the page identified by id. Output always goes
// to stdout.
//
// Unless -n is set or -log is empty, output is also appended to a file. Its
// name is -log formatted with id, after cleanRE strips disallowed characters.
// The returned io.Closer is that file, or nil when no file is in use; the
// caller should Close it when done.
//
// If the log file (or its directory) cannot be created, the failure is
// reported on the returned stdout logger and the connection continues with
// stdout-only logging, instead of exiting the whole proxy from a handler.
func createLog(id string) (io.Closer, *log.Logger) {
	stdout := log.New(os.Stdout, "", log.LstdFlags)
	if *flagNoLog || *flagLogMask == "" {
		return nil, stdout
	}

	name := fmt.Sprintf(*flagLogMask, cleanRE.ReplaceAllString(id, ""))

	// The default mask points into a "logs" directory; create it on demand so
	// the first connection does not fail just because it is missing.
	if err := os.MkdirAll(filepath.Dir(name), 0755); err != nil {
		stdout.Printf("could not create log directory for %s, got: %v", name, err)
		return nil, stdout
	}

	l, err := os.OpenFile(name, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0644)
	if err != nil {
		stdout.Printf("could not open log file %s, got: %v", name, err)
		return nil, stdout
	}
	return l, log.New(io.MultiWriter(os.Stdout, l), "", log.LstdFlags)
}