ayyuce committed (verified)
Commit 398bcaf · Parent(s): 9474534

Update app.py

Files changed (1)
  app.py  +39 -51
app.py CHANGED
@@ -1,79 +1,67 @@
 import gradio as gr
-from llava.model.builder import load_pretrained_model
-from llava.mm_utils import process_images, tokenizer_image_token
-from llava.constants import IMAGE_TOKEN_INDEX
-import torch
+from llava_med import LlavaMedProcessor, LlavaMedForCausalLM
 from PIL import Image
+import torch

-model_path = "microsoft/llava-med-v1.5-mistral-7b"
-tokenizer, model, image_processor, _ = load_pretrained_model(
-    model_path=model_path,
-    model_base=None,
-    model_name="llava-med-v1.5-mistral-7b",
-    load_4bit=False,  # Disable 4-bit quantization for CPU
-    device_map="cpu"  # Force CPU usage
+# Load model and processor
+model = LlavaMedForCausalLM.from_pretrained(
+    "microsoft/llava-med-v1.5-mistral-7b",
+    torch_dtype=torch.float32,  # Use float32 for CPU stability
+    low_cpu_mem_usage=True,
+    device_map="cpu"
+)
+processor = LlavaMedProcessor.from_pretrained(
+    "microsoft/llava-med-v1.5-mistral-7b"
 )
-model.to('cpu')

 def analyze_medical_image(image, question):
-    if isinstance(image, str):
-        image = Image.open(image)
-    else:
-        image = Image.fromarray(image)
+    # Prepare inputs
+    prompt = f"Question: {question} Answer:"

-    image_tensor = process_images([image], image_processor, model.config)[0]
-    prompt = f"USER: <image>\n{question}\nASSISTANT:"
+    # Process inputs
+    inputs = processor(
+        text=prompt,
+        images=image,
+        return_tensors="pt",
+        padding=True
+    ).to("cpu")

-    input_ids = tokenizer_image_token(
-        prompt,
-        tokenizer,
-        IMAGE_TOKEN_INDEX,
-        return_tensors='pt'
-    ).unsqueeze(0)
-
-    with torch.inference_mode():
-        output_ids = model.generate(
-            input_ids,
-            images=image_tensor.unsqueeze(0),
-            max_new_tokens=512,
+    # Generate response
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=256,
             do_sample=True,
             temperature=0.7,
-            use_cache=True
+            top_p=0.9
         )
-
-    response = tokenizer.decode(
-        output_ids[0][input_ids.shape[1]:],
+
+    # Decode response
+    response = processor.batch_decode(
+        outputs,
         skip_special_tokens=True
-    ).strip()
+    )[0].split("Answer:")[-1].strip()

     return response

+# Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# LLaVA-Med Medical Image Analysis")
-    gr.Markdown("Ask questions about medical images using Microsoft's LLaVA-Med 1.5-Mistral-7B")
+    gr.Markdown("# LLaVA-Med Medical Analysis (CPU)")
+    gr.Markdown("Official Microsoft LLaVA-Med 1.5-Mistral-7B implementation")

     with gr.Row():
         with gr.Column():
-            image_input = gr.Image(label="Upload Medical Image", type="pil")
-            question_input = gr.Textbox(label="Question", placeholder="Ask about the medical image...")
+            image_input = gr.Image(label="Medical Image", type="pil")
+            question_input = gr.Textbox(label="Clinical Question", placeholder="Enter your medical question...")
             submit_btn = gr.Button("Analyze")

         with gr.Column():
-            output_text = gr.Textbox(label="Analysis Result", interactive=False)
-
-    examples = gr.Examples(
-        examples=[
-            ["examples/chest_xray.jpg", "What abnormalities are present in this chest X-ray?"],
-            ["examples/retina_scan.jpg", "Are there any signs of diabetic retinopathy?"]
-        ],
-        inputs=[image_input, question_input],
-        label="Example Queries"
-    )
-
+            output_text = gr.Textbox(label="Clinical Analysis", interactive=False)
+
     submit_btn.click(
         fn=analyze_medical_image,
         inputs=[image_input, question_input],
         outputs=output_text
     )

-demo.queue(max_size=10).launch()
+demo.queue(max_size=5).launch()
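
A quick way to sanity-check the new inference path outside Gradio is a standalone script like the minimal sketch below. It is not part of this commit: it assumes the llava_med package imported by the new app.py is installed, and it reuses the examples/chest_xray.jpg asset and question that the previous version listed in its Examples block.

# Standalone CPU sanity check mirroring the loading/generation path in the new app.py.
# Assumption: llava_med (LlavaMedProcessor, LlavaMedForCausalLM) is installed; the class
# names are taken from this commit itself, not from a verified public API.
import torch
from PIL import Image
from llava_med import LlavaMedProcessor, LlavaMedForCausalLM

model = LlavaMedForCausalLM.from_pretrained(
    "microsoft/llava-med-v1.5-mistral-7b",
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    device_map="cpu",
)
processor = LlavaMedProcessor.from_pretrained("microsoft/llava-med-v1.5-mistral-7b")

image = Image.open("examples/chest_xray.jpg")  # example image from the previous version
prompt = "Question: What abnormalities are present in this chest X-ray? Answer:"
inputs = processor(text=prompt, images=image, return_tensors="pt", padding=True)

# Greedy decoding with a small token budget keeps the CPU run short for a smoke test.
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=64, do_sample=False)

print(processor.batch_decode(output_ids, skip_special_tokens=True)[0].split("Answer:")[-1].strip())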