yashbyname commited on
Commit
55dab18
·
verified ·
1 Parent(s): c00355e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -166
app.py CHANGED
@@ -19,15 +19,19 @@ external_user_id = 'plugin-1717464304'
19
  # Load the keras model
20
  def load_model():
21
  try:
22
- # Define custom objects dictionary
23
  custom_objects = {
24
  'KerasLayer': hub.KerasLayer,
25
- # Add any other custom layers your model might use
26
  }
27
 
28
- # Load model with custom object scope
29
  with tf.keras.utils.custom_object_scope(custom_objects):
30
- model = tf.keras.models.load_model('model_epoch_01.h5.keras')
 
 
 
 
31
 
32
  logger.info("Model loaded successfully")
33
  return model
@@ -35,167 +39,7 @@ def load_model():
35
  logger.error(f"Error loading model: {str(e)}")
36
  raise
37
 
38
- # Preprocess image for model
39
- def preprocess_image(image):
40
- try:
41
- # Convert to numpy array if needed
42
- if isinstance(image, Image.Image):
43
- image = np.array(image)
44
-
45
- # Ensure image has 3 channels (RGB)
46
- if len(image.shape) == 2: # Grayscale image
47
- image = np.stack((image,) * 3, axis=-1)
48
- elif len(image.shape) == 3 and image.shape[2] == 4: # RGBA image
49
- image = image[:, :, :3]
50
-
51
- # Resize image to match model's expected input shape
52
- target_size = (224, 224) # Change this to match your model's input size
53
- image = tf.image.resize(image, target_size)
54
-
55
- # Normalize pixel values
56
- image = image / 255.0
57
-
58
- # Add batch dimension
59
- image = np.expand_dims(image, axis=0)
60
-
61
- return image
62
- except Exception as e:
63
- logger.error(f"Error preprocessing image: {str(e)}")
64
- raise
65
-
66
- def create_chat_session():
67
- try:
68
- create_session_url = 'https://api.on-demand.io/chat/v1/sessions'
69
- create_session_headers = {
70
- 'apikey': api_key,
71
- 'Content-Type': 'application/json'
72
- }
73
- create_session_body = {
74
- "pluginIds": [],
75
- "externalUserId": external_user_id
76
- }
77
-
78
- response = requests.post(create_session_url, headers=create_session_headers, json=create_session_body)
79
- response.raise_for_status()
80
- return response.json()['data']['id']
81
-
82
- except requests.exceptions.RequestException as e:
83
- logger.error(f"Error creating chat session: {str(e)}")
84
- raise
85
-
86
- def submit_query(session_id, query, image_analysis=None):
87
- try:
88
- submit_query_url = f'https://api.on-demand.io/chat/v1/sessions/{session_id}/query'
89
- submit_query_headers = {
90
- 'apikey': api_key,
91
- 'Content-Type': 'application/json'
92
- }
93
-
94
- # Include image analysis in the query if available
95
- query_with_image = query
96
- if image_analysis:
97
- query_with_image += f"\n\nImage Analysis Results: {image_analysis}"
98
-
99
- structured_query = f"""
100
- Based on the following patient information and image analysis, provide a detailed medical analysis in JSON format:
101
- {query_with_image}
102
- Return only valid JSON with these fields:
103
- - diagnosis_details
104
- - probable_diagnoses (array)
105
- - treatment_plans (array)
106
- - lifestyle_modifications (array)
107
- - medications (array of objects with name and dosage)
108
- - additional_tests (array)
109
- - precautions (array)
110
- - follow_up (string)
111
- - image_findings (object with prediction and confidence)
112
- """
113
-
114
- submit_query_body = {
115
- "endpointId": "predefined-openai-gpt4o",
116
- "query": structured_query,
117
- "pluginIds": ["plugin-1712327325", "plugin-1713962163"],
118
- "responseMode": "sync"
119
- }
120
-
121
- response = requests.post(submit_query_url, headers=submit_query_headers, json=submit_query_body)
122
- response.raise_for_status()
123
- return response.json()
124
-
125
- except requests.exceptions.RequestException as e:
126
- logger.error(f"Error submitting query: {str(e)}")
127
- raise
128
-
129
- def extract_json_from_answer(answer):
130
- """Extract and clean JSON from the LLM response"""
131
- try:
132
- return json.loads(answer)
133
- except json.JSONDecodeError:
134
- try:
135
- # Find the first occurrence of '{' and last occurrence of '}'
136
- start_idx = answer.find('{')
137
- end_idx = answer.rfind('}') + 1
138
- if start_idx != -1 and end_idx != 0:
139
- json_str = answer[start_idx:end_idx]
140
- return json.loads(json_str)
141
- except (json.JSONDecodeError, ValueError):
142
- logger.error("Failed to parse JSON from response")
143
- raise
144
-
145
- def format_prediction(prediction):
146
- """Format model prediction into a standardized structure"""
147
- try:
148
- # Adjust this based on your model's output format
149
- confidence = float(prediction[0][0])
150
- return {
151
- "prediction": "abnormal" if confidence > 0.5 else "normal",
152
- "confidence": round(confidence * 100, 2)
153
- }
154
- except Exception as e:
155
- logger.error(f"Error formatting prediction: {str(e)}")
156
- raise
157
-
158
- # Initialize the model
159
- try:
160
- model = load_model()
161
- except Exception as e:
162
- logger.error(f"Failed to initialize model: {str(e)}")
163
- model = None
164
-
165
- def gradio_interface(patient_info, image):
166
- try:
167
- if model is None:
168
- raise ValueError("Model not properly initialized")
169
-
170
- # Process image if provided
171
- image_analysis = None
172
- if image is not None:
173
- # Preprocess image
174
- processed_image = preprocess_image(image)
175
-
176
- # Get model prediction
177
- prediction = model.predict(processed_image)
178
-
179
- # Format prediction results
180
- image_analysis = format_prediction(prediction)
181
-
182
- # Create chat session and submit query
183
- session_id = create_chat_session()
184
- llm_response = submit_query(session_id, patient_info,
185
- json.dumps(image_analysis) if image_analysis else None)
186
-
187
- if not llm_response or 'data' not in llm_response or 'answer' not in llm_response['data']:
188
- raise ValueError("Invalid response structure from LLM")
189
-
190
- # Extract and clean JSON from the response
191
- json_data = extract_json_from_answer(llm_response['data']['answer'])
192
-
193
- # Format output for better readability
194
- return json.dumps(json_data, indent=2)
195
-
196
- except Exception as e:
197
- logger.error(f"Error in gradio_interface: {str(e)}")
198
- return json.dumps({"error": str(e)}, indent=2)
199
 
200
  # Gradio interface
201
  iface = gr.Interface(
@@ -210,7 +54,7 @@ iface = gr.Interface(
210
  gr.Image(
211
  label="Medical Image",
212
  type="numpy",
213
- optional=True
214
  )
215
  ],
216
  outputs=gr.Textbox(
 
19
  # Load the keras model
20
  def load_model():
21
  try:
22
+ # Define custom objects dictionary with batch normalization handling
23
  custom_objects = {
24
  'KerasLayer': hub.KerasLayer,
25
+ 'BatchNormalization': tf.keras.layers.BatchNormalization
26
  }
27
 
28
+ # Load model with custom object scope and proper batch norm behavior
29
  with tf.keras.utils.custom_object_scope(custom_objects):
30
+ model = tf.keras.models.load_model(
31
+ 'model_epoch_01.h5.keras',
32
+ custom_objects=custom_objects,
33
+ compile=False # Don't compile the model on load
34
+ )
35
 
36
  logger.info("Model loaded successfully")
37
  return model
 
39
  logger.error(f"Error loading model: {str(e)}")
40
  raise
41
 
42
+ # Rest of the functions remain the same until the Gradio interface...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  # Gradio interface
45
  iface = gr.Interface(
 
54
  gr.Image(
55
  label="Medical Image",
56
  type="numpy",
57
+ interactive=True, # This replaces the 'optional' parameter
58
  )
59
  ],
60
  outputs=gr.Textbox(