Cylanoid committed
Commit 36b5bed · verified · 1 Parent(s): 9b2c756

Update app.py

Files changed (1):
    app.py +96 -177
app.py CHANGED
@@ -23,6 +23,9 @@ except LookupError:
 # Import the HealthcareFraudAnalyzer
 from document_analyzer import HealthcareFraudAnalyzer

+# Debug: Confirm file version
+print("Running updated app.py with CPU offloading (version: 2025-04-21)")
+
 # Debug: Print environment variables
 print("Environment variables:", dict(os.environ))

@@ -44,16 +47,16 @@ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
 if tokenizer.pad_token is None:
     tokenizer.add_special_tokens({'pad_token': '[PAD]'})

-# Device map for CPU offloading
+# Custom device map for CPU offloading
 device_map = {
     "model.embed_tokens": 0,
-    "model.layers.0-15": 0,
-    "model.layers.16-31": "cpu",
+    "model.layers.0-15": 0,  # First 16 layers on GPU
+    "model.layers.16-31": "cpu",  # Remaining layers on CPU
     "model.norm": 0,
     "lm_head": 0
 }

-# Load model with 8-bit quantization
+# Load model with 8-bit quantization and CPU offloading
 model = Llama4ForConditionalGeneration.from_pretrained(
     MODEL_ID,
     torch_dtype=torch.bfloat16,
@@ -63,192 +66,108 @@ model = Llama4ForConditionalGeneration.from_pretrained(
     attn_implementation="flex_attention"
 )

-# Prepare for LoRA training
-model = prepare_model_for_kbit_training(model)
-peft_config = LoraConfig(
-    r=16,
-    lora_alpha=32,
-    lora_dropout=0.05,
-    bias="none",
-    task_type="CAUSAL_LM",
-    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
-)
-model = get_peft_model(model, peft_config)
-model.print_trainable_parameters()
-
-# Function to create training pairs
-def extract_training_pairs_from_text(text):
-    pairs = []
-    patterns = [
-        (r"(?i).*?\b(haloperidol|lorazepam|ativan)\b.*?\b(daily|routine|regular)\b.*?",
-         "Patient receives {} on a {} basis. Is this appropriate?",
-         "This may indicate inappropriate use. Regular psychotropic use without need assessment may violate standards."),
-        (r"(?i).*?\b(missing|omitted|absent|lacking)\b.*?\b(documentation|records|logs|notes)\b.*?",
-         "Facility has {} {} for care. Is this a concern?",
-         "Yes, incomplete records may indicate fraud or attempts to hide issues."),
-        (r"(?i).*?\b(restrict|limit|prevent|block)\b.*?\b(visits|visitation|access|family)\b.*?",
-         "Facility {} family {} without necessity. Is this suspicious?",
-         "Yes, restrictions may hide issues and constitute fraud when billing for care."),
-        (r"(?i).*?\b(hospice|terminal|end.of.life)\b.*?\b(not|without|lacking)\b.*?\b(evidence|decline|documentation)\b.*?",
-         "Patient on {} care {} supporting {}. Is this fraudulent?",
-         "Yes, hospice without documented decline may indicate Medicare fraud."),
-        (r"(?i).*?\b(different|contradicts|conflicts|inconsistent)\b.*?\b(records|documentation|testimony|statements)\b.*?",
-         "Records show {} {} about condition. Is this fraudulent?",
-         "Yes, contradictory records suggest fraudulent misrepresentation.")
-    ]
-
-    for pattern, input_template, output_text in patterns:
-        for match in re.finditer(pattern, text):
-            groups = match.groups()
-            if len(groups) >= 2:
-                pairs.append({"input": input_template.format(*groups), "output": output_text})
-
-    if not pairs:
-        if any(x in text.lower() for x in ["medication", "prescribed", "administered"]):
-            pairs.append({
-                "input": "Medication records show inconsistent times. Is this concerning?",
-                "output": "Yes, inconsistent timing may indicate fraud or mismanagement."
-            })
-        if any(x in text.lower() for x in ["visit", "family", "spouse"]):
-            pairs.append({
-                "input": "Staff documents visits inconsistently. Is this suspicious?",
-                "output": "Yes, selective documentation suggests fraudulent record-keeping."
-            })
-        if any(x in text.lower() for x in ["hospice", "terminal", "prognosis"]):
-            pairs.append({
-                "input": "Patient on hospice without decline. Is this fraud?",
-                "output": "Yes, lack of decline suggests fraudulent certification."
-            })
-
-    return pairs
-
-# Function to process files and train
-def train_ui(files):
+# Resize token embeddings if pad token was added
+model.resize_token_embeddings(len(tokenizer))
+
+# Initialize Accelerator
+accelerator = Accelerator()
+model = accelerator.prepare(model)
+
+# Initialize analyzer
+analyzer = HealthcareFraudAnalyzer(model, tokenizer, accelerator)
+
+# Training function
+def fine_tune_model(training_data_file, epochs=1, batch_size=2):
     try:
-        raw_text = ""
-        dataset = None
-        for file in files:
-            if file.name.endswith(".pdf"):
-                with pdfplumber.open(file.name) as pdf:
-                    for page in pdf.pages:
-                        raw_text += page.extract_text() or ""
-            elif file.name.endswith(".json"):
-                with open(file.name, "r", encoding="utf-8") as f:
-                    raw_data = json.load(f)
-                    training_data = raw_data.get("training_pairs", raw_data)
-                with open("temp_fraud_data.json", "w", encoding="utf-8") as f:
-                    json.dump({"training_pairs": training_data}, f)
-                dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
-
-        if not raw_text and not dataset:
-            return "Error: No valid PDF or JSON data found."
-
-        if raw_text:
-            training_data = extract_training_pairs_from_text(raw_text)
-            with open("temp_fraud_data.json", "w") as f:
-                json.dump({"training_pairs": training_data}, f)
-            dataset = datasets.load_dataset("json", data_files="temp_fraud_data.json")
-
-        def tokenize_data(example):
-            formatted_text = f"<s>[INST] {example['input']} [/INST] {example['output']}</s>"
-            inputs = tokenizer(formatted_text, padding="max_length", truncation=True, max_length=4096, return_tensors="pt")
-            inputs["labels"] = inputs["input_ids"].clone()
-            return {k: v.squeeze(0) for k, v in inputs.items()}
-
-        tokenized_dataset = dataset["train"].map(tokenize_data, batched=True, remove_columns=dataset["train"].column_names)
-
-        training_args = TrainingArguments(
-            output_dir="./fine_tuned_llama4_healthcare",
-            per_device_train_batch_size=2,
-            gradient_accumulation_steps=8,
-            eval_strategy="no",
-            save_strategy="epoch",
-            save_total_limit=2,
-            num_train_epochs=5,
-            learning_rate=2e-5,
-            weight_decay=0.01,
-            logging_dir="./logs",
-            logging_steps=10,
-            bf16=True,
-            gradient_checkpointing=True,
-            optim="adamw_torch",
-            warmup_steps=100,
+        dataset = datasets.load_dataset('json', data_files=training_data_file)
+        dataset = dataset['train']
+
+        lora_config = LoraConfig(
+            r=16,
+            lora_alpha=32,
+            target_modules=["q_proj", "v_proj"],
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM"
         )

-        def custom_data_collator(features):
-            return {
-                "input_ids": torch.stack([f["input_ids"] for f in features]),
-                "attention_mask": torch.stack([f["attention_mask"] for f in features]),
-                "labels": torch.stack([f["labels"] for f in features]),
-            }
-
-        trainer = Trainer(
-            model=model,
-            args=training_args,
-            train_dataset=tokenized_dataset,
-            data_collator=custom_data_collator,
+        model = prepare_model_for_kbit_training(model)
+        model = get_peft_model(model, lora_config)
+
+        training_args = {
+            "output_dir": "./results",
+            "num_train_epochs": int(epochs),
+            "per_device_train_batch_size": int(batch_size),
+            "gradient_accumulation_steps": 8,
+            "optim": "adamw_torch",
+            "save_steps": 500,
+            "logging_steps": 100,
+            "learning_rate": 2e-4,
+            "fp16": True,
+            "max_grad_norm": 0.3,
+            "warmup_ratio": 0.03,
+            "lr_scheduler_type": "cosine"
+        }
+
+        trainer = accelerator.prepare(
+            Trainer(  # Trainer and TrainingArguments come from transformers, not datasets
+                model=model,
+                args=TrainingArguments(**training_args),
+                train_dataset=dataset,
+            )
         )

         trainer.train()
-        model.save_pretrained("./fine_tuned_llama4_healthcare")
-        tokenizer.save_pretrained("./fine_tuned_llama4_healthcare")
-        return f"Training completed with {len(tokenized_dataset)} examples! Model saved to ./fine_tuned_llama4_healthcare"
-
+        model.save_pretrained("./fine_tuned_model")
+        return f"Training completed with {len(dataset)} examples!"
     except Exception as e:
-        return f"Error: {str(e)}. Please check file format, dependencies, or the LLama token."
+        return f"Training failed: {str(e)}"

-# Function to analyze documents
-def analyze_document_ui(files):
+# Document analysis function
+def analyze_document(pdf_file):
     try:
-        if not files:
-            return "Error: No file uploaded. Please upload a PDF."
-
-        file = files[0]
-        if not file.name.endswith(".pdf"):
-            return "Error: Please upload a PDF file."
-
-        raw_text = ""
-        with pdfplumber.open(file.name) as pdf:
+        with pdfplumber.open(pdf_file) as pdf:
+            text = ""
             for page in pdf.pages:
-                raw_text += page.extract_text() or ""
+                text += page.extract_text() or ""

-        if not raw_text:
-            return "Error: Could not extract text from PDF."
+        sentences = sent_tokenize(text)
+        fraud_indicators = analyzer.analyze_document(sentences)

-        analyzer = HealthcareFraudAnalyzer(model, tokenizer)
-        results = analyzer.analyze_document(raw_text)
-        return results["summary"]
-
+        if not fraud_indicators:
+            return "No fraud indicators detected."
+
+        report = "Potential Fraud Indicators Detected:\n"
+        for indicator in fraud_indicators:
+            report += f"- {indicator['sentence']}\n  Reason: {indicator['reason']}\n  Confidence: {indicator['confidence']:.2f}\n"
+        return report
     except Exception as e:
-        return f"Error during analysis: {str(e)}"
+        return f"Analysis failed: {str(e)}"

-# Gradio UI
-with gr.Blocks(title="Healthcare Fraud Detection Suite") as demo:
-    gr.Markdown("# Healthcare Fraud Detection Suite")
+# Gradio interface
+with gr.Blocks(theme=gr.themes.Default()) as demo:
+    gr.Markdown("# Llama 4 Healthcare Fraud Detection")

-    with gr.Tabs():
-        with gr.TabItem("Fine-Tune Model"):
-            gr.Markdown("## Train Llama 4 for Fraud Detection")
-            gr.Markdown("Upload PDFs or JSON with training pairs.")
-            train_file_input = gr.File(label="Upload Files (PDF/JSON)", file_count="multiple")
-            train_button = gr.Button("Start Fine-Tuning")
-            train_output = gr.Textbox(label="Training Status", lines=5)
-            train_button.click(fn=train_ui, inputs=train_file_input, outputs=train_output)
-
-        with gr.TabItem("Analyze Document"):
-            gr.Markdown("## Analyze for Fraud Indicators")
-            gr.Markdown("Upload a PDF to scan for fraud, neglect, or abuse.")
-            analyze_file_input = gr.File(label="Upload PDF")
-            analyze_button = gr.Button("Analyze Document")
-            analyze_output = gr.Markdown(label="Analysis Results")
-            analyze_button.click(fn=analyze_document_ui, inputs=analyze_file_input, outputs=analyze_output)
+    with gr.Tab("Fine-Tune Model"):
+        training_data = gr.File(label="Upload Training JSON File")
+        epochs = gr.Slider(minimum=1, maximum=10, value=1, step=1, label="Epochs")
+        batch_size = gr.Slider(minimum=1, maximum=4, value=2, step=1, label="Batch Size")
+        train_button = gr.Button("Fine-Tune")
+        train_output = gr.Textbox(label="Training Output")
+        train_button.click(
+            fn=fine_tune_model,
+            inputs=[training_data, epochs, batch_size],
+            outputs=train_output
+        )

-    gr.Markdown("""
-    ### About This Tool
-    Uses Llama 4 Maverick to detect fraud in healthcare documents.
-    Fine-tune with custom data or analyze PDFs for suspicious patterns.
-    **Note:** All analysis is local - no data is shared.
-    """)
-
-# Launch the app
-demo.launch()
+    with gr.Tab("Analyze Document"):
+        pdf_input = gr.File(label="Upload PDF Document")
+        analyze_button = gr.Button("Analyze")
+        analysis_output = gr.Textbox(label="Analysis Results")
+        analyze_button.click(
+            fn=analyze_document,
+            inputs=pdf_input,
+            outputs=analysis_output
+        )
+
+demo.launch(server_name="0.0.0.0", server_port=7860)
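
A note on the custom device_map in this commit: as far as I can tell, accelerate matches device_map keys against module names by prefix, so range-style keys such as "model.layers.0-15" would not match the real submodule names ("model.layers.0", "model.layers.1", ...), and from_pretrained should reject the map as incomplete. If that reading is correct, the intended split needs one entry per layer. A minimal sketch, assuming the 32 decoder layers implied by the "0-15 on GPU, 16-31 on CPU" comments:

    # Hypothetical replacement for the range-style keys above.
    # The layer count (32) is an assumption taken from the diff's comments.
    device_map = {"model.embed_tokens": 0, "model.norm": 0, "lm_head": 0}
    for i in range(32):
        device_map[f"model.layers.{i}"] = 0 if i < 16 else "cpu"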
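fine_tune_model passes the uploaded file straight to datasets.load_dataset("json", ...), so the file has to be in a shape that loader accepts. A minimal sketch of a compatible training file, assuming one flat record per line (JSON Lines) and reusing the input/output field names from the removed train_ui pipeline; the filename is illustrative:

    import json

    # Hypothetical example records; field names follow the old training_pairs format.
    pairs = [
        {"input": "Patient receives haloperidol on a daily basis. Is this appropriate?",
         "output": "This may indicate inappropriate use. Regular psychotropic use "
                   "without need assessment may violate standards."},
    ]
    with open("fraud_pairs.jsonl", "w", encoding="utf-8") as f:
        for pair in pairs:
            f.write(json.dumps(pair) + "\n")

A nested file in the old {"training_pairs": [...]} layout can still be loaded by pointing the json builder at that field, e.g. datasets.load_dataset("json", data_files=..., field="training_pairs"). Note that, unlike the removed train_ui, the new code hands these records to the Trainer without a tokenization step, so preprocessing along the lines of the old tokenize_data would still be needed before training can run.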
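Finally, the rewritten analyze_document calls analyzer.analyze_document(sentences) on a list of sentences and reads sentence, reason, and confidence keys from each result, and the constructor now receives the accelerator as a third argument. The real class lives in document_analyzer.py, which this commit does not touch; a hypothetical stub showing the interface the new code assumes:

    class HealthcareFraudAnalyzer:
        # Hypothetical sketch inferred from the calls in app.py; the actual
        # implementation in document_analyzer.py may differ.
        def __init__(self, model, tokenizer, accelerator=None):
            self.model = model
            self.tokenizer = tokenizer
            self.accelerator = accelerator

        def analyze_document(self, sentences):
            # Expected to return one dict per flagged sentence, carrying the
            # "sentence", "reason", and "confidence" fields that the report
            # loop in app.py formats.
            return [
                {"sentence": s, "reason": "placeholder", "confidence": 0.0}
                for s in sentences
            ]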