Spaces:
Runtime error
Runtime error
File size: 3,145 Bytes
2f9ea03 bb16e72 2f9ea03 bb16e72 d58d5be 14f626b 8e2bfc0 ab9c414 bb16e72 ab9c414 8e2bfc0 bb16e72 ab9c414 bb16e72 ab9c414 bb16e72 8e2bfc0 bb16e72 8e2bfc0 ab9c414 8e2bfc0 bb16e72 8e2bfc0 bb16e72 ab9c414 8e2bfc0 bb16e72 8e2bfc0 bb16e72 8e2bfc0 bb16e72 ab9c414 8e2bfc0 bb16e72 8e2bfc0 ab9c414 08137ac bb16e72 ab9c414 08137ac ab9c414 e6713e2 ab9c414 e6713e2 bb16e72 ab9c414 8e2bfc0 ab9c414 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import importlib.util

import torch
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image
from diffusers import AutoencoderKL
import numpy as np
import gradio as gr
# Configure device and attention implementation.
# FlashAttention-2 is only requested when running on CUDA *and* the optional
# `flash_attn` package is importable; requesting it without the package makes
# transformers raise at model-load time, so fall back to "eager" attention.
device = "cuda" if torch.cuda.is_available() else "cpu"
flash_attn_available = importlib.util.find_spec("flash_attn") is not None
attn_implementation = (
    "flash_attention_2"
    if device == "cuda" and flash_attn_available
    else "eager"
)
print(f"Using device: {device} with {attn_implementation}")
# Initialize medical imaging components
def load_medical_models():
    """Load the Janus processor/model and the SDXL VAE onto ``device``.

    The models use bfloat16 on CUDA and float32 otherwise. If the configured
    ``attn_implementation`` is rejected while loading the Janus model, the
    load is retried once with ``"eager"`` attention.

    Returns:
        tuple: ``(processor, model, vae)``:
            - the ``VLChatProcessor``;
            - the ``MultiModalityCausalLM`` moved to ``device`` in eval mode;
            - the ``AutoencoderKL`` moved to ``device`` in eval mode.

    Raises:
        Exception: any loading error is printed and then re-raised.
    """
    dtype = torch.bfloat16 if device == "cuda" else torch.float32

    def _load_janus(attn):
        # Load the Janus model with the given attention backend.
        # `attn_implementation` supersedes the deprecated
        # `use_flash_attention_2` flag, so only the former is passed.
        return MultiModalityCausalLM.from_pretrained(
            "deepseek-ai/Janus-1.3B",
            torch_dtype=dtype,
            attn_implementation=attn,
        ).to(device).eval()

    try:
        processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
        try:
            model = _load_janus(attn_implementation)
        except (ImportError, ValueError) as err:
            # FlashAttention-2 can be refused at load time (package missing,
            # or the model class not declaring support); eager always works.
            if attn_implementation == "eager":
                raise
            print(f"Falling back to eager attention: {err}")
            model = _load_janus("eager")
        # NOTE(review): `vae` is not used by the analysis path in this file;
        # it may be removable to save memory — confirm no other consumer.
        vae = AutoencoderKL.from_pretrained(
            "stabilityai/sdxl-vae",
            torch_dtype=dtype,
        ).to(device).eval()
        return processor, model, vae
    except Exception as e:
        print(f"Error loading medical models: {str(e)}")
        raise


processor, model, vae = load_medical_models()
# Medical image analysis function with attention control
def medical_analysis(image, question, seed=42):
    """Answer a free-text question about one image using the Janus model.

    Args:
        image: A PIL image or numpy array from the Gradio image input, or
            ``None`` when nothing has been uploaded.
        question: Free-text clinical query about the image.
        seed: Seed for the torch and numpy RNGs. Sampling is enabled, so this
            makes the output repeatable, subject to GPU nondeterminism.

    Returns:
        str: The decoded model answer. Failures are returned to the UI as a
        string prefixed with ``"Radiology analysis error: "`` rather than
        raised.
    """
    if image is None:
        return "Radiology analysis error: please upload an image."
    try:
        torch.manual_seed(seed)
        np.random.seed(seed)
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Grayscale ("L") and RGBA uploads are common for scans; always hand
        # the vision encoder a 3-channel image.
        image = image.convert("RGB")

        # VLChatProcessor takes a chat-formatted conversation (keyword-only)
        # with one <image_placeholder> per image, not a raw `text=` prompt.
        conversation = [
            {
                "role": "User",
                "content": f"<image_placeholder>\n{question}",
                "images": [image],
            },
            {"role": "Assistant", "content": ""},
        ]
        inputs = processor(
            conversations=conversation,
            images=[image],
            force_batchify=True,
        ).to(device, dtype=model.dtype)  # match pixel_values to model dtype
        tokenizer = processor.tokenizer

        with torch.inference_mode():
            # Fuse image features into the token embeddings; generating from
            # input_ids alone would ignore the image entirely.
            inputs_embeds = model.prepare_inputs_embeds(**inputs)
            outputs = model.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=inputs.attention_mask,
                max_new_tokens=512,
                do_sample=True,  # otherwise temperature/top_p are ignored
                temperature=0.1,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                use_cache=True,
            )
        return tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    except Exception as e:
        return f"Radiology analysis error: {str(e)}"
# Gradio front end: one "Diagnostic Imaging" tab holding an image upload and a
# query box side by side, an "Analyze" button, and a read-only report box.
# Submitting the query box and clicking the button trigger the same handler.
with gr.Blocks(title="Medical Imaging Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# AI Radiology Assistant\n**CT/MRI/X-ray Analysis System**")
    with gr.Tab("Diagnostic Imaging"):
        with gr.Row():
            med_image = gr.Image(label="DICOM Image", type="pil")
            med_question = gr.Textbox(
                label="Clinical Query",
                placeholder="Describe findings in this CT scan...",
            )
        analysis_btn = gr.Button("Analyze", variant="primary")
        report_output = gr.Textbox(label="Radiology Report", interactive=False)

    # Wire both triggers (Enter in the query box first, then the button) to
    # the same handler, inputs, and output.
    shared_inputs = [med_image, med_question]
    for trigger in (med_question.submit, analysis_btn.click):
        trigger(medical_analysis, inputs=shared_inputs, outputs=report_output)

demo.launch(server_name="0.0.0.0", server_port=7860)