Skip to content

Refactor api #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from sqlmodel import SQLModel, Field, Session, create_engine, select
from passlib.context import CryptContext
from typing import Generator
import random


SECRET_KEY = "your_secret_key_here"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30

hash_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

connect_args = {"check_same_thread": False}
engine = create_engine("sqlite:///orm-user.db", connect_args=connect_args)


class User(SQLModel, table=True):
name: str
username: str | None = Field(default=None, primary_key=True)
email: str | None = Field(default=None)
hashed_password: str | None = Field(default=None)
disabled: bool | None = Field(default=None)


def create_db_and_table():
SQLModel.metadata.create_all(engine)


def get_session() -> Generator:
with Session(engine) as session:
yield session


def get_hash(password: str) -> str:
return hash_context.hash(password)

def verify_password(hashed_password, password):
return hash_context.verify(password, hashed_password)

def generate_unique_username(name: str, session: Session) -> str:
while True:
username = name + str(random.randint(1, 100))
statement = select(User).where(User.username == username)
if not session.exec(statement).first():
return username

111 changes: 28 additions & 83 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,43 @@
from typing import Annotated
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from datetime import datetime, timedelta, timezone
from jose import JWTError, jwt
from passlib.context import CryptContext
from sqlmodel import Field, Session, SQLModel, create_engine, select
import random
from typing import Annotated

from database import (
User, create_db_and_table, get_session, get_hash, verify_password,
generate_unique_username, SECRET_KEY, ALGORITHM, ACCESS_TOKEN_EXPIRE_MINUTES
)

SECRET_KEY = "your_secret_key_here"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
from sqlmodel import Session, select

hash_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
Oauth2Scheme = OAuth2PasswordBearer(tokenUrl="login")


class User(SQLModel, table=True):
name: str
username: str | None = Field(default=None, primary_key=True)
email: str | None = Field(default=None)
hashed_password: str | None = Field(default=None)
disabled: bool | None = Field(default=None)


connect_args = {"check_same_thread": False}
engine = create_engine("sqlite:///orm-user.db", connect_args=connect_args)

def create_db_and_table():
SQLModel.metadata.create_all(engine)


def get_session():
with Session(engine) as session:
yield session


SessionDep = Annotated[Session, Depends(get_session)]


def get_hash(password: str) -> str:
return hash_context.hash(password)
app = FastAPI()


def verify_password(hashed_password, password):
return hash_context.verify(password, hashed_password)
@app.on_event("startup")
def on_startup():
create_db_and_table()


def authenticate_user(username: str, session: SessionDep, password: str):
def authenticate_user(username: str, session: Session, password: str):
user = session.get(User, username)
if not user:
return False
if user.username != username:
return False
if not verify_password(user.hashed_password, password):
return False
return True


def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
expire = datetime.now(timezone.utc) + (
expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
)
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return encoded_jwt
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)


async def get_current_user(session: SessionDep, token: str = Depends(Oauth2Scheme)):
Expand All @@ -84,7 +55,7 @@ async def get_current_user(session: SessionDep, token: str = Depends(Oauth2Schem
raise credentials_exception

user = session.get(User, username)
if user is None or user.username != username:
if user is None:
raise credentials_exception
return user

Expand All @@ -94,34 +65,20 @@ async def get_current_active_user(current_user: User = Depends(get_current_user)
raise HTTPException(status_code=400, detail="Inactive user")
return current_user

app = FastAPI()


@app.on_event("startup")
def on_startup():
create_db_and_table()


@app.post("/signup")
async def add_user(session: SessionDep, name: str, password: str, email: str):
new_user = User()
new_user.name = name

statement = select(User).where(User.email == email)
result = session.exec(statement).first()
if not result:
new_user.email = email
else:
if session.exec(statement).first():
raise HTTPException(status_code=404, detail="email already exists")
while 1:
new_user.username = name + str(random.randint(1, 100))
statement = select(User).where(User.username == new_user.username)
result = session.exec(statement).first()
if not result:
break

new_user.hashed_password = get_hash(password)
new_user.disabled = False

new_user = User(
name=name,
email=email,
username=generate_unique_username(name, session),
hashed_password=get_hash(password),
disabled=False
)
session.add(new_user)
session.commit()
session.refresh(new_user)
Expand All @@ -130,17 +87,5 @@ async def add_user(session: SessionDep, name: str, password: str, email: str):

@app.post("/login")
async def login(session: SessionDep, form_data: OAuth2PasswordRequestForm = Depends()):
user = authenticate_user(form_data.username, session, form_data.password)
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"})
user = session.get(User, form_data.username)
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires)
return {"access_token": access_token, "token_type": "bearer"}


@app.get("/users/me", response_model=User)
async def read_users_me(current_user: User = Depends(get_current_active_user)):
return current_user
if not authenticate_user(form_data.username, session, form_data.password):