server.py

from fastapi import FastAPI, Request, Form, File, UploadFile
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from typing import List, Optional
import cv2
import numpy as np
import torch
import base64
import random

YOLO_DIR = '/workspace/yolov5'
# WEIGHTS = '/data/yolov5/runs/train/yolov5x_layout_reuslt37/weights/best.pt'
WEIGHTS = '/workspace/best.pt'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

app = FastAPI()
templates = Jinja2Templates(directory='templates')

model_selection_options = ['ocr-layout']
model_dict = {model_name: None for model_name in model_selection_options}  # set up model cache

colors = [tuple([random.randint(0, 255) for _ in range(3)]) for _ in range(100)]  # for bbox plotting

##############################################
#-------------GET Request Routes--------------
##############################################
@app.get("/")
def home(request: Request):
    ''' Returns html jinja2 template render for the home page form
    '''
    return templates.TemplateResponse('home.html', {
        "request": request,
        "model_selection_options": model_selection_options,
    })

@app.get("/drag_and_drop_detect")
def drag_and_drop_detect(request: Request):
    ''' Drag-and-drop detect page. Uses a drag-and-drop file interface
    to upload files to the server, then renders the image + bboxes +
    labels on an HTML canvas.
    '''
    return templates.TemplateResponse('drag_and_drop_detect.html', {
        "request": request,
        "model_selection_options": model_selection_options,
    })

##############################################
#------------POST Request Routes--------------
##############################################
@app.post("/")
def detect_via_web_form(request: Request,
                        file_list: List[UploadFile] = File(...),
                        model_name: str = Form(...),
                        img_size: int = Form(1824)):
    '''
    Requires an image file upload and a model name (ex. yolov5s). Optional image size parameter (default 1824).
    Intended for human (non-api) users.
    Returns: HTML template render showing bbox data and base64 encoded images.
    '''
    # assume input was validated properly if we got here
    if model_dict[model_name] is None:
        model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)

    img_batch = [cv2.imdecode(np.frombuffer(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
                 for file in file_list]

    # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
    # using cvtColor instead of [...,::-1] to keep the array contiguous in RAM
    img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]

    results = model_dict[model_name](img_batch_rgb, size=img_size)
    json_results = results_to_json(results, model_dict[model_name])

    img_str_list = []
    # plot bboxes on the image
    for img, bbox_list in zip(img_batch, json_results):
        for bbox in bbox_list:
            label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
            plot_one_box(bbox['bbox'], img, label=label,
                         color=colors[int(bbox['class'])], line_thickness=3)

        img_str_list.append(base64EncodeImage(img))

    # escape the apostrophes in the json string representation
    encoded_json_results = str(json_results).replace("'", r"\'").replace('"', r'\"')

    return templates.TemplateResponse('show_results.html', {
        'request': request,
        'bbox_image_data_zipped': zip(img_str_list, json_results),  # unzipped in jinja2 template
        'bbox_data_str': encoded_json_results,
    })

@app.post("/detect")
def detect_via_api(request: Request,
                   file_list: List[UploadFile] = File(...),
                   model_name: str = Form(...),
                   img_size: Optional[int] = Form(1824),
                   download_image: Optional[bool] = Form(False)):
    '''
    Requires an image file upload and a model name (ex. yolov5s).
    Optional image size parameter (default 1824).
    Optional download_image parameter that includes base64 encoded image(s) with bboxes drawn in the json response.
    Returns: JSON results of running YOLOv5 on the uploaded images. If download_image is True, images with
    bboxes drawn are base64 encoded and returned inside the json response.
    Intended for API usage.
    '''
    if model_dict[model_name] is None:
        model_dict[model_name] = model = torch.hub.load(YOLO_DIR, 'custom', path=WEIGHTS, source='local').to(device)

    img_batch = [cv2.imdecode(np.frombuffer(file.file.read(), np.uint8), cv2.IMREAD_COLOR)
                 for file in file_list]

    # create a copy that corrects for cv2.imdecode generating BGR images instead of RGB,
    # using cvtColor instead of [...,::-1] to keep the array contiguous in RAM
    img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]

    results = model_dict[model_name](img_batch_rgb, size=img_size)
    json_results = results_to_json(results, model_dict[model_name])

    if download_image:
        for idx, (img, bbox_list) in enumerate(zip(img_batch, json_results)):
            for bbox in bbox_list:
                label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
                plot_one_box(bbox['bbox'], img, label=label,
                             color=colors[int(bbox['class'])], line_thickness=3)

            payload = {'image_base64': base64EncodeImage(img)}
            json_results[idx].append(payload)

    # swap single quotes for double quotes so the stringified results parse as JSON
    encoded_json_results = str(json_results).replace("'", r'"')
    return encoded_json_results
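
# The comment block below sketches one way a client might call the /detect endpoint
# with the `requests` library. The host URL and the file name 'page.jpg' are
# illustrative assumptions, not part of this server.
#
#   import requests
#   resp = requests.post('http://localhost:8000/detect',
#                        files=[('file_list', open('page.jpg', 'rb'))],
#                        data={'model_name': 'ocr-layout',
#                              'img_size': 1824,
#                              'download_image': 'true'})
#   print(resp.text)  # stringified JSON detection results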

##############################################
#--------------Helper Functions---------------
##############################################
def results_to_json(results, model):
    ''' Converts yolo model output to json (list of list of dicts)'''
    return [
        [
            {
                "class": int(pred[5]),
                "class_name": model.model.names[int(pred[5])],
                "bbox": [int(x) for x in pred[:4].tolist()],  # convert bbox results to int from float
                "confidence": float(pred[4]),
            }
            for pred in result
        ]
        for result in results.xyxy
    ]
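
# For reference, each inner list above holds one dict per detection. A single entry
# might look roughly like this (class names and values are illustrative, since they
# depend on the custom weights loaded above):
#   {"class": 0, "class_name": "text", "bbox": [42, 17, 630, 88], "confidence": 0.91}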

def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
    # Directly copied from: https://github.com/ultralytics/yolov5/blob/cd540d8625bba8a05329ede3522046ee53eb349d/utils/plots.py
    # Plots one bounding box on image 'im' using OpenCV
    assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
    tl = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1  # line/font thickness
    c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
    cv2.rectangle(im, c1, c2, color, thickness=tl, lineType=cv2.LINE_AA)
    if label:
        tf = max(tl - 1, 1)  # font thickness
        t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tf)[0]
        c2 = c1[0] + t_size[0], c1[1] - t_size[1] - 3
        cv2.rectangle(im, c1, c2, color, -1, cv2.LINE_AA)  # filled
        cv2.putText(im, label, (c1[0], c1[1] - 2), 0, tl / 3, [225, 255, 255], thickness=tf, lineType=cv2.LINE_AA)

def base64EncodeImage(img):
    ''' Takes an input image and returns a base64 encoded string representation of that image (jpg format)'''
    _, im_arr = cv2.imencode('.jpg', img)
    im_b64 = base64.b64encode(im_arr.tobytes()).decode('utf-8')
    return im_b64
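
# A minimal sketch of reversing this encoding on the client side (variable names are
# illustrative; assumes the client also has cv2 and numpy available):
#   img_bytes = base64.b64decode(im_b64)
#   img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)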

if __name__ == '__main__':
    import uvicorn
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--host', default='localhost')
    parser.add_argument('--port', default=8000)
    parser.add_argument('--precache-models', action='store_true',
                        help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
    opt = parser.parse_args()

    # if opt.precache_models:
    #     model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
    #                   for model_name in model_selection_options}

    app_str = 'server:app'  # make the app string equal to whatever the name of this file is
    uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)
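
# Example launch (assumes the file is saved as server.py, matching app_str above):
#   python server.py --host 0.0.0.0 --port 8000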