Skip to content

Commit 9363703

Browse files
gh-116738: Make grp module thread-safe (#135434)
Make grp module methods getgrgid() and getgrnam() thread-safe when the GIL is disabled and getgrgid_r()/getgrnam_r() C APIs are not available. --------- Co-authored-by: Kumar Aditya <[email protected]>
1 parent d995922 commit 9363703

File tree

7 files changed

+115
-46
lines changed

7 files changed

+115
-46
lines changed

Doc/library/test.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,13 @@ The :mod:`test.support.threading_helper` module provides support for threading t
13841384
.. versionadded:: 3.8
13851385

13861386

1387+
.. function:: run_concurrently(worker_func, nthreads, args=(), kwargs={})
1388+
1389+
Run the worker function concurrently in multiple threads.
1390+
Re-raises an exception if any thread raises one, after all threads have
1391+
finished.
1392+
1393+
13871394
:mod:`test.support.os_helper` --- Utilities for os tests
13881395
========================================================================
13891396

Lib/test/support/threading_helper.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,27 @@ def requires_working_threading(*, module=False):
248248
raise unittest.SkipTest(msg)
249249
else:
250250
return unittest.skipUnless(can_start_thread, msg)
251+
252+
253+
def run_concurrently(worker_func, nthreads, args=(), kwargs={}):
254+
"""
255+
Run the worker function concurrently in multiple threads.
256+
"""
257+
barrier = threading.Barrier(nthreads)
258+
259+
def wrapper_func(*args, **kwargs):
260+
# Wait for all threads to reach this point before proceeding.
261+
barrier.wait()
262+
worker_func(*args, **kwargs)
263+
264+
with catch_threading_exception() as cm:
265+
workers = [
266+
threading.Thread(target=wrapper_func, args=args, kwargs=kwargs)
267+
for _ in range(nthreads)
268+
]
269+
with start_threads(workers):
270+
pass
271+
272+
# If a worker thread raises an exception, re-raise it.
273+
if cm.exc_value is not None:
274+
raise cm.exc_value
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import unittest
2+
3+
from test.support import import_helper, threading_helper
4+
from test.support.threading_helper import run_concurrently
5+
6+
grp = import_helper.import_module("grp")
7+
8+
from test import test_grp
9+
10+
11+
NTHREADS = 10
12+
13+
14+
@threading_helper.requires_working_threading()
15+
class TestGrp(unittest.TestCase):
16+
def setUp(self):
17+
self.test_grp = test_grp.GroupDatabaseTestCase()
18+
19+
def test_racing_test_values(self):
20+
# test_grp.test_values() calls grp.getgrall() and checks the entries
21+
run_concurrently(
22+
worker_func=self.test_grp.test_values, nthreads=NTHREADS
23+
)
24+
25+
def test_racing_test_values_extended(self):
26+
# test_grp.test_values_extended() calls grp.getgrall(), grp.getgrgid(),
27+
# grp.getgrnam() and checks the entries
28+
run_concurrently(
29+
worker_func=self.test_grp.test_values_extended,
30+
nthreads=NTHREADS,
31+
)
32+
33+
34+
if __name__ == "__main__":
35+
unittest.main()

Lib/test/test_free_threading/test_heapq.py

Lines changed: 23 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33
import heapq
44

55
from enum import Enum
6-
from threading import Thread, Barrier, Lock
6+
from threading import Barrier, Lock
77
from random import shuffle, randint
88

99
from test.support import threading_helper
10+
from test.support.threading_helper import run_concurrently
1011
from test import test_heapq
1112

1213

@@ -28,8 +29,8 @@ def test_racing_heapify(self):
2829
heap = list(range(OBJECT_COUNT))
2930
shuffle(heap)
3031

31-
self.run_concurrently(
32-
worker_func=heapq.heapify, args=(heap,), nthreads=NTHREADS
32+
run_concurrently(
33+
worker_func=heapq.heapify, nthreads=NTHREADS, args=(heap,)
3334
)
3435
self.test_heapq.check_invariant(heap)
3536

@@ -40,8 +41,8 @@ def heappush_func(heap):
4041
for item in reversed(range(OBJECT_COUNT)):
4142
heapq.heappush(heap, item)
4243

43-
self.run_concurrently(
44-
worker_func=heappush_func, args=(heap,), nthreads=NTHREADS
44+
run_concurrently(
45+
worker_func=heappush_func, nthreads=NTHREADS, args=(heap,)
4546
)
4647
self.test_heapq.check_invariant(heap)
4748

@@ -61,10 +62,10 @@ def heappop_func(heap, pop_count):
6162
# Each local list should be sorted
6263
self.assertTrue(self.is_sorted_ascending(local_list))
6364

64-
self.run_concurrently(
65+
run_concurrently(
6566
worker_func=heappop_func,
66-
args=(heap, per_thread_pop_count),
6767
nthreads=NTHREADS,
68+
args=(heap, per_thread_pop_count),
6869
)
6970
self.assertEqual(len(heap), 0)
7071

@@ -77,10 +78,10 @@ def heappushpop_func(heap, pushpop_items):
7778
popped_item = heapq.heappushpop(heap, item)
7879
self.assertTrue(popped_item <= item)
7980

80-
self.run_concurrently(
81+
run_concurrently(
8182
worker_func=heappushpop_func,
82-
args=(heap, pushpop_items),
8383
nthreads=NTHREADS,
84+
args=(heap, pushpop_items),
8485
)
8586
self.assertEqual(len(heap), OBJECT_COUNT)
8687
self.test_heapq.check_invariant(heap)
@@ -93,10 +94,10 @@ def heapreplace_func(heap, replace_items):
9394
for item in replace_items:
9495
heapq.heapreplace(heap, item)
9596

96-
self.run_concurrently(
97+
run_concurrently(
9798
worker_func=heapreplace_func,
98-
args=(heap, replace_items),
9999
nthreads=NTHREADS,
100+
args=(heap, replace_items),
100101
)
101102
self.assertEqual(len(heap), OBJECT_COUNT)
102103
self.test_heapq.check_invariant(heap)
@@ -105,8 +106,8 @@ def test_racing_heapify_max(self):
105106
max_heap = list(range(OBJECT_COUNT))
106107
shuffle(max_heap)
107108

108-
self.run_concurrently(
109-
worker_func=heapq.heapify_max, args=(max_heap,), nthreads=NTHREADS
109+
run_concurrently(
110+
worker_func=heapq.heapify_max, nthreads=NTHREADS, args=(max_heap,)
110111
)
111112
self.test_heapq.check_max_invariant(max_heap)
112113

@@ -117,8 +118,8 @@ def heappush_max_func(max_heap):
117118
for item in range(OBJECT_COUNT):
118119
heapq.heappush_max(max_heap, item)
119120

120-
self.run_concurrently(
121-
worker_func=heappush_max_func, args=(max_heap,), nthreads=NTHREADS
121+
run_concurrently(
122+
worker_func=heappush_max_func, nthreads=NTHREADS, args=(max_heap,)
122123
)
123124
self.test_heapq.check_max_invariant(max_heap)
124125

@@ -138,10 +139,10 @@ def heappop_max_func(max_heap, pop_count):
138139
# Each local list should be sorted
139140
self.assertTrue(self.is_sorted_descending(local_list))
140141

141-
self.run_concurrently(
142+
run_concurrently(
142143
worker_func=heappop_max_func,
143-
args=(max_heap, per_thread_pop_count),
144144
nthreads=NTHREADS,
145+
args=(max_heap, per_thread_pop_count),
145146
)
146147
self.assertEqual(len(max_heap), 0)
147148

@@ -154,10 +155,10 @@ def heappushpop_max_func(max_heap, pushpop_items):
154155
popped_item = heapq.heappushpop_max(max_heap, item)
155156
self.assertTrue(popped_item >= item)
156157

157-
self.run_concurrently(
158+
run_concurrently(
158159
worker_func=heappushpop_max_func,
159-
args=(max_heap, pushpop_items),
160160
nthreads=NTHREADS,
161+
args=(max_heap, pushpop_items),
161162
)
162163
self.assertEqual(len(max_heap), OBJECT_COUNT)
163164
self.test_heapq.check_max_invariant(max_heap)
@@ -170,10 +171,10 @@ def heapreplace_max_func(max_heap, replace_items):
170171
for item in replace_items:
171172
heapq.heapreplace_max(max_heap, item)
172173

173-
self.run_concurrently(
174+
run_concurrently(
174175
worker_func=heapreplace_max_func,
175-
args=(max_heap, replace_items),
176176
nthreads=NTHREADS,
177+
args=(max_heap, replace_items),
177178
)
178179
self.assertEqual(len(max_heap), OBJECT_COUNT)
179180
self.test_heapq.check_max_invariant(max_heap)
@@ -203,7 +204,7 @@ def worker():
203204
except IndexError:
204205
pass
205206

206-
self.run_concurrently(worker, (), n_threads * 2)
207+
run_concurrently(worker, n_threads * 2)
207208

208209
@staticmethod
209210
def is_sorted_ascending(lst):
@@ -241,27 +242,6 @@ def create_random_list(a, b, size):
241242
"""
242243
return [randint(-a, b) for _ in range(size)]
243244

244-
def run_concurrently(self, worker_func, args, nthreads):
245-
"""
246-
Run the worker function concurrently in multiple threads.
247-
"""
248-
barrier = Barrier(nthreads)
249-
250-
def wrapper_func(*args):
251-
# Wait for all threads to reach this point before proceeding.
252-
barrier.wait()
253-
worker_func(*args)
254-
255-
with threading_helper.catch_threading_exception() as cm:
256-
workers = (
257-
Thread(target=wrapper_func, args=args) for _ in range(nthreads)
258-
)
259-
with threading_helper.start_threads(workers):
260-
pass
261-
262-
# Worker threads should not raise any exceptions
263-
self.assertIsNone(cm.exc_value)
264-
265245

266246
if __name__ == "__main__":
267247
unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make functions in :mod:`grp` thread-safe on the :term:`free threaded <free threading>` build.

Modules/grpmodule.c

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ get_grp_state(PyObject *module)
5555

5656
static struct PyModuleDef grpmodule;
5757

58+
/* Mutex to protect calls to getgrgid(), getgrnam(), and getgrent().
59+
 * These functions return a pointer to a static data structure, which
60+
* may be overwritten by any subsequent calls. */
61+
static PyMutex group_db_mutex = {0};
62+
5863
#define DEFAULT_BUFFER_SIZE 1024
5964

6065
static PyObject *
@@ -168,9 +173,15 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
168173

169174
Py_END_ALLOW_THREADS
170175
#else
176+
PyMutex_Lock(&group_db_mutex);
177+
// POSIX does not require getgrgid() to be thread-safe.
178+
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrgid.html
171179
p = getgrgid(gid);
172180
#endif
173181
if (p == NULL) {
182+
#ifndef HAVE_GETGRGID_R
183+
PyMutex_Unlock(&group_db_mutex);
184+
#endif
174185
PyMem_RawFree(buf);
175186
if (nomem == 1) {
176187
return PyErr_NoMemory();
@@ -185,6 +196,8 @@ grp_getgrgid_impl(PyObject *module, PyObject *id)
185196
retval = mkgrent(module, p);
186197
#ifdef HAVE_GETGRGID_R
187198
PyMem_RawFree(buf);
199+
#else
200+
PyMutex_Unlock(&group_db_mutex);
188201
#endif
189202
return retval;
190203
}
@@ -249,9 +262,15 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
249262

250263
Py_END_ALLOW_THREADS
251264
#else
265+
PyMutex_Lock(&group_db_mutex);
266+
// POSIX does not require getgrnam() to be thread-safe.
267+
// https://pubs.opengroup.org/onlinepubs/9699919799/functions/getgrnam.html
252268
p = getgrnam(name_chars);
253269
#endif
254270
if (p == NULL) {
271+
#ifndef HAVE_GETGRNAM_R
272+
PyMutex_Unlock(&group_db_mutex);
273+
#endif
255274
if (nomem == 1) {
256275
PyErr_NoMemory();
257276
}
@@ -261,6 +280,9 @@ grp_getgrnam_impl(PyObject *module, PyObject *name)
261280
goto out;
262281
}
263282
retval = mkgrent(module, p);
283+
#ifndef HAVE_GETGRNAM_R
284+
PyMutex_Unlock(&group_db_mutex);
285+
#endif
264286
out:
265287
PyMem_RawFree(buf);
266288
Py_DECREF(bytes);
@@ -285,8 +307,7 @@ grp_getgrall_impl(PyObject *module)
285307
return NULL;
286308
}
287309

288-
static PyMutex getgrall_mutex = {0};
289-
PyMutex_Lock(&getgrall_mutex);
310+
PyMutex_Lock(&group_db_mutex);
290311
setgrent();
291312

292313
struct group *p;
@@ -306,7 +327,7 @@ grp_getgrall_impl(PyObject *module)
306327

307328
done:
308329
endgrent();
309-
PyMutex_Unlock(&getgrall_mutex);
330+
PyMutex_Unlock(&group_db_mutex);
310331
return d;
311332
}
312333

Tools/c-analyzer/cpython/ignored.tsv

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ Python/sysmodule.c - _preinit_xoptions -
167167
# XXX need race protection?
168168
Modules/faulthandler.c faulthandler_dump_traceback reentrant -
169169
Modules/faulthandler.c faulthandler_dump_c_stack reentrant -
170+
Modules/grpmodule.c - group_db_mutex -
170171
Python/pylifecycle.c _Py_FatalErrorFormat reentrant -
171172
Python/pylifecycle.c fatal_error reentrant -
172173

0 commit comments

Comments
 (0)