test_files.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. import os
  2. import re
  3. import time
  4. import datetime
  5. from io import BytesIO
  6. from fastapi import Depends, FastAPI, HTTPException, status, UploadFile, APIRouter, File
  7. from starlette.responses import StreamingResponse, FileResponse
  8. from starlette.testclient import TestClient
  9. from pydantic import BaseModel
  10. import json
  11. from minio import Minio
  12. from cacheout import Cache
  13. import os
  14. from datetime import timedelta
  15. from typing import Union
  16. from dotenv import load_dotenv
  17. from fastapi import Depends, FastAPI, HTTPException, status
  18. from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
  19. from jose import JWTError, jwt
  20. from passlib.context import CryptContext
# Example request/response model with an integer id and a name.
# NOTE(review): not used by any endpoint visible in this file — confirm it is still needed.
class Item(BaseModel):
    id: int
    name: str
# Token response shape: the encoded access token plus its token type.
# Presumably returned by a login/"token" endpoint, which is not present in this file.
class Token(BaseModel):
    access_token: str
    token_type: str
# Data extracted from a decoded JWT; get_current_user fills `username`
# from the token's "sub" claim.
class TokenData(BaseModel):
    username: Union[str, None] = None
# User profile without credentials.
class User(BaseModel):
    username: str
    email: Union[str, None] = None
    full_name: Union[str, None] = None
    # When truthy, get_current_active_user rejects the user with HTTP 400.
    disabled: Union[bool, None] = None
# User record as stored in the user "database", including the password hash
# that verify_password checks against.
class UserInDB(User):
    hashed_password: str
  36. fake_users_db = {
  37. "johndoe": {
  38. "username": "johndoe",
  39. "full_name": "John Doe",
  40. "email": "johndoe@example.com",
  41. "hashed_password": "$2b$12$EixZaYVK1fsbw1ZfbX3OXePaWxn96p36WQoeG6Lruj3vjPGga31lW",
  42. "disabled": False,
  43. }
  44. }
  45. # openssl rand -hex 32
  46. # SECRET_KEY="09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
  47. # ALGORITHM="HS256"
  48. # ACCESS_TOKEN_EXPIRE_MINUTES=30
  49. # 加载环境变量
  50. load_dotenv()
  51. SECRET_KEY = os.getenv("SECRET_KEY")
  52. ALGORITHM = os.getenv("ALGORITHM")
  53. ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES")
  54. # schemes 加密方式,默认第一个
  55. pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
  56. # 请求/token 返回一个令牌
  57. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
  58. # 比较哈希值是否一直,一直就返回True,否则返回False
  59. def verify_password(plain_password, hashed_password):
  60. return pwd_context.verify(plain_password, hashed_password)
  61. # 对用户输入的密码进行hash加密
  62. def get_password_hash(password):
  63. return pwd_context.hash(password)
  64. # 去数据库中寻找,用户是否在数据库中,能找到就返回用户的所有信息
  65. def get_user(db, username: str):
  66. if username in db:
  67. user_dict = db[username]
  68. return UserInDB(**user_dict)
  69. # 判断用户是否存在于数据库中,存在就比较hash密码,比对成功,返回用户信息
  70. def authenticate_user(fake_db, username: str, password: str):
  71. user = get_user(fake_db, username)
  72. if not user:
  73. return False
  74. if not verify_password(password, user.hashed_password):
  75. return False
  76. return user
  77. # data:{"sub":user.username} datetime.utcnow():2023-05-31 23:46:27.912774
  78. # utcnow()用于记录当前时间,datetime模块中的timedelta返回的数据可与utcnow()相加
  79. def create_access_token(data: dict, expires_delta: Union[timedelta, None] = None):
  80. to_encode = data.copy()
  81. if expires_delta:
  82. expire = datetime.utcnow() + expires_delta
  83. else:
  84. expire = datetime.utcnow() + timedelta(minutes=15)
  85. # update可以将过期时间加入到to_encode字典中 => to_encoded = {"sub":user.username,"exp":expire}
  86. to_encode.update({"exp": expire})
  87. # jwt.encode(加密数据,密钥,加密方式)
  88. encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
  89. # 最终返回被加密后的to_encoded
  90. return encoded_jwt
  91. # 这个函数必须携带令牌才能执行,携带令牌获取用户,返回用户信息
  92. async def get_current_user(token: str = Depends(oauth2_scheme)):
  93. credentials_exception = HTTPException(
  94. status_code=status.HTTP_401_UNAUTHORIZED,
  95. detail="Could not validate credentials",
  96. headers={"WWW-Authenticate": "Bearer"},
  97. )
  98. try:
  99. # 对token进行解码
  100. payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
  101. username: str = payload.get("sub")
  102. if username is None:
  103. raise credentials_exception
  104. token_data = TokenData(username=username)
  105. except JWTError:
  106. raise credentials_exception
  107. user = get_user(fake_users_db, username=token_data.username)
  108. if user is None:
  109. raise credentials_exception
  110. return user
  111. # 这里接收用户信息,接收到了就返回用户信息,没接收到,就表示令牌过期了,或者是未登陆
  112. async def get_current_active_user(current_user: User = Depends(get_current_user)):
  113. if current_user.disabled:
  114. raise HTTPException(status_code=400, detail="Inactive user")
  115. return current_user
  116. # 从配置文件读取设置
  117. class MinioOperate:
  118. def __init__(self):
  119. with open(r"D:\pythonProject\django\fastapi_01\config\config.json", "r") as f:
  120. self.__config = json.load(f)
  121. self.minio_client = None
  122. def link_minio(self):
  123. self.minio_client = Minio(**self.__config["minio"])
  124. return self.minio_client
  125. def create_bucket(self, buckets: []):
  126. for bucket_name in buckets:
  127. if not self.minio_client.bucket_exists(bucket_name):
  128. try:
  129. self.minio_client.make_bucket(bucket_name)
  130. except Exception as e:
  131. print(f"Bucket creation failed: {e}")
  132. else:
  133. print(f"Bucket {bucket_name} already exists")
  134. class SetCache:
  135. def __init__(self, maxsize, ttl):
  136. self.cache = Cache(maxsize=maxsize, ttl=ttl)
  137. def get(self, uid):
  138. self.data = self.cache.get(uid)
  139. return self.data
  140. def add(self, uid, data):
  141. self.cache.add(uid, data)
# Module-level setup; everything below runs at import time.
# Create the MinIO helper (reads its JSON config file).
minio_class = MinioOperate()
# Build the MinIO client.
minio_client = minio_class.link_minio()
# Create the "file" and "image" buckets if they do not exist.
minio_class.create_bucket(["file", "image"])
# Cache of downloaded file bytes: at most 128 entries, ttl=10
# (presumably seconds, cacheout's TTL unit — confirm).
cache = SetCache(maxsize=128, ttl=10)
app = FastAPI()
# Name of the most recently uploaded file; set by create_file and read by the tests.
PIC_NAME = None
  152. @app.post("/file")
  153. async def create_file(file: UploadFile = File(...)):
  154. timestamp = str(time.time()).ljust(18, "0")
  155. uid = re.sub(r"\.", "", timestamp)
  156. front, ext = os.path.splitext(file.filename)
  157. file_name = uid + ext # 168549427474778.png
  158. global PIC_NAME
  159. PIC_NAME = file_name
  160. data = await file.read()
  161. file_stream = BytesIO(initial_bytes=data)
  162. size = len(data)
  163. date = str(datetime.date.today())
  164. object_path = date + "/{}".format(file_name)
  165. if (minio_client.put_object(
  166. "image",
  167. object_path,
  168. file_stream,
  169. size
  170. )):
  171. return {"status": 200, "data": [file_name], "msg": ""}
  172. else:
  173. return {"status": 400, "data": [], "msg": "Post Failed!"}
  174. @app.get("/file/{uid}")
  175. async def download_file(uid: str):
  176. try:
  177. timestamp, ext = os.path.splitext(uid)
  178. timestamp = float(str(float(timestamp) / 10000000).ljust(18, "0"))
  179. object_path = str(time.localtime(timestamp).tm_year) + "-" + str(time.localtime(timestamp).tm_mon).rjust(2,
  180. "0") + "-" \
  181. + str(time.localtime(timestamp).tm_mday).rjust(2, "0") + "/{}".format(uid)
  182. file_obj = minio_client.get_object("image", object_path)
  183. if not cache.get(uid):
  184. # 添加缓存
  185. # print("第一次获取,添加到缓存")
  186. cache.add(uid, file_obj.read())
  187. else:
  188. # print("从缓存中找到uid,获取缓存")
  189. file_bytes = cache.get(uid)
  190. return StreamingResponse(BytesIO(file_bytes), media_type="image/{}".format(ext[1:]))
  191. file_content = BytesIO(file_obj.read())
  192. response = StreamingResponse(file_content, media_type='image/{}'.format(ext[1:]))
  193. except Exception as e:
  194. return {"status": 400, "data": [], "msg": "Get Failed!"}
  195. # return response
  196. return {"status": 200, "data": [uid], "msg": ""}
  197. # 删除 鉴权 current_user: User = Depends(get_current_active_user)
  198. @app.delete("/file/{uid}")
  199. async def delete_file(uid: str):
  200. try:
  201. timestamp, ext = os.path.splitext(uid)
  202. timestamp = float(str(float(timestamp) / 10000000).ljust(18, "0"))
  203. object_path = str(time.localtime(timestamp).tm_year) + "-" + str(time.localtime(timestamp).tm_mon).rjust(2,
  204. "0") + "-" \
  205. + str(time.localtime(timestamp).tm_mday).rjust(2, "0") + "/{}".format(uid)
  206. minio_client.get_object("image", object_path)
  207. minio_client.remove_object("image", object_path)
  208. return {"status": 200, "data": [], "msg": "Delete Success!"}
  209. except:
  210. return {"status": 404, "data": [], "msg": "Not Found"}
  211. client = TestClient()
  212. def test_create_file():
  213. file = {"file": open(r"E:\wallhaven_pic\wallhaven-5gw639.jpg","rb") }
  214. response = client.post(f"/file",files=file)
  215. assert response.json() == {
  216. "status": 200,
  217. "data": [
  218. PIC_NAME
  219. ],
  220. "msg": ""
  221. }
  222. def test_download_file():
  223. response = client.get(f"/file/{PIC_NAME}")
  224. assert response.status_code == 200
  225. assert response.json() == {
  226. "status": 200,
  227. "data": [
  228. PIC_NAME
  229. ],
  230. "msg": ""
  231. }
  232. def test_delete_file():
  233. response = client.delete(f"/file/{PIC_NAME}")
  234. assert response.status_code == 200
  235. assert response.json() == {"status": 200, "data": [], "msg": "Delete Success!"}