建立HTTPS API server-client提供AI模型進行推論(以文本糾錯模型為例)

Writing HTTPS API server code pair to provide AI model inference using python.

前言

本篇延續前一篇模型訓練,將建立一個HTTPS server佈署訓練好的模型使其可以被呼叫

 https://dotblogs.com.tw/Ryuichi/2024/06/29/171320

AI模型在經過漫長的探索、微調與驗證後,最終都是要落地應用才能發揮它的價值。一般來說要開放其他系統進行介接,最簡單的方式就是建一個API server讓其他人呼叫,計算完之後再將結果吐回去。

使用python建立server的方式有很多種,你可以使用現成的flask、gradio…等套件,快速的將server建立起來,但實務上在應用的時候,用這些現成的套件可能會遇到風險軟體(whitesource、checkmarx)掃描出高風險或是作業環境更換後為了維持相容性,導致需要更換版本的問題,通常在更換版本後隨之而來的問題就是得修改程式碼…

在職場上這麼多年,我自己歸納出來的心得:在創建功能的時候,盡量用低階、原生的語法完成,不用其他人封裝好的套件(雖然可能很省時間),往後在程式碼與專案的維護上會比較輕鬆,程式的相容性也會比較好(因為你不知道其他人封裝好的套件裡面是不是有用到某些環境不支援的語法)。

本篇文章將借助ChatGPT的力量,並修改其回饋之內容,建立一個HTTPS server,提供一個文字糾錯的API server進行應用。

一、HTTPS SSL自簽憑證產生

因為安全性的問題,原則上現在的系統都不太建立HTTP server了,傳輸資訊明碼很容易被竊取重要資訊,如果需要建立HTTPS server,則需要有憑證,本篇使用自簽憑證進行建置,這部分網路上很多教學,你也可以跟我一樣參考這篇產生出來,紅框部分就是server在啟動時需要引入的檔案。

建立時記得設定要換成你的網路IP,我這邊是本機測試所以使用127.0.0.1。

https://footmark.com.tw/news/linux/ca-localip-ssl-https/

更換IP
產生自簽憑證
二、詢問ChatGPT產生server code並進行修改

用英文下promt詢問,讓ChatGPT產生出基本的code pair出來,我下的prompt是use python code to write a https server client code pair。

我不是故意用英文,只是我覺得用英文問得到的回答通常比較讓我滿意。

第一次下prompt詢問code pair

Server code沒什麼問題,但你會發現用Jupyter Notebook在跑的時候沒有辦法手動停掉server,因此我們會增加try catch區段在httpd.serve_forever(),讓他可以接受人為停止的key interrupt。

修改 Server Code

Client code的部分就需要修改比較多,因為ChatGPT的回答是使用request套件,我記得用checkmarx有掃出高風險,這迫使我得使用其他方式撰寫。

ChatGPT回應使用request套件

經過survey後決定使用比較底層的http.client套件進行改寫,於是接著對ChatGPT下prompt進行追問,我下的prompt是:client use http.client lib。

再次下prompt修改client code
三、加入模型inference code,引入模型進行推論

接著我們將訓練好的模型實際加入server code中,並啟動server,完整程式碼如下,中間有針對checkmarx認為的一些高風險語法做了補強。

註解已經寫得很清楚,應該不用多作解釋,過程中我遇到比較弔詭exception是出現Unexpected key(s) in state_dict: "bert.bert.embeddings.position_ids,經過google確認後發現是因為不同版本的transformer訓練出來的模型混搭load會有問題,load模型時候加strict=False就可以解決(我中間環境有切換過,才發現這個問題)。

pycorrector_config.json的設定如下,主要是設定模型設定的位置,以及一些建立server時的需求設定。

# @Author           : Ryuichi
# @Time             : 2024/11/30 16:50
# @Code Description : Start a SSL server POST API with private certification to load text error corrector model.

from transformers import BertTokenizerFast
from pycorrector.macbert.macbert4csc import MacBert4Csc
from pycorrector.macbert.softmaskedbert4csc import SoftMaskedBert4Csc
from pycorrector.macbert.defaults import _C as cfg
import torch

import http.server
import ssl
import socket
import urllib.parse

import logging
from logging.handlers import RotatingFileHandler
import time
import uuid
import json
import os
from markupsafe import escape

# Load config JSON file
configFile           = open(r'D:/Data/PyCorrector/app/config/pycorrector_config.json')
configFileLoader     = json.load(configFile)

# Server port
global_binding_port  = configFileLoader.get('tecserver_config').get('binding_port')
# 呼叫API時的url
global_valid_url     = configFileLoader.get('tecserver_config').get('valid_url')
# SSL Key
global_keyfile_path  = configFileLoader.get('tecserver_config').get('keyfile_path')
# SSL cert
global_certfile_path = configFileLoader.get('tecserver_config').get('certfile_path')
# 模型路徑
global_ckpt_path     = configFileLoader.get('model_config').get('ckpt_path')
# 模型字典路徑
global_vocab_dir     = configFileLoader.get('model_config').get('vocab_dir')
# 模型設定路徑
global_cfg_path      = configFileLoader.get('model_config').get('cfg_path')
# 前端呼叫時的token
global_channel_id_token_config = configFileLoader.get('channel_id_token_config')

returnMsg = ''

# 定義推理類
class Inference:
    def __init__(self, ckpt_path, vocab_dir, cfg_path):
        # 設定裝置為 CPU
        #device = torch.device('cpu')
        #print(f'set device to: {device}')
        #要用GPU就改成cuda:0
        #device = torch.device("cuda:0")
        
        # 加載 BertTokenizer
        tokenizer = BertTokenizerFast.from_pretrained(vocab_dir)
        # 從配置文件合併配置
        cfg.merge_from_file(cfg_path)

        # 根據配置文件中的模型類型加載對應的模型
        model = MacBert4Csc.load_from_checkpoint(
                checkpoint_path=ckpt_path,
                cfg=cfg,
                #map_location=device,
                tokenizer=tokenizer,
                # 如果是使用不同版本的transformer訓練出來的模型混搭load會有問題
                # 會出現Unexpected key(s) in state_dict: "bert.bert.embeddings.position_ids
                # 加上strict=False即可解決問題
                strict=False
        )
        # 將模型移動到裝置上(這裡是 CPU)
        #model.to(device)
        
        # 設置模型為評估模式
        model.eval()
        # 儲存模型和其他屬性到類的實例中
        self.model = model
        self.tokenizer = tokenizer
        #self.device = device

    # 定義預測函數
    def predict(self, text):
        # 設置模型為評估模式
        self.model.eval()
        with torch.no_grad():
            # 使用模型進行預測並返回結果
            return self.model.predict([text])[0]


class MyHTTPRequestHandler(http.server.BaseHTTPRequestHandler):
    def _set_response(self, status_code, result):
        self.send_response(status_code)
        self.send_header('Strict-Transport-Security', 'max-age=31536000; includeSubDomains')#checkmarx 網頁應用程式沒有設定HTTP強制安全傳輸技術(HTTP Strict Transport Security,簡稱HSTS) header 導致容易被受到攻擊
        self.end_headers()
        # 將結果組成Json回傳
        response = {
            'data': {
                'correction': result
            }
        }
        returnData = json.dumps(response, ensure_ascii=False)
        self.wfile.write(returnData.encode('utf-8'))

    def empty_or_none(self, x): 
        return x is None or x == ''

    def do_POST(self):
        try:
            # Generate a UUID,用來記錄是哪一個client呼叫
            generated_uuid = uuid.uuid4()
            
            # Convert UUID to string
            uuid_str = str(generated_uuid)
            
            # print client info
            print(f'{uuid_str}-Connection established with: {self.client_address}')
            
            parsed_path = urllib.parse.urlparse(self.path)
            print(f'{uuid_str}-Client Calling Path: {parsed_path.path}')
            
            if parsed_path.path == global_valid_url:
                # 讀取client端傳過來的JSON
                content_length = int(self.headers['Content-Length'])
                post_data = self.rfile.read(content_length).decode('utf-8')
                
                # parsed_data = urllib.parse.parse_qs(post_data)
                request_json = json.loads(post_data)
                channel_id   = escape(request_json.get("channel_id"))
                token        = escape(request_json.get("token"))
                calling_uuid = escape(request_json.get("calling_uuid"))
                input_string = escape(request_json.get("input_string"))
                print(f'{uuid_str}-client post data: channel_id={channel_id}, token={token}, calling_uuid={calling_uuid}, input_string={input_string}')
                
                # 檢查參數是否為空
                if self.empty_or_none(channel_id) or self.empty_or_none(token) or self.empty_or_none(calling_uuid) or self.empty_or_none(input_string):
                    raise Exception(f'Error: parameter has empty.')
                
                getTokenPair = global_channel_id_token_config.get(channel_id)
                if getTokenPair != None and getTokenPair != token:
                    raise Exception(f'Error: channel token mismatch or not exist.')
                
                result = m.predict(input_string)
                print(f'{uuid_str}-Return Result: {result}')
                
                self._set_response(200, result)
            else:
                print(f'{uuid_str}-Error: Inavlid Url Calling.')
                raise Exception(f'Error: Inavlid Url Calling.')

        except Exception as e:
            returnMsg = str(e)
            print(f'{uuid_str}-{returnMsg}')
            self._set_response(500, returnMsg)
            

def run(server_class=http.server.HTTPServer, handler_class=MyHTTPRequestHandler):
    server_address = ('127.0.0.1', int(global_binding_port))
    httpd = server_class(server_address, handler_class)
    
    # Load SSL certificate and key
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile=global_certfile_path, keyfile=global_keyfile_path)
    httpd.socket = context.wrap_socket(httpd.socket, server_side=True)
    
    print(f'Starting httpd server on port {global_binding_port}...')
    # httpd.serve_forever() -->只寫這樣,ctrl+c停用server時port會卡住
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print('Ctrl-C Interrupt.')
    except Exception as e:
        print('Exception Occur: ' + str(e))
        print(str(e))
    finally:
        # Clean-up server (close socket, etc.)
        httpd.server_close()

# 創建推理類的實例
m = Inference(
    ckpt_path = global_ckpt_path,  # 檢查點路徑
    vocab_dir = global_vocab_dir,  # 詞彙表目錄
    cfg_path  = global_cfg_path    # 配置文件路徑
)

if __name__ == '__main__':
    run()

Client的部分,就是很單純的做POST呼叫,並把需要的token、設定以及需要糾正的句子帶過去。

import http.client
import ssl
import urllib.parse
import sys
import json

def send_post_request():
    # Establish a connection to the server
    client_context = ssl.SSLContext()
    
    #忽略憑證驗證時打開
#     client_context.verify_mode = ssl.CERT_OPTIONAL
#     client_context.check_hostname = False
#     client_context.verify_mode = ssl.CERT_NONE
    
    
    # client_context.load_verify_locations(r'D:/Data/PyCorrector/app/ssl/certs/ft.com.crt')
    
    conn = http.client.HTTPSConnection('127.0.0.1', port=7000, context=client_context)

    # Define data to be sent in the POST request
    #data = b'Hello, this is a POST request.'
    
    data = {
        'channel_id'   : 'ASR',
        'token'        : 'TokenForASR',
        'calling_uuid'  : '202411300000000001',
        "input_string" : '尼知道我在燈你嗎?'
    }
    # params = urllib.parse.urlencode(payload)
    # headers = {'Content-type': 'application/json'}

    # Send POST request
    conn.request('POST', '/predict/pycorrector_system', json.dumps(data).encode('utf-8'))

    # Get response from the server
    response = conn.getresponse()
    
    print('status_code: ' + str(response.status))

    # Print server's response
    response_text = response.read().decode()
    response_data = json.loads(response_text)
    print(response_data.get('data').get('correction'))

    # Close the connection
    conn.close()

if __name__ == '__main__':
    send_post_request()

我們設定一個錯誤句子"尼知道我在燈你嗎"送給api server,看起來模型能夠正確的做文字糾錯。

Server端收到Post,糾錯完後回傳
Client端收到糾錯完的結果

到這邊就實作完成了,其實官方的PyCorrector套件中有自帶Gradio包能夠自動啟動server開放呼叫,但考慮到日後使用的靈活度,模型與API server這兩件事情我認為要拆開去耦合性會比較好,自己客製化的API server使用較低階的語法,在日後的環境相容性較高,同時也可以套用到其他AI模型上,算是一舉多得。