450 lines
17 KiB
Python
450 lines
17 KiB
Python
|
import time
|
|||
|
|
|||
|
import faiss
|
|||
|
from flask import Flask, render_template, request, jsonify, send_from_directory
|
|||
|
from markupsafe import escape, escape_silent
|
|||
|
from werkzeug.utils import secure_filename
|
|||
|
|
|||
|
from anti import anti_spoofing, load_anti_model
|
|||
|
from face_api import load_arcface_model, load_npy, findOne, load_image, face_verification, findAll, add_one_to_database, \
|
|||
|
get_claster_tmp_file_embedding, cluster, detect_video
|
|||
|
from gender_age import set_gender_conf, gender_age, load_gender_model
|
|||
|
from retinaface_detect import load_retinaface_model, detect_one, set_retinaface_conf
|
|||
|
from werkzeug.exceptions import RequestEntityTooLarge
|
|||
|
import zipfile
|
|||
|
import os
|
|||
|
import shutil
|
|||
|
import re
|
|||
|
import numpy as np
|
|||
|
import torch
|
|||
|
|
|||
|
# Accepted image extensions (lower- and upper-case variants listed explicitly).
ALLOWED_IMG = {'png', 'jpg', 'jpeg', 'bmp', 'PNG', 'JPG', 'JPEG'}
# Maximum size of a single uploaded image: 10 MB.
ALLOWED_IMG_SIZE = 10 * 1024 * 1024
# Accepted archive extension for the cluster endpoint.
ALLOWED_FILE = {'zip'}
# Accepted video extension for the video recognition endpoint.
ALLOWED_VIDEO = {'mp4'}

app = Flask(__name__)

# Maximum size of a whole request body: 100 MB. Werkzeug raises
# RequestEntityTooLarge when a request exceeds it.
app.config['MAX_CONTENT_LENGTH'] = 100 * 1024 * 1024

# Emit non-ASCII (e.g. Chinese) text as-is in jsonify() output instead of
# \uXXXX escapes. Older Flask reads the JSON_AS_ASCII config key; Flask 2.3
# removed that key in favour of the app's JSON provider, so set both.
app.config['JSON_AS_ASCII'] = False
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
    app.json.ensure_ascii = False

# Run the models on the GPU when CUDA is available, otherwise on the CPU.
cpu_or_cuda = "cuda" if torch.cuda.is_available() else "cpu"

# Face recognition model.
arcface_model = load_arcface_model("./model/backbone100.pth", cpu_or_cuda=cpu_or_cuda)

# Face detection model.
retinaface_args = set_retinaface_conf(cpu_or_cuda=cpu_or_cuda)
retinaface_model = load_retinaface_model(retinaface_args)

# Gender / age model.
gender_args = set_gender_conf()
gender_model = load_gender_model(gender_args, 'fc1')

# Anti-spoofing (liveness) model.
anti_spoofing_model_path = "model/anti_spoof_models"
anti_model = load_anti_model(anti_spoofing_model_path, cpu_or_cuda)

# NOTE(review): the face database (k_v, database_name_list) and the faiss
# `index` used by the routes below are only created inside the
# `if __name__ == '__main__':` block at the end of this file. When the app is
# served another way (e.g. `flask run`, gunicorn) database_name_list / k_v are
# undefined and `index` still refers to the index() view function.
|
|||
|
|
|||
|
|
|||
|
@app.route('/')
def index():
    """Root endpoint: respond with the fixed plain-text body "model"."""
    body = "model"
    return body
|
|||
|
|
|||
|
|
|||
|
@app.route('/hello')
@app.route('/hello/<name>')
def hello(name=None):
    """Render hello.html with the optional <name> URL segment (None for bare /hello)."""
    template_context = {'name': name}
    return render_template('hello.html', **template_context)
|
|||
|
|
|||
|
|
|||
|
@app.route('/user', methods=['GET'])
def show_user_name():
    """Echo the ``username`` query parameter ('' when absent) back to the client.

    Flask serves a returned str as text/html, so the value is HTML-escaped;
    echoing it raw would allow reflected XSS (e.g. ?username=<script>...).
    """
    return escape(request.args.get('username', ''))
|
|||
|
|
|||
|
|
|||
|
# Build the JSON response returned by most endpoints.
def create_response(status, name=None, distance=None, verification=None, gender=None, age=None, num=None, anti=None,
                    score=None, box_and_point=None, addfile_names=None, fail_names=None, database_name=None, msg=None,
                    delete_names=None, not_exist_names=None):
    """Return ``jsonify(res)`` where ``res`` always has ``status`` plus optional parts.

    Layout of ``res``:
        status: always present.
        data:   only if non-empty; may hold box_and_point, liveness
                {spoofing, score}, distance, verification, number, gender,
                age, name.
        msg:    only if ``msg`` is given.
        result: only if non-empty; may hold database_name, success_names +
                fail_names (add-face), delete_names + not_exist_names
                (delete-face).

    Every optional argument is emitted only when it is not None.
    ``box_and_point`` is included only when it has a ``tolist()`` method
    (e.g. a numpy array). ``liveness`` needs both ``anti`` and ``score``.
    ``distance`` is converted with ``float()``. The add-face pair and the
    delete-face pair are each emitted together when either member is given.
    """
    res = {'status': status}

    data = {}
    # Same effect as the former try/except AttributeError: skip values
    # (such as None) that have no tolist() method.
    to_list = getattr(box_and_point, 'tolist', None)
    if to_list is not None:
        data['box_and_point'] = to_list()
    # `is not None` instead of `!= None`: numpy arrays compare element-wise
    # with `!=`, which makes the surrounding `if` raise ValueError.
    if anti is not None and score is not None:
        data['liveness'] = {'spoofing': anti, 'score': score}
    if distance is not None:
        data['distance'] = float(distance)
    if verification is not None:
        data['verification'] = verification
    if num is not None:
        data['number'] = num
    if gender is not None:
        data['gender'] = gender
    if age is not None:
        data['age'] = age
    if name is not None:
        data['name'] = name
    if data:
        res['data'] = data

    # Fields used by the database add/delete endpoints.
    if msg is not None:
        res['msg'] = msg

    result = {}
    if database_name is not None:
        result['database_name'] = database_name
    # Add-face outcome.
    if addfile_names is not None or fail_names is not None:
        result['success_names'] = addfile_names
        result['fail_names'] = fail_names
    # Delete-face outcome.
    if delete_names is not None or not_exist_names is not None:
        result['delete_names'] = delete_names
        result['not_exist_names'] = not_exist_names
    if result:
        res['result'] = result

    return jsonify(res)
|
|||
|
|
|||
|
|
|||
|
# Build the (dict) response of the cluster endpoint.
def create_cluster_response(status, all_cluster):
    """Wrap clustering output as ``{'data': {...}, 'status': status}``.

    The i-th group of ``all_cluster`` is stored under the key ``'cluster<i>'``.
    A plain dict is returned (not a jsonify()-ed response).
    """
    grouped = {'cluster' + str(position): members
               for position, members in enumerate(all_cluster)}
    return {'data': grouped, 'status': status}
|
|||
|
|
|||
|
|
|||
|
# Check the extension of an uploaded file.
def check_file_format(file_name, format):
    """Return True if ``file_name``'s extension is in the collection ``format``.

    The extension is the text after the LAST dot and is compared
    case-insensitively. The old ``rsplit('.')[1]`` picked the second
    dot-component instead, which accepted ``evil.jpg.php`` and rejected
    ``my.photo.jpg``.
    Returns False for empty/None names and names without a dot.
    """
    if not file_name or '.' not in file_name:
        return False
    extension = file_name.rsplit('.', 1)[1].lower()
    return extension in {allowed.lower() for allowed in format}
|
|||
|
|
|||
|
|
|||
|
# Reject saved images larger than ALLOWED_IMG_SIZE (10 MB).
def check_img_size(img_path):
    """Raise RequestEntityTooLarge if the file at ``img_path`` exceeds ALLOWED_IMG_SIZE.

    Returns None when the size is acceptable.
    """
    size_in_bytes = os.path.getsize(img_path)
    if size_in_bytes <= ALLOWED_IMG_SIZE:
        return
    raise RequestEntityTooLarge
|
|||
|
|
|||
|
|
|||
|
# Extract a zip archive into a directory.
def unzip(zip_src, dst_dir):
    """Extract every member of the zip archive ``zip_src`` into ``dst_dir``.

    Args:
        zip_src: a path or a binary file-like object (e.g. an uploaded file).
        dst_dir: destination directory; member directories are created as
            needed by ``extractall``.

    Returns:
        True if ``zip_src`` is a zip archive and was extracted, else False.
    """
    if not zipfile.is_zipfile(zip_src):
        return False
    # The context manager closes the archive even if extraction fails
    # (the previous version never closed it).
    with zipfile.ZipFile(zip_src, 'r') as archive:
        # extractall() strips absolute paths and '..' components from
        # member names, so members cannot be written outside dst_dir.
        archive.extractall(dst_dir)
    return True
|
|||
|
|
|||
|
|
|||
|
# Extract a zip archive, creating the output directory if needed.
def un_zip(file_path, output_path):
    """Extract the zip archive ``file_path`` into ``output_path``.

    ``output_path`` and any missing parent directories are created.
    Raises zipfile.BadZipFile if ``file_path`` is not a valid zip archive;
    in that case no directory is created, because the archive is opened first.
    """
    # The context manager closes the archive even if extraction raises.
    with zipfile.ZipFile(file_path) as archive:
        # makedirs (not mkdir) so a missing parent directory is not an error.
        os.makedirs(output_path, exist_ok=True)
        archive.extractall(output_path)
|
|||
|
|
|||
|
|
|||
|
# Face recognition (plus liveness check when exactly one face is found).
@app.route('/recognition', methods=['POST'])
def recognition():
    """Recognise the faces in an uploaded image (form field ``file_name``).

    - no face:        status 'no face'
    - several faces:  names from findAll() plus all detected boxes
    - exactly one:    name and distance from findOne(), the anti-spoofing
                      label/score, and the detected box
    gender/age are returned as empty lists (the model call is disabled).

    Relies on the module globals ``index`` and ``database_name_list``, which
    are only created in the ``__main__`` block of this file.
    """
    try:
        f = request.files['file_name']
        if not (f and check_file_format(f.filename, ALLOWED_IMG)):
            return create_response('png jpg jpeg bmp are allowed')

        img_path = './img/recognition/' + secure_filename(f.filename)
        f.save(img_path)
        check_img_size(img_path)

        tic = time.time()
        img3, box_and_point = detect_one(img_path, retinaface_model, retinaface_args)
        print('detect time: {:.4f}'.format(time.time() - tic))

        if len(img3) == 0:
            return create_response('no face')

        # Gender/age prediction is currently disabled:
        # gender_list, age_list = gender_age(img3, gender_model)
        gender_list, age_list = [], []

        if len(img3) > 1:
            namelist = findAll(img3, arcface_model, index, database_name_list, cpu_or_cuda)
            return create_response('success', namelist, gender=gender_list, age=age_list,
                                   box_and_point=box_and_point)

        # Exactly one face. The anti-spoofing call takes the box as integer
        # (x, y, w, h). Build it as a new array: the old code overwrote
        # box_and_point[0] in place, so the box sent to the client changed to
        # (x, y, w, h) while the multi-face path returns (x1, y1, x2, y2).
        x1, y1, x2, y2 = box_and_point[0][:4]
        xywh = np.array([x1, y1, x2 - x1, y2 - y1], int)
        label, value = anti_spoofing(img_path, anti_spoofing_model_path, cpu_or_cuda, xywh, anti_model)
        name, distance = findOne(img3, arcface_model, index, database_name_list, cpu_or_cuda)
        return create_response('success', name, gender=gender_list, age=age_list, distance=distance,
                               anti=label, score=value, box_and_point=box_and_point)
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')
|
|||
|
|
|||
|
|
|||
|
# Compare two images (1:1 face verification).
@app.route('/compare', methods=['POST'])
def compare_file():
    """Verify whether two uploaded images (``file1_name``, ``file2_name``) show the same person.

    Each image must contain exactly one detected face; the response carries
    face_verification()'s result as ``verification`` and its ``distance``.
    """
    try:
        file1 = request.files['file1_name']
        file2 = request.files['file2_name']
        if not (file1 and check_file_format(file1.filename, ALLOWED_IMG)
                and file2 and check_file_format(file2.filename, ALLOWED_IMG)):
            return create_response('png jpg jpeg bmp are allowed')

        # Prefix each saved copy. Two uploads with the same filename
        # (e.g. both "image.jpg") used to overwrite each other, so the
        # endpoint compared an image against itself.
        img1_path = './img/compare/1_' + secure_filename(file1.filename)
        img2_path = './img/compare/2_' + secure_filename(file2.filename)
        file1.save(img1_path)
        file2.save(img2_path)
        check_img_size(img1_path)
        check_img_size(img2_path)

        img1, _box_and_point1 = detect_one(img1_path, retinaface_model, retinaface_args)
        img2, _box_and_point2 = detect_one(img2_path, retinaface_model, retinaface_args)
        if len(img1) != 1 or len(img2) != 1:
            return create_response('image contains no face or more than 1 face')

        result, distance = face_verification(img1, img2, arcface_model, cpu_or_cuda)
        print(result, distance)
        return create_response('success', verification=result, distance=distance)
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')
|
|||
|
|
|||
|
|
|||
|
# Add one or more faces to an existing face database (add / update).
@app.route('/databaseAdd', methods=['POST'])
def DB_add_face():
    """Add the faces in the uploaded images to ./Database/<database_name>.npy.

    Form fields:
        file_list:     one or more image files (repeat the key to upload
                       several). The person's name is the first run of CJK
                       characters (U+4E00-U+9FA5) in each filename.
        database_name: name of an existing ./Database/<name>.npy file.

    Response: status 1 if every file was added, else 0; ``result`` holds
    success_names and fail_names = {formatWrong, alreadyExist}.

    NOTE(review): nothing here rebuilds the in-memory faiss ``index`` used
    by /recognition, so new faces are presumably only searchable after a
    restart - confirm.
    """
    import tempfile  # only this endpoint needs a scratch directory

    try:
        upload_files = request.files.getlist("file_list")
        if not upload_files:
            return create_response(0, msg="上传文件为空")

        database_name = request.form.get("database_name")
        # Reject a missing name or one carrying path components, so a client
        # cannot point the endpoint at a .npy file outside ./Database/.
        if not database_name or os.path.basename(database_name) != database_name:
            return create_response(0, msg="数据库不存在")
        database_path = "./Database/" + database_name + ".npy"
        if not os.path.exists(database_path):
            return create_response(0, msg="数据库不存在")

        # Names already in the database. Names added by this request are
        # recorded too, so a duplicate within one batch is reported instead
        # of being added twice.
        existing_names = set(load_npy(database_path).keys())

        name_pattern = re.compile('[\u4e00-\u9fa5]+')
        success_names = []
        # The two failure cases: bad format/filename, or name already present.
        format_wrong = []
        alreadyExist = []

        # Per-request scratch directory: concurrent requests no longer delete
        # each other's files, and it is removed even if an exception occurs.
        os.makedirs('./img/uploadNew/', exist_ok=True)
        file_temp_path = tempfile.mkdtemp(dir='./img/uploadNew/')
        try:
            for position, file in enumerate(upload_files):
                filename = file.filename or ''
                match = name_pattern.search(filename)
                if not (file and match and check_file_format(filename, ALLOWED_IMG)):
                    # Filenames without a CJK name used to raise IndexError.
                    format_wrong.append(match.group() if match else filename)
                    continue
                name = match.group()
                if name in existing_names:
                    alreadyExist.append(name)
                    continue

                # Save under a generated name: the client filename may contain
                # path components, and secure_filename() would strip the CJK
                # name characters. Only the extension is kept.
                suffix = os.path.splitext(os.path.basename(filename))[1]
                save_path = os.path.join(file_temp_path, str(position) + suffix)
                file.save(save_path)
                check_img_size(save_path)

                img_file, _box_and_point = detect_one(save_path, retinaface_model, retinaface_args)
                add_one_to_database(img=img_file, model=arcface_model, name=name, database_path=database_path,
                                    cpu_or_cuda=cpu_or_cuda)
                success_names.append(name)
                existing_names.add(name)
        finally:
            shutil.rmtree(file_temp_path, ignore_errors=True)

        status = 0 if (format_wrong or alreadyExist) else 1
        fail_names = {'formatWrong': format_wrong, 'alreadyExist': alreadyExist}
        return create_response(status=status, addfile_names=success_names, fail_names=fail_names,
                               database_name=database_name, msg="新增人脸操作执行完成")
    except RequestEntityTooLarge:
        return create_response(0, msg='image size should be less than 10M')
|
|||
|
|
|||
|
|
|||
|
# Delete one or more faces from an existing face database.
@app.route('/databaseDelete', methods=['POST'])
def DB_delete_face():
    """Remove the given names from ./Database/<database_name>.npy.

    Form fields:
        delete_names:  repeatable; each value is a name to delete.
        database_name: name of an existing ./Database/<name>.npy file.

    Response: status 1 if every name existed, else 0; ``result`` lists the
    deleted names and the names that were not found.
    """
    try:
        delete_names = request.form.getlist("delete_names")
        database_name = request.form.get("database_name")
        # Reject a missing name or one carrying path components, so a client
        # cannot point the endpoint at a .npy file outside ./Database/.
        if not database_name or os.path.basename(database_name) != database_name:
            return create_response(0, msg="数据库不存在")
        database_path = "./Database/" + database_name + ".npy"
        if not os.path.exists(database_path):
            return create_response(0, msg="数据库不存在")
        if not delete_names:
            return create_response(0, msg="delete_names参数为空")

        k_v = load_npy(database_path)
        print(k_v.keys())
        success_list = []
        fail_list = []
        for name in delete_names:
            if name in k_v.keys():
                del k_v[name]
                success_list.append(name)
            else:
                fail_list.append(name)
        np.save(database_path, k_v)

        status = 0 if fail_list else 1
        return create_response(status=status, delete_names=success_list, not_exist_names=fail_list,
                               database_name=database_name, msg="删除人脸操作完成")
    except RequestEntityTooLarge:
        # Passed as msg=: positionally the text landed in the `name` field.
        return create_response(0, msg='image size should be less than 10M')
|
|||
|
|
|||
|
|
|||
|
# Image search: upload a zip of images to build the search gallery.
@app.route('/uploadZip', methods=['POST'])
def upload_Zip():
    """Extract the uploaded zip (form field ``zip_name``) into ./img/search/.

    NOTE(review): /imgSearchImg scans ./img/search/face, so the archive
    presumably has to contain a top-level "face" folder - confirm.
    """
    try:
        # Named zip_file so the builtin zip() is not shadowed.
        zip_file = request.files['zip_name']
        dst_dir = './img/search/'
        if unzip(zip_file, dst_dir):
            return create_response('upload zip success')
        return create_response('upload zip file please')
    except RequestEntityTooLarge:
        # Archives are bounded by MAX_CONTENT_LENGTH (100 MB), not by the
        # 10 MB per-image limit; same message as the other archive/video
        # endpoints.
        return create_response('file size should be less than 100M')
|
|||
|
|
|||
|
|
|||
|
# Image search: find gallery images of the same person.
@app.route('/imgSearchImg', methods=['POST'])
def img_search_img():
    """Find images in ./img/search/face that match the uploaded image (``img_name``).

    The query image must contain exactly one face. The response carries the
    matching gallery filenames in ``name`` and their count in ``number``.
    """
    searchfile = './img/search/face'
    try:
        file = request.files['img_name']
        if not (file and check_file_format(file.filename, ALLOWED_IMG)):
            return create_response('png jpg jpeg bmp are allowed')

        img_path = './img/search/' + secure_filename(file.filename)
        file.save(img_path)
        check_img_size(img_path)

        img, _ = detect_one(img_path, retinaface_model, retinaface_args)
        if len(img) != 1:
            return create_response('image contains no face or more than 1 face')

        matched = []
        for filename in os.listdir(searchfile):
            gallery_img, _ = detect_one(os.path.join(searchfile, filename), retinaface_model, retinaface_args)
            # A gallery image with no detected face cannot match.
            if len(gallery_img) == 0:
                continue
            result = face_verification(img, gallery_img, arcface_model, cpu_or_cuda)
            # /compare unpacks face_verification() as (result, distance), so
            # calling .split() on it raised AttributeError; the older
            # "same <distance>" string form is still accepted.
            # NOTE(review): assumes the verdict is the string 'same' for a
            # match, as the original code did - confirm against face_api.
            verdict = result.split(' ')[0] if isinstance(result, str) else result[0]
            if verdict == 'same':
                matched.append(filename)
        return create_response('success', name=matched, num=len(matched))
    except RequestEntityTooLarge:
        return create_response('image size should be less than 10M')
|
|||
|
|
|||
|
|
|||
|
# Face clustering endpoint.
@app.route('/cluster', methods=['POST'])
def zip_cluster():
    """Cluster the faces in an uploaded zip of images (form field ``file_name``).

    The archive is saved and extracted into ./img/cluster_tmp_file/. Images
    are then read from the folder named after the archive without its
    extension, which assumes the zip contains a top-level folder of that name.
    """
    try:
        f = request.files['file_name']
        if not (f and check_file_format(f.filename, ALLOWED_FILE)):
            return create_response('zip are allowed')

        tmp_dir = './img/cluster_tmp_file/'
        zip_name = secure_filename(f.filename)
        zip_path = tmp_dir + zip_name
        f.save(zip_path)
        try:
            un_zip(zip_path, tmp_dir)
        except zipfile.BadZipFile:
            # A corrupt or fake .zip used to surface as a 500 error.
            return create_response('zip are allowed')

        # Strip only the final extension: "team.photos.zip" -> "team.photos"
        # (rsplit('.')[0] used to yield "team").
        folder = tmp_dir + os.path.splitext(zip_name)[0]
        emb_list, name_list = get_claster_tmp_file_embedding(folder, retinaface_model, retinaface_args,
                                                             arcface_model, cpu_or_cuda)
        return create_cluster_response("success", cluster(emb_list, name_list))
    except RequestEntityTooLarge:
        return create_response('file size should be less than 100M')
|
|||
|
|
|||
|
|
|||
|
# Video recognition endpoint.
@app.route('/videorecognition', methods=['POST'])
def video_recognition():
    """Process an uploaded .mp4 (form field ``file_name``) with detect_video().

    The upload is stored as ./video/<sanitised name>, and ./videoout/<same
    name> is passed to detect_video (presumably its output path; /download
    serves files from ./videoout/). Uses the module-level face database
    ``k_v``, which is only defined in this file's ``__main__`` block.
    """
    try:
        upload = request.files['file_name']
        if not upload or not check_file_format(upload.filename, ALLOWED_VIDEO):
            return create_response('mp4 are allowed')

        safe_name = secure_filename(upload.filename)
        source_path = './video/' + safe_name
        upload.save(source_path)
        detect_video(source_path, './videoout/' + safe_name, retinaface_model, arcface_model, k_v,
                     retinaface_args)
        return create_response("success")
    except RequestEntityTooLarge:
        return create_response('file size should be less than 100M')
|
|||
|
|
|||
|
|
|||
|
@app.route('/download/<string:filename>', methods=['GET'])
def download(filename):
    """Send ./videoout/<filename> as an attachment.

    Responds with status "Download failed" when no such regular file exists.
    """
    out_dir = './videoout/'
    if not os.path.isfile(os.path.join(out_dir, filename)):
        return create_response("Download failed")
    return send_from_directory(out_dir, filename, as_attachment=True)
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Load the face database. Its keys are used as the candidate names and
    # its values are stacked into the embedding matrix searched by faiss.
    k_v = load_npy("./Database/student.npy")
    database_name_list = list(k_v.keys())
    # faiss works on contiguous float32 matrices; convert explicitly.
    vector_list = np.ascontiguousarray(np.array(list(k_v.values())), dtype='float32')
    print(vector_list.shape)

    # Exact L2 search over the 512-d embeddings. The previous IndexIVFFlat
    # used nlist = nprobe = 50, which already probed every list (an exact
    # search). Its k-means training, however, fails when the database holds
    # fewer than 50 vectors. A flat index returns the same neighbours and
    # needs no training.
    # NOTE: this rebinds the module-level name `index` (also the name of the
    # '/' view function); /recognition reads this global.
    index = faiss.IndexFlatL2(512)
    index.add(vector_list)

    app.run(host="0.0.0.0", port=5000)
|