# 1、app.py 编写 & 服务启动

  • 发布 api 服务时,需要在相应环境中安装 simcse 包。
  • 在以下脚本的同级目录下,执行 `flask run -h 0.0.0.0 -p 5005` 启动服务 ```python from simcse import SimCSE import sys import re import os from flask import jsonify from flask import Flask, request

# jy: index of the GPU this service should use.
os.environ["CUDA_VISIBLE_DEVICES"] = "5"

# ASCII characters listed in the original pattern. "|" stays in the set because
# the original wrote it inside the character class, where it matched literally.
# "-", "[" and "]" were listed too but never matched: "|-|" parsed as a range
# and the bare "]" closed the class early. They are removed here as intended.
_ASCII_PUNCT = ",?=;:!()%-/<>|[] "

# Chinese punctuation kept verbatim from the original pattern:
# U+2265, ideographic full stop, ideographic comma, ellipsis and the four
# curly quotation marks.
_CN_PUNCT = "\u2265\u3002\u3001\u2026\u2019\u2018\u201c\u201d"

# Fullwidth comma, question mark, semicolon, colon, exclamation mark and
# parentheses.
# NOTE(review): the published pattern lists ",", ":", "(" and ")" twice, which
# suggests the fullwidth forms were normalised to ASCII on export - confirm
# against the original script.
_FULLWIDTH_PUNCT = "\uff0c\uff1f\uff1b\uff1a\uff01\uff08\uff09"

_PUNCT_TABLE = str.maketrans("", "", _ASCII_PUNCT + _CN_PUNCT + _FULLWIDTH_PUNCT)


def remove_cn_punct(str_cn):
    """Strip punctuation characters and ASCII spaces from ``str_cn``.

    Args:
        str_cn: Text to clean.

    Returns:
        ``str_cn`` with every character of the punctuation sets defined above
        deleted, then stripped of leading/trailing whitespace.
    """
    # A translation table deletes single characters in one pass and avoids the
    # regex-escaping mistakes of the original character class.
    return str_cn.translate(_PUNCT_TABLE).strip()

# Directory that contains one sub-folder per model.
model_path = "/home/huangjiayue/04_SimCSE/jy_model/simcse-model"

# jy: model names (each one is the folder that stores that model)
en_model_name = "en-256"
cn_model_name = "cn-256_withPunct"
encn_model_name = "encn-256-cnWithPunct"
sup_simcse_model_name = "sup-simcse-bert-base-uncased"
unsup_simcse_model_name = "unsup-simcse-bert-base-uncased"

# All five models are constructed once, when this module is imported.
en_model = SimCSE(os.path.join(model_path, en_model_name))
cn_model = SimCSE(os.path.join(model_path, cn_model_name))
encn_model = SimCSE(os.path.join(model_path, encn_model_name))
sup_simcse_model = SimCSE(os.path.join(model_path, sup_simcse_model_name))
unsup_simcse_model = SimCSE(os.path.join(model_path, unsup_simcse_model_name))

# Value of the request's "model" field -> loaded SimCSE instance.
model_map = {
    "en": en_model,
    "cn": cn_model,
    "encn": encn_model,
    "sup_simcse": sup_simcse_model,
    "unsup_simcse": unsup_simcse_model,
}

SUPPORT_MODEL = list(model_map.keys())

# The published text read Flask(name); the double underscores of __name__ were
# lost, which would raise NameError at import time.
app = Flask(__name__)

def get_model(model_name):
    """Return the loaded model registered under ``model_name``.

    Args:
        model_name: Key to look up; the supported keys are ``SUPPORT_MODEL``.

    Returns:
        The ``model_map`` entry for a supported name. For any other name the
        function does not raise: it returns the message string
        ``"not supported model: 【<name>】"``, so callers must check whether
        they received a string before using the result as a model.
    """
    if model_name not in SUPPORT_MODEL:
        return "not supported model: 【%s】" % model_name
    return model_map[model_name]

@app.route("/get_sent_embedding", methods=["POST"])
def get_sent_embedding():
    """Encode one sentence with the requested model and return it as text.

    JSON request body:
        model: key accepted by ``get_model`` (jy: "en", "cn" or "encn").
        text: sentence to encode, e.g. "A woman is reading."

    Returns:
        ``str(embedding)`` on success. jy's original note reports
        ``torch.Size([768])`` as the embedding shape for the example sentence.
        A JSON ``{"error": ...}`` body with status 400 when the body is not a
        JSON object, a field is missing or not a string, or the model name is
        not supported.
    """
    # The route only accepts POST, so no method check is needed here.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    model_name = payload.get("model")
    text_ = payload.get("text")
    if not isinstance(model_name, str) or not isinstance(text_, str):
        return jsonify({"error": "'model' and 'text' must both be strings"}), 400

    model_ = get_model(model_name.strip())
    if isinstance(model_, str):
        # get_model reports an unsupported name by returning its message
        # string instead of a model object.
        return jsonify({"error": model_}), 400

    embedding = model_.encode(text_.strip())
    # jy: note that this string form cannot be deserialized by the caller; a
    # real deployment should post-process the embedding before returning it.
    return str(embedding)

# The published route read "/getsents_similarity" (an underscore was lost);
# both request scripts in this document call "/get_sents_similarity".
@app.route("/get_sents_similarity", methods=["POST"])
def get_sents_similarity():
    """Return the similarity matrix between two groups of sentences as text.

    JSON request body:
        model: key accepted by ``get_model``.
        ls_sents1: e.g. ['A woman is reading.', 'A man is playing a guitar.']
        ls_sents2: e.g. ['He plays guitar.', 'A woman is making a photo.']

    Returns:
        ``str(similarities)`` on success; for the example above jy recorded
        ``[[0.18968055 0.4845423 ] [0.7468339 0.27683863]]``.
        A JSON ``{"error": ...}`` body with status 400 when the body is not a
        JSON object, a field is missing or empty, or the model name is not
        supported.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    model_name = payload.get("model")
    if not isinstance(model_name, str):
        return jsonify({"error": "'model' must be a string"}), 400

    ls_sents1 = payload.get("ls_sents1")
    ls_sents2 = payload.get("ls_sents2")
    if not ls_sents1 or not ls_sents2:
        return jsonify({"error": "'ls_sents1' and 'ls_sents2' are required"}), 400

    # The published code bound the result to `model` but then used `model_`,
    # which would raise NameError.
    model_ = get_model(model_name.strip())
    if isinstance(model_, str):
        # get_model reports an unsupported name by returning its message
        # string instead of a model object.
        return jsonify({"error": model_}), 400

    similarities = model_.similarity(ls_sents1, ls_sents2)
    # jy: note that this string form cannot be deserialized by the caller; a
    # real deployment should post-process the result before returning it.
    return str(similarities)

@app.route("/search_similarity_sents", methods=["POST"])
def similarity_search():
    """Search the candidate sentences for those similar to the query.

    JSON request body:
        model: key accepted by ``get_model``.
        ls_sents: candidates, e.g.
            ['A woman is reading.', 'A man is playing a guitar.']
        sent_or_ls_sents: query, e.g. "He plays guitar."
        threshold: optional, defaults to 0.6.
        top_k: optional, defaults to 5.

    Returns:
        ``str(results)`` on success. A JSON ``{"error": ...}`` body with
        status 400 when the body is not a JSON object, a required field is
        missing or empty, or the model name is not supported.
    """
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        return jsonify({"error": "request body must be a JSON object"}), 400

    model_name = payload.get("model")
    if not isinstance(model_name, str):
        return jsonify({"error": "'model' must be a string"}), 400

    sent_or_ls_sents = payload.get("sent_or_ls_sents")
    ls_sents = payload.get("ls_sents")
    if not sent_or_ls_sents or not ls_sents:
        return jsonify(
            {"error": "'sent_or_ls_sents' and 'ls_sents' are required"}
        ), 400

    model_ = get_model(model_name.strip())
    if isinstance(model_, str):
        # get_model reports an unsupported name by returning its message
        # string instead of a model object.
        return jsonify({"error": model_}), 400

    threshold = payload.get("threshold", 0.6)
    top_k = payload.get("top_k", 5)

    # jy: if the faiss package is installed, build_index (defined in the
    # SimCSE class in /simcse/tool.py) imports it automatically to speed up
    # the computation.
    # Note: faiss did not well support Nvidia AMPERE GPUs (3090 and A100).
    # In that case, you should change to other GPUs or install the CPU
    # version of faiss package.
    # The index is rebuilt from ls_sents on every request.
    model_.build_index(ls_sents)
    results = model_.search(sent_or_ls_sents, threshold=threshold, top_k=top_k)
    return str(results)
<a name="gPK6K"></a>
# 2、请求服务
<a name="J32HE"></a>
## (1)shell 脚本请求
  5. ```shell
  6. PORT=5005
  7. IP=192.168.3.250
  8. :<<!
  9. curl -X POST -H 'Content-Type: application/json' http://${IP}:${PORT}/get_sent_embedding -d '
  10. {"text": "Immunometabolism features of metabolic deregulation and cancer. immunometabolics for cancer.",
  11. "model": "en"
  12. }'
  13. !
  14. :<<!
  15. curl -X POST -H 'Content-Type: application/json' http://${IP}:${PORT}/get_sents_similarity -d '
  16. {"ls_sents1": ["A woman is reading.", "A man is playing a guitar."],
  17. "ls_sents2": ["He plays guitar.", "A woman is making a photo."],
  18. "model": "unsup_simcse"
  19. }'
  20. !
  21. curl -X POST -H 'Content-Type: application/json' http://${IP}:${PORT}/search_similarity_sents -d '
  22. {"ls_sents": ["A woman is reading.", "A man is playing a guitar."],
  23. "sent_or_ls_sents": ["He plays guitar.", "A woman is making a photo."],
  24. "model": "sup_simcse",
  25. "threshold": -1000,
  26. "top_k": 5
  27. }'

## (2)python 脚本请求

import requests
import io
import os
import json
from w3lib.html import remove_tags

# Model keys this client accepts.
# NOTE(review): the service's model_map shown earlier in this document also
# registers "sup_simcse" and "unsup_simcse" - confirm whether they should be
# allowed here as well.
SUPPORTED_MODEL = ["en", "cn", "encn"]


def get_sent_embedding(model, sent, base_url='http://192.168.3.250:5005',
                       timeout=None):
    """Ask the SimCSE service for the embedding of one sentence.

    Args:
        model: Model key; must be one of ``SUPPORTED_MODEL``.
        sent: Sentence to encode.
        base_url: Scheme, host and port of the service. Defaults to the
            address that used to be hard-coded in this function.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) keeps
            the previous behaviour of waiting without a limit.

    Returns:
        The raw response body (``response.text``) as a string.

    Raises:
        Exception: If ``model`` is not in ``SUPPORTED_MODEL``.
    """
    if model not in SUPPORTED_MODEL:
        raise Exception("暂不支持输入的模型:【%s】" % model)

    json_data = {
        'text': sent,
        'model': model,
    }
    url_ = base_url.rstrip('/') + '/get_sent_embedding'
    response = requests.post(url_, headers={}, json=json_data, timeout=timeout)
    # jy: still to be decided - convert the text back into a tensor
    # (something like eval plus a tensor constructor).
    return response.text

"""
sent = 'Immunometabolism features of metabolic deregulation and cancer. immunometabolics for cancer.'
model = "en33"
embedding_str = get_sent_embedding(model, sent)
print(embedding_str)
"""

def get_sents_similarity(ls_sents1, ls_sents2, model,
                         base_url='http://192.168.3.250:5005', timeout=None):
    """Ask the SimCSE service for the similarity between two sentence lists.

    Args:
        ls_sents1: First list of sentences.
        ls_sents2: Second list of sentences.
        model: Model key; must be one of ``SUPPORTED_MODEL``.
        base_url: Scheme, host and port of the service. Defaults to the
            address that used to be hard-coded in this function.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) keeps
            the previous behaviour of waiting without a limit.

    Returns:
        The raw response body (``response.text``) as a string.

    Raises:
        Exception: If ``model`` is not in ``SUPPORTED_MODEL``.
    """
    if model not in SUPPORTED_MODEL:
        raise Exception("暂不支持输入的模型:【%s】" % model)

    json_data = {
        'ls_sents1': ls_sents1,
        'ls_sents2': ls_sents2,
        'model': model,
    }
    url_ = base_url.rstrip('/') + '/get_sents_similarity'
    response = requests.post(url_, headers={}, json=json_data, timeout=timeout)
    return response.text

"""
ls_sents1 = [
'A woman is reading.',
'A man is playing a guitar.',]

ls_sents2 = [
'He plays guitar.',
'A woman is making a photo.',]

model = "en"

str_sim_score_matrix = get_sents_similarity(ls_sents1, ls_sents2, model)
print(str_sim_score_matrix)
"""

def search_similarity_sents(sent_or_ls_sents, ls_sents, model, threshold=0.6,
                            top_k=5, base_url='http://192.168.3.250:5005',
                            timeout=None):
    """Ask the SimCSE service which candidates are similar to the query.

    Args:
        sent_or_ls_sents: The query; one text or a list of texts, e.g.
            ['He plays guitar.', 'A woman is making a photo.'].
        ls_sents: The candidate list, e.g.
            ['A woman is reading.', 'A man is playing a guitar.'].
        model: Model key; must be one of ``SUPPORTED_MODEL``.
        threshold: Sent to the service as ``threshold``.
        top_k: Sent to the service as ``top_k``.
        base_url: Scheme, host and port of the service. Defaults to the
            address that used to be hard-coded in this function.
        timeout: Forwarded to ``requests.post``. ``None`` (the default) keeps
            the previous behaviour of waiting without a limit.

    Returns:
        The parsed response, e.g.
        [[('A man is playing a guitar.', 0.790188193321228)], []].

    Raises:
        Exception: If ``model`` is not in ``SUPPORTED_MODEL``.
        ValueError, SyntaxError: If the response body is not a Python literal.
    """
    # Local import: only this function needs it.
    import ast

    if model not in SUPPORTED_MODEL:
        raise Exception("暂不支持输入的模型:【%s】" % model)

    json_data = {
        'ls_sents': ls_sents,
        'sent_or_ls_sents': sent_or_ls_sents,
        'model': model,
        'threshold': threshold,
        'top_k': top_k,
    }
    url_ = base_url.rstrip('/') + '/search_similarity_sents'
    response = requests.post(url_, headers={}, json=json_data, timeout=timeout)
    # SECURITY: this used to be eval(response.text), which executes whatever
    # expression the remote end sends back. literal_eval accepts only Python
    # literals, which is all the documented result format needs.
    ls_res = ast.literal_eval(response.text)
    return ls_res

"""
ls_sents = [
'A woman is reading.',
'A man is playing a guitar.'
]

sent_or_ls_sents = ['He plays guitar.', 'A woman is making a photo.']
model = "encn"
threshold = -1
ls_res = search_similarity_sents(sent_or_ls_sents, ls_sents, model, threshold=threshold)
# jy: [[('A man is playing a guitar.', 0.790188193321228)], []]
print(ls_res)
"""