123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220 |
import base64
import json
import random
import sys
from typing import List, Optional

import cv2
import numpy as np
import torch
from fastapi import FastAPI, Request, Form, File, UploadFile
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel

from sx_utils import web_try
from sx_utils import format_print
# Directory handed to torch.hub.load as the local repo, and the weights file
# loaded from it.
YOLO_DIR = './yolov7'
WEIGHTS = './yiliv7_718.pt'

app = FastAPI()
templates = Jinja2Templates(directory='templates')
format_print()

# Model names offered to clients; model_dict caches one loaded model per name.
model_selection_options = ['ocr-layout']
model_dict = dict.fromkeys(model_selection_options)  # every entry starts as None

print(f'是否可使用GPU=======>{torch.cuda.is_available()}')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 100 random colour triples, indexed by class id when plotting bboxes.
colors = [tuple(random.randint(0, 255) for _ in range(3)) for _ in range(100)]

# Load the layout model at import time so request handlers find it cached.
if model_dict['ocr-layout'] is None:
    model = torch.hub.load(YOLO_DIR, 'custom', WEIGHTS, source='local').to(device)
    model_dict['ocr-layout'] = model
    print("========>模型加载成功")
##############################################
# -------------GET Request Routes--------------
##############################################
@app.get("/")
def home(request: Request):
    """Render the home page form (``home.html``).

    The template context carries the request and the list of selectable
    model names.
    """
    context = {
        "request": request,
        "model_selection_options": model_selection_options,
    }
    return templates.TemplateResponse('home.html', context)
@app.get("/drag_and_drop_detect")
def drag_and_drop_detect(request: Request):
    """Render the drag-and-drop detection page (``drag_and_drop_detect.html``).

    The page offers a drag and drop file interface for uploading files to the
    server and draws the image, bboxes and labels on an HTML canvas.
    """
    context = {
        "request": request,
        "model_selection_options": model_selection_options,
    }
    return templates.TemplateResponse('drag_and_drop_detect.html', context)
##############################################
# ------------POST Request Routes--------------
##############################################
@app.post("/")
def detect_via_web_form(request: Request,
                        file_list: List[UploadFile] = File(...),
                        model_name: str = Form(...),
                        img_size: int = Form(1824)):
    """Run detection on uploaded images and render the results page.

    Intended for human (non-API) users submitting the home page form.

    Args:
        request: Incoming request, forwarded to the template context.
        file_list: One or more uploaded image files.
        model_name: Key into ``model_dict`` selecting the model to run.
        img_size: Inference size passed to the model as ``size`` (default 1824).

    Returns:
        The rendered ``show_results.html`` template. Its context holds each
        annotated image (base64-encoded JPEG) zipped with that image's
        detections, plus a quote-escaped string form of all detections.

    Raises:
        ValueError: If an uploaded file cannot be decoded as an image.
    """
    # Input is assumed to be validated already; an unknown model_name raises
    # KeyError on this lookup.
    if model_dict[model_name] is None:
        model_dict[model_name] = torch.hub.load(
            YOLO_DIR, 'custom', WEIGHTS, source='local').to(device)
    detector = model_dict[model_name]

    img_batch = []
    for file in file_list:
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        raw = np.frombuffer(file.file.read(), np.uint8)
        img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals an undecodable payload by returning None; fail
            # here with a clear message instead of inside cvtColor.
            raise ValueError(
                f'Could not decode uploaded file as an image: {file.filename!r}')
        img_batch.append(img)

    # cv2.imdecode produces BGR images; build RGB copies for the model.
    # cvtColor is used instead of [..., ::-1] to keep the arrays contiguous in RAM.
    img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]

    results = detector(img_batch_rgb, size=img_size)
    json_results = results_to_json(results, detector)

    img_str_list = []
    for img, bbox_list in zip(img_batch, json_results):
        # Plot the bboxes on the (BGR) image, then encode the annotated image.
        for bbox in bbox_list:
            label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
            # Wrap around the palette so class ids >= len(colors) cannot raise IndexError.
            box_color = colors[int(bbox['class']) % len(colors)]
            plot_one_box(bbox['bbox'], img, label=label,
                         color=box_color, line_thickness=3)
        img_str_list.append(base64EncodeImage(img))

    # Escape the quote characters in the string representation of the results.
    encoded_json_results = str(json_results).replace("'", r"\'").replace('"', r'\"')

    return templates.TemplateResponse('show_results.html', {
        'request': request,
        'bbox_image_data_zipped': zip(img_str_list, json_results),  # unzipped in jinja2 template
        'bbox_data_str': encoded_json_results,
    })
@app.post("/detect")
@web_try()
def detect_via_api(request: Request,
                   file_list: List[UploadFile] = File(...),
                   model_name: str = Form(...),
                   img_size: Optional[int] = Form(1920),
                   download_image: Optional[bool] = Form(False)):
    """Run detection on uploaded images and return the results as a JSON string.

    Intended for API usage. The handler is wrapped by ``web_try()`` from
    ``sx_utils``; what that wrapper does with the return value or with raised
    exceptions is not visible in this file.

    Args:
        request: Incoming request (unused by the handler body).
        file_list: One or more uploaded image files.
        model_name: Key into ``model_dict`` selecting the model to run.
        img_size: Inference size passed to the model as ``size`` (default 1920).
        download_image: When true, each image's result list additionally ends
            with ``{"image_base64": <base64 JPEG with bboxes drawn>}``.

    Returns:
        A JSON-encoded string: one list per image, each holding the detection
        dicts produced by ``results_to_json``.

    Raises:
        ValueError: If an uploaded file cannot be decoded as an image.
    """
    img_batch = []
    for file in file_list:
        # np.frombuffer replaces the deprecated binary mode of np.fromstring.
        raw = np.frombuffer(file.file.read(), np.uint8)
        img = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if img is None:
            # imdecode signals an undecodable payload by returning None; fail
            # here with a clear message instead of inside cvtColor.
            raise ValueError(
                f'Could not decode uploaded file as an image: {file.filename!r}')
        img_batch.append(img)

    # Convert the image format: cv2.imdecode produces BGR, so build RGB copies.
    # cvtColor is used instead of [..., ::-1] to keep the arrays contiguous in RAM.
    img_batch_rgb = [cv2.cvtColor(img, cv2.COLOR_BGR2RGB) for img in img_batch]

    # Run the selected model on the batch.
    detector = model_dict[model_name]
    results = detector(img_batch_rgb, size=img_size)

    # Convert the raw results into plain lists/dicts.
    json_results = results_to_json(results, detector)

    # If the caller asked for the images, draw the boxes and attach each image.
    if download_image:
        for img, bbox_list in zip(img_batch, json_results):
            for bbox in bbox_list:
                label = f'{bbox["class_name"]} {bbox["confidence"]:.2f}'
                # Wrap around the palette so class ids >= len(colors) cannot raise IndexError.
                box_color = colors[int(bbox['class']) % len(colors)]
                plot_one_box(bbox['bbox'], img, label=label,
                             color=box_color, line_thickness=3)
            bbox_list.append({'image_base64': base64EncodeImage(img)})

    # json.dumps yields the same text as the former str(...).replace("'", '"')
    # for ordinary values, but stays valid JSON when a string contains quotes.
    # ensure_ascii=False keeps non-ASCII class names unescaped, as before.
    return json.dumps(json_results, ensure_ascii=False)
##############################################
# --------------Helper Functions---------------
##############################################
def results_to_json(results, model):
    """Convert yolo model output to JSON-friendly data (list of list of dicts).

    One inner list is produced per entry of ``results.xyxy``. Each prediction
    row becomes a dict with the keys ``class``, ``class_name``, ``bbox`` and
    ``confidence``; class names are looked up in ``model.model.names``.
    """
    per_image = []
    for detections in results.xyxy:
        entries = []
        for row in detections:
            entries.append({
                "class": int(row[5]),
                "class_name": model.model.names[int(row[5])],
                # bbox coordinates are converted from float to int
                "bbox": [int(coord) for coord in row[:4].tolist()],
                "confidence": float(row[4]),
            })
        per_image.append(entries)
    return per_image
# Draw a box on the image
def plot_one_box(x, im, color=(128, 128, 128), label=None, line_thickness=3):
    """Plot one bounding box, and optionally a label, on image ``im`` using OpenCV.

    Directly copied from:
    https://github.com/ultralytics/yolov5/blob/cd540d8625bba8a05329ede3522046ee53eb349d/utils/plots.py

    Args:
        x: Box corners, indexed as ``x[0], x[1]`` (first corner) and
            ``x[2], x[3]`` (opposite corner).
        im: Image array that is drawn on in place.
        color: Colour used for the box outline and the label background.
        label: Optional text drawn in a filled rectangle at the first corner.
        line_thickness: Line thickness; a falsy value derives one from the
            image shape.
    """
    assert im.data.contiguous, 'Image not contiguous. Apply np.ascontiguousarray(im) to plot_on_box() input image.'
    # line/font thickness
    thickness = line_thickness or round(0.002 * (im.shape[0] + im.shape[1]) / 2) + 1
    first_corner = (int(x[0]), int(x[1]))
    opposite_corner = (int(x[2]), int(x[3]))
    cv2.rectangle(im, first_corner, opposite_corner, color, thickness=thickness, lineType=cv2.LINE_AA)
    if not label:
        return
    font_thickness = max(thickness - 1, 1)
    text_size = cv2.getTextSize(label, 0, fontScale=thickness / 3, thickness=font_thickness)[0]
    label_corner = (first_corner[0] + text_size[0], first_corner[1] - text_size[1] - 3)
    cv2.rectangle(im, first_corner, label_corner, color, -1, cv2.LINE_AA)  # filled
    cv2.putText(im, label, (first_corner[0], first_corner[1] - 2), 0, thickness / 3,
                [225, 255, 255], thickness=font_thickness, lineType=cv2.LINE_AA)
def base64EncodeImage(img):
    """Encode an image as JPEG and return its base64 string representation.

    Args:
        img: Image passed to ``cv2.imencode``.

    Returns:
        The encoded JPEG bytes, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        ValueError: If ``cv2.imencode`` reports that encoding failed.
    """
    ok, im_arr = cv2.imencode('.jpg', img)
    if not ok:
        # The success flag used to be discarded; surface the failure instead
        # of base64-encoding whatever buffer came back.
        raise ValueError('cv2.imencode failed to encode the image as JPEG')
    return base64.b64encode(im_arr.tobytes()).decode('utf-8')
@app.get("/ping", description="健康检查")
def ping():
    """Health-check endpoint: print a marker to stdout and answer ``"pong!"``."""
    reply = "pong!"
    print("->ping")
    return reply
# if __name__ == '__main__':
#     import uvicorn
#     import argparse
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--host', default='localhost')
#     parser.add_argument('--port', default=8080)
#     parser.add_argument('--precache-models', action='store_true',
#                         help='Pre-cache all models in memory upon initialization, otherwise dynamically caches models')
#     opt = parser.parse_args()
#
#     # if opt.precache_models:
#     #     model_dict = {model_name: torch.hub.load('ultralytics/yolov5', model_name, pretrained=True)
#     #                   for model_name in model_selection_options}
#
#     app_str = 'server:app'  # make the app string equal to whatever the name of this file is
#     uvicorn.run(app_str, host=opt.host, port=int(opt.port), reload=True)
|