diff --git a/drivers/net/usb/usbnet.c b/drivers/net/usb/usbnet.c
index 9a6450f796dc..b1f93810a6f3 100644
--- a/drivers/net/usb/usbnet.c
+++ b/drivers/net/usb/usbnet.c
@@ -91,6 +91,38 @@ static const char * const usbnet_event_names[] = {
 	[EVENT_NO_IP_ALIGN] = "EVENT_NO_IP_ALIGN",
 };
 
+/*
+ * Check that the endpoints hard-coded in @info exist on the current
+ * altsetting of @intf with the expected type: a bulk IN endpoint numbered
+ * info->in and a bulk OUT endpoint numbered info->out.
+ *
+ * info->in/info->out are endpoint numbers (they are what usbnet_probe()
+ * hands to usb_rcvbulkpipe()/usb_sndbulkpipe()), not indices into
+ * alt->endpoint[], so the descriptors are searched by address.
+ * @dev is currently unused.
+ *
+ * Returns true if both endpoints are present, false otherwise.
+ */
+static bool usbnet_validate_endpoints(struct usbnet *dev,
+				      struct usb_interface *intf,
+				      const struct driver_info *info)
+{
+	struct usb_host_interface *alt = intf->cur_altsetting;
+	bool have_in = false, have_out = false;
+	unsigned int i;
+
+	for (i = 0; i < alt->desc.bNumEndpoints; i++) {
+		const struct usb_endpoint_descriptor *d = &alt->endpoint[i].desc;
+
+		if (usb_endpoint_is_bulk_in(d) && usb_endpoint_num(d) == info->in)
+			have_in = true;
+		if (usb_endpoint_is_bulk_out(d) && usb_endpoint_num(d) == info->out)
+			have_out = true;
+	}
+
+	return have_in && have_out;
+}
+
 /* handles CDC Ethernet and many other network "bulk data" interfaces */
 int usbnet_get_endpoints(struct usbnet *dev, struct usb_interface *intf)
 {
@@ -1772,6 +1804,10 @@ usbnet_probe (struct usb_interface *udev, const struct usb_device_id *prod)
 	} else if (!info->in || !info->out)
 		status = usbnet_get_endpoints (dev, udev);
 	else {
+		if (!usbnet_validate_endpoints(dev, udev, info)) {
+			status = -EINVAL;
+			goto out3;
+		}
 		dev->in = usb_rcvbulkpipe (xdev, info->in);
 		dev->out = usb_sndbulkpipe (xdev, info->out);
 		if (!(info->flags & FLAG_NO_SETINT))