--- a/net/netfilter/ipvs/ip_vs_ctl.c +++ b/net/netfilter/ipvs/ip_vs_ctl.c @@ -2384,11 +2384,7 @@ do_ip_vs_set_ctl(struct sock *sk, int cmd, void __user *user, unsigned int len) strlcpy(cfg.mcast_ifn, dm->mcast_ifn, sizeof(cfg.mcast_ifn)); cfg.syncid = dm->syncid; - rtnl_lock(); - mutex_lock(&ipvs->sync_mutex); ret = start_sync_thread(ipvs, &cfg, dm->state); - mutex_unlock(&ipvs->sync_mutex); - rtnl_unlock(); } else { mutex_lock(&ipvs->sync_mutex); ret = stop_sync_thread(ipvs, dm->state); @@ -3481,12 +3477,8 @@ static int ip_vs_genl_new_daemon(struct netns_ipvs *ipvs, struct nlattr **attrs) if (ipvs->mixed_address_family_dests > 0) return -EINVAL; - rtnl_lock(); - mutex_lock(&ipvs->sync_mutex); ret = start_sync_thread(ipvs, &c, nla_get_u32(attrs[IPVS_DAEMON_ATTR_STATE])); - mutex_unlock(&ipvs->sync_mutex); - rtnl_unlock(); return ret; } diff --git a/net/netfilter/ipvs/ip_vs_sync.c b/net/netfilter/ipvs/ip_vs_sync.c index fbaf3bd..0d56b5a 100644 --- a/net/netfilter/ipvs/ip_vs_sync.c +++ b/net/netfilter/ipvs/ip_vs_sync.c @@ -1360,15 +1360,9 @@ static void set_mcast_pmtudisc(struct sock *sk, int val) /* * Specifiy default interface for outgoing multicasts */ -static int set_mcast_if(struct sock *sk, char *ifname) +static int set_mcast_if(struct sock *sk, struct net_device *dev) { - struct net_device *dev; struct inet_sock *inet = inet_sk(sk); - struct net *net = sock_net(sk); - - dev = __dev_get_by_name(net, ifname); - if (!dev) - return -ENODEV; if (sk->sk_bound_dev_if && dev->ifindex != sk->sk_bound_dev_if) return -EINVAL; @@ -1396,19 +1390,14 @@ static int set_mcast_if(struct sock *sk, char *ifname) * in the in_addr structure passed in as a parameter. */ static int -join_mcast_group(struct sock *sk, struct in_addr *addr, char *ifname) +join_mcast_group(struct sock *sk, struct in_addr *addr, struct net_device *dev) { - struct net *net = sock_net(sk); struct ip_mreqn mreq; - struct net_device *dev; int ret; memset(&mreq, 0, sizeof(mreq)); memcpy(&mreq.imr_multiaddr, addr, sizeof(struct in_addr)); - dev = __dev_get_by_name(net, ifname); - if (!dev) - return -ENODEV; if (sk->sk_bound_dev_if && dev->ifindex != sk->sk_bound_dev_if) return -EINVAL; @@ -1423,15 +1412,10 @@ join_mcast_group(struct sock *sk, struct in_addr *addr, char *ifname) #ifdef CONFIG_IP_VS_IPV6 static int join_mcast_group6(struct sock *sk, struct in6_addr *addr, - char *ifname) + struct net_device *dev) { - struct net *net = sock_net(sk); - struct net_device *dev; int ret; - dev = __dev_get_by_name(net, ifname); - if (!dev) - return -ENODEV; if (sk->sk_bound_dev_if && dev->ifindex != sk->sk_bound_dev_if) return -EINVAL; @@ -1443,24 +1427,18 @@ static int join_mcast_group6(struct sock *sk, struct in6_addr *addr, } #endif -static int bind_mcastif_addr(struct socket *sock, char *ifname) +static int bind_mcastif_addr(struct socket *sock, struct net_device *dev) { - struct net *net = sock_net(sock->sk); - struct net_device *dev; __be32 addr; struct sockaddr_in sin; - dev = __dev_get_by_name(net, ifname); - if (!dev) - return -ENODEV; - addr = inet_select_addr(dev, 0, RT_SCOPE_UNIVERSE); if (!addr) pr_err("You probably need to specify IP address on " "multicast interface.\n"); IP_VS_DBG(7, "binding socket with (%s) %pI4\n", - ifname, &addr); + dev->name, &addr); /* Now bind the socket with the address of multicast interface */ sin.sin_family = AF_INET; @@ -1493,7 +1471,8 @@ static void get_mcast_sockaddr(union ipvs_sockaddr *sa, int *salen, /* * Set up sending multicast socket over UDP */ -static struct socket *make_send_sock(struct netns_ipvs *ipvs, int id) +static int make_send_sock(struct netns_ipvs *ipvs, int id, + struct net_device *dev, struct socket **sock_ret) { /* multicast addr */ union ipvs_sockaddr mcast_addr; @@ -1505,9 +1484,10 @@ static struct socket *make_send_sock(struct netns_ipvs *ipvs, int id) IPPROTO_UDP, &sock); if (result < 0) { pr_err("Error during creation of socket; terminating\n"); - return ERR_PTR(result); + goto error; } - result = set_mcast_if(sock->sk, ipvs->mcfg.mcast_ifn); + *sock_ret = sock; + result = set_mcast_if(sock->sk, dev); if (result < 0) { pr_err("Error setting outbound mcast interface\n"); goto error; @@ -1522,7 +1502,7 @@ static struct socket *make_send_sock(struct netns_ipvs *ipvs, int id) set_sock_size(sock->sk, 1, result); if (AF_INET == ipvs->mcfg.mcast_af) - result = bind_mcastif_addr(sock, ipvs->mcfg.mcast_ifn); + result = bind_mcastif_addr(sock, dev); else result = 0; if (result < 0) { @@ -1538,19 +1518,18 @@ static struct socket *make_send_sock(struct netns_ipvs *ipvs, int id) goto error; } - return sock; + return 0; error: - sock_release(sock); - return ERR_PTR(result); + return result; } /* * Set up receiving multicast socket over UDP */ -static struct socket *make_receive_sock(struct netns_ipvs *ipvs, int id, - int ifindex) +static int make_receive_sock(struct netns_ipvs *ipvs, int id, + struct net_device *dev, struct socket **sock_ret) { /* multicast addr */ union ipvs_sockaddr mcast_addr; @@ -1562,8 +1541,9 @@ static struct socket *make_receive_sock(struct netns_ipvs *ipvs, int id, IPPROTO_UDP, &sock); if (result < 0) { pr_err("Error during creation of socket; terminating\n"); - return ERR_PTR(result); + goto error; } + *sock_ret = sock; /* it is equivalent to the REUSEADDR option in user-space */ sock->sk->sk_reuse = SK_CAN_REUSE; result = sysctl_sync_sock_size(ipvs); @@ -1571,7 +1551,7 @@ static struct socket *make_receive_sock(struct netns_ipvs *ipvs, int id, set_sock_size(sock->sk, 0, result); get_mcast_sockaddr(&mcast_addr, &salen, &ipvs->bcfg, id); - sock->sk->sk_bound_dev_if = ifindex; + sock->sk->sk_bound_dev_if = dev->ifindex; result = sock->ops->bind(sock, (struct sockaddr *)&mcast_addr, salen); if (result < 0) { pr_err("Error binding to the multicast addr\n"); @@ -1582,21 +1562,20 @@ static struct socket *make_receive_sock(struct netns_ipvs *ipvs, int id, #ifdef CONFIG_IP_VS_IPV6 if (ipvs->bcfg.mcast_af == AF_INET6) result = join_mcast_group6(sock->sk, &mcast_addr.in6.sin6_addr, - ipvs->bcfg.mcast_ifn); + dev); else #endif result = join_mcast_group(sock->sk, &mcast_addr.in.sin_addr, - ipvs->bcfg.mcast_ifn); + dev); if (result < 0) { pr_err("Error joining to the multicast group\n"); goto error; } - return sock; + return 0; error: - sock_release(sock); - return ERR_PTR(result); + return result; } @@ -1778,13 +1757,12 @@ static int sync_thread_backup(void *data) int start_sync_thread(struct netns_ipvs *ipvs, struct ipvs_sync_daemon_cfg *c, int state) { - struct ip_vs_sync_thread_data *tinfo; + struct ip_vs_sync_thread_data *tinfo = NULL; struct task_struct **array = NULL, *task; - struct socket *sock; struct net_device *dev; char *name; int (*threadfn)(void *data); - int id, count, hlen; + int id = 0, count, hlen; int result = -ENOMEM; u16 mtu, min_mtu; @@ -1792,6 +1770,8 @@ int start_sync_thread(struct netns_ipvs *ipvs, struct ipvs_sync_daemon_cfg *c, IP_VS_DBG(7, "Each ip_vs_sync_conn entry needs %zd bytes\n", sizeof(struct ip_vs_sync_conn_v0)); + rtnl_lock(); + mutex_lock(&ipvs->sync_mutex); if (!ipvs->sync_state) { count = clamp(sysctl_sync_ports(ipvs), 1, IPVS_SYNC_PORTS_MAX); ipvs->threads_mask = count - 1; @@ -1810,7 +1790,8 @@ int start_sync_thread(struct netns_ipvs *ipvs, struct ipvs_sync_daemon_cfg *c, dev = __dev_get_by_name(ipvs->net, c->mcast_ifn); if (!dev) { pr_err("Unknown mcast interface: %s\n", c->mcast_ifn); - return -ENODEV; + result = -ENODEV; + goto out_early; } hlen = (AF_INET6 == c->mcast_af) ? sizeof(struct ipv6hdr) + sizeof(struct udphdr) : @@ -1827,26 +1808,30 @@ int start_sync_thread(struct netns_ipvs *ipvs, struct ipvs_sync_daemon_cfg *c, c->sync_maxlen = mtu - hlen; if (state == IP_VS_STATE_MASTER) { + result = -EEXIST; if (ipvs->ms) - return -EEXIST; + goto out_early; ipvs->mcfg = *c; name = "ipvs-m:%d:%d"; threadfn = sync_thread_master; } else if (state == IP_VS_STATE_BACKUP) { + result = -EEXIST; if (ipvs->backup_threads) - return -EEXIST; + goto out_early; ipvs->bcfg = *c; name = "ipvs-b:%d:%d"; threadfn = sync_thread_backup; } else { - return -EINVAL; + result = -EINVAL; + goto out_early; } if (state == IP_VS_STATE_MASTER) { struct ipvs_master_sync_state *ms; + result = -ENOMEM; ipvs->ms = kcalloc(count, sizeof(ipvs->ms[0]), GFP_KERNEL); if (!ipvs->ms) goto out; @@ -1862,39 +1847,38 @@ int start_sync_thread(struct netns_ipvs *ipvs, struct ipvs_sync_daemon_cfg *c, } else { array = kcalloc(count, sizeof(struct task_struct *), GFP_KERNEL); + result = -ENOMEM; if (!array) goto out; } - tinfo = NULL; for (id = 0; id < count; id++) { - if (state == IP_VS_STATE_MASTER) - sock = make_send_sock(ipvs, id); - else - sock = make_receive_sock(ipvs, id, dev->ifindex); - if (IS_ERR(sock)) { - result = PTR_ERR(sock); - goto outtinfo; - } + result = -ENOMEM; tinfo = kmalloc(sizeof(*tinfo), GFP_KERNEL); if (!tinfo) - goto outsocket; + goto out; tinfo->ipvs = ipvs; - tinfo->sock = sock; + tinfo->sock = NULL; if (state == IP_VS_STATE_BACKUP) { tinfo->buf = kmalloc(ipvs->bcfg.sync_maxlen, GFP_KERNEL); if (!tinfo->buf) - goto outtinfo; + goto out; } else { tinfo->buf = NULL; } tinfo->id = id; + if (state == IP_VS_STATE_MASTER) + result = make_send_sock(ipvs, id, dev, &tinfo->sock); + else + result = make_receive_sock(ipvs, id, dev, &tinfo->sock); + if (result < 0) + goto out; task = kthread_run(threadfn, tinfo, name, ipvs->gen, id); if (IS_ERR(task)) { result = PTR_ERR(task); - goto outtinfo; + goto out; } tinfo = NULL; if (state == IP_VS_STATE_MASTER) @@ -1911,20 +1895,20 @@ int start_sync_thread(struct netns_ipvs *ipvs, struct ipvs_sync_daemon_cfg *c, ipvs->sync_state |= state; spin_unlock_bh(&ipvs->sync_buff_lock); + mutex_unlock(&ipvs->sync_mutex); + rtnl_unlock(); + /* increase the module use count */ ip_vs_use_count_inc(); return 0; -outsocket: - sock_release(sock); - -outtinfo: - if (tinfo) { - sock_release(tinfo->sock); - kfree(tinfo->buf); - kfree(tinfo); - } +out: + /* We do not need RTNL lock anymore, release it here so that + * sock_release below and in the kthreads can use rtnl_lock + * to leave the mcast group. + */ + rtnl_unlock(); count = id; while (count-- > 0) { if (state == IP_VS_STATE_MASTER) @@ -1932,13 +1916,23 @@ int start_sync_thread(struct netns_ipvs *ipvs, struct ipvs_sync_daemon_cfg *c, else kthread_stop(array[count]); } - kfree(array); - -out: if (!(ipvs->sync_state & IP_VS_STATE_MASTER)) { kfree(ipvs->ms); ipvs->ms = NULL; } + mutex_unlock(&ipvs->sync_mutex); + if (tinfo) { + if (tinfo->sock) + sock_release(tinfo->sock); + kfree(tinfo->buf); + kfree(tinfo); + } + kfree(array); + return result; + +out_early: + mutex_unlock(&ipvs->sync_mutex); + rtnl_unlock(); return result; }