Spaces:
Runtime error
Runtime error
#pip install fastapi ###for fastapi
#pip install uvicorn ###for the server; to run the API service from a terminal: uvicorn main:app --reload
#pip install gunicorn ###gunicorn --bind 0.0.0.0:8000 -k uvicorn.workers.UvicornWorker main:app
#pip install python-multipart ###for UploadFile
#pip install pillow ###for PIL
#pip install transformers ###for transformers
#pip install torch ###for torch
#pip install sentencepiece ###for AutoTokenizer
#pip install -U cos-python-sdk-v5 ###Tencent Cloud Object Storage SDK (COS-SDK)
from typing import Optional | |
from fastapi import FastAPI, Header | |
from PIL import Image | |
#from transformers import pipeline, EfficientNetImageProcessor, EfficientNetForImageClassification, AutoTokenizer, AutoModelForSeq2SeqLM | |
import torch | |
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification, pipeline | |
from models import ItemInHistory, ItemUploaded, ServiceLoginInfo | |
from openai import OpenAI | |
from qcloud_cos import CosConfig, CosS3Client | |
import sys, os, logging | |
import urllib.parse as urlparse | |
import json, requests | |
# class Conversation: | |
# def __init__(self, openai_client: OpenAI, prompt, num_of_round): | |
# self.openai_client = openai_client | |
# self.prompt = prompt | |
# self.num_of_round = num_of_round | |
# self.messages = [] | |
# self.messages.append({"role": "system", "content": self.prompt}) | |
# def ask(self, question): | |
# message = '' | |
# num_of_tokens = 0 | |
# try: | |
# self.messages.append( {"role": "user", "content": question}) | |
# chat_completion = self.openai_client.chat.completions.create( | |
# model="gpt-3.5-turbo", | |
# messages=self.messages, | |
# temperature=0.5, | |
# max_tokens=2048, | |
# top_p=1, | |
# ) | |
# message = chat_completion.choices[0].message.content | |
# # num_of_tokens = chat_completion.usage.total_tokens | |
# self.messages.append({"role": "assistant", "content": message}) | |
# except Exception as e: | |
# print(e) | |
# return e | |
# if len(self.messages) > self.num_of_round*2 + 1: | |
# del self.messages[1:3] | |
# return message, num_of_tokens | |
# ---------------------------------------------------------------------------
# Module-level initialisation: web app, logging, COS client and AI models.
# Everything below runs once, at import time.
# ---------------------------------------------------------------------------
app = FastAPI()
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Tencent COS credentials are mandatory: os.environ[...] raises KeyError at
# import time when either variable is missing.
cos_secret_id = os.environ['COS_SECRET_ID']
cos_secret_key = os.environ['COS_SECRET_KEY']
cos_region = 'ap-shanghai'
cos_bucket = '7072-prod-3g52ms9o7a81f23c-1324125412'
token = None
scheme = 'https'
config = CosConfig(Region=cos_region, SecretId=cos_secret_id, SecretKey=cos_secret_key, Token=token, Scheme=scheme)
client = CosS3Client(config)
logging.info("COS init succeeded.")

# Bind every optional handle up front so that a failed load leaves a defined
# (None) name behind instead of a NameError at first use.
ai_model_bc_preprocessor = None
ai_model_bc_model = None
openai_client = None
ai_model_bc_pipe = None

# Bird classifier loaded from the local checkpoint directory.
try:
    ai_model_bc_preprocessor = EfficientNetImageProcessor.from_pretrained("./birds-classifier-efficientnetb2")
    ai_model_bc_model = EfficientNetForImageClassification.from_pretrained("./birds-classifier-efficientnetb2")
    logging.info("local model dennisjooo/Birds-Classifier-EfficientNetB2 loaded.")
except Exception:
    # Best effort: the service still starts; the traceback is kept in the log.
    logging.exception("failed to load local model dennisjooo/Birds-Classifier-EfficientNetB2")

# OpenAI client; the key is read with .get(), so a missing variable yields None
# here rather than a KeyError.
try:
    openai_client = OpenAI(
        api_key=os.environ.get("OPENAI_API_KEY"),
    )
    # Disabled multi-turn conversation setup, kept for reference:
    # prompt = """你是一个鸟类学家,用中文回答关于鸟类的问题。你的回答需要满足以下要求:
    # 1. 你的回答必须是中文
    # 2. 回答限制在100个字以内"""
    # conv = Conversation(open_client, prompt, 3)
    logging.info("openai chat model loaded.")
except Exception:
    logging.exception("failed to create the OpenAI client")

# Same bird classifier, obtained through the transformers pipeline helper by
# model id.
try:
    ai_model_bc_pipe = pipeline("image-classification", model="dennisjooo/Birds-Classifier-EfficientNetB2")
    logging.info("remote model dennisjooo/Birds-Classifier-EfficientNetB2 loaded.")
except Exception:
    # Logged like the other loaders (the original used print here).
    logging.exception("failed to load remote model dennisjooo/Birds-Classifier-EfficientNetB2")

# Disabled English-to-Chinese translation models, kept for reference:
#try:
#    ai_model_ez_preprocessor = AutoTokenizer.from_pretrained("./opus-mt-en-zh")
#    ai_model_ez_model = AutoModelForSeq2SeqLM.from_pretrained("./opus-mt-en-zh")
#    print(f"local model Helsinki-NLP/opus-mt-en-zh loaded.")
#except Exception as e:
#    print(e)
#try:
#    ai_model_ez_pipe= pipeline(task="translation_en_to_zh", model="Helsinki-NLP/opus-mt-en-zh", device=0)
#    print(f"remote model Helsinki-NLP/opus-mt-en-zh loaded.")
#except Exception as e:
#    print(e)
def bird_classifier(image_file: str) -> str:
    """Classify the bird shown in an image file and return the predicted label.

    Inference uses the module-level ``ai_model_bc_preprocessor`` and
    ``ai_model_bc_model``. The ``ai_model_bc_pipe`` pipeline is then run on the
    same image purely so both predictions can be logged side by side; its
    outcome never affects the return value.

    Args:
        image_file: Path of the image to classify.

    Returns:
        The ``id2label`` entry of the highest-scoring class, or ``""`` when
        inference fails (the failure is logged, not raised).

    Raises:
        Whatever ``PIL.Image.open`` raises when the file cannot be opened;
        only errors after the file is opened are swallowed.
    """
    result: str = ""
    # The context manager releases the file handle that Image.open holds.
    with Image.open(image_file) as source_image:
        logging.info(f"image file {image_file} is opened.")
        try:
            # Convert up front so RGBA, palette and greyscale files reach the
            # processor as three-channel images.
            img = source_image.convert("RGB")
            inputs = ai_model_bc_preprocessor(img, return_tensors="pt")
            # Running the inference
            with torch.no_grad():
                logits = ai_model_bc_model(**inputs).logits
            # Getting the predicted label
            predicted_label = logits.argmax(-1).item()
            result = ai_model_bc_model.config.id2label[predicted_label]
        except Exception:
            logging.exception("bird classification failed")
        else:
            # Diagnostic comparison only; kept apart so that a pipeline failure
            # is not reported as a classification failure.
            try:
                logging.info(f"{result}:{ai_model_bc_pipe(img)[0]['label']}")
            except Exception as e:
                logging.error(e)
    logging.info(result)
    return result
# def text_en_zh(text_en: str) -> str: | |
# text_zh = "" | |
# if ai_model_ez_status is MODEL_STATUS.LOCAL: | |
# input = ai_model_ez_preprocessor(text_en) | |
# translated = ai_model_ez_model.generate(**ai_model_ez_preprocessor(text_en, return_tensors="pt", padding=True)) | |
# for t in translated: | |
# text_zh += ai_model_ez_preprocessor.decode(t, skip_special_tokens=True) | |
# elif ai_model_ez_status is MODEL_STATUS.REMOTE: | |
# text_zh = ai_model_ez_pipe(text_en) | |
# return text_zh | |
# Route to upload a file | |
# @app.post("/uploadfile/") | |
# async def create_upload_file(file: UploadFile): | |
# contents: bytes = await file.read() | |
# contents_len = len(contents) | |
# file_name = file.filename | |
# server_file_name = f"server-{file_name}" | |
# with open(server_file_name,"wb") as server_file: | |
# server_file.write(contents) | |
# logging.info(f"{file_name} is received and saved as {server_file_name}.") | |
# bird_classification = bird_classifier(server_file_name) | |
# # if bird_classification != "": | |
# # bird_classification = "the species of bird is " + bird_classification | |
# # bird_classification = text_en_zh(bird_classification) | |
# logging.info(f"AI feedback: {bird_classification}.") | |
# return {"filename": server_file_name, "AI feedback": bird_classification} | |
# Route to login to zhizhi-service | |
def service_login(item: ServiceLoginInfo):
    """Exchange a WeChat login code for the user's openid.

    Calls the ``jscode2session`` endpoint with the fields carried by ``item``
    and returns the ``openid`` from the JSON reply.

    Args:
        item: Login payload; ``appid``, ``secret``, ``js_code`` and
            ``grant_type`` are read from it.

    Returns:
        ``{"user_openid": <openid or None>}``; ``None`` when the reply has no
        ``openid`` field.

    Raises:
        requests.RequestException: On connection failure or timeout.
        ValueError: When the reply body is not valid JSON.
    """
    logging.info("service_login")
    # Credentials (secret, js_code) are deliberately kept out of the log.
    logging.info(f"appid={item.appid} grant_type={item.grant_type}")
    # NOTE(review): the endpoint is called over plain http, which exposes the
    # secret in transit on a public network - confirm whether https can be used.
    endpoint = "http://api.weixin.qq.com/sns/jscode2session"
    query = {
        "appid": item.appid,
        "secret": item.secret,
        "js_code": item.js_code,
        "grant_type": item.grant_type,
    }
    # Passing params lets requests URL-encode each value; the timeout keeps a
    # stalled upstream from hanging the worker indefinitely.
    response = requests.get(endpoint, params=query, timeout=10)
    json_response = response.json()
    # session_key is a credential as well, so it is dropped before logging.
    logging.info({k: v for k, v in json_response.items() if k != "session_key"})
    return {"user_openid": json_response.get("openid")}
# Route to create an item | |
async def create_item(item: ItemUploaded, x_wx_openid: Optional[str]=Header(None)):
    """Process an uploaded media item and store a history record for it.

    Steps:
      1. Derive bucket, object key and local file name from
         ``item.item_fileurl`` and download the object from COS.
      2. For ``item_mediatype == "image"``, run ``bird_classifier`` on the
         download and ask the OpenAI chat model about the predicted species
         (best effort; failures are only logged).
      3. Write an ``ItemInHistory`` record to a JSON file and upload it to
         ``cos_bucket`` under ``<openid>/history/``.

    Args:
        item: Upload description; ``item_fileurl``, ``item_mediatype`` and
            ``upload_datetime`` are read from it.
        x_wx_openid: Caller's openid taken from the request header; a missing
            header is treated as the empty string.

    Returns:
        ``{"filename": <history file name>, "AI feedback": <classification>}``.

    Raises:
        IndexError: When the URL host has no second dot-separated label or the
            object key has no second ``/``-separated segment.
    """
    logging.info("create_item")
    logging.info(item)
    logging.info(x_wx_openid)
    if x_wx_openid is None:
        x_wx_openid = ""

    # The bucket is the second dot-separated label of the host and the local
    # file name is the second path segment - presumably a
    # "<scheme>://<env>.<bucket>/<openid>/<file>" style URL; TODO confirm.
    url = urlparse.urlparse(item.item_fileurl)
    key = url.path[1:]
    bucket = url.netloc.split('.')[1]
    contentfile = key.split('/')[1]
    historyid = contentfile.split('.')[0]

    download = client.get_object(
        Bucket=bucket,
        Key=key
    )
    download['Body'].get_stream_to_file(contentfile)

    if item.item_mediatype == "image":
        bird_classification = bird_classifier(contentfile)
        try:
            prompt = (
                "你是一个鸟类学家,用中文回答关于鸟类的问题。你的回答需要满足以下要求:\n"
                "1. 你的回答必须是中文\n"
                "2. 回答限制在100个字以内"
            )
            question = f"鸟类的英文名是{bird_classification},它的中文名是什么?有什么样的习性?"
            messages = [
                {"role": "system", "content": prompt},
                {"role": "user", "content": question},
            ]
            chat_completion = openai_client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=messages,
                temperature=0.5,
                max_tokens=2048,
                top_p=1,
            )
            answer = chat_completion.choices[0].message.content
            # NOTE(review): the answer is only logged; the stored and returned
            # feedback is still the raw classifier label - confirm intent.
            logging.info(f"chatgpt feedback: {answer}.\n")
        except Exception as e:
            logging.error(e)
    else:
        bird_classification = "不是image类型,暂不能识别"
    logging.info(f"AI feedback: {bird_classification}.\n")

    historyfile = itemToJsonFile(ItemInHistory(
        history_id=historyid, union_id=x_wx_openid,
        item_fileurl=item.item_fileurl, item_mediatype=item.item_mediatype,
        upload_datetime=item.upload_datetime, ai_feedback=bird_classification))
    upload = client.upload_file(
        Bucket=cos_bucket,
        LocalFilePath=historyfile,
        Key=f'{x_wx_openid}/history/{historyfile}',
        PartSize=1,
        MAXThread=10,
        EnableMD5=False
    )
    logging.info(upload['ETag'])
    return {"filename": historyfile, "AI feedback": bird_classification}


# Disabled route, kept for reference (its trailing return used to be live,
# unreachable code here and is now commented out with the rest):
# Route to list all items uploaded by a specific user by unionid
# @app.get("/items/{user_unionid}")
# def list_items(user_unionid: str) -> dict[str, list[ItemInHistory]]:
#     logging.info("list_items")
#     logging.info(user_unionid)
#     items: list[ItemInHistory] = []
#     response = client.list_objects(
#         Bucket=cos_bucket,
#         Prefix=f'{user_unionid}/history/'
#     )
#     logging.info(response['Contents'])
#     for obj in response['Contents']:
#         key:str = obj['Key']
#         response = client.get_object(
#             Bucket = cos_bucket,
#             Key = key
#         )
#         localfile = key.split('/')[2]
#         response['Body'].get_stream_to_file(localfile)
#         item = itemFromJsonFile(localfile)
#         items.append(item)
#     return {"items": items}
# Route to list all items uploaded by a specific user by unionid from header | |
def list_items_byheader(x_wx_openid: Optional[str]=Header(None)) -> dict[str, list[ItemInHistory]]:
    """Return every history record stored for the caller.

    Lists the objects under ``<openid>/history/`` in ``cos_bucket``, downloads
    each one to a local file named after the third key segment, and parses it
    with ``itemFromJsonFile``.

    Args:
        x_wx_openid: Caller's openid taken from the request header; a missing
            header is treated as the empty string, matching ``create_item``.

    Returns:
        ``{"items": [...]}``; the list is empty when nothing is stored.
    """
    logging.info("list_items_byheader")
    logging.info(x_wx_openid)
    if x_wx_openid is None:
        # Same normalisation as create_item, so both use the same key prefix.
        x_wx_openid = ""

    listing = client.list_objects(
        Bucket=cos_bucket,
        Prefix=f'{x_wx_openid}/history/'
    )
    # 'Contents' may be absent when no object matches the prefix.
    # NOTE(review): only the first page of the listing is read - confirm
    # whether truncated listings need to be followed.
    contents = listing.get('Contents', [])
    logging.info(contents)

    items: list[ItemInHistory] = []
    for obj in contents:
        key: str = obj['Key']
        localfile = key.split('/')[2]
        if not localfile:
            # A key ending in '/' (folder placeholder) has no file name.
            continue
        download = client.get_object(
            Bucket=cos_bucket,
            Key=key
        )
        download['Body'].get_stream_to_file(localfile)
        items.append(itemFromJsonFile(localfile))
    return {"items": items}
def itemFromJsonFile(jsonfile: str) -> ItemInHistory:
    """Load a history record from a JSON file.

    Args:
        jsonfile: Path of a JSON file holding the keys ``history_id``,
            ``union_id``, ``item_fileurl``, ``item_mediatype``,
            ``upload_datetime`` and ``ai_feedback``.

    Returns:
        An ``ItemInHistory`` built from those six values.

    Raises:
        OSError: When the file cannot be opened.
        json.JSONDecodeError: When the file is not valid JSON.
        KeyError: When one of the expected keys is missing.
    """
    # The context manager closes the file even when parsing fails.
    with open(jsonfile, 'r', encoding='utf-8') as f:
        a = json.load(f)
    return ItemInHistory(
        history_id=a['history_id'], union_id=a['union_id'],
        item_fileurl=a['item_fileurl'], item_mediatype=a["item_mediatype"],
        upload_datetime=a["upload_datetime"], ai_feedback=a['ai_feedback'])
def itemToJsonFile(item: ItemInHistory) -> str:
    """Serialise a history record to ``<history_id>.json`` in the working directory.

    Args:
        item: Record to store; ``history_id``, ``union_id``, ``item_fileurl``,
            ``item_mediatype``, ``upload_datetime`` and ``ai_feedback`` are
            read from it.

    Returns:
        The name of the file that was written (``f"{item.history_id}.json"``).

    Raises:
        OSError: When the file cannot be created or written.
        TypeError: When a field value is not JSON-serialisable.
    """
    history_json = {
        "history_id": item.history_id,
        "union_id": item.union_id,
        "item_fileurl": item.item_fileurl,
        "item_mediatype": item.item_mediatype,
        "upload_datetime": item.upload_datetime,
        "ai_feedback": item.ai_feedback
    }
    # Serialise before opening so that a serialisation error does not leave an
    # empty file behind. json.dumps escapes non-ASCII by default, so the bytes
    # on disk do not depend on the platform's default encoding.
    b = json.dumps(history_json)
    historyfile = f'{item.history_id}.json'
    with open(historyfile, 'w', encoding='utf-8') as f:
        f.write(b)
    return historyfile