DaddyDaniel commited on
Commit
e497738
Β·
1 Parent(s): e8a63bc

Reinstate model caching

Browse files

- uncomment @cache_resource to reactivate model caching
- Added subheaders for generated output

Files changed (3) hide show
  1. main.py +0 -7
  2. qwen2_inference.py +1 -2
  3. sketch2diagram.py +5 -4
main.py CHANGED
@@ -1,5 +1,4 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
 
4
  st.logo("NLP_Group_logo.svg", size="large")
5
  main_page = st.Page("main_page.py", title="Main Page", icon="🏠")
@@ -8,9 +7,3 @@ sketch2diagram_page = st.Page("sketch2diagram.py", title="Sketch2Diagram", icon=
8
  pg = st.navigation([main_page, sketch2diagram_page])
9
 
10
  pg.run()
11
-
12
- @st.cache_resource
13
- def get_model():
14
- # Load the model here
15
- model = pipeline("image-to-text", model="itsumi-st/imgtikz_qwen2vl")
16
- return model
 
1
  import streamlit as st
 
2
 
3
  st.logo("NLP_Group_logo.svg", size="large")
4
  main_page = st.Page("main_page.py", title="Main Page", icon="🏠")
 
7
  pg = st.navigation([main_page, sketch2diagram_page])
8
 
9
  pg.run()
 
 
 
 
 
 
qwen2_inference.py CHANGED
@@ -20,8 +20,7 @@ def print_gpu_memory(label, memory_allocated, memory_reserved):
20
 
21
 
22
  # Inference steps taken from https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
23
-
24
- # @st.cache_resource
25
  def get_model(model_path):
26
  try:
27
  with st.spinner(f"Loading model {model_path}"):
 
20
 
21
 
22
  # Inference steps taken from https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
23
+ @st.cache_resource
 
24
  def get_model(model_path):
25
  try:
26
  with st.spinner(f"Loading model {model_path}"):
sketch2diagram.py CHANGED
@@ -53,6 +53,7 @@ if generate_command:
53
  output = run_inference(input_file, model_name, args)[0]
54
  pdf_file_path = compile_tikz_to_pdf(output)
55
  if output and pdf_file_path:
 
56
  st.success("TikZ code generated successfully!")
57
  st.code(output, language='latex')
58
 
@@ -63,16 +64,16 @@ if generate_command:
63
  mime="text/plain"
64
  )
65
 
66
- # st.image(pdf_file_path, caption="Generated Diagram", use_column_width=True)
 
 
67
  with open(pdf_file_path, "rb") as f:
68
  st.download_button(
69
  label="Download PDF",
70
- data=f.read(), # βœ… this is the binary content
71
  file_name="output.pdf",
72
  mime="application/pdf"
73
  )
74
 
75
- images = convert_from_path(pdf_file_path)
76
- st.image(images[0], caption="Generated Diagram", use_column_width=True)
77
  else:
78
  st.error("Failed to generate TikZ code.")
 
53
  output = run_inference(input_file, model_name, args)[0]
54
  pdf_file_path = compile_tikz_to_pdf(output)
55
  if output and pdf_file_path:
56
+ st.subheader("Generated TikZ Code")
57
  st.success("TikZ code generated successfully!")
58
  st.code(output, language='latex')
59
 
 
64
  mime="text/plain"
65
  )
66
 
67
+ st.subheader("Generated Diagram")
68
+ images = convert_from_path(pdf_file_path)
69
+ st.image(images[0], caption="Generated Diagram", use_column_width=True)
70
  with open(pdf_file_path, "rb") as f:
71
  st.download_button(
72
  label="Download PDF",
73
+ data=f.read(),
74
  file_name="output.pdf",
75
  mime="application/pdf"
76
  )
77
 
 
 
78
  else:
79
  st.error("Failed to generate TikZ code.")