server.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217
  1. import random
  2. from fastapi import FastAPI, Request, Form, File, UploadFile
  3. from fastapi.templating import Jinja2Templates
  4. from typing import Dict, List, Optional
  5. from sx_utils import web_try
  6. import cv2
  7. import numpy as np
  8. import base64
  9. from core.predictor import predict_img
  10. from core.layout import LayoutBox
  11. from sx_utils import format_print
  app = FastAPI()
  templates = Jinja2Templates(directory='templates')
  format_print()

  # Box colours looked up by class id when drawing; 100 random RGB triples,
  # regenerated on every process start (the random module is not seeded here).
  colors = [tuple(random.randint(0, 255) for _channel in range(3)) for _slot in range(100)]

  # Model names handed to the HTML templates for selection.
  model_selection_options = ['ocr-layout', 'ocr-layout-paddle']

  # Detector class names; a box's class id indexes into this list.
  clazz_names = [
      "code",
      "logo_hb", "logo_qz", "logo_rain", "logo_stack", "logo_sun", "logo_up", "logo_ys",
      "style", "table", "text", "title",
  ]
  ##############################################
  # -------------GET Request Routes--------------
  ##############################################
  @app.get("/")
  def home(request: Request):
      '''Serve the home page upload form.

      The template context carries the incoming request and the list of
      selectable model names.
      '''
      page_context = dict(request=request,
                          model_selection_options=model_selection_options)
      return templates.TemplateResponse('home.html', page_context)
  @app.get("/drag_and_drop_detect")
  def drag_and_drop_detect(request: Request):
      '''Serve the drag-and-drop detection page.

      The page itself uploads files to the server and renders the image, boxes
      and labels on an HTML canvas; this handler only renders the template with
      the request and the selectable model names.
      '''
      template_name = 'drag_and_drop_detect.html'
      context = {"request": request}
      context["model_selection_options"] = model_selection_options
      return templates.TemplateResponse(template_name, context)
  ##############################################
  # ------------POST Request Routes--------------
  ##############################################
  @app.post("/")
  def detect_via_web_form(request: Request,
                          file_list: List[UploadFile] = File(...),
                          model_name: str = Form(...),
                          img_size: int = Form(1824),
                          multi_scale: bool = Form(False),
                          ):
      '''Run layout detection on uploaded images and render the results page.

      Intended for human (non-API) users of the HTML form.

      Args:
          request: incoming request, forwarded to the template.
          file_list: one or more uploaded image files.
          model_name: model identifier forwarded to ``predict_img``
              (e.g. one of ``model_selection_options``).
          img_size: inference image size forwarded to ``predict_img`` (default 1824).
          multi_scale: multi-scale flag forwarded to ``predict_img``.

      Returns:
          ``show_results.html`` rendered with each base64 JPEG (boxes drawn)
          zipped with that image's box dicts, plus a quote-escaped string form
          of all box data.

      Raises:
          HTTPException: 400 if an upload is empty or cannot be decoded as an image.
      '''
      from fastapi import HTTPException  # function-scope so the file header is untouched

      # Decode every upload to an OpenCV (BGR) image, rejecting unreadable files up front.
      img_batch = []
      for upload in file_list:
          # np.frombuffer replaces np.fromstring, whose binary mode is deprecated.
          raw = np.frombuffer(upload.file.read(), np.uint8)
          # cv2.imdecode returns None for undecodable bytes; an empty buffer is
          # treated the same way instead of being handed to imdecode.
          img = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
          if img is None:
              raise HTTPException(status_code=400,
                                  detail=f'Could not decode uploaded file as an image: {upload.filename!r}')
          img_batch.append(img)

      # The model gets RGB copies (cvtColor instead of [..., ::-1] keeps them
      # contiguous in RAM); boxes are drawn on the BGR originals, which are what
      # base64EncodeImage hands to cv2.imencode.
      img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
      results = [predict_img(img, model_name, img_size, multi_scale) for img in img_batch_rgb]
      json_results = boxes_list_to_json(results, clazz_names)

      img_str_list = []
      for img, bbox_list in zip(img_batch, json_results):
          for bbox in bbox_list:
              label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
              plot_one_box(bbox['bbox'], img, label=label,
                           color=colors[int(bbox['class'])], line_thickness=3)
          img_str_list.append(base64EncodeImage(img))

      # Escape both quote kinds in the repr string; presumably the template embeds
      # it inside a quoted JS string literal (confirm against show_results.html).
      encoded_json_results = str(json_results).replace("'", r"\'").replace('"', r'\"')
      return templates.TemplateResponse('show_results.html', {
          'request': request,
          'bbox_image_data_zipped': zip(img_str_list, json_results),  # unzipped in jinja2 template
          'bbox_data_str': encoded_json_results,
      })
  @app.post("/detect")
  @web_try()
  def detect_via_api(request: Request,
                     file_list: List[UploadFile] = File(...),
                     model_name: str = Form(...),
                     img_size: int = Form(1920),
                     multi_scale: bool = Form(False),
                     download_image: Optional[bool] = Form(False)):
      '''Run layout detection on uploaded images and return the boxes as JSON text.

      Intended for API usage.

      Args:
          request: incoming request (not read by this handler).
          file_list: one or more uploaded image files.
          model_name: model identifier forwarded to ``predict_img``
              (e.g. one of ``model_selection_options``).
          img_size: inference image size forwarded to ``predict_img`` (default 1920).
          multi_scale: multi-scale flag forwarded to ``predict_img``.
          download_image: when true, each image's list of box dicts gets one
              extra ``{'image_base64': ...}`` entry holding the JPEG with boxes drawn.

      Returns:
          A JSON-encoded string: one list per uploaded image holding that image's
          box dicts (plus the optional ``image_base64`` entry).

      Raises:
          ValueError: if an upload is empty or cannot be decoded as an image
              (presumably converted into an error response by ``web_try`` — confirm).
      '''
      import json  # function-scope so the file header is untouched

      def _to_builtin(obj):
          # json.dumps fallback: unwrap NumPy scalars in case to_service_dict yields
          # them (unverified); any other unknown type is still rejected.
          if isinstance(obj, np.generic):
              return obj.item()
          raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

      # Decode every upload to an OpenCV (BGR) image, rejecting unreadable files up front.
      img_batch = []
      for upload in file_list:
          # np.frombuffer replaces np.fromstring, whose binary mode is deprecated.
          raw = np.frombuffer(upload.file.read(), np.uint8)
          # cv2.imdecode returns None for undecodable bytes; an empty buffer is
          # treated the same way instead of being handed to imdecode.
          img = cv2.imdecode(raw, cv2.IMREAD_COLOR) if raw.size else None
          if img is None:
              raise ValueError(f'Could not decode uploaded file as an image: {upload.filename!r}')
          img_batch.append(img)

      # The model gets RGB copies (cvtColor instead of [..., ::-1] keeps them
      # contiguous in RAM); boxes are drawn on the BGR originals for JPEG encoding.
      img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]
      results = [predict_img(img, model_name, img_size, multi_scale) for img in img_batch_rgb]
      json_results = boxes_list_to_json(results, clazz_names)

      if download_image:
          for idx, (img, bbox_list) in enumerate(zip(img_batch, json_results)):
              for bbox in bbox_list:
                  label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
                  plot_one_box(bbox['bbox'], img, label=label,
                               color=colors[int(bbox['class'])], line_thickness=3)
              # The encoded image rides along as an extra entry in this image's box list.
              json_results[idx].append({'image_base64': base64EncodeImage(img)})

      # json.dumps stays valid JSON when strings contain quotes or values are
      # booleans/None, which a repr-then-swap-quotes approach does not; for plain
      # ints, floats and quote-free strings the text is identical.
      return json.dumps(json_results, ensure_ascii=False, default=_to_builtin)
  ##############################################
  # --------------Helper Functions---------------
  ##############################################
  def results_to_json(results, model):
      '''Convert YOLO-style detection output into a per-image list of box dicts.

      Each element of ``results.xyxy`` holds one image's prediction rows.
      Columns 0-3 are the box (presumably x1, y1, x2, y2, given the ``xyxy``
      name), column 4 is the confidence and column 5 the class id. The class id
      is looked up in ``model.model.names``.
      '''
      converted = []
      for per_image in results.xyxy:
          rows = []
          for pred in per_image:
              cls_id = int(pred[5])
              rows.append({
                  "class": cls_id,
                  "class_name": model.model.names[cls_id],
                  # int() truncates the float coordinates for the response
                  "bbox": [int(coord) for coord in pred[:4].tolist()],
                  "confidence": float(pred[4]),
              })
          converted.append(rows)
      return converted
  # Draw a bounding box (and optional label) on an image.
  def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
      '''Draw one bounding box, and optionally a filled label tag, onto ``im`` in place.

      Adapted from:
      https://github.com/ultralytics/yolov5/blob/cd540d8625bba8a05329ede3522046ee53eb349d/utils/plots.py

      Args:
          x: box corners; elements 0-3 are cast to int and used as the
              (x1, y1) and (x2, y2) corners passed to cv2.rectangle.
          im: image array drawn on in place with OpenCV; must be contiguous.
          color: colour of the box outline and label tag.
          label: optional text drawn in a filled tag anchored at (x1, y1).
          line_thickness: outline thickness; if falsy, derived from the image size.

      Raises:
          ValueError: if ``im`` is not contiguous in memory.
      '''
      # Explicit check instead of ``assert`` so validation still runs under ``python -O``.
      if not im.data.contiguous:
          raise ValueError('Image not contiguous. Apply np.ascontiguousarray(im) to plot_one_box() input image.')
      tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1  # line/font thickness
      c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
      cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
      if label:
          tf = max(tl - 1, 1)  # font thickness
          t_w, t_h = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
          # Put the tag above the box when there is room; for boxes touching the
          # top edge, put it just inside the box so it is not clipped off-image.
          outside = c1[1] - t_h >= 3
          tag_y = c1[1] - t_h - 3 if outside else c1[1] + t_h + 3
          cv2.rectangle(im, c1, (c1[0] + t_w, tag_y), color, -1, cv2.LINE_AA)  # filled
          text_y = c1[1] - 2 if outside else c1[1] + t_h + 2
          cv2.putText(im, label, (c1[0], text_y), 0, tl / 3, [225, 255, 255],
                      thickness=tf, lineType=cv2.LINE_AA)
  def base64EncodeImage(img):
      '''Encode an image as JPEG and return it as a base64 ASCII string.

      ``img`` is passed to ``cv2.imencode`` unchanged, so it should be in
      OpenCV's BGR channel order.

      Raises:
          ValueError: if OpenCV reports that JPEG encoding failed.
      '''
      ok, jpeg_buf = cv2.imencode('.jpg', img)
      # imencode signals failure through its first return value; without this
      # check the failure would surface as an obscure error on the buffer.
      if not ok:
          raise ValueError('cv2.imencode failed to encode image as JPEG')
      return base64.b64encode(jpeg_buf.tobytes()).decode('utf-8')
  def boxes_list_to_json(boxes_list: List[List[LayoutBox]], clazz_names: List[str]) -> List[List[Dict]]:
      '''Name every box's class, then serialize each image's boxes to dicts.

      Side effect: sets ``box.clazz_name = clazz_names[box.clazz]`` on every box
      in ``boxes_list``. All boxes are named before any is serialized.

      Args:
          boxes_list: one list of LayoutBox per image.
          clazz_names: class names indexed by ``box.clazz``.

      Returns:
          One list per image of ``box.to_service_dict()`` results, in input order.
      '''
      # Pass 1: attach class names (mutates the boxes in place).
      for box in (b for image_boxes in boxes_list for b in image_boxes):
          box.clazz_name = clazz_names[box.clazz]
      # Pass 2: serialize, preserving the per-image grouping.
      serialized = []
      for image_boxes in boxes_list:
          serialized.append([b.to_service_dict() for b in image_boxes])
      return serialized
  @app.get("/ping", description="健康检查")
  def ping():
      '''Health-check endpoint: prints a marker line to stdout and replies "pong!".'''
      reply = "pong!"
      print("->ping")
      return reply