ipv4: add __rcu annotations to routes.c
diff --git a/net/ipv4/route.c b/net/ipv4/route.c
index 3888f6ba0a5c559a9ce6437a059c91831f08ba2b..987bf9adb31833c19a0db04ce76060306d8e6994 100644
--- a/net/ipv4/route.c
+++ b/net/ipv4/route.c
@@ -159,7 +159,6 @@ static struct dst_ops ipv4_dst_ops = {
        .link_failure =         ipv4_link_failure,
        .update_pmtu =          ip_rt_update_pmtu,
        .local_out =            __ip_local_out,
-       .entries =              ATOMIC_INIT(0),
 };
 
 #define ECN_OR_COST(class)     TC_PRIO_##class
@@ -199,7 +198,7 @@ const __u8 ip_tos2prio[16] = {
  */
 
 struct rt_hash_bucket {
-       struct rtable   *chain;
+       struct rtable __rcu     *chain;
 };
 
 #if defined(CONFIG_SMP) || defined(CONFIG_DEBUG_SPINLOCK) || \
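
The heart of the patch: hash-chain pointers that lockless readers follow are now declared __rcu, so sparse ("make C=1" with CONFIG_SPARSE_RCU_POINTER=y) flags any access that bypasses the RCU accessors. A minimal sketch of the discipline the annotation enforces, with illustrative names not taken from the patch:

    #include <linux/rcupdate.h>
    #include <linux/spinlock.h>

    struct item {
            struct item __rcu *next;
            int key;
    };

    static struct item __rcu *head;         /* annotated like ->chain above */
    static DEFINE_SPINLOCK(head_lock);

    /* Writer: set up ->next first, then publish with rcu_assign_pointer(),
     * which supplies the memory barrier readers depend on. */
    static void item_add(struct item *it)
    {
            spin_lock_bh(&head_lock);
            rcu_assign_pointer(it->next,
                    rcu_dereference_protected(head, lockdep_is_held(&head_lock)));
            rcu_assign_pointer(head, it);
            spin_unlock_bh(&head_lock);
    }

    /* Reader: every hop through an __rcu pointer goes via rcu_dereference(). */
    static bool item_present(int key)
    {
            struct item *it;
            bool found = false;

            rcu_read_lock();
            for (it = rcu_dereference(head); it; it = rcu_dereference(it->next))
                    if (it->key == key) {
                            found = true;
                            break;
                    }
            rcu_read_unlock();
            return found;
    }
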
@@ -281,7 +280,7 @@ static struct rtable *rt_cache_get_first(struct seq_file *seq)
        struct rtable *r = NULL;
 
        for (st->bucket = rt_hash_mask; st->bucket >= 0; --st->bucket) {
-               if (!rt_hash_table[st->bucket].chain)
+               if (!rcu_dereference_raw(rt_hash_table[st->bucket].chain))
                        continue;
                rcu_read_lock_bh();
                r = rcu_dereference_bh(rt_hash_table[st->bucket].chain);
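
Note the two accessors in rt_cache_get_first(): the bucket head is probed with rcu_dereference_raw() before any read-side critical section, which is legitimate only because the value is compared against NULL and never followed; once a non-empty bucket is found, rcu_read_lock_bh() is taken and the pointer is re-read with rcu_dereference_bh(). A hedged sketch of the idiom (illustrative names; like the seq_file walker, it returns with the BH read lock held on success):

    /* Returns the first chained item, with rcu_read_lock_bh() held iff
     * the return value is non-NULL; the caller must eventually unlock. */
    static struct item *first_item(struct item __rcu **table, int nbuckets)
    {
            struct item *it;
            int i;

            for (i = 0; i < nbuckets; i++) {
                    /* raw probe: result is only tested, never dereferenced */
                    if (!rcu_dereference_raw(table[i]))
                            continue;
                    rcu_read_lock_bh();
                    it = rcu_dereference_bh(table[i]);
                    if (it)
                            return it;
                    /* bucket emptied between probe and locked re-read */
                    rcu_read_unlock_bh();
            }
            return NULL;
    }
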
@@ -301,17 +300,17 @@ static struct rtable *__rt_cache_get_next(struct seq_file *seq,
 {
        struct rt_cache_iter_state *st = seq->private;
 
-       r = r->dst.rt_next;
+       r = rcu_dereference_bh(r->dst.rt_next);
        while (!r) {
                rcu_read_unlock_bh();
                do {
                        if (--st->bucket < 0)
                                return NULL;
-               } while (!rt_hash_table[st->bucket].chain);
+               } while (!rcu_dereference_raw(rt_hash_table[st->bucket].chain));
                rcu_read_lock_bh();
-               r = rt_hash_table[st->bucket].chain;
+               r = rcu_dereference_bh(rt_hash_table[st->bucket].chain);
        }
-       return rcu_dereference_bh(r);
+       return r;
 }
 
 static struct rtable *rt_cache_get_next(struct seq_file *seq,
@@ -466,7 +465,7 @@ static int rt_cpu_seq_show(struct seq_file *seq, void *v)
 
        seq_printf(seq,"%08x  %08x %08x %08x %08x %08x %08x %08x "
                   " %08x %08x %08x %08x %08x %08x %08x %08x %08x \n",
-                  atomic_read(&ipv4_dst_ops.entries),
+                  dst_entries_get_slow(&ipv4_dst_ops),
                   st->in_hit,
                   st->in_slow_tot,
                   st->in_slow_mc,
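
atomic_read(&ipv4_dst_ops.entries) is gone because a companion patch in this series moved dst accounting from a single shared atomic_t to a percpu_counter, eliminating a contended cache line. This /proc path is cold, so it pays for the exact per-cpu summation; the hot paths further down use the cheap, possibly slightly stale read. The two getters are essentially the following (paraphrased from include/net/dst_ops.h of this era; treat as a sketch, not a quote):

    static inline int dst_entries_get_fast(struct dst_ops *dst)
    {
            /* One read of the batched total: O(1), may lag each CPU's
             * uncommitted delta by up to the percpu batch size. */
            return percpu_counter_read_positive(&dst->pcpuc_entries);
    }

    static inline int dst_entries_get_slow(struct dst_ops *dst)
    {
            int res;

            /* Exact: folds every CPU's local delta into the total. */
            local_bh_disable();
            res = percpu_counter_sum_positive(&dst->pcpuc_entries);
            local_bh_enable();
            return res;
    }
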
@@ -722,19 +721,23 @@ static void rt_do_flush(int process_context)
        for (i = 0; i <= rt_hash_mask; i++) {
                if (process_context && need_resched())
                        cond_resched();
-               rth = rt_hash_table[i].chain;
+               rth = rcu_dereference_raw(rt_hash_table[i].chain);
                if (!rth)
                        continue;
 
                spin_lock_bh(rt_hash_lock_addr(i));
 #ifdef CONFIG_NET_NS
                {
-               struct rtable ** prev, * p;
+               struct rtable __rcu **prev;
+               struct rtable *p;
 
-               rth = rt_hash_table[i].chain;
+               rth = rcu_dereference_protected(rt_hash_table[i].chain,
+                       lockdep_is_held(rt_hash_lock_addr(i)));
 
                /* defer releasing the head of the list after spin_unlock */
-               for (tail = rth; tail; tail = tail->dst.rt_next)
+               for (tail = rth; tail;
+                    tail = rcu_dereference_protected(tail->dst.rt_next,
+                               lockdep_is_held(rt_hash_lock_addr(i))))
                        if (!rt_is_expired(tail))
                                break;
                if (rth != tail)
@@ -742,8 +745,12 @@ static void rt_do_flush(int process_context)
 
                /* call rt_free on entries after the tail requiring flush */
                prev = &rt_hash_table[i].chain;
-               for (p = *prev; p; p = next) {
-                       next = p->dst.rt_next;
+               for (p = rcu_dereference_protected(*prev,
+                               lockdep_is_held(rt_hash_lock_addr(i)));
+                    p != NULL;
+                    p = next) {
+                       next = rcu_dereference_protected(p->dst.rt_next,
+                               lockdep_is_held(rt_hash_lock_addr(i)));
                        if (!rt_is_expired(p)) {
                                prev = &p->dst.rt_next;
                        } else {
@@ -753,14 +760,15 @@ static void rt_do_flush(int process_context)
                }
                }
 #else
-               rth = rt_hash_table[i].chain;
-               rt_hash_table[i].chain = NULL;
+               rth = rcu_dereference_protected(rt_hash_table[i].chain,
+                       lockdep_is_held(rt_hash_lock_addr(i)));
+               rcu_assign_pointer(rt_hash_table[i].chain, NULL);
                tail = NULL;
 #endif
                spin_unlock_bh(rt_hash_lock_addr(i));
 
                for (; rth != tail; rth = next) {
-                       next = rth->dst.rt_next;
+                       next = rcu_dereference_protected(rth->dst.rt_next, 1);
                        rt_free(rth);
                }
        }
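
On the update side, every load of an __rcu pointer now goes through rcu_dereference_protected() with a lockdep expression proving the per-bucket spinlock is held: it emits no barriers, but lockdep complains at runtime if the condition is ever false. The unlink-and-free shape that rt_do_flush() applies per bucket reduces to this hedged sketch (illustrative names; item_expired() is a hypothetical predicate, struct item as in the earlier sketch plus an rcu_head):

    #include <linux/slab.h>

    struct item {
            struct item __rcu *next;
            struct rcu_head rcu;
    };

    static bool item_expired(struct item *p);   /* hypothetical */

    static void item_free_rcu(struct rcu_head *head)
    {
            kfree(container_of(head, struct item, rcu));
    }

    static void purge_expired(struct item __rcu **headp, spinlock_t *lock)
    {
            struct item __rcu **prev = headp;
            struct item *p, *next;

            spin_lock_bh(lock);
            for (p = rcu_dereference_protected(*prev, lockdep_is_held(lock));
                 p != NULL; p = next) {
                    next = rcu_dereference_protected(p->next,
                                                     lockdep_is_held(lock));
                    if (!item_expired(p)) {
                            prev = &p->next;
                            continue;
                    }
                    /* Unlink: readers already at p still see a valid ->next,
                     * and the grace period keeps p itself alive for them. */
                    rcu_assign_pointer(*prev, next);
                    call_rcu_bh(&p->rcu, item_free_rcu);
            }
            spin_unlock_bh(lock);
    }
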
@@ -791,7 +799,7 @@ static int has_noalias(const struct rtable *head, const struct rtable *rth)
        while (aux != rth) {
                if (compare_hash_inputs(&aux->fl, &rth->fl))
                        return 0;
-               aux = aux->dst.rt_next;
+               aux = rcu_dereference_protected(aux->dst.rt_next, 1);
        }
        return 1;
 }
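
Not every update-side access can name a lock: has_noalias() and slow_chain_length() run in contexts that already guarantee the chain is stable, so they pass the literal 1, meaning "caller vouches for exclusion, check nothing". The three accessor flavours this patch deploys, side by side (hedged sketch, reusing the illustrative struct item):

    static DEFINE_SPINLOCK(chain_lock);
    static struct item __rcu *chain;

    static void accessor_flavours(void)
    {
            struct item *p;

            /* Checked: lockdep verifies chain_lock really is held. */
            spin_lock_bh(&chain_lock);
            p = rcu_dereference_protected(chain,
                                          lockdep_is_held(&chain_lock));
            spin_unlock_bh(&chain_lock);

            /* Asserted: constant 1 silences sparse, verifies nothing. */
            p = rcu_dereference_protected(chain, 1);

            /* Unchecked: legal only while the value is tested, not followed. */
            if (rcu_dereference_raw(chain) != NULL)
                    p = NULL;
            (void)p;
    }
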
@@ -800,7 +808,8 @@ static void rt_check_expire(void)
 {
        static unsigned int rover;
        unsigned int i = rover, goal;
-       struct rtable *rth, **rthp;
+       struct rtable *rth;
+       struct rtable __rcu **rthp;
        unsigned long samples = 0;
        unsigned long sum = 0, sum2 = 0;
        unsigned long delta;
@@ -826,11 +835,12 @@ static void rt_check_expire(void)
 
                samples++;
 
-               if (*rthp == NULL)
+               if (rcu_dereference_raw(*rthp) == NULL)
                        continue;
                length = 0;
                spin_lock_bh(rt_hash_lock_addr(i));
-               while ((rth = *rthp) != NULL) {
+               while ((rth = rcu_dereference_protected(*rthp,
+                                       lockdep_is_held(rt_hash_lock_addr(i)))) != NULL) {
                        prefetch(rth->dst.rt_next);
                        if (rt_is_expired(rth)) {
                                *rthp = rth->dst.rt_next;
@@ -942,9 +952,11 @@ static int rt_garbage_collect(struct dst_ops *ops)
        static unsigned long last_gc;
        static int rover;
        static int equilibrium;
-       struct rtable *rth, **rthp;
+       struct rtable *rth;
+       struct rtable __rcu **rthp;
        unsigned long now = jiffies;
        int goal;
+       int entries = dst_entries_get_fast(&ipv4_dst_ops);
 
        /*
         * Garbage collection is pretty expensive,
@@ -954,28 +966,28 @@ static int rt_garbage_collect(struct dst_ops *ops)
        RT_CACHE_STAT_INC(gc_total);
 
        if (now - last_gc < ip_rt_gc_min_interval &&
-           atomic_read(&ipv4_dst_ops.entries) < ip_rt_max_size) {
+           entries < ip_rt_max_size) {
                RT_CACHE_STAT_INC(gc_ignored);
                goto out;
        }
 
+       entries = dst_entries_get_slow(&ipv4_dst_ops);
        /* Calculate number of entries, which we want to expire now. */
-       goal = atomic_read(&ipv4_dst_ops.entries) -
-               (ip_rt_gc_elasticity << rt_hash_log);
+       goal = entries - (ip_rt_gc_elasticity << rt_hash_log);
        if (goal <= 0) {
                if (equilibrium < ipv4_dst_ops.gc_thresh)
                        equilibrium = ipv4_dst_ops.gc_thresh;
-               goal = atomic_read(&ipv4_dst_ops.entries) - equilibrium;
+               goal = entries - equilibrium;
                if (goal > 0) {
                        equilibrium += min_t(unsigned int, goal >> 1, rt_hash_mask + 1);
-                       goal = atomic_read(&ipv4_dst_ops.entries) - equilibrium;
+                       goal = entries - equilibrium;
                }
        } else {
                /* We are in dangerous area. Try to reduce cache really
                 * aggressively.
                 */
                goal = max_t(unsigned int, goal >> 1, rt_hash_mask + 1);
-               equilibrium = atomic_read(&ipv4_dst_ops.entries) - goal;
+               equilibrium = entries - goal;
        }
 
        if (now - last_gc >= ip_rt_gc_min_interval)
@@ -995,7 +1007,8 @@ static int rt_garbage_collect(struct dst_ops *ops)
                        k = (k + 1) & rt_hash_mask;
                        rthp = &rt_hash_table[k].chain;
                        spin_lock_bh(rt_hash_lock_addr(k));
-                       while ((rth = *rthp) != NULL) {
+                       while ((rth = rcu_dereference_protected(*rthp,
+                                       lockdep_is_held(rt_hash_lock_addr(k)))) != NULL) {
                                if (!rt_is_expired(rth) &&
                                        !rt_may_expire(rth, tmo, expire)) {
                                        tmo >>= 1;
@@ -1032,14 +1045,16 @@ static int rt_garbage_collect(struct dst_ops *ops)
                expire >>= 1;
 #if RT_CACHE_DEBUG >= 2
                printk(KERN_DEBUG "expire>> %u %d %d %d\n", expire,
-                               atomic_read(&ipv4_dst_ops.entries), goal, i);
+                               dst_entries_get_fast(&ipv4_dst_ops), goal, i);
 #endif
 
-               if (atomic_read(&ipv4_dst_ops.entries) < ip_rt_max_size)
+               if (dst_entries_get_fast(&ipv4_dst_ops) < ip_rt_max_size)
                        goto out;
        } while (!in_softirq() && time_before_eq(jiffies, now));
 
-       if (atomic_read(&ipv4_dst_ops.entries) < ip_rt_max_size)
+       if (dst_entries_get_fast(&ipv4_dst_ops) < ip_rt_max_size)
+               goto out;
+       if (dst_entries_get_slow(&ipv4_dst_ops) < ip_rt_max_size)
                goto out;
        if (net_ratelimit())
                printk(KERN_WARNING "dst cache overflow\n");
@@ -1049,11 +1064,12 @@ static int rt_garbage_collect(struct dst_ops *ops)
 work_done:
        expire += ip_rt_gc_min_interval;
        if (expire > ip_rt_gc_timeout ||
-           atomic_read(&ipv4_dst_ops.entries) < ipv4_dst_ops.gc_thresh)
+           dst_entries_get_fast(&ipv4_dst_ops) < ipv4_dst_ops.gc_thresh ||
+           dst_entries_get_slow(&ipv4_dst_ops) < ipv4_dst_ops.gc_thresh)
                expire = ip_rt_gc_timeout;
 #if RT_CACHE_DEBUG >= 2
        printk(KERN_DEBUG "expire++ %u %d %d %d\n", expire,
-                       atomic_read(&ipv4_dst_ops.entries), goal, rover);
+                       dst_entries_get_fast(&ipv4_dst_ops), goal, rover);
 #endif
 out:   return 0;
 }
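
The garbage collector shows why both getters exist. Per-cpu batching lets dst_entries_get_fast() overshoot the true count, so before warning "dst cache overflow" the code re-checks with the exact dst_entries_get_slow(); only when the precise sum also exceeds ip_rt_max_size is the overflow real. The pattern in isolation (hedged sketch):

    /* Approximate-then-exact limit test, as in rt_garbage_collect(). */
    static bool dst_over_limit(struct dst_ops *ops, int limit)
    {
            if (dst_entries_get_fast(ops) < limit)
                    return false;   /* cheap read says we are fine */
            /* fast read can be off by the batch size: confirm exactly */
            return dst_entries_get_slow(ops) >= limit;
    }
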
@@ -1068,7 +1084,7 @@ static int slow_chain_length(const struct rtable *head)
 
        while (rth) {
                length += has_noalias(head, rth);
-               rth = rth->dst.rt_next;
+               rth = rcu_dereference_protected(rth->dst.rt_next, 1);
        }
        return length >> FRACT_BITS;
 }
@@ -1076,9 +1092,9 @@ static int slow_chain_length(const struct rtable *head)
 static int rt_intern_hash(unsigned hash, struct rtable *rt,
                          struct rtable **rp, struct sk_buff *skb, int ifindex)
 {
-       struct rtable   *rth, **rthp;
+       struct rtable   *rth, *cand;
+       struct rtable __rcu **rthp, **candp;
        unsigned long   now;
-       struct rtable *cand, **candp;
        u32             min_score;
        int             chain_length;
        int attempts = !in_softirq();
@@ -1102,9 +1118,9 @@ restart:
                 * Note that we do rt_free on this new route entry, so that
                 * once its refcount hits zero, we are still able to reap it
                 * (Thanks Alexey)
-                * Note also the rt_free uses call_rcu.  We don't actually
-                * need rcu protection here, this is just our path to get
-                * on the route gc list.
+                * Note: To avoid expensive rcu stuff for this uncached dst,
+                * we set DST_NOCACHE so that dst_release() can free dst without
+                * waiting a grace period.
                 */
 
                rt->dst.flags |= DST_NOCACHE;
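
The rewritten comment explains the rt_drop()/rt_free() removals that follow. A hashed route must die through call_rcu_bh(), since lockless readers may still be walking it; a DST_NOCACHE route was never published in the hash, so the final ip_rt_put() can reclaim it at once. Hedged sketch of the two paths (rt_free() mirrors route.c; the dst_release() side paraphrases net/core/dst.c from this series and may differ in detail):

    /* Published (hashed) entry: readers may hold RCU references,
     * so defer the free by one BH grace period. */
    static inline void rt_free(struct rtable *rt)
    {
            call_rcu_bh(&rt->dst.rcu_head, dst_rcu_free);
    }

    /* Unpublished (DST_NOCACHE) entry: plain refcounting suffices. */
    void dst_release(struct dst_entry *dst)
    {
            if (!dst)
                    return;
            if (atomic_dec_return(&dst->__refcnt) == 0 &&
                unlikely(dst->flags & DST_NOCACHE)) {
                    dst = dst_destroy(dst); /* immediate, no grace period */
                    if (dst)
                            __dst_free(dst);
            }
    }

This is also why the neighbour-failure path below becomes a bare ip_rt_put(): the reference drop itself now reclaims the entry, and the old rt_free() call before skip_hashing is simply deleted.
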
@@ -1114,19 +1130,19 @@ restart:
                                if (net_ratelimit())
                                        printk(KERN_WARNING
                                            "Neighbour table failure & not caching routes.\n");
-                               rt_drop(rt);
+                               ip_rt_put(rt);
                                return err;
                        }
                }
 
-               rt_free(rt);
                goto skip_hashing;
        }
 
        rthp = &rt_hash_table[hash].chain;
 
        spin_lock_bh(rt_hash_lock_addr(hash));
-       while ((rth = *rthp) != NULL) {
+       while ((rth = rcu_dereference_protected(*rthp,
+                       lockdep_is_held(rt_hash_lock_addr(hash)))) != NULL) {
                if (rt_is_expired(rth)) {
                        *rthp = rth->dst.rt_next;
                        rt_free(rth);
@@ -1322,12 +1338,14 @@ EXPORT_SYMBOL(__ip_select_ident);
 
 static void rt_del(unsigned hash, struct rtable *rt)
 {
-       struct rtable **rthp, *aux;
+       struct rtable __rcu **rthp;
+       struct rtable *aux;
 
        rthp = &rt_hash_table[hash].chain;
        spin_lock_bh(rt_hash_lock_addr(hash));
        ip_rt_put(rt);
-       while ((aux = *rthp) != NULL) {
+       while ((aux = rcu_dereference_protected(*rthp,
+                       lockdep_is_held(rt_hash_lock_addr(hash)))) != NULL) {
                if (aux == rt || rt_is_expired(aux)) {
                        *rthp = aux->dst.rt_next;
                        rt_free(aux);
@@ -1344,7 +1362,8 @@ void ip_rt_redirect(__be32 old_gw, __be32 daddr, __be32 new_gw,
 {
        int i, k;
        struct in_device *in_dev = __in_dev_get_rcu(dev);
-       struct rtable *rth, **rthp;
+       struct rtable *rth;
+       struct rtable __rcu **rthp;
        __be32  skeys[2] = { saddr, 0 };
        int  ikeys[2] = { dev->ifindex, 0 };
        struct netevent_redirect netevent;
@@ -1377,7 +1396,7 @@ void ip_rt_redirect(__be32 old_gw, __be32 daddr, __be32 new_gw,
                        unsigned hash = rt_hash(daddr, skeys[i], ikeys[k],
                                                rt_genid(net));
 
-                       rthp=&rt_hash_table[hash].chain;
+                       rthp = &rt_hash_table[hash].chain;
 
                        while ((rth = rcu_dereference(*rthp)) != NULL) {
                                struct rtable *rt;
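
ip_rt_redirect() runs entirely on the read side (its caller holds rcu_read_lock(), as the __in_dev_get_rcu() above shows), so its walk keeps plain rcu_dereference(); the hunk is only a whitespace fix. Its cursor idiom, a pointer-to-pointer advanced through __rcu next fields, sketches as (hedged, illustrative names):

    /* Read-side search with a pointer-to-pointer cursor.
     * Caller must hold rcu_read_lock(). */
    static struct item *item_find(struct item __rcu **headp, int key)
    {
            struct item __rcu **itp = headp;
            struct item *it;

            while ((it = rcu_dereference(*itp)) != NULL) {
                    if (it->key == key)
                            return it;
                    itp = &it->next;
            }
            return NULL;
    }
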
@@ -2121,7 +2140,7 @@ static int ip_route_input_slow(struct sk_buff *skb, __be32 daddr, __be32 saddr,
            ipv4_is_loopback(saddr))
                goto martian_source;
 
-       if (daddr == htonl(0xFFFFFFFF) || (saddr == 0 && daddr == 0))
+       if (ipv4_is_lbcast(daddr) || (saddr == 0 && daddr == 0))
                goto brd_input;
 
        /* Accept zero addresses only to limited broadcast;
@@ -2130,8 +2149,7 @@ static int ip_route_input_slow(struct sk_buff *skb, __be32 daddr, __be32 saddr,
        if (ipv4_is_zeronet(saddr))
                goto martian_source;
 
-       if (ipv4_is_lbcast(daddr) || ipv4_is_zeronet(daddr) ||
-           ipv4_is_loopback(daddr))
+       if (ipv4_is_zeronet(daddr) || ipv4_is_loopback(daddr))
                goto martian_destination;
 
        /*
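
These hunks are a cleanup riding along with the RCU work: each open-coded htonl(0xFFFFFFFF) becomes ipv4_is_lbcast(), and tests made redundant by earlier branches drop out (a limited-broadcast daddr has already jumped to brd_input, so the martian check need not repeat it). For reference, the helper in include/linux/in.h is essentially:

    static inline bool ipv4_is_lbcast(__be32 addr)
    {
            /* limited broadcast, 255.255.255.255 */
            return addr == htonl(INADDR_BROADCAST);
    }
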
@@ -2364,11 +2382,11 @@ static int __mkroute_output(struct rtable **result,
        if (ipv4_is_loopback(fl->fl4_src) && !(dev_out->flags & IFF_LOOPBACK))
                return -EINVAL;
 
-       if (fl->fl4_dst == htonl(0xFFFFFFFF))
+       if (ipv4_is_lbcast(fl->fl4_dst))
                res->type = RTN_BROADCAST;
        else if (ipv4_is_multicast(fl->fl4_dst))
                res->type = RTN_MULTICAST;
-       else if (ipv4_is_lbcast(fl->fl4_dst) || ipv4_is_zeronet(fl->fl4_dst))
+       else if (ipv4_is_zeronet(fl->fl4_dst))
                return -EINVAL;
 
        if (dev_out->flags & IFF_LOOPBACK)
@@ -2527,7 +2545,7 @@ static int ip_route_output_slow(struct net *net, struct rtable **rp,
 
                if (oldflp->oif == 0 &&
                    (ipv4_is_multicast(oldflp->fl4_dst) ||
-                    oldflp->fl4_dst == htonl(0xFFFFFFFF))) {
+                    ipv4_is_lbcast(oldflp->fl4_dst))) {
                        /* It is equivalent to inet_addr_type(saddr) == RTN_LOCAL */
                        dev_out = __ip_dev_find(net, oldflp->fl4_src, false);
                        if (dev_out == NULL)
@@ -2571,7 +2589,7 @@ static int ip_route_output_slow(struct net *net, struct rtable **rp,
                        goto out;       /* Wrong error code */
 
                if (ipv4_is_local_multicast(oldflp->fl4_dst) ||
-                   oldflp->fl4_dst == htonl(0xFFFFFFFF)) {
+                   ipv4_is_lbcast(oldflp->fl4_dst)) {
                        if (!fl.fl4_src)
                                fl.fl4_src = inet_select_addr(dev_out, 0,
                                                              RT_SCOPE_LINK);
@@ -2717,7 +2735,6 @@ static struct dst_ops ipv4_dst_blackhole_ops = {
        .destroy                =       ipv4_dst_destroy,
        .check                  =       ipv4_blackhole_dst_check,
        .update_pmtu            =       ipv4_rt_blackhole_update_pmtu,
-       .entries                =       ATOMIC_INIT(0),
 };
 
 
@@ -3287,6 +3304,12 @@ int __init ip_rt_init(void)
 
        ipv4_dst_blackhole_ops.kmem_cachep = ipv4_dst_ops.kmem_cachep;
 
+       if (dst_entries_init(&ipv4_dst_ops) < 0)
+               panic("IP: failed to allocate ipv4_dst_ops counter\n");
+
+       if (dst_entries_init(&ipv4_dst_blackhole_ops) < 0)
+               panic("IP: failed to allocate ipv4_dst_blackhole_ops counter\n");
+
        rt_hash_table = (struct rt_hash_bucket *)
                alloc_large_system_hash("IP route cache",
                                        sizeof(struct rt_hash_bucket),
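
With dst_ops.entries gone, each dst_ops needs its percpu_counter allocated before the first dst_alloc(), and the allocation can fail with -ENOMEM; at this point in boot there is no sane fallback, hence the panics. The init/teardown pair is roughly the following (paraphrased from include/net/dst_ops.h; in this era percpu_counter_init() takes just the counter and an initial value):

    static inline int dst_entries_init(struct dst_ops *dst)
    {
            return percpu_counter_init(&dst->pcpuc_entries, 0);
    }

    static inline void dst_entries_destroy(struct dst_ops *dst)
    {
            percpu_counter_destroy(&dst->pcpuc_entries);
    }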