jwt.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import os
  2. from datetime import datetime, timedelta
  3. from typing import Union
  4. from dotenv import load_dotenv
  5. from fastapi import Depends, FastAPI, HTTPException, status
  6. from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
  7. from jose import JWTError, jwt
  8. from passlib.context import CryptContext
  9. from models.model import UserInDB, TokenData, User, Token, fake_users_db
  10. # 加载环境变量
  11. load_dotenv()
  12. SECRET_KEY = os.getenv("SECRET_KEY")
  13. ALGORITHM = os.getenv("ALGORITHM")
  14. ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES")
  15. # schemes 加密方式,默认第一个
  16. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  17. # 请求/token 返回一个令牌
  18. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  19. # 比较哈希值是否一直,一直就返回True,否则返回False
  20. def verify_password(plain_password, hashed_password):
  21. return pwd_context.verify(plain_password, hashed_password)
  22. # 对用户输入的密码进行hash加密
  23. def get_password_hash(password):
  24. return pwd_context.hash(password)
  25. # 去数据库中寻找,用户是否在数据库中,能找到就返回用户的所有信息
  26. def get_user(db, username: str):
  27. if username in db:
  28. user_dict = db[username]
  29. return UserInDB(**user_dict)
  30. # 判断用户是否存在于数据库中,存在就比较hash密码,比对成功,返回用户信息
  31. def authenticate_user(fake_db, username: str, password: str):
  32. user = get_user(fake_db, username)
  33. if not user:
  34. return False
  35. if not verify_password(password, user.hashed_password):
  36. return False
  37. return user
  38. # data:{"sub":user.username} datetime.utcnow():2023-05-31 23:46:27.912774
  39. # utcnow()用于记录当前时间,datetime模块中的timedelta返回的数据可与utcnow()相加
  40. def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None):
  41. to_encode = data.copy()
  42. if expires_delta:
  43. expire = datetime.utcnow() + expires_delta
  44. else:
  45. expire = datetime.utcnow() + timedelta(minutes=15)
  46. # update可以将过期时间加入到to_encode字典中 => to_encoded = {"sub":user.username,"exp":expire}
  47. to_encode.update({"exp": expire})
  48. # jwt.encode(加密数据,密钥,加密方式)
  49. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  50. # 最终返回被加密后的to_encoded
  51. return encoded_jwt
  52. # 这个函数必须携带令牌才能执行,携带令牌获取用户,返回用户信息
  53. async def get_current_user(token: str = Depends(oauth2_scheme)):
  54. credentials_exception = HTTPException(
  55. status_code=status.HTTP_401_UNAUTHORIZED,
  56. detail="Could not validate credentials",
  57. headers={"WWW-Authenticate": "Bearer"},
  58. )
  59. try:
  60. # 对token进行解码
  61. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  62. username: str = payload.get("sub")
  63. if username is None:
  64. raise credentials_exception
  65. token_data = TokenData(username=username)
  66. except JWTError:
  67. raise credentials_exception
  68. user = get_user(fake_users_db, username=token_data.username)
  69. if user is None:
  70. raise credentials_exception
  71. return user
  72. # 这里接收用户信息,接收到了就返回用户信息,没接收到,就表示令牌过期了,或者是未登陆
  73. async def get_current_active_user(current_user: User = Depends(get_current_user)):
  74. if current_user.disabled:
  75. raise HTTPException(status_code=400, detail="Inactive user")
  76. return current_user