# 1、app.py 编写 & 服务启动
- 发布 api 服务时,需要在相应环境中安装 simcse 包。
- 在以下脚本的同级目录下,执行 `flask run -h 0.0.0.0 -p 5005` 启动服务:

```python
from simcse import SimCSE
import sys
import re
import os
from flask import jsonify
from flask import Flask, request
# jy: index of the GPU this process is allowed to use.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"
# Characters removed by remove_cn_punct. Inside a character class "|" is a
# literal (not alternation), so the pipe itself is stripped as well; "-", "]"
# and "[" are escaped so they are matched literally instead of forming a
# range or closing the class early.
# NOTE(review): the source listing repeated ",", ":", "(" and ")" — possibly
# full-width variants that were normalised away in this copy; confirm against
# the original script before relying on full-width punctuation being removed.
_CN_PUNCT_RE = re.compile(r"[,|≥。、…?=’‘“”;:!()%\- /\]\[><]")


def remove_cn_punct(str_cn):
    """Return ``str_cn`` with the listed punctuation and spaces removed.

    Every character matched by ``_CN_PUNCT_RE`` (punctuation, the pipe
    character and the ASCII space) is deleted, then leading/trailing
    whitespace is stripped from what remains.

    Args:
        str_cn: the text to clean.

    Returns:
        The cleaned string.
    """
    return _CN_PUNCT_RE.sub("", str_cn).strip()
# Directory that contains one sub-folder per model.
model_path = "/home/huangjiayue/04_SimCSE/jy_model/simcse-model"

# jy: model names (the folder each model is stored in).
en_model_name = "en-256"
cn_model_name = "cn-256_withPunct"
encn_model_name = "encn-256-cnWithPunct"
sup_simcse_model_name = "sup-simcse-bert-base-uncased"
unsup_simcse_model_name = "unsup-simcse-bert-base-uncased"

# Every model is loaded once at import time and shared by all requests.
en_model = SimCSE(os.path.join(model_path, en_model_name))
cn_model = SimCSE(os.path.join(model_path, cn_model_name))
encn_model = SimCSE(os.path.join(model_path, encn_model_name))
sup_simcse_model = SimCSE(os.path.join(model_path, sup_simcse_model_name))
unsup_simcse_model = SimCSE(os.path.join(model_path, unsup_simcse_model_name))

# Maps the "model" value sent by clients to the loaded model instance.
model_map = {
    "en": en_model,
    "cn": cn_model,
    "encn": encn_model,
    "sup_simcse": sup_simcse_model,
    "unsup_simcse": unsup_simcse_model,
}

# Names clients may pass as "model".
SUPPORT_MODEL = list(model_map.keys())
# Flask needs the import name of this module; the dunder underscores were
# lost in the listing, which left an undefined bare `name`.
app = Flask(__name__)
def get_model(model_name):
    """Return the loaded model registered under ``model_name``.

    Args:
        model_name: one of the keys of ``model_map``.

    Returns:
        The model instance stored in ``model_map``.

    Raises:
        werkzeug.exceptions.HTTPException: HTTP 400 when ``model_name`` is
            not a registered model. Previously the error message was returned
            as a plain string, which callers then used as if it were a model.
    """
    if model_name not in model_map:
        # Imported here so the module-level import list stays untouched.
        from flask import abort

        abort(400, description="not supported model: 【%s】" % model_name)
    return model_map[model_name]
@app.route("/get_sent_embedding", methods=["POST"])
def get_sent_embedding():
    """Encode one sentence with the requested model.

    Expected JSON body:
        model: key selecting the model (the original note lists "en", "cn"
            or "encn").
        text: the sentence to encode, e.g. "A woman is reading."

    Returns:
        ``str(embedding)`` as the response body. According to the original
        notes the embedding of a single sentence has shape torch.Size([768]).
    """
    payload = request.json
    # jy: "en", "cn" or "encn"
    model_name = payload.get("model")
    model_ = get_model(model_name.strip())
    text_ = payload.get("text")
    embedding = model_.encode(text_.strip())
    # jy: note that this string result cannot be deserialized by the caller;
    #     a real application should post-process it before returning.
    return str(embedding)
@app.route("/get_sents_similarity", methods=["POST"])
def get_sents_similarity():
    """Compute similarity scores between two lists of sentences.

    Expected JSON body:
        model: key selecting the model.
        ls_sents1: e.g. ['A woman is reading.', 'A man is playing a guitar.']
        ls_sents2: e.g. ['He plays guitar.', 'A woman is making a photo.']

    Returns:
        ``str(similarities)`` as the response body; for the example above:
            [[0.18968055 0.4845423 ]
             [0.7468339  0.27683863]]
    """
    payload = request.json
    model_name = payload.get("model")
    # The result must be bound to `model_`, the name used below.
    model_ = get_model(model_name.strip())
    ls_sents1 = payload.get("ls_sents1")
    ls_sents2 = payload.get("ls_sents2")
    similarities = model_.similarity(ls_sents1, ls_sents2)
    # jy: note that this string result cannot be deserialized by the caller;
    #     a real application should post-process it before returning.
    return str(similarities)
@app.route("/search_similarity_sents", methods=["POST"])
def similarity_search():
    """Search the candidate sentences for those similar to the query.

    Expected JSON body:
        model: key selecting the model.
        ls_sents: candidate sentences the index is built from, e.g.
            ['A woman is reading.', 'A man is playing a guitar.']
        sent_or_ls_sents: the query, e.g. "He plays guitar."
        threshold: forwarded to ``search``; defaults to 0.6.
        top_k: forwarded to ``search``; defaults to 5.

    Returns:
        ``str(results)`` as the response body.
    """
    payload = request.json
    model_name = payload.get("model")
    model_ = get_model(model_name.strip())
    sent_or_ls_sents = payload.get("sent_or_ls_sents")
    ls_sents = payload.get("ls_sents")
    threshold = payload.get("threshold", 0.6)
    top_k = payload.get("top_k", 5)
    # jy: if the faiss package is installed, build_index (defined on the
    #     SimCSE class in /simcse/tool.py) imports it automatically to speed
    #     up the search.
    #     Note: faiss did not well support Nvidia AMPERE GPUs (3090 and A100).
    #     In that case, you should change to other GPUs or install the CPU
    #     version of faiss package.
    model_.build_index(ls_sents)
    results = model_.search(sent_or_ls_sents, threshold=threshold, top_k=top_k)
    return str(results)
```
<a name="gPK6K"></a>
# 2、请求服务
<a name="J32HE"></a>
## (1)shell 脚本请求
```shell
# Host and port used to build the request URLs below.
PORT=5005
IP=192.168.3.250
# `:<<!` ... `!` feeds a here-document to the no-op `:` builtin, so the
# enclosed request is kept for reference but not executed.
# Disabled request: sentence embedding from the "en" model.
:<<!
curl -X POST -H 'Content-Type: application/json' http://${IP}:${PORT}/get_sent_embedding -d '
{"text": "Immunometabolism features of metabolic deregulation and cancer. immunometabolics for cancer.",
"model": "en"
}'
!
# Disabled request: similarity scores between two sentence lists.
:<<!
curl -X POST -H 'Content-Type: application/json' http://${IP}:${PORT}/get_sents_similarity -d '
{"ls_sents1": ["A woman is reading.", "A man is playing a guitar."],
"ls_sents2": ["He plays guitar.", "A woman is making a photo."],
"model": "unsup_simcse"
}'
!
# Active request: search ls_sents for sentences similar to each query.
curl -X POST -H 'Content-Type: application/json' http://${IP}:${PORT}/search_similarity_sents -d '
{"ls_sents": ["A woman is reading.", "A man is playing a guitar."],
"sent_or_ls_sents": ["He plays guitar.", "A woman is making a photo."],
"model": "sup_simcse",
"threshold": -1000,
"top_k": 5
}'
```

## (2)python 脚本请求
```python
import requests
import io
import os
import json
from w3lib.html import remove_tags
# Model names the client is willing to send. The service's model_map also
# registers "sup_simcse" and "unsup_simcse" (both used by the shell examples),
# so they are accepted here too.
SUPPORTED_MODEL = ["en", "cn", "encn", "sup_simcse", "unsup_simcse"]
def get_sent_embedding(model, sent):
    """Request the embedding of ``sent`` from the service.

    Args:
        model: model name; must be listed in ``SUPPORTED_MODEL``.
        sent: the sentence to encode.

    Returns:
        The response body as text (the service returns ``str(embedding)``).

    Raises:
        Exception: if ``model`` is not in ``SUPPORTED_MODEL``.
        requests.HTTPError: if the service answers with an error status, so
            an error page is never returned as if it were an embedding.
    """
    headers = {}
    if model not in SUPPORTED_MODEL:
        raise Exception("暂不支持输入的模型:【%s】" % model)
    json_data = {
        'text': sent,
        'model': model,
    }
    url_ = 'http://192.168.3.250:5005/get_sent_embedding'
    response = requests.post(url_, headers=headers, json=json_data)
    response.raise_for_status()
    # jy: TODO decide how to turn the text back into a tensor (something
    #     eval-like that yields a torch tensor).
    return response.text
"""
sent = 'Immunometabolism features of metabolic deregulation and cancer. immunometabolics for cancer.'
model = "en33"
embedding_str = get_sent_embedding(model, sent)
print(embedding_str)
"""
def get_sents_similarity(ls_sents1, ls_sents2, model):
    """Request similarity scores between two sentence lists from the service.

    Args:
        ls_sents1: first list of sentences.
        ls_sents2: second list of sentences.
        model: model name; must be listed in ``SUPPORTED_MODEL``.

    Returns:
        The response body as text (the service returns ``str(similarities)``).

    Raises:
        Exception: if ``model`` is not in ``SUPPORTED_MODEL``.
        requests.HTTPError: if the service answers with an error status.
    """
    if model not in SUPPORTED_MODEL:
        raise Exception("暂不支持输入的模型:【%s】" % model)
    headers = {}
    json_data = {'ls_sents1': ls_sents1,
                 'ls_sents2': ls_sents2,
                 'model': model}
    url_ = 'http://192.168.3.250:5005/get_sents_similarity'
    response = requests.post(url_, headers=headers, json=json_data)
    response.raise_for_status()
    return response.text
"""
ls_sents1 = [
'A woman is reading.',
'A man is playing a guitar.',]
ls_sents2 = [
'He plays guitar.',
'A woman is making a photo.',]
model = "en"
str_sim_score_matrix = get_sents_similarity(ls_sents1, ls_sents2, model)
print(str_sim_score_matrix)
"""
def search_similarity_sents(sent_or_ls_sents, ls_sents, model, threshold=0.6, top_k=5):
    """Ask the service which candidates are similar to the query.

    Args:
        sent_or_ls_sents: the query; a single text or a list of texts, e.g.
            ['He plays guitar.', 'A woman is making a photo.']
        ls_sents: the candidate list, e.g.
            ['A woman is reading.', 'A man is playing a guitar.']
        model: model name; must be listed in ``SUPPORTED_MODEL``.
        threshold: forwarded to the service (default 0.6).
        top_k: forwarded to the service (default 5).

    Returns:
        The parsed response, e.g.
        [[('A man is playing a guitar.', 0.790188193321228)], []]

    Raises:
        Exception: if ``model`` is not in ``SUPPORTED_MODEL``.
        requests.HTTPError: if the service answers with an error status.
        ValueError, SyntaxError: if the response body is not a Python literal.
    """
    # Local import keeps the file's top-level import list untouched.
    import ast

    if model not in SUPPORTED_MODEL:
        raise Exception("暂不支持输入的模型:【%s】" % model)
    headers = {}
    json_data = {'ls_sents': ls_sents,
                 'sent_or_ls_sents': sent_or_ls_sents,
                 'model': model,
                 'threshold': threshold,
                 'top_k': top_k}
    url_ = 'http://192.168.3.250:5005/search_similarity_sents'
    response = requests.post(url_, headers=headers, json=json_data)
    response.raise_for_status()
    # SECURITY: this used to be eval(response.text), which executes whatever
    # the remote end sends. literal_eval only accepts Python literals (lists,
    # tuples, strings, numbers), which is what str(results) produces.
    ls_res = ast.literal_eval(response.text)
    return ls_res
"""
ls_sents = [
'A woman is reading.',
'A man is playing a guitar.'
]
sent_or_ls_sents = ['He plays guitar.', 'A woman is making a photo.']
model = "encn"
threshold = -1
ls_res = search_similarity_sents(sent_or_ls_sents, ls_sents, model, threshold=threshold)
# jy: [[('A man is playing a guitar.', 0.790188193321228)], []]
print(ls_res)
"""