diff --git a/actions.go b/actions.go index 60cf966..1590776 100644 --- a/actions.go +++ b/actions.go @@ -49,7 +49,6 @@ func Sleep(d time.Duration) Action { return ActionFunc(func(ctxt context.Context, h cdp.Executor) error { select { case <-time.After(d): - case <-ctxt.Done(): return ctxt.Err() } diff --git a/allocate.go b/allocate.go index 5110ba1..382af83 100644 --- a/allocate.go +++ b/allocate.go @@ -136,11 +136,12 @@ func (p *ExecAllocator) Allocate(ctx context.Context) (*Browser, error) { } stderr.Close() - browser, err := NewBrowser(wsURL) + browser, err := NewBrowser(ctx, wsURL) if err != nil { return nil, err } browser.UserDataDir = dataDir + browser.Start(ctx) return browser, nil } diff --git a/browser.go b/browser.go index 5995af1..b8c1bdc 100644 --- a/browser.go +++ b/browser.go @@ -8,10 +8,13 @@ package chromedp import ( "context" + "encoding/json" "log" "sync/atomic" "github.com/chromedp/cdproto" + "github.com/chromedp/cdproto/cdp" + "github.com/chromedp/cdproto/target" "github.com/mailru/easyjson" ) @@ -21,26 +24,39 @@ import ( type Browser struct { UserDataDir string + pages map[target.SessionID]*Target + conn Transport // next is the next message id. next int64 + cmdQueue chan cmdJob + + // qres is the incoming command result queue. + qres chan *cdproto.Message + // logging funcs logf func(string, ...interface{}) errf func(string, ...interface{}) } +type cmdJob struct { + msg *cdproto.Message + resp chan *cdproto.Message +} + // NewBrowser creates a new browser. 
-func NewBrowser(urlstr string, opts ...BrowserOption) (*Browser, error) { - conn, err := Dial(ForceIP(urlstr)) +func NewBrowser(ctx context.Context, urlstr string, opts ...BrowserOption) (*Browser, error) { + conn, err := DialContext(ctx, ForceIP(urlstr)) if err != nil { return nil, err } b := &Browser{ - conn: conn, - logf: log.Printf, + conn: conn, + pages: make(map[target.SessionID]*Target, 1024), + logf: log.Printf, } // apply options @@ -72,25 +88,182 @@ func (b *Browser) Shutdown() error { // send writes the supplied message and params. func (b *Browser) send(method cdproto.MethodType, params easyjson.RawMessage) error { msg := &cdproto.Message{ - Method: method, ID: atomic.AddInt64(&b.next, 1), + Method: method, Params: params, } - buf, err := msg.MarshalJSON() - if err != nil { - return err - } - return b.conn.Write(buf) + return b.conn.Write(msg) } -// sendToTarget writes the supplied message to the target. -func (b *Browser) sendToTarget(targetID string, method cdproto.MethodType, params easyjson.RawMessage) error { +func (b *Browser) executorForTarget(ctx context.Context, sessionID target.SessionID) *Target { + if sessionID == "" { + panic("empty session ID") + } + if t, ok := b.pages[sessionID]; ok { + return t + } + t := &Target{ + browser: b, + sessionID: sessionID, + + eventQueue: make(chan *cdproto.Message, 1024), + waitQueue: make(chan func(cur *cdp.Frame) bool, 1024), + frames: make(map[cdp.FrameID]*cdp.Frame), + + logf: b.logf, + errf: b.errf, + } + go t.run(ctx) + b.pages[sessionID] = t + return t +} + +func (b *Browser) Execute(ctx context.Context, method string, params json.Marshaler, res json.Unmarshaler) error { + paramsMsg := emptyObj + if params != nil { + var err error + if paramsMsg, err = json.Marshal(params); err != nil { + return err + } + } + + id := atomic.AddInt64(&b.next, 1) + ch := make(chan *cdproto.Message, 1) + b.cmdQueue <- cmdJob{ + msg: &cdproto.Message{ + ID: id, + Method: cdproto.MethodType(method), + Params: paramsMsg, 
+ }, + resp: ch, + } + select { + case msg := <-ch: + switch { + case msg == nil: + return ErrChannelClosed + case msg.Error != nil: + return msg.Error + case res != nil: + return json.Unmarshal(msg.Result, res) + } + case <-ctx.Done(): + return ctx.Err() + } return nil } -// CreateContext creates a new browser context. -func (b *Browser) CreateContext() (context.Context, error) { - return nil, nil +func (b *Browser) Start(ctx context.Context) { + b.cmdQueue = make(chan cmdJob) + b.qres = make(chan *cdproto.Message) + + go b.run(ctx) +} + +func (b *Browser) run(ctx context.Context) { + defer b.conn.Close() + + // add cancel to context + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + defer cancel() + + for { + select { + case <-ctx.Done(): + return + default: + // continue below + } + msg, err := b.conn.Read() + if err != nil { + return + } + var sessionID target.SessionID + if msg.Method == cdproto.EventTargetReceivedMessageFromTarget { + recv := new(target.EventReceivedMessageFromTarget) + if err := json.Unmarshal(msg.Params, recv); err != nil { + b.errf("%s", err) + continue + } + sessionID = recv.SessionID + msg = new(cdproto.Message) + if err := json.Unmarshal([]byte(recv.Message), msg); err != nil { + b.errf("%s", err) + continue + } + } + + switch { + case msg.Method != "": + if sessionID == "" { + // TODO: are we interested in + // these events? 
+ continue + } + + page, ok := b.pages[sessionID] + if !ok { + b.errf("unknown session ID %q", sessionID) + continue + } + select { + case page.eventQueue <- msg: + default: + panic("eventQueue is full") + } + + case msg.ID != 0: + b.qres <- msg + + default: + b.errf("ignoring malformed incoming message (missing id or method): %#v", msg) + } + } + }() + + respByID := make(map[int64]chan *cdproto.Message) + + // process queues + for { + select { + case res := <-b.qres: + resp, ok := respByID[res.ID] + if !ok { + b.errf("id %d not present in response map", res.ID) + continue + } + if resp != nil { + // resp could be nil, if we're not interested in + // this response; for CommandSendMessageToTarget. + resp <- res + close(resp) + } + delete(respByID, res.ID) + + case q := <-b.cmdQueue: + if _, ok := respByID[q.msg.ID]; ok { + b.errf("id %d already present in response map", q.msg.ID) + continue + } + respByID[q.msg.ID] = q.resp + + if q.msg.Method == "" { + // Only register the channel in respByID; + // useful for CommandSendMessageToTarget. + continue + } + if err := b.conn.Write(q.msg); err != nil { + b.errf("%s", err) + continue + } + + case <-ctx.Done(): + return + } + } } // BrowserOption is a browser option. diff --git a/conn.go b/conn.go index 8274b0c..b2de814 100644 --- a/conn.go +++ b/conn.go @@ -1,10 +1,12 @@ package chromedp import ( + "context" "io" "net" "strings" + "github.com/chromedp/cdproto" "github.com/gorilla/websocket" ) @@ -18,8 +20,8 @@ var ( // Transport is the common interface to send/receive messages to a target. type Transport interface { - Read() ([]byte, error) - Write([]byte) error + Read() (*cdproto.Message, error) + Write(*cdproto.Message) error io.Closer } @@ -29,28 +31,28 @@ type Conn struct { } // Read reads the next message. 
-func (c *Conn) Read() ([]byte, error) { - _, buf, err := c.ReadMessage() - if err != nil { +func (c *Conn) Read() (*cdproto.Message, error) { + msg := new(cdproto.Message) + if err := c.ReadJSON(msg); err != nil { return nil, err } - return buf, nil + return msg, nil } // Write writes a message. -func (c *Conn) Write(buf []byte) error { - return c.WriteMessage(websocket.TextMessage, buf) +func (c *Conn) Write(msg *cdproto.Message) error { + return c.WriteJSON(msg) } // Dial dials the specified websocket URL using gorilla/websocket. -func Dial(urlstr string) (*Conn, error) { +func DialContext(ctx context.Context, urlstr string) (*Conn, error) { d := &websocket.Dialer{ ReadBufferSize: DefaultReadBufferSize, WriteBufferSize: DefaultWriteBufferSize, } // connect - conn, _, err := d.Dial(urlstr, nil) + conn, _, err := d.DialContext(ctx, urlstr, nil) if err != nil { return nil, err } diff --git a/context.go b/context.go index 0a2531c..44399a7 100644 --- a/context.go +++ b/context.go @@ -3,7 +3,15 @@ package chromedp import ( "context" "encoding/json" - "net/http" + "fmt" + + "github.com/chromedp/cdproto/css" + "github.com/chromedp/cdproto/dom" + "github.com/chromedp/cdproto/inspector" + "github.com/chromedp/cdproto/log" + "github.com/chromedp/cdproto/page" + "github.com/chromedp/cdproto/runtime" + "github.com/chromedp/cdproto/target" ) // Executor @@ -16,10 +24,8 @@ type Context struct { Allocator Allocator browser *Browser - handler *TargetHandler - logf func(string, ...interface{}) - errf func(string, ...interface{}) + sessionID target.SessionID } // Wait can be called after cancelling the context containing Context, to block @@ -75,41 +81,47 @@ func Run(ctx context.Context, action Action) error { } c.browser = browser } - if c.handler == nil { - if err := c.newHandler(ctx); err != nil { + if c.sessionID == "" { + if err := c.newSession(ctx); err != nil { return err } } - return action.Do(ctx, c.handler) + return action.Do(ctx, c.browser.executorForTarget(ctx, 
c.sessionID)) } -func (c *Context) newHandler(ctx context.Context) error { - // TODO: add RemoteAddr() to the Transport interface? - conn := c.browser.conn.(*Conn).Conn - addr := conn.RemoteAddr() - url := "http://" + addr.String() + "/json/new" - resp, err := http.Get(url) +func (c *Context) newSession(ctx context.Context) error { + create := target.CreateTarget("about:blank") + targetID, err := create.Do(ctx, c.browser) if err != nil { return err } - defer resp.Body.Close() - var wurl withWebsocketURL - if err := json.NewDecoder(resp.Body).Decode(&wurl); err != nil { - return err - } - c.handler, err = NewTargetHandler(wurl.WebsocketURL) + + attach := target.AttachToTarget(targetID) + sessionID, err := attach.Do(ctx, c.browser) if err != nil { return err } - if err := c.handler.Run(ctx); err != nil { - return err + + target := c.browser.executorForTarget(ctx, sessionID) + + // enable domains + for _, enable := range []Action{ + log.Enable(), + runtime.Enable(), + //network.Enable(), + inspector.Enable(), + page.Enable(), + dom.Enable(), + css.Enable(), + } { + if err := enable.Do(ctx, target); err != nil { + return fmt.Errorf("unable to execute %T: %v", enable, err) + } } + + c.sessionID = sessionID return nil } -type withWebsocketURL struct { - WebsocketURL string `json:"webSocketDebuggerUrl"` -} - // ContextOption type ContextOption func(*Context) diff --git a/handler.go b/handler.go index 668bffc..567e9ce 100644 --- a/handler.go +++ b/handler.go @@ -3,235 +3,138 @@ package chromedp import ( "context" "encoding/json" - "fmt" - "reflect" - goruntime "runtime" - "strings" "sync" + "sync/atomic" "time" "github.com/mailru/easyjson" "github.com/chromedp/cdproto" "github.com/chromedp/cdproto/cdp" - "github.com/chromedp/cdproto/css" "github.com/chromedp/cdproto/dom" "github.com/chromedp/cdproto/inspector" - "github.com/chromedp/cdproto/log" "github.com/chromedp/cdproto/page" - "github.com/chromedp/cdproto/runtime" + "github.com/chromedp/cdproto/target" ) -// 
TargetHandler manages a Chrome DevTools Protocol target. -type TargetHandler struct { - conn Transport +// Target manages a Chrome DevTools Protocol target. +type Target struct { + browser *Browser + sessionID target.SessionID + + waitQueue chan func(cur *cdp.Frame) bool + eventQueue chan *cdproto.Message + + // below are the old TargetHandler fields. // frames is the set of encountered frames. frames map[cdp.FrameID]*cdp.Frame - // cur is the current top level frame. - cur *cdp.Frame - - // qcmd is the outgoing message queue. - qcmd chan *cdproto.Message - - // qres is the incoming command result queue. - qres chan *cdproto.Message - - // qevents is the incoming event queue. - qevents chan *cdproto.Message - - // detached is closed when the detached event is received. - detached chan *inspector.EventDetached - - pageWaitGroup, domWaitGroup *sync.WaitGroup - - // last is the last sent message identifier. - last int64 - lastm sync.Mutex - - // res is the id->result channel map. - res map[int64]chan *cdproto.Message - resrw sync.RWMutex + // cur is the current top level frame. TODO: delete mutex + curMu sync.RWMutex + cur *cdp.Frame // logging funcs - logf, debugf, errf func(string, ...interface{}) - - sync.RWMutex + logf, errf func(string, ...interface{}) } -// NewTargetHandler creates a new handler for the specified client target. -func NewTargetHandler(urlstr string, opts ...TargetHandlerOption) (*TargetHandler, error) { - conn, err := Dial(urlstr) - if err != nil { - return nil, err - } +func (t *Target) run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case msg := <-t.eventQueue: + //fmt.Printf("%d %s: %s\n", msg.ID, msg.Method, msg.Params) + if err := t.processEvent(ctx, msg); err != nil { + t.errf("could not process event: %v", err) + continue + } + default: + // prevent busy spinning. 
TODO: do better + time.Sleep(5 * time.Millisecond) + n := len(t.waitQueue) + if n == 0 { + continue + } - h := &TargetHandler{ - conn: conn, - errf: func(string, ...interface{}) {}, - } + t.curMu.RLock() + cur := t.cur + t.curMu.RUnlock() + if cur == nil { + continue + } - for _, o := range opts { - o(h) - } - - return h, nil -} - -// Run starts the processing of commands and events of the client target -// provided to NewTargetHandler. -// -// Callers can stop Run by closing the passed context. -func (h *TargetHandler) Run(ctxt context.Context) error { - // reset - h.Lock() - h.frames = make(map[cdp.FrameID]*cdp.Frame) - h.qcmd = make(chan *cdproto.Message) - h.qres = make(chan *cdproto.Message) - h.qevents = make(chan *cdproto.Message) - h.res = make(map[int64]chan *cdproto.Message) - h.detached = make(chan *inspector.EventDetached, 1) - h.pageWaitGroup = new(sync.WaitGroup) - h.domWaitGroup = new(sync.WaitGroup) - h.Unlock() - - // run - go h.run(ctxt) - - // enable domains - for _, a := range []Action{ - log.Enable(), - runtime.Enable(), - //network.Enable(), - inspector.Enable(), - page.Enable(), - dom.Enable(), - css.Enable(), - } { - if err := a.Do(ctxt, h); err != nil { - return fmt.Errorf("unable to execute %s: %v", reflect.TypeOf(a), err) + for i := 0; i < n; i++ { + fn := <-t.waitQueue + if !fn(cur) { + // try again later. 
+ t.waitQueue <- fn + } + } } } +} - h.Lock() - - // get page resources - tree, err := page.GetResourceTree().Do(ctxt, h) +func (t *Target) Execute(ctx context.Context, method string, params json.Marshaler, res json.Unmarshaler) error { + paramsMsg := emptyObj + if params != nil { + var err error + if paramsMsg, err = json.Marshal(params); err != nil { + return err + } + } + innerID := atomic.AddInt64(&t.browser.next, 1) + msg := &cdproto.Message{ + ID: innerID, + Method: cdproto.MethodType(method), + Params: paramsMsg, + } + msgJSON, err := json.Marshal(msg) if err != nil { - return fmt.Errorf("unable to get resource tree: %v", err) + return err + } + sendParams := target.SendMessageToTarget(string(msgJSON)). + WithSessionID(t.sessionID) + sendParamsJSON, _ := json.Marshal(sendParams) + + // We want to grab the response from the inner message. + ch := make(chan *cdproto.Message, 1) + t.browser.cmdQueue <- cmdJob{ + msg: &cdproto.Message{ID: innerID}, + resp: ch, } - h.frames[tree.Frame.ID] = tree.Frame - h.cur = tree.Frame - - for _, c := range tree.ChildFrames { - h.frames[c.Frame.ID] = c.Frame + // The response from the outer message is uninteresting; pass a nil + // resp channel. + outerID := atomic.AddInt64(&t.browser.next, 1) + t.browser.cmdQueue <- cmdJob{ + msg: &cdproto.Message{ + ID: outerID, + Method: target.CommandSendMessageToTarget, + Params: sendParamsJSON, + }, } - h.Unlock() - - h.documentUpdated(ctxt) - + select { + case msg := <-ch: + switch { + case msg == nil: + return ErrChannelClosed + case msg.Error != nil: + return msg.Error + case res != nil: + return json.Unmarshal(msg.Result, res) + } + case <-ctx.Done(): + return ctx.Err() + } return nil } -// run handles the actual message processing to / from the web socket connection. 
-func (h *TargetHandler) run(ctxt context.Context) { - defer h.conn.Close() - - // add cancel to context - ctxt, cancel := context.WithCancel(ctxt) - defer cancel() - - go func() { - defer cancel() - - for { - select { - default: - msg, err := h.read() - if err != nil { - return - } - - switch { - case msg.Method != "": - select { - case h.qevents <- msg: - case <-ctxt.Done(): - return - } - - case msg.ID != 0: - select { - case h.qres <- msg: - case <-ctxt.Done(): - return - } - - default: - h.errf("ignoring malformed incoming message (missing id or method): %#v", msg) - } - - case <-h.detached: - // FIXME: should log when detached, and reason - return - - case <-ctxt.Done(): - return - } - } - }() - - // process queues - for { - select { - case ev := <-h.qevents: - err := h.processEvent(ctxt, ev) - if err != nil { - h.errf("could not process event %s: %v", ev.Method, err) - } - - case res := <-h.qres: - err := h.processResult(res) - if err != nil { - h.errf("could not process result for message %d: %v", res.ID, err) - } - - case cmd := <-h.qcmd: - err := h.processCommand(cmd) - if err != nil { - h.errf("could not process command message %d: %v", cmd.ID, err) - } - - case <-ctxt.Done(): - return - } - } -} - -// read reads a message from the client connection. -func (h *TargetHandler) read() (*cdproto.Message, error) { - // read - buf, err := h.conn.Read() - if err != nil { - return nil, err - } - - //h.debugf("-> %s", string(buf)) - - // unmarshal - msg := new(cdproto.Message) - if err := json.Unmarshal(buf, msg); err != nil { - return nil, err - } - - return msg, nil -} +// below are the old TargetHandler methods. // processEvent processes an incoming event. 
-func (h *TargetHandler) processEvent(ctxt context.Context, msg *cdproto.Message) error { +func (t *Target) processEvent(ctxt context.Context, msg *cdproto.Message) error { if msg == nil { return ErrChannelClosed } @@ -251,46 +154,29 @@ func (h *TargetHandler) processEvent(ctxt context.Context, msg *cdproto.Message) return err } - switch e := ev.(type) { + switch ev.(type) { case *inspector.EventDetached: - h.Lock() - defer h.Unlock() - h.detached <- e return nil - case *dom.EventDocumentUpdated: - h.domWaitGroup.Wait() - go h.documentUpdated(ctxt) + t.documentUpdated(ctxt) return nil } - d := msg.Method.Domain() - if d != "Page" && d != "DOM" { - return nil - } - - switch d { + switch msg.Method.Domain() { case "Page": - h.pageWaitGroup.Add(1) - go h.pageEvent(ctxt, ev) - + t.pageEvent(ctxt, ev) case "DOM": - h.domWaitGroup.Add(1) - go h.domEvent(ctxt, ev) + t.domEvent(ctxt, ev) } - return nil } // documentUpdated handles the document updated event, retrieving the document // root for the root frame. 
-func (h *TargetHandler) documentUpdated(ctxt context.Context) { - f, err := h.WaitFrame(ctxt, cdp.EmptyFrameID) - if err != nil { - h.errf("could not get current frame: %v", err) - return - } - +func (t *Target) documentUpdated(ctxt context.Context) { + t.curMu.RLock() + f := t.cur + t.curMu.RUnlock() f.Lock() defer f.Unlock() @@ -300,242 +186,33 @@ func (h *TargetHandler) documentUpdated(ctxt context.Context) { } f.Nodes = make(map[cdp.NodeID]*cdp.Node) - f.Root, err = dom.GetDocument().WithPierce(true).Do(ctxt, h) + var err error + f.Root, err = dom.GetDocument().WithPierce(true).Do(ctxt, t) + if err == context.Canceled { + return // TODO: perhaps not necessary, but useful to keep the tests less noisy + } if err != nil { - h.errf("could not retrieve document root for %s: %v", f.ID, err) + t.errf("could not retrieve document root for %s: %v", f.ID, err) return } f.Root.Invalidated = make(chan struct{}) walk(f.Nodes, f.Root) } -// processResult processes an incoming command result. -func (h *TargetHandler) processResult(msg *cdproto.Message) error { - h.resrw.RLock() - defer h.resrw.RUnlock() - - ch, ok := h.res[msg.ID] - if !ok { - return fmt.Errorf("id %d not present in res map", msg.ID) - } - defer close(ch) - - ch <- msg - - return nil -} - -// processCommand writes a command to the client connection. -func (h *TargetHandler) processCommand(cmd *cdproto.Message) error { - // marshal - buf, err := json.Marshal(cmd) - if err != nil { - return err - } - - //h.debugf("<- %s", string(buf)) - - return h.conn.Write(buf) -} - // emptyObj is an empty JSON object message. var emptyObj = easyjson.RawMessage([]byte(`{}`)) -// Execute executes commandType against the endpoint passed to Run, using the -// provided context and params, decoding the result of the command to res. 
-func (h *TargetHandler) Execute(ctxt context.Context, methodType string, params json.Marshaler, res json.Unmarshaler) error { - var paramsBuf easyjson.RawMessage - if params == nil { - paramsBuf = emptyObj - } else { - var err error - paramsBuf, err = json.Marshal(params) - if err != nil { - return err - } - } - - id := h.next() - - // save channel - ch := make(chan *cdproto.Message, 1) - h.resrw.Lock() - h.res[id] = ch - h.resrw.Unlock() - - // queue message - select { - case h.qcmd <- &cdproto.Message{ - ID: id, - Method: cdproto.MethodType(methodType), - Params: paramsBuf, - }: - case <- ctxt.Done(): - return ctxt.Err() - } - - errch := make(chan error, 1) - go func() { - defer close(errch) - - select { - case msg := <-ch: - switch { - case msg == nil: - errch <- ErrChannelClosed - - case msg.Error != nil: - errch <- msg.Error - - case res != nil: - errch <- json.Unmarshal(msg.Result, res) - } - - case <-ctxt.Done(): - errch <- ctxt.Err() - } - - h.resrw.Lock() - defer h.resrw.Unlock() - - delete(h.res, id) - }() - - return <-errch -} - -// next returns the next message id. -func (h *TargetHandler) next() int64 { - h.lastm.Lock() - defer h.lastm.Unlock() - h.last++ - return h.last -} - -// GetRoot returns the current top level frame's root document node. -func (h *TargetHandler) GetRoot(ctxt context.Context) (*cdp.Node, error) { - var root *cdp.Node - - for { - var cur *cdp.Frame - select { - default: - h.RLock() - cur = h.cur - if cur != nil { - cur.RLock() - root = cur.Root - cur.RUnlock() - } - h.RUnlock() - - if cur != nil && root != nil { - return root, nil - } - - time.Sleep(DefaultCheckDuration) - - case <-ctxt.Done(): - return nil, ctxt.Err() - } - } -} - -// SetActive sets the currently active frame after a successful navigation. 
-func (h *TargetHandler) SetActive(ctxt context.Context, id cdp.FrameID) error { - // get frame - f, err := h.WaitFrame(ctxt, id) - if err != nil { - return err - } - - h.Lock() - defer h.Unlock() - - h.cur = f - - return nil -} - -// WaitFrame waits for a frame to be loaded using the provided context. -func (h *TargetHandler) WaitFrame(ctxt context.Context, id cdp.FrameID) (*cdp.Frame, error) { - // TODO: fix this - timeout := time.After(time.Second) - - for { - select { - default: - var f *cdp.Frame - var ok bool - - h.RLock() - if id == cdp.EmptyFrameID { - f, ok = h.cur, h.cur != nil - } else { - f, ok = h.frames[id] - } - h.RUnlock() - - if ok { - return f, nil - } - - time.Sleep(DefaultCheckDuration) - - case <-ctxt.Done(): - return nil, ctxt.Err() - - case <-timeout: - return nil, fmt.Errorf("timeout waiting for frame `%s`", id) - } - } -} - -// WaitNode waits for a node to be loaded using the provided context. -func (h *TargetHandler) WaitNode(ctxt context.Context, f *cdp.Frame, id cdp.NodeID) (*cdp.Node, error) { - // TODO: fix this - timeout := time.After(time.Second) - - for { - select { - default: - var n *cdp.Node - var ok bool - - f.RLock() - n, ok = f.Nodes[id] - f.RUnlock() - - if n != nil && ok { - return n, nil - } - - time.Sleep(DefaultCheckDuration) - - case <-ctxt.Done(): - return nil, ctxt.Err() - - case <-timeout: - return nil, fmt.Errorf("timeout waiting for node `%d`", id) - } - } -} - // pageEvent handles incoming page events. 
-func (h *TargetHandler) pageEvent(ctxt context.Context, ev interface{}) { - defer h.pageWaitGroup.Done() - +func (t *Target) pageEvent(ctxt context.Context, ev interface{}) { var id cdp.FrameID var op frameOp switch e := ev.(type) { case *page.EventFrameNavigated: - h.Lock() - h.frames[e.Frame.ID] = e.Frame - if h.cur != nil && h.cur.ID == e.Frame.ID { - h.cur = e.Frame - } - h.Unlock() + t.frames[e.Frame.ID] = e.Frame + t.curMu.Lock() + t.cur = e.Frame + t.curMu.Unlock() return case *page.EventFrameAttached: @@ -545,7 +222,10 @@ func (h *TargetHandler) pageEvent(ctxt context.Context, ev interface{}) { id, op = e.FrameID, frameDetached case *page.EventFrameStartedLoading: - id, op = e.FrameID, frameStartedLoading + // TODO: this happens before EventFrameNavigated, so the frame + // isn't in t.frames yet. + //id, op = e.FrameID, frameStartedLoading + return case *page.EventFrameStoppedLoading: id, op = e.FrameID, frameStoppedLoading @@ -563,18 +243,11 @@ func (h *TargetHandler) pageEvent(ctxt context.Context, ev interface{}) { return default: - h.errf("unhandled page event %s", reflect.TypeOf(ev)) + t.errf("unhandled page event %T", ev) return } - f, err := h.WaitFrame(ctxt, id) - if err != nil { - h.errf("could not get frame %s: %v", id, err) - return - } - - h.Lock() - defer h.Unlock() + f := t.frames[id] f.Lock() defer f.Unlock() @@ -583,15 +256,10 @@ func (h *TargetHandler) pageEvent(ctxt context.Context, ev interface{}) { } // domEvent handles incoming DOM events. 
-func (h *TargetHandler) domEvent(ctxt context.Context, ev interface{}) { - defer h.domWaitGroup.Done() - - // wait current frame - f, err := h.WaitFrame(ctxt, cdp.EmptyFrameID) - if err != nil { - h.errf("could not process DOM event %s: %v", reflect.TypeOf(ev), err) - return - } +func (t *Target) domEvent(ctxt context.Context, ev interface{}) { + t.curMu.RLock() + f := t.cur + t.curMu.RUnlock() var id cdp.NodeID var op nodeOp @@ -620,12 +288,6 @@ func (h *TargetHandler) domEvent(ctxt context.Context, ev interface{}) { id, op = e.NodeID, childNodeCountUpdated(e.ChildNodeCount) case *dom.EventChildNodeInserted: - if e.PreviousNodeID != cdp.EmptyNodeID { - _, err = h.WaitNode(ctxt, f, e.PreviousNodeID) - if err != nil { - return - } - } id, op = e.ParentNodeID, childNodeInserted(f.Nodes, e.PreviousNodeID, e.Node) case *dom.EventChildNodeRemoved: @@ -647,29 +309,20 @@ func (h *TargetHandler) domEvent(ctxt context.Context, ev interface{}) { id, op = e.InsertionPointID, distributedNodesUpdated(e.DistributedNodes) default: - h.errf("unhandled node event %s", reflect.TypeOf(ev)) + t.errf("unhandled node event %T", ev) return } - // retrieve node - n, err := h.WaitNode(ctxt, f, id) - if err != nil { - s := strings.TrimSuffix(goruntime.FuncForPC(reflect.ValueOf(op).Pointer()).Name(), ".func1") - i := strings.LastIndex(s, ".") - if i != -1 { - s = s[i+1:] - } - h.errf("could not perform (%s) operation on node %d (wait node): %v", s, id, err) + n, ok := f.Nodes[id] + if !ok { + // Node ID has been invalidated. Nothing to do. return } - h.Lock() - defer h.Unlock() - f.Lock() defer f.Unlock() op(n) } -type TargetHandlerOption func(*TargetHandler) +type TargetOption func(*Target) diff --git a/nav.go b/nav.go index 2063431..6463205 100644 --- a/nav.go +++ b/nav.go @@ -11,17 +11,8 @@ import ( // Navigate navigates the current frame. 
func Navigate(urlstr string) Action { return ActionFunc(func(ctxt context.Context, h cdp.Executor) error { - th, ok := h.(*TargetHandler) - if !ok { - return ErrInvalidHandler - } - - frameID, _, _, err := page.Navigate(urlstr).Do(ctxt, th) - if err != nil { - return err - } - - return th.SetActive(ctxt, frameID) + _, _, _, err := page.Navigate(urlstr).Do(ctxt, h) + return err }) } diff --git a/query.go b/query.go index 35aefc1..165307c 100644 --- a/query.go +++ b/query.go @@ -25,7 +25,7 @@ func Nodes(sel interface{}, nodes *[]*cdp.Node, opts ...QueryOption) Action { panic("nodes cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, n ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, n ...*cdp.Node) error { *nodes = n return nil }, opts...) @@ -37,7 +37,7 @@ func NodeIDs(sel interface{}, ids *[]cdp.NodeID, opts ...QueryOption) Action { panic("nodes cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { nodeIDs := make([]cdp.NodeID, len(nodes)) for i, n := range nodes { nodeIDs[i] = n.NodeID @@ -51,7 +51,7 @@ func NodeIDs(sel interface{}, ids *[]cdp.NodeID, opts ...QueryOption) Action { // Focus focuses the first node matching the selector. func Focus(sel interface{}, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -62,7 +62,7 @@ func Focus(sel interface{}, opts ...QueryOption) Action { // Blur unfocuses (blurs) the first node matching the selector. 
func Blur(sel interface{}, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -87,7 +87,7 @@ func Dimensions(sel interface{}, model **dom.BoxModel, opts ...QueryOption) Acti if model == nil { panic("model cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -103,7 +103,7 @@ func Text(sel interface{}, text *string, opts ...QueryOption) Action { panic("text cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -114,7 +114,7 @@ func Text(sel interface{}, text *string, opts ...QueryOption) Action { // Clear clears the values of any input/textarea nodes matching the selector. 
func Clear(sel interface{}, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -190,7 +190,7 @@ func Attributes(sel interface{}, attributes *map[string]string, opts ...QueryOpt panic("attributes cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -219,7 +219,7 @@ func AttributesAll(sel interface{}, attributes *[]map[string]string, opts ...Que panic("attributes cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -243,7 +243,7 @@ func AttributesAll(sel interface{}, attributes *[]map[string]string, opts ...Que // SetAttributes sets the element attributes for the first node matching the // selector. 
func SetAttributes(sel interface{}, attributes map[string]string, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return errors.New("expected at least one element") } @@ -265,7 +265,7 @@ func AttributeValue(sel interface{}, name string, value *string, ok *bool, opts panic("value cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return errors.New("expected at least one element") } @@ -295,7 +295,7 @@ func AttributeValue(sel interface{}, name string, value *string, ok *bool, opts // SetAttributeValue sets the element attribute with name to value for the // first node matching the selector. func SetAttributeValue(sel interface{}, name, value string, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -307,7 +307,7 @@ func SetAttributeValue(sel interface{}, name, value string, opts ...QueryOption) // RemoveAttribute removes the element attribute with name from the first node // matching the selector. func RemoveAttribute(sel interface{}, name string, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -322,7 +322,7 @@ func JavascriptAttribute(sel interface{}, name string, res interface{}, opts ... 
if res == nil { panic("res cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -334,7 +334,7 @@ func JavascriptAttribute(sel interface{}, name string, res interface{}, opts ... // SetJavascriptAttribute sets the javascript attribute for the first node // matching the selector. func SetJavascriptAttribute(sel interface{}, name, value string, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -370,7 +370,7 @@ func InnerHTML(sel interface{}, html *string, opts ...QueryOption) Action { // Click sends a mouse click event to the first node matching the selector. func Click(sel interface{}, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -382,7 +382,7 @@ func Click(sel interface{}, opts ...QueryOption) Action { // DoubleClick sends a mouse double click event to the first node matching the // selector. 
func DoubleClick(sel interface{}, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -397,7 +397,7 @@ func DoubleClick(sel interface{}, opts ...QueryOption) Action { // Note: when selector matches a input[type="file"] node, then dom.SetFileInputFiles // is used to set the upload path of the input node to v. func SendKeys(sel interface{}, v string, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -426,7 +426,7 @@ func SendKeys(sel interface{}, v string, opts ...QueryOption) Action { // SetUploadFiles sets the files to upload (ie, for a input[type="file"] node) // for the first node matching the selector. 
func SetUploadFiles(sel interface{}, files []string, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -441,7 +441,7 @@ func Screenshot(sel interface{}, picbuf *[]byte, opts ...QueryOption) Action { panic("picbuf cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -497,7 +497,7 @@ func Screenshot(sel interface{}, picbuf *[]byte, opts ...QueryOption) Action { // Submit is an action that submits the form of the first node matching the // selector belongs to. func Submit(sel interface{}, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -519,7 +519,7 @@ func Submit(sel interface{}, opts ...QueryOption) Action { // Reset is an action that resets the form of the first node matching the // selector belongs to. 
func Reset(sel interface{}, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -544,7 +544,7 @@ func ComputedStyle(sel interface{}, style *[]*css.ComputedProperty, opts ...Quer panic("style cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -567,7 +567,7 @@ func MatchedStyle(sel interface{}, style **css.GetMatchedStylesForNodeReturns, o panic("style cannot be nil") } - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } @@ -589,7 +589,7 @@ func MatchedStyle(sel interface{}, style **css.GetMatchedStylesForNodeReturns, o // ScrollIntoView scrolls the window to the first node matching the selector. 
func ScrollIntoView(sel interface{}, opts ...QueryOption) Action { - return QueryAfter(sel, func(ctxt context.Context, h *TargetHandler, nodes ...*cdp.Node) error { + return QueryAfter(sel, func(ctxt context.Context, h *Target, nodes ...*cdp.Node) error { if len(nodes) < 1 { return fmt.Errorf("selector `%s` did not return any nodes", sel) } diff --git a/query_test.go b/query_test.go index 6b68f3b..3487265 100644 --- a/query_test.go +++ b/query_test.go @@ -925,6 +925,7 @@ func TestFileUpload(t *testing.T) { t.Fatal(err) } defer os.Remove(tmpfile.Name()) + defer tmpfile.Close() if _, err := tmpfile.WriteString(uploadHTML); err != nil { t.Fatal(err) } diff --git a/sel.go b/sel.go index 68c0412..12cc31e 100644 --- a/sel.go +++ b/sel.go @@ -5,7 +5,6 @@ import ( "fmt" "strings" "sync" - "time" "github.com/chromedp/cdproto/cdp" "github.com/chromedp/cdproto/dom" @@ -26,9 +25,9 @@ tagname type Selector struct { sel interface{} exp int - by func(context.Context, *TargetHandler, *cdp.Node) ([]cdp.NodeID, error) - wait func(context.Context, *TargetHandler, *cdp.Node, ...cdp.NodeID) ([]*cdp.Node, error) - after func(context.Context, *TargetHandler, ...*cdp.Node) error + by func(context.Context, *Target, *cdp.Node) ([]cdp.NodeID, error) + wait func(context.Context, *Target, *cdp.Frame, ...cdp.NodeID) ([]*cdp.Node, error) + after func(context.Context, *Target, ...*cdp.Node) error } // Query is an action to query for document nodes match the specified sel and @@ -57,15 +56,11 @@ func Query(sel interface{}, opts ...QueryOption) Action { // Do satisfies the Action interface. 
func (s *Selector) Do(ctxt context.Context, h cdp.Executor) error { - th, ok := h.(*TargetHandler) + th, ok := h.(*Target) if !ok { return ErrInvalidHandler } - // TODO: fix this - ctxt, cancel := context.WithTimeout(ctxt, 100*time.Second) - defer cancel() - var err error select { case err = <-s.run(ctxt, th): @@ -79,53 +74,35 @@ func (s *Selector) Do(ctxt context.Context, h cdp.Executor) error { // run runs the selector action, starting over if the original returned nodes // are invalidated prior to finishing the selector's by, wait, check, and after // funcs. -func (s *Selector) run(ctxt context.Context, h *TargetHandler) chan error { +func (s *Selector) run(ctxt context.Context, h *Target) chan error { ch := make(chan error, 1) + h.waitQueue <- func(cur *cdp.Frame) bool { + cur.RLock() + root := cur.Root + cur.RUnlock() - go func() { - defer close(ch) + if root == nil { + // not ready? + return false + } - for { - root, err := h.GetRoot(ctxt) - if err != nil { - select { - case <-ctxt.Done(): - ch <- ctxt.Err() - return - default: - continue - } - } - - select { - default: - ids, err := s.by(ctxt, h, root) - if err == nil && len(ids) >= s.exp { - nodes, err := s.wait(ctxt, h, root, ids...) - if err == nil { - if s.after == nil { - return - } - - if err := s.after(ctxt, h, nodes...); err != nil { - ch <- err - } - return - } - } - - time.Sleep(DefaultCheckDuration) - - case <-root.Invalidated: - continue - - case <-ctxt.Done(): - ch <- ctxt.Err() - return + ids, err := s.by(ctxt, h, root) + if err != nil || len(ids) < s.exp { + return false + } + nodes, err := s.wait(ctxt, h, cur, ids...) 
+ // if nodes==nil, we're not yet ready + if nodes == nil || err != nil { + return false + } + if s.after != nil { + if err := s.after(ctxt, h, nodes...); err != nil { + ch <- err } } - }() - + close(ch) + return true + } return ch } @@ -151,7 +128,7 @@ func (s *Selector) selAsString() string { // QueryAfter is an action that will match the specified sel using the supplied // query options, and after the visibility conditions of the query have been // met, will execute f. -func QueryAfter(sel interface{}, f func(context.Context, *TargetHandler, ...*cdp.Node) error, opts ...QueryOption) Action { +func QueryAfter(sel interface{}, f func(context.Context, *Target, ...*cdp.Node) error, opts ...QueryOption) Action { return Query(sel, append(opts, After(f))...) } @@ -159,7 +136,7 @@ func QueryAfter(sel interface{}, f func(context.Context, *TargetHandler, ...*cdp type QueryOption func(*Selector) // ByFunc is a query option to set the func used to select elements. -func ByFunc(f func(context.Context, *TargetHandler, *cdp.Node) ([]cdp.NodeID, error)) QueryOption { +func ByFunc(f func(context.Context, *Target, *cdp.Node) ([]cdp.NodeID, error)) QueryOption { return func(s *Selector) { s.by = f } @@ -168,7 +145,7 @@ func ByFunc(f func(context.Context, *TargetHandler, *cdp.Node) ([]cdp.NodeID, er // ByQuery is a query option to select a single element using // DOM.querySelector. func ByQuery(s *Selector) { - ByFunc(func(ctxt context.Context, h *TargetHandler, n *cdp.Node) ([]cdp.NodeID, error) { + ByFunc(func(ctxt context.Context, h *Target, n *cdp.Node) ([]cdp.NodeID, error) { nodeID, err := dom.QuerySelector(n.NodeID, s.selAsString()).Do(ctxt, h) if err != nil { return nil, err @@ -184,7 +161,7 @@ func ByQuery(s *Selector) { // ByQueryAll is a query option to select elements by DOM.querySelectorAll. 
func ByQueryAll(s *Selector) { - ByFunc(func(ctxt context.Context, h *TargetHandler, n *cdp.Node) ([]cdp.NodeID, error) { + ByFunc(func(ctxt context.Context, h *Target, n *cdp.Node) ([]cdp.NodeID, error) { return dom.QuerySelectorAll(n.NodeID, s.selAsString()).Do(ctxt, h) })(s) } @@ -198,7 +175,7 @@ func ByID(s *Selector) { // BySearch is a query option via DOM.performSearch (works with both CSS and // XPath queries). func BySearch(s *Selector) { - ByFunc(func(ctxt context.Context, h *TargetHandler, n *cdp.Node) ([]cdp.NodeID, error) { + ByFunc(func(ctxt context.Context, h *Target, n *cdp.Node) ([]cdp.NodeID, error) { id, count, err := dom.PerformSearch(s.selAsString()).Do(ctxt, h) if err != nil { return nil, err @@ -224,7 +201,7 @@ func ByNodeID(s *Selector) { panic("ByNodeID can only work on []cdp.NodeID") } - ByFunc(func(ctxt context.Context, h *TargetHandler, n *cdp.Node) ([]cdp.NodeID, error) { + ByFunc(func(ctxt context.Context, h *Target, n *cdp.Node) ([]cdp.NodeID, error) { for _, id := range ids { err := dom.RequestChildNodes(id).WithPierce(true).Do(ctxt, h) if err != nil { @@ -237,32 +214,22 @@ func ByNodeID(s *Selector) { } // waitReady waits for the specified nodes to be ready. 
-func (s *Selector) waitReady(check func(context.Context, *TargetHandler, *cdp.Node) error) func(context.Context, *TargetHandler, *cdp.Node, ...cdp.NodeID) ([]*cdp.Node, error) { - return func(ctxt context.Context, h *TargetHandler, n *cdp.Node, ids ...cdp.NodeID) ([]*cdp.Node, error) { - f, err := h.WaitFrame(ctxt, cdp.EmptyFrameID) - if err != nil { - return nil, err - } - - wg := new(sync.WaitGroup) +func (s *Selector) waitReady(check func(context.Context, *Target, *cdp.Node) error) func(context.Context, *Target, *cdp.Frame, ...cdp.NodeID) ([]*cdp.Node, error) { + return func(ctxt context.Context, h *Target, cur *cdp.Frame, ids ...cdp.NodeID) ([]*cdp.Node, error) { nodes := make([]*cdp.Node, len(ids)) - errs := make([]error, len(ids)) + cur.RLock() for i, id := range ids { - wg.Add(1) - go func(i int, id cdp.NodeID) { - defer wg.Done() - nodes[i], errs[i] = h.WaitNode(ctxt, f, id) - }(i, id) - } - wg.Wait() - - for _, err := range errs { - if err != nil { - return nil, err + nodes[i] = cur.Nodes[id] + if nodes[i] == nil { + cur.RUnlock() + // not yet ready + return nil, nil } } + cur.RUnlock() if check != nil { + var wg sync.WaitGroup errs := make([]error, len(nodes)) for i, n := range nodes { wg.Add(1) @@ -285,7 +252,7 @@ func (s *Selector) waitReady(check func(context.Context, *TargetHandler, *cdp.No } // WaitFunc is a query option to set a custom wait func. -func WaitFunc(wait func(context.Context, *TargetHandler, *cdp.Node, ...cdp.NodeID) ([]*cdp.Node, error)) QueryOption { +func WaitFunc(wait func(context.Context, *Target, *cdp.Frame, ...cdp.NodeID) ([]*cdp.Node, error)) QueryOption { return func(s *Selector) { s.wait = wait } @@ -298,7 +265,7 @@ func NodeReady(s *Selector) { // NodeVisible is a query option to wait until the element is visible. 
func NodeVisible(s *Selector) { - WaitFunc(s.waitReady(func(ctxt context.Context, h *TargetHandler, n *cdp.Node) error { + WaitFunc(s.waitReady(func(ctxt context.Context, h *Target, n *cdp.Node) error { // check box model _, err := dom.GetBoxModel().WithNodeID(n.NodeID).Do(ctxt, h) if err != nil { @@ -324,7 +291,7 @@ func NodeVisible(s *Selector) { // NodeNotVisible is a query option to wait until the element is not visible. func NodeNotVisible(s *Selector) { - WaitFunc(s.waitReady(func(ctxt context.Context, h *TargetHandler, n *cdp.Node) error { + WaitFunc(s.waitReady(func(ctxt context.Context, h *Target, n *cdp.Node) error { // check box model _, err := dom.GetBoxModel().WithNodeID(n.NodeID).Do(ctxt, h) if err != nil { @@ -350,7 +317,7 @@ func NodeNotVisible(s *Selector) { // NodeEnabled is a query option to wait until the element is enabled. func NodeEnabled(s *Selector) { - WaitFunc(s.waitReady(func(ctxt context.Context, h *TargetHandler, n *cdp.Node) error { + WaitFunc(s.waitReady(func(ctxt context.Context, h *Target, n *cdp.Node) error { n.RLock() defer n.RUnlock() @@ -366,7 +333,7 @@ func NodeEnabled(s *Selector) { // NodeSelected is a query option to wait until the element is selected. func NodeSelected(s *Selector) { - WaitFunc(s.waitReady(func(ctxt context.Context, h *TargetHandler, n *cdp.Node) error { + WaitFunc(s.waitReady(func(ctxt context.Context, h *Target, n *cdp.Node) error { n.RLock() defer n.RUnlock() @@ -380,11 +347,11 @@ func NodeSelected(s *Selector) { }))(s) } -// NodeNotPresent is a query option to wait until no elements match are -// present matching the selector. +// NodeNotPresent is a query option to wait until no elements are present +// matching the selector. 
func NodeNotPresent(s *Selector) { s.exp = 0 - WaitFunc(func(ctxt context.Context, h *TargetHandler, n *cdp.Node, ids ...cdp.NodeID) ([]*cdp.Node, error) { + WaitFunc(func(ctxt context.Context, h *Target, cur *cdp.Frame, ids ...cdp.NodeID) ([]*cdp.Node, error) { if len(ids) != 0 { return nil, ErrHasResults } @@ -402,7 +369,7 @@ func AtLeast(n int) QueryOption { // After is a query option to set a func that will be executed after the wait // has succeeded. -func After(f func(context.Context, *TargetHandler, ...*cdp.Node) error) QueryOption { +func After(f func(context.Context, *Target, ...*cdp.Node) error) QueryOption { return func(s *Selector) { s.after = f }