server.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. from fastapi import FastAPI, Request, Form, File, UploadFile
  2. from fastapi.templating import Jinja2Templates
  3. from pydantic import BaseModel
  4. from typing import List, Optional
  5. from sx_utils import web_try
  6. import cv2
  7. import numpy as np
  8. import torch
  9. import base64
  10. import random
  11. import sys
  12. import logging
  13. YOLO_DIR = '/workspace/yolov5'
  14. # WEIGHTS = '/data/yolov5/runs/train/yolov5x_layout_reuslt37/weights/best.pt'
  15. # WEIGHTS = '/workspace/best.pt'
  16. # WEIGHTS = '/workspace/yili.pt'
  17. # WEIGHTS = '/workspace/best2.pt'
  18. WEIGHTS = '/workspace/11-17.pt'
  19. logger = logging.getLogger('log')
  20. logger.setLevel(logging.DEBUG)
  21. # 调用模块时,如果错误引用,比如多次调用,每次会添加Handler,造成重复日志,这边每次都移除掉所有的handler,后面在重新添加,可以解决这类问题
  22. while logger.hasHandlers():
  23. for i in logger.handlers:
  24. logger.removeHandler(i)
  25. # file log 写入文件配置
  26. formatter = logging.Formatter('%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s') # 日志的格式
  27. fh = logging.FileHandler(r'/var/log/be.log', encoding='utf-8') # 日志文件路径文件名称,编码格式
  28. fh.setLevel(logging.DEBUG) # 日志打印级别
  29. fh.setFormatter(formatter)
  30. logger.addHandler(fh)
  31. # console log 控制台输出控制
  32. ch = logging.StreamHandler(sys.stdout)
  33. ch.setLevel(logging.DEBUG)
  34. ch.setFormatter(formatter)
  35. logger.addHandler(ch)
  36. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  37. bl = torch.cuda.is_available()
  38. logger.info(f'是否可使用GPU=======>{bl}')
  39. app = FastAPI()
  40. templates = Jinja2Templates(directory='templates')
  41. model_selection_options = ['ocr-layout', 'ocr-logo']
  42. model_dict = {model_name: None for model_name in model_selection_options} # set up model cache
  43. colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)] # for bbox plotting
  44. if model_dict['ocr-layout'] is None:
  45. model_dict['ocr-layout'] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
  46. logger.info("========>模型加载成功")
  47. # logo检测
  48. if model_dict['ocr-logo'] is None:
  49. model_dict['ocr-logo'] = torch.hub.load(YOLO_DIR, 'custom', path='/workspace/logo.pt', source='local').to(device)
  50. logger.info("========>模型加载成功")
  51. ##############################################
  52. # -------------GET Request Routes--------------
  53. ##############################################
  54. @app.get("/")
  55. def home(request: Request):
  56. ''' Returns html jinja2 template render for home page form
  57. '''
  58. return templates.TemplateResponse('home.html', {
  59. "request": request,
  60. "model_selection_options": model_selection_options,
  61. })
  62. @app.get("/drag_and_drop_detect")
  63. def drag_and_drop_detect(request: Request):
  64. ''' drag_and_drop_detect detect page. Uses a drag and drop
  65. file interface to upload files to the server, then renders
  66. the image + bboxes + labels on HTML canvas.
  67. '''
  68. return templates.TemplateResponse('drag_and_drop_detect.html',
  69. {"request": request,
  70. "model_selection_options": model_selection_options,
  71. })
  72. ##############################################
  73. # ------------POST Request Routes--------------
  74. ##############################################
  75. @app.post("/")
  76. def detect_via_web_form(request: Request,
  77. file_list: List[UploadFile] = File(...),
  78. model_name: str = Form(...),
  79. img_size: int = Form(1824)):
  80. '''
  81. Requires an image file upload, model name (ex. yolov5s). Optional image size parameter (Default 1824).
  82. Intended for human (non-api) users.
  83. Returns: HTML template render showing bbox data and base64 encoded image
  84. '''
  85. # assume input validated properly if we got here
  86. if model_dict[model_name] is None:
  87. model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)
  88. img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
  89. for file in file_list]
  90. # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB
  91. # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
  92. img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
  93. results = model_dict[model_name](img_batch_rgb, size=img_size)
  94. json_results = results_to_json(results, model_dict[model_name])
  95. img_str_list = []
  96. # plot bboxes on the image
  97. for img, bbox_list in zip(img_batch, json_results):
  98. for bbox in bbox_list:
  99. label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
  100. plot_one_box(bbox['bbox'], img, label=label,
  101. color=colors[int(bbox['class'])], line_thickness=3)
  102. img_str_list.append(base64EncodeImage(img))
  103. # escape the apostrophes in the json string representation
  104. encoded_json_results = str(json_results).replace("'", r"\'").replace('"', r'\"')
  105. return templates.TemplateResponse('show_results.html', {
  106. 'request': request,
  107. 'bbox_image_data_zipped': zip(img_str_list, json_results), # unzipped in jinja2 template
  108. 'bbox_data_str': encoded_json_results,
  109. })
  110. @app.post("/detect")
  111. @web_try()
  112. def detect_via_api(request: Request,
  113. file_list: List[UploadFile] = File(...),
  114. model_name: str = Form(...),
  115. img_size: Optional[int] = Form(1920),
  116. download_image: Optional[bool] = Form(False)):
  117. '''
  118. Requires an image file upload, model name (ex. yolov5s).
  119. Optional image size parameter (Default 1824)
  120. Optional download_image parameter that includes base64 encoded image(s) with bbox's drawn in the json response
  121. Returns: JSON results of running YOLOv5 on the uploaded image. If download_image parameter is True, images with
  122. bboxes drawn are base64 encoded and returned inside the json response.
  123. Intended for API usage.
  124. '''
  125. img_batch = [cv2.imdecode(np.fromstring(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
  126. for file in file_list]
  127. # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
  128. # using cvtColor instead of [...,::-1] to keep array contiguous in RAM
  129. img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
  130. results = model_dict[model_name](img_batch_rgb, size=img_size)
  131. json_results = results_to_json(results, model_dict[model_name])
  132. if download_image:
  133. for idx, (img, bbox_list) in enumerate(zip(img_batch, json_results)):
  134. for bbox in bbox_list:
  135. label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
  136. plot_one_box(bbox['bbox'], img, label=label,
  137. color=colors[int(bbox['class'])], line_thickness=3)
  138. payload = {'image_base64': base64EncodeImage(img)}
  139. json_results[idx].append(payload)
  140. encoded_json_results = str(json_results).replace("'", r'"')
  141. return encoded_json_results
  142. ##############################################
  143. # --------------Helper Functions---------------
  144. ##############################################
  145. def results_to_json(results, model):
  146. ''' Converts yolo model output to json (list of list of dicts)'''
  147. return [
  148. [
  149. {
  150. "class": int(pred[5]),
  151. "class_name": model.model.names[int(pred[5])],
  152. "bbox": [int(x) for x in pred[:4].tolist()], # convert bbox results to int from float
  153. "confidence": float(pred[4]),
  154. }
  155. for pred in result
  156. ]
  157. for result in results.xyxy
  158. ]
  159. def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
  160. # Directly copied from: https://github.com/ultralytics/yolov5/blob/cd540d8625bba8a05329ede3522046ee53eb349d/utils/plots.py
  161. # Plots one bounding box on image 'im' using OpenCV
  162. assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
  163. tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1 # line/font thickness
  164. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  165. cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
  166. if label:
  167. tf = max(tl - 1, 1) # font thickness
  168. t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
  169. c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
  170. cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA) # filled
  171. cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)
  172. def base64EncodeImage(img):
  173. ''' Takes an input image and returns a base64 encoded string representation of that image (jpg format)'''
  174. _, im_arr = cv2.imencode('.jpg', img)
  175. im_b64 = base64.b64encode(im_arr.tobytes()).decode('utf-8')
  176. return im_b64
  177. @app.get("/ping", description="健康检查")
  178. def ping():
  179. logger.info("->ping")
  180. return "pong!"
  181. if __name__ == '__main__':
  182. import uvicorn
  183. import argparse
  184. parser = argparse.ArgumentParser()
  185. parser.add_argument('--host', default='localhost')
  186. parser.add_argument('--port', default=8080)
  187. parser.add_argument('--precache-models', action='store_true',
  188. help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
  189. opt = parser.parse_args()
  190. # if opt.precache_models:
  191. # model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
  192. # for model_name in model_selection_options}
  193. app_str = 'server:app' # make the app string equal to whatever the name of this file is
  194. uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)