naveenus commited on
Commit
b355243
Β·
verified Β·
1 Parent(s): 09ffcaf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -22,11 +22,12 @@ MODEL_CHOICES = [
22
  "FractalAIResearch/Fathom-R1-14B" # placeholder β€” replace with real phantom model
23
  ]
24
 
 
25
 
26
- PIPELINES = {
27
- name: pipeline("zero-shot-classification", model=name)
28
- for name in MODEL_CHOICES
29
- }
30
 
31
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
32
  # Ensure log files exist
@@ -48,7 +49,7 @@ def run_stage1(question, model_name):
48
  if not question or not question.strip():
49
  return {}, gr.update(choices=[]), ""
50
  start = time.time()
51
- clf = PIPELINES[model_name]
52
  out = clf(question, candidate_labels=coarse_labels)
53
  labels, scores = out["labels"][:3], out["scores"][:3]
54
  duration = round(time.time() - start, 3)
@@ -71,7 +72,7 @@ def run_stage2(question, model_name, subject):
71
 
72
  # 2) Inference (fast, using preloaded pipeline)
73
  start = time.time()
74
- clf = PIPELINES[model_name]
75
  out = clf(question, candidate_labels=fine_labels)
76
  labels, scores = out["labels"][:3], out["scores"][:3]
77
  duration = round(time.time() - start, 3)
 
22
  "FractalAIResearch/Fathom-R1-14B" # placeholder β€” replace with real phantom model
23
  ]
24
 
25
+ PIPELINES = {}
26
 
27
+ def get_pipeline(name):
28
+ if name not in PIPELINES:
29
+ PIPELINES[name] = pipeline("zero-shot-classification", model=name)
30
+ return PIPELINES[name]
31
 
32
  # β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”β€”
33
  # Ensure log files exist
 
49
  if not question or not question.strip():
50
  return {}, gr.update(choices=[]), ""
51
  start = time.time()
52
+ clf = get_pipeline(model_name)
53
  out = clf(question, candidate_labels=coarse_labels)
54
  labels, scores = out["labels"][:3], out["scores"][:3]
55
  duration = round(time.time() - start, 3)
 
72
 
73
  # 2) Inference (fast, using preloaded pipeline)
74
  start = time.time()
75
+ clf = get_pipeline(model_name)
76
  out = clf(question, candidate_labels=fine_labels)
77
  labels, scores = out["labels"][:3], out["scores"][:3]
78
  duration = round(time.time() - start, 3)