12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394 |
- import os
- from datetime import datetime, timedelta
- from typing import Union
- from dotenv import load_dotenv
- from fastapi import Depends, FastAPI, HTTPException, status
- from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
- from jose import JWTError, jwt
- from passlib.context import CryptContext
- from models.model import UserInDB, TokenData, User, Token, fake_users_db
- # 加载环境变量
- load_dotenv()
- SECRET_KEY = os.getenv("SECRET_KEY")
- ALGORITHM = os.getenv("ALGORITHM")
- ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES")
- print(SECRET_KEY)
- # schemes 加密方式,默认第一个
- pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
- # 请求/token 返回一个令牌
- oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
- # 比较哈希值是否一直,一直就返回True,否则返回False
- def verify_password(plain_password, hashed_password):
- return pwd_context.verify(plain_password, hashed_password)
- # 对用户输入的密码进行hash加密
- def get_password_hash(password):
- return pwd_context.hash(password)
- # 去数据库中寻找,用户是否在数据库中,能找到就返回用户的所有信息
- def get_user(db, username: str):
- if username in db:
- user_dict = db[username]
- return UserInDB(**user_dict)
- # 判断用户是否存在于数据库中,存在就比较hash密码,比对成功,返回用户信息
- def authenticate_user(fake_db, username: str, password: str):
- user = get_user(fake_db, username)
- if not user:
- return False
- if not verify_password(password, user.hashed_password):
- return False
- return user
- # data:{"sub":user.username} datetime.utcnow():2023-05-31 23:46:27.912774
- # utcnow()用于记录当前时间,datetime模块中的timedelta返回的数据可与utcnow()相加
- def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None):
- to_encode = data.copy()
- if expires_delta:
- expire = datetime.utcnow() + expires_delta
- else:
- expire = datetime.utcnow() + timedelta(minutes=15)
- # update可以将过期时间加入到to_encode字典中 => to_encoded = {"sub":user.username,"exp":expire}
- to_encode.update({"exp": expire})
- # jwt.encode(加密数据,密钥,加密方式)
- encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
- # 最终返回被加密后的to_encoded
- return encoded_jwt
- # 这个函数必须携带令牌才能执行,携带令牌获取用户,返回用户信息
- async def get_current_user(token: str = Depends(oauth2_scheme)):
- credentials_exception = HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED,
- detail="Could not validate credentials",
- headers={"WWW-Authenticate": "Bearer"},
- )
- try:
- # 对token进行解码
- payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
- username: str = payload.get("sub")
- if username is None:
- raise credentials_exception
- token_data = TokenData(username=username)
- except JWTError:
- raise credentials_exception
- user = get_user(fake_users_db, username=token_data.username)
- if user is None:
- raise credentials_exception
- return user
- # 这里接收用户信息,接收到了就返回用户信息,没接收到,就表示令牌过期了,或者是未登陆
- async def get_current_active_user(current_user: User = Depends(get_current_user)):
- if current_user.disabled:
- raise HTTPException(status_code=400, detail="Inactive user")
- return current_user
|