|
import base64
import io
import os
import threading

import google.generativeai as genai
import gradio as gr
import openai
from openai import OpenAI
|
|
|
|
|
|
|
# Server-side API key used when the UI key box is left empty. An "sk-" prefix
# selects OpenAI; any other non-empty value is treated as a Google key.
api_key = ""

# Model identifier sent to whichever provider is selected; read from the
# MODEL environment variable (None when the variable is unset).
MODEL = os.environ.get("MODEL")

# HTML header rendered at the top of the page.
DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Medster - Medical Diagnostic Assistant</h1>
<p>An AI tool that helps you analyze symptoms and test reports. </p>
<p>🔎 Select the department you need to consult, and enter the symptom description or physical examination information in the input box; you can also upload the test report image in the picture box. </p>
<p>🦕 Please note that the generated information may be inaccurate and does not have any actual reference value. Please contact a professional doctor if necessary. </p>
</div>
'''

# Page CSS: centre the title and hide Gradio's default footer.
css = """
h1 {
    text-align: center;
    display: block;
}
footer {
    display:none !important
}
"""

# Status line shown under the UI. MODEL may be None, and concatenating None
# to a str raises TypeError at import time, so show a placeholder instead.
LICENSE = f"MODEL: {MODEL} LOADED" if MODEL else "MODEL: NOT SET"
|
|
|
|
|
def endpoints(api_key):
    """Return the provider that should serve requests authenticated by *api_key*.

    Keys starting with ``"sk-"`` are OpenAI keys; any other non-empty key is
    treated as a Google Generative AI key. Surrounding whitespace (common when
    keys are pasted) is ignored.

    Args:
        api_key: The API key string, or ``None``.

    Returns:
        ``"OPENAI"``, ``"GOOGLE"``, or ``None`` when no usable key is given.
    """
    if api_key is None:
        return None
    key = api_key.strip()
    # An empty key cannot authenticate with any provider.
    if not key:
        return None
    return "OPENAI" if key.startswith("sk-") else "GOOGLE"
|
|
|
# Resolve the provider once, at import time, from the module-level key.
# NOTE: this rebinds the name `endpoints` from the detector function to its
# result (a provider string or None); later code reads it as that value.
endpoints = endpoints(api_key=api_key)
|
|
|
def read(filename, encoding="utf-8"):
    """Return the entire contents of the text file *filename*.

    Args:
        filename: Path of the file to read.
        encoding: Text encoding. Defaults to UTF-8 so the result does not
            depend on the platform's locale encoding.

    Returns:
        The file's contents as one string.
    """
    with open(filename, encoding=encoding) as handle:
        return handle.read()
|
|
|
# Look for the prompt file next to this script first, so the app also starts
# when launched from another working directory; fall back to the
# working-directory-relative path otherwise.
_PROMPT_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "system_prompt.txt")
if not os.path.isfile(_PROMPT_FILE):
    _PROMPT_FILE = "system_prompt.txt"
SYS_PROMPT = read(_PROMPT_FILE)
|
|
|
def process_text(text_input, unit):
    """Ask the configured LLM to analyse a free-text symptom description.

    The provider is chosen by the module-level ``endpoints`` value
    (``"OPENAI"`` or ``"GOOGLE"``), authenticated with the module-level
    ``api_key``, and ``MODEL`` names the model on that provider. Errors raised
    by the provider SDKs are not caught here.

    Args:
        text_input: Symptom description / examination notes from the user.
        unit: Department name, used to frame the assistant's persona.

    Returns:
        The model's reply text, or ``""`` when there is no input or no
        recognised provider.
    """
    if not text_input:
        return ""
    system_prompt = f" You are a experienced {unit} doctor." + SYS_PROMPT

    if endpoints == "OPENAI":
        client = OpenAI(api_key=api_key)
        completion = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Hello! Could you solve {text_input}?"},
            ],
        )
        return completion.choices[0].message.content

    if endpoints == "GOOGLE":
        # The Google key is the same module-level key that selected this branch.
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(MODEL)
        # Prepend the same persona/system prompt used for OpenAI so the
        # selected department also applies on this provider.
        return model.generate_content(system_prompt + "\n\n" + text_input).text

    return ""
|
|
|
def encode_image_to_base64(image_input):
    """Encode a PIL image as a base64 JPEG string (without a data-URI prefix).

    JPEG cannot store alpha or palette data, so images whose mode is not
    ``"RGB"`` or ``"L"`` (e.g. RGBA PNG uploads) are converted to RGB first;
    saving them directly as JPEG would raise ``OSError``.

    Args:
        image_input: A PIL image (anything with ``mode``, ``convert`` and ``save``).

    Returns:
        The JPEG bytes, base64-encoded, as a str.
    """
    if image_input.mode not in ("RGB", "L"):
        image_input = image_input.convert("RGB")
    with io.BytesIO() as buffer:
        image_input.save(buffer, format="JPEG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
|
|
def process_image(image_input, unit):
    """Ask the configured LLM to interpret an uploaded test-report image.

    The provider is chosen by the module-level ``endpoints`` value
    (``"OPENAI"`` or ``"GOOGLE"``), authenticated with the module-level
    ``api_key``, and ``MODEL`` names the model on that provider. Errors raised
    by the provider SDKs are not caught here.

    Args:
        image_input: The uploaded image (a PIL image from the Gradio component), or None.
        unit: Department name, used to frame the assistant's persona.

    Returns:
        The model's reply text, or ``""`` when there is no image or no
        recognised provider.
    """
    if image_input is None:
        return ""
    system_prompt = f" You are a experienced {unit} doctor." + SYS_PROMPT
    question = "Help me understand what is in this picture and analysis."

    if endpoints == "OPENAI":
        client = OpenAI(api_key=api_key)
        base64_image = encode_image_to_base64(image_input)
        response = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": [
                    {"type": "text", "text": question},
                    {"type": "image_url",
                     "image_url": {
                         "url": f"data:image/jpeg;base64,{base64_image}",
                         "detail": "low"}},
                ]},
            ],
            temperature=0.0,
            max_tokens=1024,
        )
        return response.choices[0].message.content

    if endpoints == "GOOGLE":
        # The Google key is the same module-level key that selected this branch.
        genai.configure(api_key=api_key)
        model = genai.GenerativeModel(MODEL)
        # Send the same persona and question alongside the image so the model
        # gets the department context, as on the OpenAI path.
        return model.generate_content([system_prompt + "\n\n" + question, image_input]).text

    return ""
|
|
|
|
|
def main(text_input="", image_input=None, unit=""):
    """Route one submission to image or text analysis.

    An uploaded image takes precedence; in that case ``text_input`` is not
    sent to the model. Without an image, non-empty text is analysed.

    Args:
        text_input: Symptom description text (may be empty or None).
        image_input: Uploaded image, or None.
        unit: Selected department name.

    Returns:
        The analysis text, or ``""`` when neither text nor image was given.
    """
    if image_input is not None:
        return process_image(image_input, unit)
    if text_input:
        return process_text(text_input, unit)
    return ""
|
|
|
# Serialises requests while a UI-supplied key is swapped into the module
# globals, so concurrent requests never see each other's key.
_request_lock = threading.Lock()


def _submit(key, text, image, dept):
    """Handle one "Send" click, using the API key typed into the UI if present.

    ``process_text``/``process_image`` read the module-level ``api_key`` and
    ``endpoints``, so a typed key is swapped into those globals for the
    duration of this call only and the previous values are restored
    afterwards. An empty key box falls back to the server-side key.
    NOTE(review): passing the key explicitly through main() would be cleaner
    than swapping globals.

    Args:
        key: Contents of the API-key textbox (may be empty or None).
        text: Symptom description text.
        image: Uploaded PIL image, or None.
        dept: Selected department.

    Returns:
        Markdown text for the diagnosis box.
    """
    global api_key, endpoints
    user_key = (key or "").strip()
    with _request_lock:
        if not (user_key or api_key):
            return "⚠️ Please enter an API key."
        saved = (api_key, endpoints)
        try:
            if user_key:
                api_key = user_key
                endpoints = "OPENAI" if user_key.startswith("sk-") else "GOOGLE"
            return main(text, image, dept)
        finally:
            api_key, endpoints = saved


with gr.Blocks(theme='shivi/calm_seafoam', css=css, title="Medster - Medical Diagnostic Assistant") as iface:
    with gr.Accordion(""):
        gr.Markdown(DESCRIPTION)

    unit = gr.Dropdown(label="🩺Department", value='Traditional Medicine', elem_id="units",
                       choices=["Traditional Medicine", "Internal Medicine", "Surgery", "Obstetrics and Gynecology", "Pediatrics",
                                "Orthodontics", "Andrology", "Dermatology and Venereology", "Infectious Diseases", "Psychiatry",
                                "Plastic Surgery Department", "Nutrition Department", "Reproductive Center", "Anesthesiology Department", "Medical Imaging Department",
                                "Orthopedics", "Oncology", "Emergency Department", "Laboratory Department"])

    with gr.Row():
        output_box = gr.Markdown(label="Diagnosis")

    with gr.Row():
        # Distinct name: binding `api_key` here would replace the module-level
        # key string that process_text/process_image pass to the SDKs.
        # type="password" keeps the secret masked on screen.
        api_key_box = gr.Textbox(label="API Key", type="password")

    with gr.Row():
        image_input = gr.Image(type="pil", label="Upload Image")
        text_input = gr.Textbox(label="Submit")

    with gr.Row():
        submit_btn = gr.Button("🚀 Send")
        clear_btn = gr.ClearButton(output_box, value="🗑️ Clear")

    # _submit accepts the key plus the three values main() expects, in this order.
    submit_btn.click(_submit, inputs=[api_key_box, text_input, image_input, unit], outputs=output_box)

    gr.Markdown(LICENSE)

iface.queue().launch(show_api=False)