zhizhi-aiservice / main.py
hujameson's picture
create app with docker mode on Space
c1e5d84 verified
raw
history blame
12.2 kB
#pip install fastapi ###for fastapi
#pip install uvicorn ###for server. To run the API service from a terminal: uvicorn main:app --reload
#pip install gunicorn ###gunicorn --bind 0.0.0.0:8000 -k uvicorn.workers.UvicornWorker main:app
#pip install python-multipart ###for UploadFile
#pip install pillow ###for PIL
#pip install transformers ###for transformers
#pip install torch ###for torch
#pip install sentencepiece ###for AutoTokenizer
#pip install -U cos-python-sdk-v5 ###Tencent Cloud Object Storage SDK (COS-SDK)
from typing import Optional
from fastapi import FastAPI, Header
from PIL import Image
#from transformers import pipeline, EfficientNetImageProcessor, EfficientNetForImageClassification, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification, pipeline
from models import ItemInHistory, ItemUploaded, ServiceLoginInfo
from openai import OpenAI
from qcloud_cos import CosConfig, CosS3Client
import sys, os, logging
import urllib.parse as urlparse
import json, requests
# class Conversation:
# def __init__(self, openai_client: OpenAI, prompt, num_of_round):
# self.openai_client = openai_client
# self.prompt = prompt
# self.num_of_round = num_of_round
# self.messages = []
# self.messages.append({"role": "system", "content": self.prompt})
# def ask(self, question):
# message = ''
# num_of_tokens = 0
# try:
# self.messages.append( {"role": "user", "content": question})
# chat_completion = self.openai_client.chat.completions.create(
# model="gpt-3.5-turbo",
# messages=self.messages,
# temperature=0.5,
# max_tokens=2048,
# top_p=1,
# )
# message = chat_completion.choices[0].message.content
# # num_of_tokens = chat_completion.usage.total_tokens
# self.messages.append({"role": "assistant", "content": message})
# except Exception as e:
# print(e)
# return e
# if len(self.messages) > self.num_of_round*2 + 1:
# del self.messages[1:3]
# return message, num_of_tokens
# FastAPI application object; the route decorators further down register
# their handlers on it.
app = FastAPI()
# Emit INFO-and-above log records on stdout.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
# COS credentials come from the environment. Indexing os.environ directly
# means a missing variable raises KeyError while this module is imported.
cos_secret_id = os.environ['COS_SECRET_ID']
cos_secret_key = os.environ['COS_SECRET_KEY']
# Region and bucket used by the routes when reading and writing history files.
cos_region = 'ap-shanghai'
cos_bucket = '7072-prod-3g52ms9o7a81f23c-1324125412'
# No temporary-credential token is supplied; requests are configured for HTTPS.
token = None
scheme = 'https'
# Shared COS client used by every route in this module.
config = CosConfig(Region=cos_region, SecretId=cos_secret_id, SecretKey=cos_secret_key, Token=token, Scheme=scheme)
client = CosS3Client(config)
logging.info(f"COS init succeeded.")
# Every handle defaults to None so that a failed load leaves a defined (but
# unusable) name rather than an unbound one. Code that uses these handles
# already wraps the calls in try/except, so a None handle is reported through
# the same logging path as any other inference failure.
ai_model_bc_preprocessor = None
ai_model_bc_model = None
openai_client = None
ai_model_bc_pipe = None

# Bird classifier loaded from the local directory ./birds-classifier-efficientnetb2.
try:
    ai_model_bc_preprocessor = EfficientNetImageProcessor.from_pretrained("./birds-classifier-efficientnetb2")
    ai_model_bc_model = EfficientNetForImageClassification.from_pretrained("./birds-classifier-efficientnetb2")
    logging.info("local model dennisjooo/Birds-Classifier-EfficientNetB2 loaded.")
except Exception as e:
    # Best effort: the service still starts when the local model is missing.
    logging.error(e)

# OpenAI client; the API key is read from the OPENAI_API_KEY environment variable.
# NOTE: a multi-round Conversation wrapper with a fixed system prompt used to
# be constructed here; it is currently disabled.
try:
    openai_client = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
    )
    logging.info("openai chat model loaded.")
except Exception as e:
    logging.error(e)

# Same classifier loaded through the transformers pipeline API by model id.
try:
    ai_model_bc_pipe = pipeline("image-classification", model="dennisjooo/Birds-Classifier-EfficientNetB2")
    logging.info("remote model dennisjooo/Birds-Classifier-EfficientNetB2 loaded.")
except Exception as e:
    # Logged through logging (not print) so it matches the handlers above.
    logging.error(e)
#try:
# ai_model_ez_preprocessor = AutoTokenizer.from_pretrained("./opus-mt-en-zh")
# ai_model_ez_model = AutoModelForSeq2SeqLM.from_pretrained("./opus-mt-en-zh")
# print(f"local model Helsinki-NLP/opus-mt-en-zh loaded.")
#except Exception as e:
# print(e)
#try:
# ai_model_ez_pipe= pipeline(task="translation_en_to_zh", model="Helsinki-NLP/opus-mt-en-zh", device=0)
# print(f"remote model Helsinki-NLP/opus-mt-en-zh loaded.")
#except Exception as e:
# print(e)
def bird_classifier(image_file: str) -> str:
    """Classify the bird shown in an image file.

    Args:
        image_file: Path of the image to classify.

    Returns:
        The label chosen by ``ai_model_bc_model`` (looked up in its
        ``config.id2label``), or an empty string when inference raises.

    Raises:
        Whatever ``PIL.Image.open`` raises when the file cannot be opened;
        only the inference step is treated as best effort.
    """
    result: str = ""
    # The context manager closes the underlying file handle once inference is
    # done; previously the image was opened and never closed.
    with Image.open(image_file) as img:
        logging.info(f"image file {image_file} is opened.")
        try:
            inputs = ai_model_bc_preprocessor(img, return_tensors="pt")
            # Inference only, so gradient tracking is switched off.
            with torch.no_grad():
                logits = ai_model_bc_model(**inputs).logits
            # The index of the highest logit is the predicted class id.
            predicted_label = logits.argmax(-1).item()
            result = ai_model_bc_model.config.id2label[predicted_label]
            # Log the local prediction next to the pipeline's top label for
            # comparison. `result` is already set, so a pipeline failure here
            # is logged below without losing the local prediction.
            logging.info(f"{result}:{ai_model_bc_pipe(img)[0]['label']}")
        except Exception as e:
            logging.error(e)
    logging.info(result)
    return result
# def text_en_zh(text_en: str) -> str:
# text_zh = ""
# if ai_model_ez_status is MODEL_STATUS.LOCAL:
# input = ai_model_ez_preprocessor(text_en)
# translated = ai_model_ez_model.generate(**ai_model_ez_preprocessor(text_en, return_tensors="pt", padding=True))
# for t in translated:
# text_zh += ai_model_ez_preprocessor.decode(t, skip_special_tokens=True)
# elif ai_model_ez_status is MODEL_STATUS.REMOTE:
# text_zh = ai_model_ez_pipe(text_en)
# return text_zh
# Route to upload a file
# @app.post("/uploadfile/")
# async def create_upload_file(file: UploadFile):
# contents: bytes = await file.read()
# contents_len = len(contents)
# file_name = file.filename
# server_file_name = f"server-{file_name}"
# with open(server_file_name,"wb") as server_file:
# server_file.write(contents)
# logging.info(f"{file_name} is received and saved as {server_file_name}.")
# bird_classification = bird_classifier(server_file_name)
# # if bird_classification != "":
# # bird_classification = "the species of bird is " + bird_classification
# # bird_classification = text_en_zh(bird_classification)
# logging.info(f"AI feedback: {bird_classification}.")
# return {"filename": server_file_name, "AI feedback": bird_classification}
# Route to login to zhizhi-service
@app.post("/login/")
def service_login(item: ServiceLoginInfo):
    """Exchange a WeChat login code for the user's openid.

    Calls the WeChat ``jscode2session`` endpoint with the fields of ``item``
    (``appid``, ``secret``, ``js_code``, ``grant_type``).

    Returns:
        ``{"user_openid": <openid>}``; the value is ``None`` when the WeChat
        response carries no ``openid`` field.

    Raises:
        requests.exceptions.RequestException: on connection failure or when
            the call exceeds the timeout.
        ValueError: when the response body is not valid JSON.
    """
    logging.info("service_login")
    # The app secret and the login code are credentials, so only the
    # non-sensitive fields are logged.
    logging.info(f"appid={item.appid} grant_type={item.grant_type}")
    # NOTE(review): the endpoint is addressed over plain http as in the
    # original code; confirm whether the deployment requires that before
    # switching to https.
    code2session_url = "http://api.weixin.qq.com/sns/jscode2session"
    query = {
        "appid": item.appid,
        "secret": item.secret,
        "js_code": item.js_code,
        "grant_type": item.grant_type,
    }
    logging.info(code2session_url)
    # `params=` lets requests URL-encode every value; the timeout keeps a
    # stalled upstream from blocking the worker indefinitely.
    response = requests.get(code2session_url, params=query, timeout=10)
    json_response = response.json()
    # session_key is a secret and must stay out of the logs.
    logging.info({k: v for k, v in json_response.items() if k != "session_key"})
    return {"user_openid": json_response.get("openid")}
# Route to create an item
@app.post("/items/")
async def create_item(item: ItemUploaded, x_wx_openid: Optional[str]=Header(None)):
    """Classify an uploaded media file and store a history record in COS.

    The object referenced by ``item.item_fileurl`` is downloaded, images are
    run through ``bird_classifier``, and an ``ItemInHistory`` record is
    written to ``<openid>/history/<history_id>.json`` in ``cos_bucket``.

    Args:
        item: Upload description (file URL, media type, upload time).
        x_wx_openid: Caller's openid from the ``x-wx-openid`` header; a
            missing header is treated as the empty string.

    Returns:
        ``{"filename": <history file name>, "AI feedback": <classification>}``.

    Raises:
        IndexError: when the URL host has no second dot-separated label or
            the object key has no second path segment.
    """
    # NOTE(review): this coroutine only makes blocking calls (COS, model
    # inference, OpenAI); consider declaring the route with plain `def`.
    import tempfile

    logging.info("create_item")
    logging.info(item)
    logging.info(x_wx_openid)
    if x_wx_openid is None:
        x_wx_openid = ""

    url = urlparse.urlparse(item.item_fileurl)
    key = url.path[1:]                        # object key: path without the leading "/"
    bucket = url.netloc.split('.')[1]         # bucket: second dot-separated label of the host
    contentfile = key.split('/')[1]           # file name: second segment of the key
    historyid = contentfile.split('.')[0]     # history id: file name up to the first "."

    # Download into a private temporary directory so a client-chosen file
    # name cannot overwrite files in the working directory and the download
    # is removed once the request has been handled.
    with tempfile.TemporaryDirectory() as workdir:
        local_path = os.path.join(workdir, contentfile)
        download = client.get_object(
            Bucket = bucket,
            Key = key
        )
        download['Body'].get_stream_to_file(local_path)
        if item.item_mediatype == "image":
            bird_classification = bird_classifier(local_path)
            try:
                prompt = """你是一个鸟类学家,用中文回答关于鸟类的问题。你的回答需要满足以下要求:
1. 你的回答必须是中文
2. 回答限制在100个字以内"""
                messages = []
                messages.append({"role": "system", "content": prompt})
                question = f"鸟类的英文名是{bird_classification},它的中文名是什么?有什么样的习性?"
                messages.append( {"role": "user", "content": question})
                chat_completion = openai_client.chat.completions.create(
                    model="gpt-3.5-turbo",
                    messages=messages,
                    temperature=0.5,
                    max_tokens=2048,
                    top_p=1,
                )
                # The chat answer is only logged; the stored feedback remains
                # the classifier label.
                answer = chat_completion.choices[0].message.content
                logging.info(f"chatgpt feedback: {answer}.\n")
            except Exception as e:
                # Best effort: a chat failure must not fail the upload.
                logging.error(e)
        else:
            bird_classification = "不是image类型,暂不能识别"

    logging.info(f"AI feedback: {bird_classification}.\n")
    historyfile = itemToJsonFile(ItemInHistory(history_id = historyid,union_id = x_wx_openid,
                                               item_fileurl = item.item_fileurl,item_mediatype = item.item_mediatype,
                                               upload_datetime = item.upload_datetime,ai_feedback = bird_classification))
    upload = client.upload_file(
        Bucket = cos_bucket,
        LocalFilePath=historyfile,
        Key=f'{x_wx_openid}/history/{historyfile}',
        PartSize=1,
        MAXThread=10,
        EnableMD5=False
    )
    logging.info(upload['ETag'])
    return {"filename": historyfile, "AI feedback": bird_classification}


# Route to list all items uploaded by a specific user by unionid
# (disabled; the header-based route below is used instead)
# @app.get("/items/{user_unionid}")
# def list_items(user_unionid: str) -> dict[str, list[ItemInHistory]]:
#     logging.info("list_items")
#     logging.info(user_unionid)
#     items: list[ItemInHistory] = []
#     response = client.list_objects(
#         Bucket=cos_bucket,
#         Prefix=f'{user_unionid}/history/'
#     )
#     logging.info(response['Contents'])
#     for obj in response['Contents']:
#         key:str = obj['Key']
#         response = client.get_object(
#             Bucket = cos_bucket,
#             Key = key
#         )
#         localfile = key.split('/')[2]
#         response['Body'].get_stream_to_file(localfile)
#         item = itemFromJsonFile(localfile)
#         items.append(item)
#     return {"items": items}
# Route to list all items uploaded by a specific user by unionid from header
@app.get("/items/")
def list_items_byheader(x_wx_openid: Optional[str]=Header(None)) -> dict[str, list[ItemInHistory]]:
    """List the history records stored for the calling user.

    Every object under ``<openid>/history/`` in ``cos_bucket`` is downloaded
    and parsed with ``itemFromJsonFile``.

    Args:
        x_wx_openid: Caller's openid from the ``x-wx-openid`` header; a
            missing header is treated as the empty string, matching the key
            layout ``create_item`` writes.

    Returns:
        ``{"items": [...]}``; the list is empty when nothing is stored under
        the prefix.
    """
    import tempfile

    logging.info("list_items_byheader")
    logging.info(x_wx_openid)
    if x_wx_openid is None:
        x_wx_openid = ""
    items: list[ItemInHistory] = []
    listing = client.list_objects(
        Bucket=cos_bucket,
        Prefix=f'{x_wx_openid}/history/'
    )
    # Tolerate a listing without a 'Contents' key (e.g. nothing stored under
    # the prefix) by treating it as an empty result.
    # NOTE(review): only the first page of the listing is read; truncated
    # listings are not followed.
    contents = listing.get('Contents', [])
    logging.info(contents)
    # Downloads go to a temporary directory that is removed afterwards,
    # instead of accumulating in the working directory.
    with tempfile.TemporaryDirectory() as workdir:
        for obj in contents:
            key: str = obj['Key']
            localfile = key.split('/')[2]
            if not localfile:
                # A key that is just the prefix itself has no file to read.
                continue
            local_path = os.path.join(workdir, localfile)
            download = client.get_object(
                Bucket = cos_bucket,
                Key = key
            )
            download['Body'].get_stream_to_file(local_path)
            items.append(itemFromJsonFile(local_path))
    return {"items": items}
def itemFromJsonFile(jsonfile: str) -> ItemInHistory:
    """Load an ``ItemInHistory`` record from a JSON file.

    Args:
        jsonfile: Path of a JSON file holding the keys ``history_id``,
            ``union_id``, ``item_fileurl``, ``item_mediatype``,
            ``upload_datetime`` and ``ai_feedback``.

    Returns:
        The record built from those six fields.

    Raises:
        OSError: when the file cannot be opened.
        json.JSONDecodeError: when the file is not valid JSON.
        KeyError: when one of the expected keys is missing.
    """
    # `with` closes the file even when reading or parsing raises.
    with open(jsonfile, 'r', encoding='utf-8') as f:
        a = json.load(f)
    return ItemInHistory(history_id = a['history_id'],union_id = a['union_id'],
                         item_fileurl = a['item_fileurl'],item_mediatype = a["item_mediatype"],
                         upload_datetime = a["upload_datetime"],ai_feedback = a['ai_feedback'])
def itemToJsonFile(item: ItemInHistory) -> str:
    """Write a history record to ``<history_id>.json`` and return that name.

    Args:
        item: Record whose six fields (``history_id``, ``union_id``,
            ``item_fileurl``, ``item_mediatype``, ``upload_datetime``,
            ``ai_feedback``) are serialized.

    Returns:
        The path of the file written, ``f'{item.history_id}.json'``
        (relative to the working directory unless ``history_id`` itself
        contains a directory part).

    Raises:
        OSError: when the file cannot be created or written.
        TypeError: when a field value is not JSON serializable.
    """
    history_json = {
        "history_id": item.history_id,
        "union_id": item.union_id,
        "item_fileurl": item.item_fileurl,
        "item_mediatype": item.item_mediatype,
        "upload_datetime": item.upload_datetime,
        "ai_feedback": item.ai_feedback
    }
    # Serialize before opening so a serialization error does not leave an
    # empty file behind.
    b = json.dumps(history_json)
    historyfile = f'{item.history_id}.json'
    # `with` closes the file even when the write raises.
    with open(historyfile, 'w', encoding='utf-8') as f:
        f.write(b)
    return historyfile