1
1
from datetime import datetime , timedelta
2
2
from typing import Annotated
3
3
4
- from fastapi import Depends , FastAPI , HTTPException , status
5
- from fastapi .security import OAuth2PasswordBearer , OAuth2PasswordRequestForm
6
- from jose import JWTError , jwt
7
- from passlib .context import CryptContext
8
- from sqlalchemy .orm import Session
4
+ import bcrypt
5
+ from chat import models
6
+ from chat .database import get_db
7
+ from chat .schema import TokenData , User
9
8
from chat .setting import setting
10
9
from chat .utils .exception import CredentialsException
11
- from chat .database import SessionLocal , engine
12
- import chat .models as models
13
- from chat .schema import TokenData , User
10
+ from fastapi import Depends , HTTPException
11
+ from fastapi .security import OAuth2PasswordBearer
12
+ from jose import JWTError , jwt
13
+ from sqlalchemy .orm import Session
14
14
15
- models . Base . metadata . create_all ( bind = engine )
15
+ oauth2_scheme = OAuth2PasswordBearer ( tokenUrl = "token" )
16
16
17
17
18
- def get_db () -> Session :
19
- db = SessionLocal ()
20
- try :
21
- yield db
22
- finally :
23
- db .close ()
18
+ def verify_password (plain_password : str , hashed_password : str ) -> bool :
19
+ """
20
+ Verifies if the provided plain text password matches the stored hashed password.
21
+
22
+ Args:
23
+ plain_password: The plain text password entered by the user.
24
+ hashed_password: The stored hashed password from the database.
24
25
26
+ Returns:
27
+ True if the passwords match, False otherwise.
28
+ """
29
+ encoded_hashed_password = hashed_password .encode ("utf-8" )
30
+ return bcrypt .checkpw (
31
+ plain_password .encode ("utf-8" ),
32
+ encoded_hashed_password ,
33
+ )
25
34
26
- pwd_context = CryptContext (schemes = ["bcrypt" ], deprecated = "auto" )
27
35
28
- oauth2_scheme = OAuth2PasswordBearer (tokenUrl = "token" )
36
+ def get_password_hash (password : str ) -> str :
37
+ """
38
+ Generates a bcrypt hash for the provided password.
29
39
40
+ Args:
41
+ password: The plain text password to hash.
30
42
31
- def verify_password (plain_password , hashed_password ):
32
- return pwd_context .verify (plain_password , hashed_password )
43
+ Returns:
44
+ The generated password hash.
45
+ """
46
+ hashed_bytes = bcrypt .hashpw (password .encode ("utf-8" ), bcrypt .gensalt ())
47
+ return hashed_bytes .decode ("utf-8" )
33
48
34
49
35
- def get_password_hash (password ):
36
- return pwd_context .hash (password )
50
+ def get_user (user_db : Session , username : str ) -> models .User :
51
+ """
52
+ Retrieve a user from the database based on the username.
37
53
54
+ Args:
55
+ user_db (Session): The database session.
56
+ username (str): The username of the user to retrieve.
38
57
39
- def get_user (user_db : Session , username : str ):
58
+ Returns:
59
+ Optional[models.User]: The user object if found, None otherwise.
60
+ """
40
61
user = user_db .query (models .User ).filter (models .User .username == username ).first ()
41
62
return user
42
63
43
64
44
- def authenticate_user (user_db : Session , username : str , password : str ):
65
+ def authenticate_user (
66
+ user_db : Session , username : str , password : str
67
+ ) -> models .User | None :
68
+ """
69
+ Authenticate a user based on the provided username and password.
70
+
71
+ Args:
72
+ user_db (Session): The database session.
73
+ username (str): The username of the user to authenticate.
74
+ password (str): The password of the user to authenticate.
75
+
76
+ Returns:
77
+ Optional[models.User]: The authenticated user object if successful, None otherwise.
78
+ """
45
79
user = get_user (user_db , username )
46
80
if not user :
47
- return False
81
+ return None
48
82
if not verify_password (password , user .password ):
49
- return False
83
+ return None
50
84
return user
51
85
52
86
53
- def create_access_token (data : dict , expires_delta : timedelta | None = None ):
87
+ def create_access_token (data : dict , expires_delta : timedelta | None = None ) -> str :
88
+ """
89
+ Create an access token with the provided data.
90
+
91
+ Args:
92
+ data (dict): The data to include in the token payload.
93
+ expires_delta (timedelta, optional): The expiration time delta for the token. Defaults to None.
94
+
95
+ Returns:
96
+ str: The generated access token.
97
+ """
54
98
to_encode = data .copy ()
55
99
if expires_delta :
56
100
expire = datetime .utcnow () + expires_delta
@@ -71,7 +115,17 @@ async def get_current_user(
71
115
user_db : Session = Depends (
72
116
get_db ,
73
117
),
74
- ):
118
+ ) -> User :
119
+ """
120
+ Get the current authenticated user from the provided token.
121
+
122
+ Args:
123
+ token (str): The JWT token representing the user.
124
+ user_db (Session, optional): The database session. Defaults to Depends(get_db).
125
+
126
+ Returns:
127
+ models.User: The current authenticated user.
128
+ """
75
129
token_data = decode_jwt (token )
76
130
user = get_user (user_db , username = token_data .username )
77
131
if user is None :
@@ -80,14 +134,37 @@ async def get_current_user(
80
134
81
135
82
136
async def get_current_active_user (
83
- current_user : Annotated [User , Depends (get_current_user )]
84
- ):
137
+ current_user : Annotated [User , Depends (get_current_user )],
138
+ ) -> User :
139
+ """
140
+ Get the current active authenticated user.
141
+
142
+ Args:
143
+ current_user (User): The current authenticated user.
144
+
145
+ Raises:
146
+ HTTPException: If the user is inactive.
147
+
148
+ Returns:
149
+ models.User: The current active authenticated user.
150
+ """
85
151
if current_user .disabled :
86
- raise HTTPException (status_code = 400 , detail = "Inactive user" )
152
+ raise HTTPException (
153
+ status_code = 400 , detail = "Inactive user"
154
+ ) # TODO Add this to Exceptions
87
155
return current_user
88
156
89
157
90
- def get_admin_payload (token : str ):
158
+ def get_admin_payload (token : str ) -> dict | None :
159
+ """
160
+ Decode the payload of the provided JWT token for admin user.
161
+
162
+ Args:
163
+ token (str): The JWT token to decode.
164
+
165
+ Returns:
166
+ Optional[dict]: The payload data containing username and id if the token is valid, None otherwise.
167
+ """
91
168
try :
92
169
payload = jwt .decode (token , setting .SECRET_KEY , setting .ALGORITHM )
93
170
username : str = payload .get ("username" )
@@ -97,7 +174,18 @@ def get_admin_payload(token: str):
97
174
return
98
175
99
176
100
- def decode_jwt (token : Annotated [str , Depends (oauth2_scheme )]) -> TokenData :
177
+ def decode_jwt (
178
+ token : Annotated [str , Depends (oauth2_scheme )]
179
+ ) -> TokenData | CredentialsException :
180
+ """
181
+ Decode the provided JWT token and extract the token data.
182
+
183
+ Args:
184
+ token (str): The JWT token to decode.
185
+
186
+ Returns:
187
+ TokenData: The token data containing username and id.
188
+ """
101
189
try :
102
190
payload = jwt .decode (
103
191
token ,
0 commit comments