mjavadmt commited on
Commit
754963a
·
1 Parent(s): da5c744

better visualization

Browse files
Files changed (1) hide show
  1. app.py +35 -13
app.py CHANGED
@@ -1,6 +1,8 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer
3
  import torch
 
 
4
  from model import EnergySmellsDetector
5
  from config import SMELLS, BEST_THRESHOLD
6
 
@@ -8,29 +10,49 @@ TOKENIZER = "microsoft/graphcodebert-base"
8
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
9
  model = EnergySmellsDetector.load_model_from_hf()
10
 
11
-
12
- def round_logit(logits, threshold):
13
- logits = (logits > threshold).to(int)
14
- return logits.cpu().numpy()
15
-
16
-
17
- def greet(code_snippet):
18
  inputs = tokenizer(code_snippet, return_tensors="pt", truncation=True)
19
  with torch.no_grad():
20
  logits = model(**inputs)[0]
21
- rounded_logits = round_logit(logits, BEST_THRESHOLD)
22
- return f"{dict(zip(SMELLS, map(int, rounded_logits)))}"
 
 
 
 
 
 
 
 
 
 
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
- textbox = gr.Textbox(label="Enter your code snippet", placeholder="Here goes your code")
27
- description = "An application to identify whether your code has energy smells or not. It predicts the presence of 9 different energy smells."
28
  title = "Energy Smells Detector"
 
29
 
30
  gr.Interface(
31
  title=title,
32
  description=description,
33
  inputs=textbox,
34
- fn=greet,
35
- outputs="text"
 
 
 
36
  ).launch()
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer
3
  import torch
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
  from model import EnergySmellsDetector
7
  from config import SMELLS, BEST_THRESHOLD
8
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
11
  model = EnergySmellsDetector.load_model_from_hf()
12
 
13
+ def get_predictions(code_snippet):
 
 
 
 
 
 
14
  inputs = tokenizer(code_snippet, return_tensors="pt", truncation=True)
15
  with torch.no_grad():
16
  logits = model(**inputs)[0]
17
+ probs = torch.sigmoid(logits).cpu().numpy().flatten()
18
+ rounded_logits = (probs > BEST_THRESHOLD).astype(int)
19
+
20
+ # Prepare results in a dictionary
21
+ results = {label: {"Detected": bool(pred), "Confidence": round(prob * 100, 2)}
22
+ for label, pred, prob in zip(SMELLS, rounded_logits, probs)}
23
+
24
+ return results, plot_bar_chart(results)
25
+
26
+
27
+ def plot_bar_chart(results):
28
+ labels = list(results.keys())
29
+ confidences = [results[label]["Confidence"] for label in labels]
30
 
31
+ plt.figure(figsize=(8, 4))
32
+ plt.barh(labels, confidences, color=['green' if results[label]["Detected"] else 'red' for label in labels])
33
+ plt.xlabel("Confidence (%)")
34
+ plt.xlim(0, 100)
35
+ plt.title("Energy Smells Detection Confidence")
36
+ plt.gca().invert_yaxis() # Invert y-axis for better readability
37
+ plt.tight_layout()
38
+ img_path = "confidence_chart.png"
39
+ plt.savefig(img_path)
40
+ plt.close()
41
+
42
+ return img_path
43
 
44
 
45
+ textbox = gr.Textbox(label="Enter your code snippet", placeholder="Paste your code here...")
 
46
  title = "Energy Smells Detector"
47
+ description = "Analyze your code for potential energy smells. The model detects 9 different energy inefficiencies in your code."
48
 
49
  gr.Interface(
50
  title=title,
51
  description=description,
52
  inputs=textbox,
53
+ fn=get_predictions,
54
+ outputs=[
55
+ gr.Json(label="Detection Results"),
56
+ gr.Image(label="Confidence Bar Chart")
57
+ ]
58
  ).launch()