jwt.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. import os
  2. from datetime import datetime, timedelta
  3. from typing import Union
  4. from dotenv import load_dotenv
  5. from fastapi import Depends, FastAPI, HTTPException, status
  6. from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
  7. from jose import JWTError, jwt
  8. from passlib.context import CryptContext
  9. from models.model import UserInDB, TokenData, User, Token, fake_users_db
  10. # 加载环境变量
  11. load_dotenv()
  12. SECRET_KEY = os.getenv("SECRET_KEY")
  13. ALGORITHM = os.getenv("ALGORITHM")
  14. ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES")
  15. print(SECRET_KEY)
  16. # schemes 加密方式,默认第一个
  17. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  18. # 请求/token 返回一个令牌
  19. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  20. # 比较哈希值是否一直,一直就返回True,否则返回False
  21. def verify_password(plain_password, hashed_password):
  22. return pwd_context.verify(plain_password, hashed_password)
  23. # 对用户输入的密码进行hash加密
  24. def get_password_hash(password):
  25. return pwd_context.hash(password)
  26. # 去数据库中寻找,用户是否在数据库中,能找到就返回用户的所有信息
  27. def get_user(db, username: str):
  28. if username in db:
  29. user_dict = db[username]
  30. return UserInDB(**user_dict)
  31. # 判断用户是否存在于数据库中,存在就比较hash密码,比对成功,返回用户信息
  32. def authenticate_user(fake_db, username: str, password: str):
  33. user = get_user(fake_db, username)
  34. if not user:
  35. return False
  36. if not verify_password(password, user.hashed_password):
  37. return False
  38. return user
  39. # data:{"sub":user.username} datetime.utcnow():2023-05-31 23:46:27.912774
  40. # utcnow()用于记录当前时间,datetime模块中的timedelta返回的数据可与utcnow()相加
  41. def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None):
  42. to_encode = data.copy()
  43. if expires_delta:
  44. expire = datetime.utcnow() + expires_delta
  45. else:
  46. expire = datetime.utcnow() + timedelta(minutes=15)
  47. # update可以将过期时间加入到to_encode字典中 => to_encoded = {"sub":user.username,"exp":expire}
  48. to_encode.update({"exp": expire})
  49. # jwt.encode(加密数据,密钥,加密方式)
  50. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  51. # 最终返回被加密后的to_encoded
  52. return encoded_jwt
  53. # 这个函数必须携带令牌才能执行,携带令牌获取用户,返回用户信息
  54. async def get_current_user(token: str = Depends(oauth2_scheme)):
  55. credentials_exception = HTTPException(
  56. status_code=status.HTTP_401_UNAUTHORIZED,
  57. detail="Could not validate credentials",
  58. headers={"WWW-Authenticate": "Bearer"},
  59. )
  60. try:
  61. # 对token进行解码
  62. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  63. username: str = payload.get("sub")
  64. if username is None:
  65. raise credentials_exception
  66. token_data = TokenData(username=username)
  67. except JWTError:
  68. raise credentials_exception
  69. user = get_user(fake_users_db, username=token_data.username)
  70. if user is None:
  71. raise credentials_exception
  72. return user
  73. # 这里接收用户信息,接收到了就返回用户信息,没接收到,就表示令牌过期了,或者是未登陆
  74. async def get_current_active_user(current_user: User = Depends(get_current_user)):
  75. if current_user.disabled:
  76. raise HTTPException(status_code=400, detail="Inactive user")
  77. return current_user