tests/lock_test: add a linux kernel module for MCS locks

Write the test parameters ("nr_threads nr_loops hold_time delay_time")
to /sys/kernel/mcs, then reopen the file and read back the raw
lock_sample results.

tests/linux/modules/mcs.c
// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Copyright (c) 2013, 2014 The Regents of the University of California
 * Copyright (c) 2020 Google Inc
 *
 * Barret Rhoden <brho@cs.berkeley.edu>
 *
 * Sorry, but you'll need to change your linux source to expose this function:
 *
 *	EXPORT_SYMBOL_GPL(kthread_create_on_cpu);
 */
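
/*
 * A hedged sketch of that kernel change (the location is an assumption:
 * kthread_create_on_cpu() is defined in kernel/kthread.c in the trees I've
 * looked at; adjust if yours differs).  Add the export right after the
 * function's definition, then rebuild and boot that kernel before loading
 * this module:
 *
 *	// kernel/kthread.c, after kthread_create_on_cpu()'s closing brace:
 *	EXPORT_SYMBOL_GPL(kthread_create_on_cpu);
 */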

#include <linux/module.h>
#include <linux/moduleparam.h>
#include <linux/kobject.h>
#include <linux/sysfs.h>
#include <linux/slab.h>
#include <linux/overflow.h>

#include <linux/sched/task.h>
#include <linux/sched/mm.h>
#include <linux/delay.h>
#include <linux/kthread.h>
#include <linux/completion.h>
#include <asm/msr.h>

/* One sample per lock acquisition: TSC snapshots around the critical section */
struct lock_sample {
        u64 pre;        /* just before attempting the lock */
        u64 acq;        /* right after acquisition */
        u64 un;         /* right after release */
        bool valid;
};

/* mtx protects all variables and the test run */
static struct mutex mtx;

static DECLARE_COMPLETION(test_done);

static unsigned int nr_threads;
static unsigned int nr_loops;
static unsigned int hold_time;
static unsigned int delay_time;

/* array[nr_threads] of pointers to lock_sample[nr_loops] */
static struct lock_sample **times;
/* array[nr_threads] of task* */
static struct task_struct **threads;
/* array[nr_threads] of void* */
static void **retvals;
static void *results;
static size_t results_sz;

static bool run_locktest __cacheline_aligned_in_smp;
static atomic_t horses __cacheline_aligned_in_smp;

/* Kernel qspinlocks queue their waiters on MCS nodes, hence the module name */
static struct qspinlock l = __ARCH_SPIN_LOCK_UNLOCKED;

static int __mcs_thread_lock_test(void *arg)
{
        long thread_id = (long)arg;
        u64 pre_lock, acq_lock, un_lock;
        struct lock_sample *this_time;
        int i;

        atomic_dec(&horses);
        while (atomic_read(&horses))
                cpu_relax();
        for (i = 0; i < nr_loops; i++) {
                /*
                 * might be able to replace this with post-processing.  let the
                 * test run, and discard all entries after the first finisher
                 */
                if (!READ_ONCE(run_locktest))
                        break;

                local_irq_disable();
                pre_lock = rdtsc_ordered();

                queued_spin_lock(&l);

                acq_lock = rdtsc_ordered();

                if (hold_time)
                        ndelay(hold_time);

                queued_spin_unlock(&l);

                un_lock = rdtsc_ordered();

                local_irq_enable();

                this_time = &times[thread_id][i];
                this_time->pre = pre_lock;
                this_time->acq = acq_lock;
                this_time->un = un_lock;
                /* Can turn these on/off to control which samples we gather */
                this_time->valid = true;
                if (delay_time)
                        ndelay(delay_time);
                /*
                 * This can throw off your delay_time.  Think of delay_time as
                 * the least amount of time we'll wait between reacquiring the
                 * lock.  After all, IRQs are enabled, so all bets are off.
                 */
                cond_resched();
        }
        /* First thread to finish stops the test */
        WRITE_ONCE(run_locktest, false);
        /*
         * Wakes the controller thread.  The others will be done soon, to
         * complete the hokey thread join.
         */
        complete(&test_done);

        WRITE_ONCE(retvals[thread_id], (void*)(long)i);

        return 0;
}

/*
 * This consolidates the results in a format we will export to userspace.  We
 * could have just used this format for the test itself, but then the times
 * arrays wouldn't be NUMA local.
 */
static int mcs_build_output(struct lock_sample **times, void **retvals)
{
        int i;
        size_t sz_rets = nr_threads * sizeof(void*);
        size_t sz_times_per = nr_loops * sizeof(struct lock_sample);

        results_sz = sz_rets + nr_threads * sz_times_per;

        kvfree(results);

        results = kvzalloc(results_sz, GFP_KERNEL);
        if (!results) {
                pr_err("fucked %d\n", __LINE__);
                return -ENOMEM;
        }

        memcpy(results, retvals, sz_rets);
        for (i = 0; i < nr_threads; i++) {
                memcpy(results + sz_rets + i * sz_times_per,
                       times[i], sz_times_per);
        }

        return 0;
}
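
/*
 * For post-processing, a hedged sketch of a userspace parser for that blob
 * (x86-64/LP64 assumed, so the struct padding and %lu match the kernel side;
 * the dump file name and passing nr_threads/nr_loops on the command line are
 * my inventions, not part of this module):
 *
 *	// gcc -o mcs_parse mcs_parse.c
 *	// usage: ./mcs_parse dump_file nr_threads nr_loops
 *	#include <stdio.h>
 *	#include <stdlib.h>
 *	#include <stdint.h>
 *
 *	struct lock_sample {
 *		uint64_t pre, acq, un;
 *		uint8_t valid;	// padded to 32 bytes, same as the kernel's
 *	};
 *
 *	int main(int argc, char **argv)
 *	{
 *		unsigned int nr_threads, nr_loops, t, i;
 *		struct lock_sample *s;
 *		uintptr_t *rets;
 *		FILE *f;
 *
 *		if (argc != 4)
 *			return 1;
 *		f = fopen(argv[1], "r");
 *		nr_threads = atoi(argv[2]);
 *		nr_loops = atoi(argv[3]);
 *		rets = calloc(nr_threads, sizeof(*rets));
 *		s = calloc(nr_loops, sizeof(*s));
 *		fread(rets, sizeof(*rets), nr_threads, f);
 *		for (t = 0; t < nr_threads; t++) {
 *			fread(s, sizeof(*s), nr_loops, f);
 *			// rets[t] is how many loops thread t completed
 *			for (i = 0; i < rets[t]; i++) {
 *				if (!s[i].valid)
 *					continue;
 *				printf("%u wait %lu hold %lu\n", t,
 *				       s[i].acq - s[i].pre,
 *				       s[i].un - s[i].acq);
 *			}
 *		}
 *		return 0;
 *	}
 */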

static int mcs_lock_test(void)
{
        int i;
        int ret = -ENOMEM;
        size_t amt;

        atomic_set(&horses, nr_threads);
        WRITE_ONCE(run_locktest, true);

        times = kcalloc(nr_threads, sizeof(struct lock_sample *), GFP_KERNEL);
        if (!times) {
                pr_err("fucked %d\n", __LINE__);
                return ret;
        }

        if (check_mul_overflow((size_t)nr_loops, sizeof(struct lock_sample),
                               &amt)) {
                pr_err("fucked %d\n", __LINE__);
                ret = -EOVERFLOW;
                goto out_times;
        }
        for (i = 0; i < nr_threads; i++) {
                times[i] = kvzalloc_node(amt, GFP_KERNEL, cpu_to_node(i));
                if (!times[i]) {
                        /* we clean up the times[i]s below */
                        pr_err("fucked %d\n", __LINE__);
                        goto out_times;
                }
        }

        retvals = kcalloc(nr_threads, sizeof(void *), GFP_KERNEL);
        if (!retvals) {
                pr_err("fucked %d\n", __LINE__);
                goto out_times;
        }
        for (i = 0; i < nr_threads; i++)
                retvals[i] = (void*)-1;

        threads = kcalloc(nr_threads, sizeof(struct task_struct *),
                          GFP_KERNEL);
        if (!threads) {
                pr_err("fucked %d\n", __LINE__);
                goto out_retvals;
        }

        for (i = 0; i < nr_threads; i++) {
                threads[i] = kthread_create_on_cpu(__mcs_thread_lock_test,
                                                   (void*)(long)i, i, "mcs-%u");
                if (IS_ERR(threads[i])) {
                        ret = PTR_ERR(threads[i]);
                        while (--i >= 0) {
                                /*
                                 * We could recover, perhaps with something like
                                 * kthread_stop(threads[i]), but we'd need those
                                 * threads to check kthread_should_stop(),
                                 * perhaps in their hokey barrier (see the
                                 * sketch after this function).  I've never had
                                 * this fail, so I haven't tested it.
                                 */
                        }
                        pr_err("fucked %d\n", __LINE__);
                        goto out_threads;
                }
        }
        for (i = 0; i < nr_threads; i++) {
                /*
                 * What's the deal with refcnting here?  It looks like an
                 * uncounted ref: create->result = current.  So once we start
                 * them, we probably can't touch this again.
                 */
                wake_up_process(threads[i]);
        }

        /* Hokey join.  We know when the test is done but wait for the others */
        wait_for_completion(&test_done);
        for (i = 0; i < nr_threads; i++) {
                while (READ_ONCE(retvals[i]) == (void*)-1)
                        cond_resched();
        }

        ret = mcs_build_output(times, retvals);

out_threads:
        kfree(threads);
out_retvals:
        kfree(retvals);
out_times:
        for (i = 0; i < nr_threads; i++)
                kvfree(times[i]);
        kfree(times);
        return ret;
}
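
/*
 * For the untested recovery mentioned above, a hedged sketch (not what this
 * module does): have each thread also poll kthread_should_stop() in its
 * hokey barrier, so the creator can kthread_stop() the threads that did get
 * created before bailing out:
 *
 *	atomic_dec(&horses);
 *	while (atomic_read(&horses)) {
 *		if (kthread_should_stop())
 *			return 0;
 *		cpu_relax();
 *	}
 */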

static ssize_t mcs_read(struct file *filp, struct kobject *kobj,
                        struct bin_attribute *bin_attr,
                        char *buf, loff_t off, size_t count)
{
        int ret;

        mutex_lock(&mtx);

        if (!off) {
                ret = mcs_lock_test();
                if (ret) {
                        mutex_unlock(&mtx);
                        return ret;
                }
        }
        if (!results) {
                pr_err("fucked %d\n", __LINE__);
                mutex_unlock(&mtx);
                return -ENODATA;
        }
        /* Clamp the read; sysfs's idea of our size may be stale (see below) */
        if ((size_t)off >= results_sz) {
                mutex_unlock(&mtx);
                return 0;
        }
        if (count > results_sz - (size_t)off) {
                pr_err("fucked off %lld count %zu sz %zu\n", off, count,
                       results_sz);
                count = results_sz - (size_t)off;
        }
        memcpy(buf, results + off, count);

        mutex_unlock(&mtx);

        return count;
}

static loff_t __mcs_get_results_size(void)
{
        return nr_threads *
                (sizeof(void*) + nr_loops * sizeof(struct lock_sample));
}

/*
 * Unfortunately, this doesn't update the file live.  It'll only take effect
 * the next time you open it.  So users need to write, close, open, read (see
 * the usage sketch after mcs_write()).
 */
static void __mcs_update_size(void)
{
        struct kernfs_node *kn = kernfs_find_and_get(kernel_kobj->sd, "mcs");

        if (!kn) {
                pr_err("fucked %d\n", __LINE__);
                return;
        }
        kn->attr.size = __mcs_get_results_size();
        kernfs_put(kn);
}

static ssize_t mcs_write(struct file *filp, struct kobject *kobj,
                         struct bin_attribute *bin_attr,
                         char *buf, loff_t off, size_t count)
{
        unsigned int threads, loops, hold, delay;
        size_t per_thread, total;
        ssize_t ret;

        ret = sscanf(buf, "%u %u %u %u", &threads, &loops, &hold,
                     &delay);
        if (ret != 4)
                return -EINVAL;
        if (threads > num_online_cpus())
                return -ENODEV;
        if (threads == 0)
                threads = num_online_cpus();
        /* Make sure the results buffer size computations can't overflow */
        if (check_mul_overflow((size_t)loops, sizeof(struct lock_sample),
                               &per_thread) ||
            check_add_overflow(per_thread, sizeof(void *), &per_thread) ||
            check_mul_overflow(per_thread, (size_t)threads, &total))
                return -EOVERFLOW;
        mutex_lock(&mtx);
        nr_threads = threads;
        nr_loops = loops;
        hold_time = hold;
        delay_time = delay;
        __mcs_update_size();
        mutex_unlock(&mtx);
        return count;
}
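
/*
 * A hedged usage sketch from userspace (the parameter values are arbitrary
 * examples and the dump file name is my invention).  Note that reading at
 * offset 0 is what actually runs the test, and the file must be reopened
 * after a write for the new size to show up:
 *
 *	// gcc -o mcs_run mcs_run.c; run as root
 *	#include <stdio.h>
 *
 *	int main(void)
 *	{
 *		char buf[4096];
 *		size_t amt;
 *		FILE *cfg = fopen("/sys/kernel/mcs", "w");
 *		FILE *out = fopen("mcs.dump", "w");
 *
 *		if (!cfg || !out)
 *			return 1;
 *		// nr_threads (0 = all cpus), nr_loops, hold (ns), delay (ns)
 *		fprintf(cfg, "4 10000 100 500");
 *		fclose(cfg);	// the size only updates on the next open
 *		cfg = fopen("/sys/kernel/mcs", "r");
 *		if (!cfg)
 *			return 1;
 *		while ((amt = fread(buf, 1, sizeof(buf), cfg)) > 0)
 *			fwrite(buf, 1, amt, out);
 *		return 0;
 *	}
 */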

static struct bin_attribute mcs_attr = {
        .attr = {
                .name = "mcs",
                .mode = 0666,
        },
        .size = 0,
        .private = NULL,
        .read = mcs_read,
        .write = mcs_write,
};

static int __init mcs_init(void)
{
        int ret;

        mutex_init(&mtx);

        /*
         * The user needs to set these, but start with sensible defaults in case
         * they read without writing.
         */
        nr_threads = num_online_cpus();
        nr_loops = 10000;
        mcs_attr.size = __mcs_get_results_size();

        ret = sysfs_create_bin_file(kernel_kobj, &mcs_attr);
        if (ret) {
                pr_err("\n\nfucked %d !!!\n\n\n", __LINE__);
                return ret;
        }
        return 0;
}

static void __exit mcs_exit(void)
{
        sysfs_remove_bin_file(kernel_kobj, &mcs_attr);
}

module_init(mcs_init);
module_exit(mcs_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("Barret Rhoden <brho@google.com>");
MODULE_DESCRIPTION("MCS lock test");