]> bbs.cooldavid.org Git - net-next-2.6.git/blobdiff - net/ipv4/route.c
ipv4: rcu conversion in ip_route_output_slow
[net-next-2.6.git] / net / ipv4 / route.c
index 3f56b6e6c6aab583d65902e7190bbf1a6eaa60b6..a61acea975f10a2ae2490070a1da284d7ff4f518 100644 (file)
@@ -1268,18 +1268,11 @@ skip_hashing:
 
 void rt_bind_peer(struct rtable *rt, int create)
 {
-       static DEFINE_SPINLOCK(rt_peer_lock);
        struct inet_peer *peer;
 
        peer = inet_getpeer(rt->rt_dst, create);
 
-       spin_lock_bh(&rt_peer_lock);
-       if (rt->peer == NULL) {
-               rt->peer = peer;
-               peer = NULL;
-       }
-       spin_unlock_bh(&rt_peer_lock);
-       if (peer)
+       if (peer && cmpxchg(&rt->peer, NULL, peer) != NULL)
                inet_putpeer(peer);
 }
 
@@ -2365,9 +2358,8 @@ static int __mkroute_output(struct rtable **result,
        struct rtable *rth;
        struct in_device *in_dev;
        u32 tos = RT_FL_TOS(oldflp);
-       int err = 0;
 
-       if (ipv4_is_loopback(fl->fl4_src) && !(dev_out->flags&IFF_LOOPBACK))
+       if (ipv4_is_loopback(fl->fl4_src) && !(dev_out->flags & IFF_LOOPBACK))
                return -EINVAL;
 
        if (fl->fl4_dst == htonl(0xFFFFFFFF))
@@ -2380,11 +2372,12 @@ static int __mkroute_output(struct rtable **result,
        if (dev_out->flags & IFF_LOOPBACK)
                flags |= RTCF_LOCAL;
 
-       /* get work reference to inet device */
-       in_dev = in_dev_get(dev_out);
-       if (!in_dev)
+       rcu_read_lock();
+       in_dev = __in_dev_get_rcu(dev_out);
+       if (!in_dev) {
+               rcu_read_unlock();
                return -EINVAL;
-
+       }
        if (res->type == RTN_BROADCAST) {
                flags |= RTCF_BROADCAST | RTCF_LOCAL;
                if (res->fi) {
@@ -2392,13 +2385,13 @@ static int __mkroute_output(struct rtable **result,
                        res->fi = NULL;
                }
        } else if (res->type == RTN_MULTICAST) {
-               flags |= RTCF_MULTICAST|RTCF_LOCAL;
+               flags |= RTCF_MULTICAST | RTCF_LOCAL;
                if (!ip_check_mc(in_dev, oldflp->fl4_dst, oldflp->fl4_src,
                                 oldflp->proto))
                        flags &= ~RTCF_LOCAL;
                /* If multicast route do not exist use
-                  default one, but do not gateway in this case.
-                  Yes, it is hack.
+                * default one, but do not gateway in this case.
+                * Yes, it is hack.
                 */
                if (res->fi && res->prefixlen < 4) {
                        fib_info_put(res->fi);
@@ -2409,9 +2402,12 @@ static int __mkroute_output(struct rtable **result,
 
        rth = dst_alloc(&ipv4_dst_ops);
        if (!rth) {
-               err = -ENOBUFS;
-               goto cleanup;
+               rcu_read_unlock();
+               return -ENOBUFS;
        }
+       in_dev_hold(in_dev);
+       rcu_read_unlock();
+       rth->idev = in_dev;
 
        atomic_set(&rth->dst.__refcnt, 1);
        rth->dst.flags= DST_HOST;
@@ -2432,7 +2428,6 @@ static int __mkroute_output(struct rtable **result,
           cache entry */
        rth->dst.dev    = dev_out;
        dev_hold(dev_out);
-       rth->idev       = in_dev_get(dev_out);
        rth->rt_gateway = fl->fl4_dst;
        rth->rt_spec_dst= fl->fl4_src;
 
@@ -2467,13 +2462,8 @@ static int __mkroute_output(struct rtable **result,
        rt_set_nexthop(rth, res, 0);
 
        rth->rt_flags = flags;
-
        *result = rth;
- cleanup:
-       /* release work reference to inet device */
-       in_dev_put(in_dev);
-
-       return err;
+       return 0;
 }
 
 static int ip_mkroute_output(struct rtable **rp,
@@ -2497,6 +2487,7 @@ static int ip_mkroute_output(struct rtable **rp,
 
 /*
  * Major route resolver routine.
+ * called with rcu_read_lock();
  */
 
 static int ip_route_output_slow(struct net *net, struct rtable **rp,
@@ -2515,7 +2506,7 @@ static int ip_route_output_slow(struct net *net, struct rtable **rp,
                            .iif = net->loopback_dev->ifindex,
                            .oif = oldflp->oif };
        struct fib_result res;
-       unsigned flags = 0;
+       unsigned int flags = 0;
        struct net_device *dev_out = NULL;
        int free_res = 0;
        int err;
@@ -2545,7 +2536,7 @@ static int ip_route_output_slow(struct net *net, struct rtable **rp,
                    (ipv4_is_multicast(oldflp->fl4_dst) ||
                     oldflp->fl4_dst == htonl(0xFFFFFFFF))) {
                        /* It is equivalent to inet_addr_type(saddr) == RTN_LOCAL */
-                       dev_out = ip_dev_find(net, oldflp->fl4_src);
+                       dev_out = __ip_dev_find(net, oldflp->fl4_src, false);
                        if (dev_out == NULL)
                                goto out;
 
@@ -2570,26 +2561,21 @@ static int ip_route_output_slow(struct net *net, struct rtable **rp,
 
                if (!(oldflp->flags & FLOWI_FLAG_ANYSRC)) {
                        /* It is equivalent to inet_addr_type(saddr) == RTN_LOCAL */
-                       dev_out = ip_dev_find(net, oldflp->fl4_src);
-                       if (dev_out == NULL)
+                       if (!__ip_dev_find(net, oldflp->fl4_src, false))
                                goto out;
-                       dev_put(dev_out);
-                       dev_out = NULL;
                }
        }
 
 
        if (oldflp->oif) {
-               dev_out = dev_get_by_index(net, oldflp->oif);
+               dev_out = dev_get_by_index_rcu(net, oldflp->oif);
                err = -ENODEV;
                if (dev_out == NULL)
                        goto out;
 
                /* RACE: Check return value of inet_select_addr instead. */
-               if (__in_dev_get_rtnl(dev_out) == NULL) {
-                       dev_put(dev_out);
+               if (rcu_dereference(dev_out->ip_ptr) == NULL)
                        goto out;       /* Wrong error code */
-               }
 
                if (ipv4_is_local_multicast(oldflp->fl4_dst) ||
                    oldflp->fl4_dst == htonl(0xFFFFFFFF)) {
@@ -2612,10 +2598,7 @@ static int ip_route_output_slow(struct net *net, struct rtable **rp,
                fl.fl4_dst = fl.fl4_src;
                if (!fl.fl4_dst)
                        fl.fl4_dst = fl.fl4_src = htonl(INADDR_LOOPBACK);
-               if (dev_out)
-                       dev_put(dev_out);
                dev_out = net->loopback_dev;
-               dev_hold(dev_out);
                fl.oif = net->loopback_dev->ifindex;
                res.type = RTN_LOCAL;
                flags |= RTCF_LOCAL;
@@ -2649,8 +2632,6 @@ static int ip_route_output_slow(struct net *net, struct rtable **rp,
                        res.type = RTN_UNICAST;
                        goto make_route;
                }
-               if (dev_out)
-                       dev_put(dev_out);
                err = -ENETUNREACH;
                goto out;
        }
@@ -2659,10 +2640,7 @@ static int ip_route_output_slow(struct net *net, struct rtable **rp,
        if (res.type == RTN_LOCAL) {
                if (!fl.fl4_src)
                        fl.fl4_src = fl.fl4_dst;
-               if (dev_out)
-                       dev_put(dev_out);
                dev_out = net->loopback_dev;
-               dev_hold(dev_out);
                fl.oif = dev_out->ifindex;
                if (res.fi)
                        fib_info_put(res.fi);
@@ -2682,28 +2660,23 @@ static int ip_route_output_slow(struct net *net, struct rtable **rp,
        if (!fl.fl4_src)
                fl.fl4_src = FIB_RES_PREFSRC(res);
 
-       if (dev_out)
-               dev_put(dev_out);
        dev_out = FIB_RES_DEV(res);
-       dev_hold(dev_out);
        fl.oif = dev_out->ifindex;
 
 
 make_route:
        err = ip_mkroute_output(rp, &res, &fl, oldflp, dev_out, flags);
 
-
        if (free_res)
                fib_res_put(&res);
-       if (dev_out)
-               dev_put(dev_out);
 out:   return err;
 }
 
 int __ip_route_output_key(struct net *net, struct rtable **rp,
                          const struct flowi *flp)
 {
-       unsigned hash;
+       unsigned int hash;
+       int res;
        struct rtable *rth;
 
        if (!rt_caching(net))
@@ -2734,10 +2707,18 @@ int __ip_route_output_key(struct net *net, struct rtable **rp,
        rcu_read_unlock_bh();
 
 slow_output:
-       return ip_route_output_slow(net, rp, flp);
+       rcu_read_lock();
+       res = ip_route_output_slow(net, rp, flp);
+       rcu_read_unlock();
+       return res;
 }
 EXPORT_SYMBOL_GPL(__ip_route_output_key);
 
+static struct dst_entry *ipv4_blackhole_dst_check(struct dst_entry *dst, u32 cookie)
+{
+       return NULL;
+}
+
 static void ipv4_rt_blackhole_update_pmtu(struct dst_entry *dst, u32 mtu)
 {
 }
@@ -2746,7 +2727,7 @@ static struct dst_ops ipv4_dst_blackhole_ops = {
        .family                 =       AF_INET,
        .protocol               =       cpu_to_be16(ETH_P_IP),
        .destroy                =       ipv4_dst_destroy,
-       .check                  =       ipv4_dst_check,
+       .check                  =       ipv4_blackhole_dst_check,
        .update_pmtu            =       ipv4_rt_blackhole_update_pmtu,
        .entries                =       ATOMIC_INIT(0),
 };
@@ -2793,7 +2774,7 @@ static int ipv4_dst_blackhole(struct net *net, struct rtable **rp, struct flowi
 
        dst_release(&(*rp)->dst);
        *rp = rt;
-       return (rt ? 0 : -ENOMEM);
+       return rt ? 0 : -ENOMEM;
 }
 
 int ip_route_output_flow(struct net *net, struct rtable **rp, struct flowi *flp,