ayyuce committed
Commit 3644e34 · verified · Parent(s): 66c8382

Update app.py

Files changed (1):
  app.py  +45 -53
app.py CHANGED
@@ -1,67 +1,59 @@
  import gradio as gr
- from llava_med import LlavaMedProcessor, LlavaMedForCausalLM
+ from llava.model.builder import load_pretrained_model
+ from llava.mm_utils import get_model_name_from_path
+ from llava.eval.run_llava import eval_model
  from PIL import Image
  import torch
 
- # Load model and processor
- model = LlavaMedForCausalLM.from_pretrained(
-     "microsoft/llava-med-v1.5-mistral-7b",
-     torch_dtype=torch.float32,  # Use float32 for CPU stability
-     low_cpu_mem_usage=True,
-     device_map="cpu"
- )
- processor = LlavaMedProcessor.from_pretrained(
-     "microsoft/llava-med-v1.5-mistral-7b"
+ # Load model configuration
+ model_path = "microsoft/llava-med-v1.5-mistral-7b"
+ model_name = get_model_name_from_path(model_path)
+ tokenizer, model, image_processor, _ = load_pretrained_model(
+     model_path=model_path,
+     model_base=None,
+     model_name=model_name,
+     device_map="cpu",
+     load_4bit=False
  )
 
  def analyze_medical_image(image, question):
-     # Prepare inputs
-     prompt = f"Question: {question} Answer:"
-
-     # Process inputs
-     inputs = processor(
-         text=prompt,
-         images=image,
-         return_tensors="pt",
-         padding=True
-     ).to("cpu")
+     # Convert Gradio input to PIL Image
+     if isinstance(image, str):
+         image = Image.open(image)
+     else:
+         image = Image.fromarray(image)
 
-     # Generate response
-     with torch.no_grad():
-         outputs = model.generate(
-             **inputs,
-             max_new_tokens=256,
-             do_sample=True,
-             temperature=0.7,
-             top_p=0.9
-         )
+     # Prepare prompt
+     prompt = f"<image>\nUSER: {question}\nASSISTANT:"
 
-     # Decode response
-     response = processor.batch_decode(
-         outputs,
-         skip_special_tokens=True
-     )[0].split("Answer:")[-1].strip()
+     # Run inference
+     args = type('Args', (), {
+         "model_name": model_name,
+         "query": prompt,
+         "conv_mode": None,
+         "image_file": image,
+         "sep": ",",
+         "temperature": 0.2,
+         "top_p": None,
+         "num_beams": 1,
+         "max_new_tokens": 512
+     })()
 
-     return response
+     return eval_model(args, tokenizer, model, image_processor)
 
  # Gradio interface
  with gr.Blocks() as demo:
-     gr.Markdown("# LLaVA-Med Medical Analysis (CPU)")
-     gr.Markdown("Official Microsoft LLaVA-Med 1.5-Mistral-7B implementation")
-
+     gr.Markdown("# LLaVA-Med Medical Analysis")
      with gr.Row():
-         with gr.Column():
-             image_input = gr.Image(label="Medical Image", type="pil")
-             question_input = gr.Textbox(label="Clinical Question", placeholder="Enter your medical question...")
-             submit_btn = gr.Button("Analyze")
-
-         with gr.Column():
-             output_text = gr.Textbox(label="Clinical Analysis", interactive=False)
-
-     submit_btn.click(
-         fn=analyze_medical_image,
-         inputs=[image_input, question_input],
-         outputs=output_text
-     )
-
- demo.queue(max_size=5).launch()
+         gr.Image(type="pil", label="Input Image", source="upload", elem_id="image")
+         gr.Textbox(label="Question", placeholder="Ask about the medical image...")
+         gr.Textbox(label="Analysis Result", interactive=False)
+
+     examples = [
+         ["examples/xray.jpg", "Are there any signs of pneumonia in this chest X-ray?"],
+         ["examples/mri.jpg", "What abnormalities are visible in this brain MRI?"]
+     ]
+
+     gr.Examples(examples=examples, inputs=[image, question])
+
+ demo.launch()
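
A note on the new Blocks layout: the Image and Textbox components are created without being bound to names and no click handler is registered, so `gr.Examples(examples=examples, inputs=[image, question])` references undefined names and `analyze_medical_image` is never invoked from the UI. Below is a minimal wiring sketch (not the committed code) that keeps `analyze_medical_image` as defined above and only fixes the UI plumbing; it drops `source="upload"`, which `gr.Image` no longer accepts in Gradio 4.x, and uses `type="filepath"` so the handler's `Image.open(...)` branch applies.

```python
# Minimal wiring sketch; assumes analyze_medical_image is the function defined above in app.py.
import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("# LLaVA-Med Medical Analysis")

    with gr.Row():
        # type="filepath" passes a path string, so the handler's Image.open(...) branch is used
        image = gr.Image(type="filepath", label="Input Image")
        question = gr.Textbox(label="Question", placeholder="Ask about the medical image...")
        answer = gr.Textbox(label="Analysis Result", interactive=False)

    submit = gr.Button("Analyze")
    # Wire the button to the existing handler
    submit.click(fn=analyze_medical_image, inputs=[image, question], outputs=answer)

    # Examples now reference the named components above
    gr.Examples(
        examples=[
            ["examples/xray.jpg", "Are there any signs of pneumonia in this chest X-ray?"],
            ["examples/mri.jpg", "What abnormalities are visible in this brain MRI?"],
        ],
        inputs=[image, question],
    )

demo.launch()
```

Separately, in the upstream LLaVA codebase `llava.eval.run_llava.eval_model` accepts a single `args` namespace and loads the model itself, so the four-argument call `eval_model(args, tokenizer, model, image_processor)` appears to assume a locally modified helper rather than the stock function.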