31
31
import time
32
32
from functools import partial
33
33
from itertools import groupby
34
- from typing import TYPE_CHECKING , Any , Callable , ClassVar , Iterator , Sequence
34
+ from typing import TYPE_CHECKING , Any , Callable , ClassVar , Iterator , Sequence , TypeVar
35
35
36
36
from ..components import ActionRow as ActionRowComponent
37
37
from ..components import Button as ButtonComponent
57
57
from ..state import ConnectionState
58
58
from ..types .components import Component as ComponentPayload
59
59
60
+ V = TypeVar ("V" , bound = "View" , covariant = True )
61
+
60
62
61
63
def _walk_all_components (components : list [Component ]) -> Iterator [Component ]:
62
64
for item in components :
@@ -66,6 +68,7 @@ def _walk_all_components(components: list[Component]) -> Iterator[Component]:
66
68
yield item
67
69
68
70
71
+
69
72
def _walk_all_components_v2 (components : list [Component ]) -> Iterator [Component ]:
70
73
for item in components :
71
74
if isinstance (item , ActionRowComponent ):
@@ -74,9 +77,10 @@ def _walk_all_components_v2(components: list[Component]) -> Iterator[Component]:
74
77
yield from item .walk_components ()
75
78
else :
76
79
yield item
80
+
81
+
82
+ def _component_to_item (component : Component ) -> Item [V ]:
77
83
78
-
79
- def _component_to_item (component : Component ) -> Item :
80
84
if isinstance (component , ButtonComponent ):
81
85
from .button import Button
82
86
@@ -123,7 +127,7 @@ def _component_to_item(component: Component) -> Item:
123
127
class _ViewWeights :
124
128
__slots__ = ("weights" ,)
125
129
126
- def __init__ (self , children : list [Item ]):
130
+ def __init__ (self , children : list [Item [ V ] ]):
127
131
self .weights : list [int ] = [0 , 0 , 0 , 0 , 0 ]
128
132
129
133
key = lambda i : sys .maxsize if i .row is None else i .row
@@ -132,7 +136,7 @@ def __init__(self, children: list[Item]):
132
136
for item in group :
133
137
self .add_item (item )
134
138
135
- def find_open_space (self , item : Item ) -> int :
139
+ def find_open_space (self , item : Item [ V ] ) -> int :
136
140
for index , weight in enumerate (self .weights ):
137
141
# check if open space AND (next row has no items OR this is the last row)
138
142
if (weight + item .width <= 5 ) and (
@@ -143,11 +147,12 @@ def find_open_space(self, item: Item) -> int:
143
147
144
148
raise ValueError ("could not find open space for item" )
145
149
146
- def add_item (self , item : Item ) -> None :
150
+ def add_item (self , item : Item [ V ] ) -> None :
147
151
if (
148
152
item ._underlying .is_v2 () or not self .fits_legacy (item )
149
153
) and not self .requires_v2 ():
150
154
self .weights .extend ([0 , 0 , 0 , 0 , 0 ] * 7 )
155
+
151
156
if item .row is not None :
152
157
total = self .weights [item .row ] + item .width
153
158
if total > 5 :
@@ -161,7 +166,7 @@ def add_item(self, item: Item) -> None:
161
166
self .weights [index ] += item .width
162
167
item ._rendered_row = index
163
168
164
- def remove_item (self , item : Item ) -> None :
169
+ def remove_item (self , item : Item [ V ] ) -> None :
165
170
if item ._rendered_row is not None :
166
171
self .weights [item ._rendered_row ] -= item .width
167
172
item ._rendered_row = None
@@ -227,15 +232,15 @@ def __init_subclass__(cls) -> None:
227
232
228
233
def __init__ (
229
234
self ,
230
- * items : Item ,
235
+ * items : Item [ V ] ,
231
236
timeout : float | None = 180.0 ,
232
237
disable_on_timeout : bool = False ,
233
238
):
234
239
self .timeout = timeout
235
240
self .disable_on_timeout = disable_on_timeout
236
- self .children : list [Item ] = []
241
+ self .children : list [Item [ V ] ] = []
237
242
for func in self .__view_children_items__ :
238
- item : Item = func .__discord_ui_model_type__ (
243
+ item : Item [ V ] = func .__discord_ui_model_type__ (
239
244
** func .__discord_ui_model_kwargs__
240
245
)
241
246
item .callback = partial (func , self , item )
@@ -278,7 +283,7 @@ async def __timeout_task_impl(self) -> None:
278
283
await asyncio .sleep (self .__timeout_expiry - now )
279
284
280
285
def to_components (self ) -> list [dict [str , Any ]]:
281
- def key (item : Item ) -> int :
286
+ def key (item : Item [ V ] ) -> int :
282
287
return item ._rendered_row or 0
283
288
284
289
children = sorted (self .children , key = key )
@@ -365,7 +370,7 @@ def _expires_at(self) -> float | None:
365
370
return time .monotonic () + self .timeout
366
371
return None
367
372
368
- def add_item (self , item : Item ) -> None :
373
+ def add_item (self , item : Item [ V ] ) -> None :
369
374
"""Adds an item to the view.
370
375
371
376
Parameters
@@ -397,7 +402,7 @@ def add_item(self, item: Item) -> None:
397
402
self .children .append (item )
398
403
return self
399
404
400
- def remove_item (self , item : Item | int | str ) -> None :
405
+ def remove_item (self , item : Item [ V ] | int | str ) -> None :
401
406
"""Removes an item from the view. If an int or str is passed, it will remove by Item :attr:`id` or ``custom_id`` respectively.
402
407
403
408
Parameters
@@ -422,7 +427,7 @@ def clear_items(self) -> None:
422
427
self .__weights .clear ()
423
428
return self
424
429
425
- def get_item (self , custom_id : str | int ) -> Item | None :
430
+ def get_item (self , custom_id : str | int ) -> Item [ V ] | None :
426
431
"""Get an item from the view. Roughly equal to `utils.get(view.children, ...)`.
427
432
If an ``int`` is provided it will retrieve by ``id``, otherwise it will check ``custom_id``.
428
433
This method will also search nested items.
@@ -508,7 +513,7 @@ async def on_check_failure(self, interaction: Interaction) -> None:
508
513
"""
509
514
510
515
async def on_error (
511
- self , error : Exception , item : Item , interaction : Interaction
516
+ self , error : Exception , item : Item [ V ] , interaction : Interaction
512
517
) -> None :
513
518
"""|coro|
514
519
@@ -528,7 +533,7 @@ async def on_error(
528
533
"""
529
534
interaction .client .dispatch ("view_error" , error , item , interaction )
530
535
531
- async def _scheduled_task (self , item : Item , interaction : Interaction ):
536
+ async def _scheduled_task (self , item : Item [ V ] , interaction : Interaction ):
532
537
try :
533
538
if self .timeout :
534
539
self .__timeout_expiry = time .monotonic () + self .timeout
@@ -560,7 +565,7 @@ def _dispatch_timeout(self):
560
565
self .on_timeout (), name = f"discord-ui-view-timeout-{ self .id } "
561
566
)
562
567
563
- def _dispatch_item (self , item : Item , interaction : Interaction ):
568
+ def _dispatch_item (self , item : Item [ V ] , interaction : Interaction ):
564
569
if self .__stopped .done ():
565
570
return
566
571
@@ -656,7 +661,7 @@ async def wait(self) -> bool:
656
661
"""
657
662
return await self .__stopped
658
663
659
- def disable_all_items (self , * , exclusions : list [Item ] | None = None ) -> None :
664
+ def disable_all_items (self , * , exclusions : list [Item [ V ] ] | None = None ) -> None :
660
665
"""
661
666
Disables all buttons and select menus in the view.
662
667
@@ -674,7 +679,7 @@ def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
674
679
child .disable_all_items (exclusions = exclusions )
675
680
return self
676
681
677
- def enable_all_items (self , * , exclusions : list [Item ] | None = None ) -> None :
682
+ def enable_all_items (self , * , exclusions : list [Item [ V ] ] | None = None ) -> None :
678
683
"""
679
684
Enables all buttons and select menus in the view.
680
685
@@ -715,7 +720,7 @@ def message(self, value):
715
720
class ViewStore :
716
721
def __init__ (self , state : ConnectionState ):
717
722
# (component_type, message_id, custom_id): (View, Item)
718
- self ._views : dict [tuple [int , int | None , str ], tuple [View , Item ]] = {}
723
+ self ._views : dict [tuple [int , int | None , str ], tuple [View , Item [ V ] ]] = {}
719
724
# message_id: View
720
725
self ._synced_message_views : dict [int , View ] = {}
721
726
self ._state : ConnectionState = state
0 commit comments