DaddyDaniel commited on
Commit
e8a63bc
·
2 Parent(s): f4b5b0a 847829a

Merge branch 'streamlit'

Browse files

# Conflicts:
# main_page.py
# requirements.txt
# sketch2diagram.py

Files changed (9) hide show
  1. .dockerignore +1 -0
  2. Dockerfile +34 -0
  3. NLP_Group_logo.png +0 -0
  4. app.py +15 -0
  5. main_page.py +6 -0
  6. qwen2_inference.py +108 -0
  7. requirements.txt +11 -2
  8. sketch2diagram.py +45 -13
  9. util.py +26 -0
.dockerignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .venv
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.4.1-cudnn-runtime-ubuntu22.04
2
+
3
+ # Set environment variables to reduce interactive prompts
4
+ ENV DEBIAN_FRONTEND=noninteractive
5
+
6
+ # Install dependencies
7
+ RUN apt-get update && apt-get install -y \
8
+ python3.10 \
9
+ python3-pip \
10
+ git \
11
+ texlive-latex-base \
12
+ texlive-latex-extra \
13
+ texlive-fonts-recommended \
14
+ texlive-latex-recommended \
15
+ latexmk \
16
+ poppler-utils \
17
+ && rm -rf /var/lib/apt/lists/*
18
+
19
+ # Copy the files
20
+ WORKDIR /app
21
+ COPY requirements.txt .
22
+
23
+ RUN pip install --upgrade pip \
24
+ && pip install --no-cache-dir -r requirements.txt
25
+
26
+ ENV PATH="/root/.local/bin:$PATH"
27
+ ENV STREAMLIT_WATCHER_TYPE none
28
+
29
+ RUN pip install --no-cache-dir https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.0.6/flash_attn-2.6.3+cu124torch2.6-cp310-cp310-linux_x86_64.whl
30
+
31
+ COPY . .
32
+
33
+ # Default command
34
+ ENTRYPOINT ["streamlit", "run", "app.py"]
NLP_Group_logo.png ADDED
app.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import streamlit as st
4
+ from PIL import Image
5
+
6
+ logo_path = os.path.join(os.path.dirname(__file__), "NLP_Group_logo.png")
7
+ logo = Image.open(logo_path)
8
+ st.logo(logo, size="large")
9
+ main_page = st.Page("main_page.py", title="Main Page", icon="🏠")
10
+ sketch2diagram_page = st.Page("sketch2diagram.py", title="Sketch2Diagram", icon="🖼️")
11
+ # Add pages to the main page
12
+
13
+ pg = st.navigation([main_page, sketch2diagram_page])
14
+
15
+ pg.run()
main_page.py CHANGED
@@ -3,3 +3,9 @@ import streamlit as st
3
  st.title("Tohoku NLP Group - Language and Information Science Laboratory ")
4
  st.write("Welcome to the Language and Information Science Laboratory!")
5
  st.write("We are working on various projects and research focused on Visual Language Models.")
 
 
 
 
 
 
 
3
  st.title("Tohoku NLP Group - Language and Information Science Laboratory ")
4
  st.write("Welcome to the Language and Information Science Laboratory!")
5
  st.write("We are working on various projects and research focused on Visual Language Models.")
6
+
7
+
8
+ # Link to sketch2diagram page
9
+ st.subheader("You can check out our models and demos here:")
10
+
11
+ st.write("[Sketch2Diagram](sketch2diagram) - A model that generates TikZ code from sketches.")
qwen2_inference.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import streamlit as st
4
+ import torch
5
+ from PIL import Image
6
+ from dotenv import load_dotenv
7
+ from qwen_vl_utils import process_vision_info
8
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
9
+
10
+ load_dotenv()
11
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_HUB_TOKEN")
12
+
13
+
14
+ def print_gpu_memory(label, memory_allocated, memory_reserved):
15
+ if torch.cuda.is_available():
16
+ print("-----------------------------------")
17
+ print(f"{label} GPU Memory Usage:")
18
+ print(f"Allocated: {memory_allocated / 1024 ** 2:.2f} MB")
19
+ print(f"Cached: {memory_reserved / 1024 ** 2:.2f} MB")
20
+
21
+
22
+ # Inference steps taken from https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
23
+
24
+ # @st.cache_resource
25
+ def get_model(model_path):
26
+ try:
27
+ with st.spinner(f"Loading model {model_path}"):
28
+ # Load the model here
29
+ model_import = Qwen2VLForConditionalGeneration.from_pretrained(
30
+ model_path, torch_dtype="auto", device_map="auto",
31
+ attn_implementation="flash_attention_2",
32
+ token=HUGGINGFACE_TOKEN,
33
+ )
34
+ model_import = model_import.to("cuda")
35
+ size = {
36
+ "shortest_edge": 224,
37
+ "longest_edge": 1024,
38
+ }
39
+ processor_import = AutoProcessor.from_pretrained("itsumi-st/imgtikz_qwen2vl",
40
+ size=size,
41
+ min_pixels=256 * 256,
42
+ max_pixels=1024 * 1024,
43
+ token=HUGGINGFACE_TOKEN)
44
+ processor_import.tokenizer.padding_side = 'left'
45
+
46
+ return model_import, processor_import
47
+ except Exception as e:
48
+ st.error(f"Error loading model: {e}")
49
+ return None, None
50
+
51
+
52
+ def run_inference(input_file, model_path, args):
53
+ model, processor = get_model(model_path)
54
+ if model is None or processor is None:
55
+ return "Error loading model."
56
+
57
+ # GPU Memory after model loading:
58
+ after_model_dump = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
59
+
60
+ image = Image.open(input_file)
61
+ conversation = [
62
+ {
63
+ "role": "user",
64
+ "content": [
65
+ {"type": "image", "image": image},
66
+ {"type": "text", "text": "Please generate TikZ code to draw the diagram of the given image."}
67
+ ],
68
+ }
69
+ ]
70
+ text_prompt = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
71
+ image_input, video_inputs = process_vision_info(conversation)
72
+ inputs = processor(
73
+ text=[text_prompt],
74
+ images=image_input,
75
+ videos=video_inputs,
76
+ padding=True,
77
+ return_tensors="pt",
78
+ ).to("cuda")
79
+
80
+ # GPU Memory after input processing
81
+ after_input_dump = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
82
+
83
+ output_ids = model.generate(**inputs,
84
+ max_new_tokens=args['max_length'],
85
+ do_sample=True,
86
+ top_p=args['top_p'],
87
+ top_k=args['top_k'],
88
+ use_cache=True,
89
+ num_return_sequences=1,
90
+ pad_token_id=processor.tokenizer.pad_token_id,
91
+ temperature=args['temperature']
92
+ )
93
+ generated_ids = [
94
+ output_ids[len(input_ids):]
95
+ for input_ids, output_ids in zip(inputs.input_ids, output_ids)
96
+ ]
97
+ output_text = processor.batch_decode(
98
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
99
+ )
100
+
101
+ # GPU Memory after generation
102
+ after_gen_dump = (torch.cuda.memory_allocated(), torch.cuda.memory_reserved())
103
+
104
+ print_gpu_memory("Before Model", after_model_dump[0], after_model_dump[1])
105
+ print_gpu_memory("After Input", after_input_dump[0], after_input_dump[1])
106
+ print_gpu_memory("After Generation", after_gen_dump[0], after_gen_dump[1])
107
+
108
+ return output_text
requirements.txt CHANGED
@@ -1,3 +1,12 @@
1
  streamlit~=1.43.2
2
- transformers~=4.50.0
3
- pillow~=11.1.0
 
 
 
 
 
 
 
 
 
 
1
  streamlit~=1.43.2
2
+ torch==2.6.0
3
+ torchvision==0.21.0
4
+ torchaudio
5
+ transformers==4.48.2
6
+ qwen-vl-utils==0.0.10
7
+ packaging
8
+ accelerate==1.0.1
9
+ requests
10
+ pillow
11
+ python-dotenv
12
+ pdf2image
sketch2diagram.py CHANGED
@@ -1,12 +1,25 @@
1
  import streamlit as st
2
- from PIL import Image
3
 
4
- from main import get_model
 
 
 
5
 
6
  # Sidebar Setup
7
  st.sidebar.title("Model Configuration")
8
- inference_strat = st.sidebar.selectbox("Inference Strategy", ["Iterative", "Multi-candidate"],
9
- help="Choose the inference strategy for the model. Iterative generates one candidate at a time until an output compiles, while Multi-candidate generates multiple candidates in parallel.")
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Introduction Section
12
  st.title("Sketch2Diagram")
@@ -14,7 +27,6 @@ st.title("Sketch2Diagram")
14
  st.write("This is a runnable demo of ImgTikZ model introduced in the Sketch2Diagram paper.")
15
  st.write("Please refer to the [original paper](https://openreview.net/pdf?id=KvaDHPhhir) for more details.")
16
  st.write("The model is trained to convert sketches into TikZ code, which can be used to generate vectorized diagrams.")
17
- st.write(f"Inference Strategy: {inference_strat}")
18
 
19
  # User Input Section
20
  st.subheader("Upload your sketch")
@@ -25,22 +37,42 @@ input_method = st.selectbox("Input Method", ["Upload", "Camera"],
25
  input_file = None
26
  if input_method == "Camera":
27
  input_file = st.camera_input("Take a picture of your sketch")
28
- # Implement camera input functionality here
29
  else:
30
  input_file = st.file_uploader("Upload an image of your sketch", type=["png", "jpg", "jpeg"])
31
-
32
  generate_command = None
33
  # Display the uploaded image
34
  if input_file is not None:
35
  st.image(input_file, caption="Uploaded Sketch")
36
  generate_command = st.button("Generate TikZ Code")
37
 
 
38
  if generate_command:
39
- model = get_model()
40
- image = Image.open(input_file)
41
  with st.spinner("Generating TikZ code..."):
42
- output = model(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- tikz_code = output[0]['generated_text']
45
- st.subheader("Generated TikZ Code")
46
- st.code(tikz_code, language='latex')
 
 
1
  import streamlit as st
2
+ from pdf2image import convert_from_path
3
 
4
+ from qwen2_inference import run_inference
5
+ from util import compile_tikz_to_pdf
6
+
7
+ args = {}
8
 
9
  # Sidebar Setup
10
  st.sidebar.title("Model Configuration")
11
+ model_name = st.sidebar.selectbox("Model Name", ['Itsumi-st/Imgtikz_Qwen2vl', 'Qwen/Qwen2-VL-7B-Instruct'])
12
+ args['inference_strat'] = st.sidebar.selectbox("Inference Strategy", ["Iterative", "Multi-candidate"],
13
+ help="Choose the inference strategy for the model. Iterative generates one candidate at a time until an output compiles, while Multi-candidate generates multiple candidates in parallel.")
14
+ args['max_length'] = st.sidebar.slider("Max Length", 1, 5096, 2048,
15
+ help="Maximum length of the generated output. The model will generate text up to this length.")
16
+ args['seed'] = st.sidebar.number_input("Seed", min_value=0, value=42, step=1)
17
+ args['temperature'] = st.sidebar.slider("Temperature", 0.0, 1.0, 0.6, step=0.01,
18
+ help="Temperature parameter for sampling. Higher values result in more random outputs.")
19
+ args['top_p'] = st.sidebar.slider("Top P", 0.0, 1.0, 1.0, step=0.01,
20
+ help="Top P sampling parameter. The model will sample from the top P percentage of the probability distribution.")
21
+ args['top_k'] = st.sidebar.slider("Top K", 0, 100, 50, step=1,
22
+ help="Top K sampling parameter. The model will sample from the top K tokens with the highest probabilities.")
23
 
24
  # Introduction Section
25
  st.title("Sketch2Diagram")
 
27
  st.write("This is a runnable demo of ImgTikZ model introduced in the Sketch2Diagram paper.")
28
  st.write("Please refer to the [original paper](https://openreview.net/pdf?id=KvaDHPhhir) for more details.")
29
  st.write("The model is trained to convert sketches into TikZ code, which can be used to generate vectorized diagrams.")
 
30
 
31
  # User Input Section
32
  st.subheader("Upload your sketch")
 
37
  input_file = None
38
  if input_method == "Camera":
39
  input_file = st.camera_input("Take a picture of your sketch")
40
+ # todo: Implement camera input functionality here
41
  else:
42
  input_file = st.file_uploader("Upload an image of your sketch", type=["png", "jpg", "jpeg"])
43
+ st.write(args)
44
  generate_command = None
45
  # Display the uploaded image
46
  if input_file is not None:
47
  st.image(input_file, caption="Uploaded Sketch")
48
  generate_command = st.button("Generate TikZ Code")
49
 
50
+ # Run model inference
51
  if generate_command:
 
 
52
  with st.spinner("Generating TikZ code..."):
53
+ output = run_inference(input_file, model_name, args)[0]
54
+ pdf_file_path = compile_tikz_to_pdf(output)
55
+ if output and pdf_file_path:
56
+ st.success("TikZ code generated successfully!")
57
+ st.code(output, language='latex')
58
+
59
+ st.download_button(
60
+ label="Download LaTeX Code",
61
+ data=output,
62
+ file_name="output.tex",
63
+ mime="text/plain"
64
+ )
65
+
66
+ # st.image(pdf_file_path, caption="Generated Diagram", use_column_width=True)
67
+ with open(pdf_file_path, "rb") as f:
68
+ st.download_button(
69
+ label="Download PDF",
70
+ data=f.read(), # ✅ this is the binary content
71
+ file_name="output.pdf",
72
+ mime="application/pdf"
73
+ )
74
 
75
+ images = convert_from_path(pdf_file_path)
76
+ st.image(images[0], caption="Generated Diagram", use_column_width=True)
77
+ else:
78
+ st.error("Failed to generate TikZ code.")
util.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import subprocess
3
+ import tempfile
4
+
5
+
6
+ def compile_tikz_to_pdf(tikz_code):
7
+ temp_dir = tempfile.mkdtemp()
8
+
9
+ tex_path = os.path.join(temp_dir, "output.tex")
10
+ pdf_path = os.path.join(temp_dir, "output.pdf")
11
+
12
+ with open(tex_path, "w") as f:
13
+ f.write(tikz_code)
14
+
15
+ try:
16
+ subprocess.run(
17
+ ["pdflatex", "-interaction=nonstopmode", tex_path],
18
+ cwd=temp_dir,
19
+ stdout=subprocess.PIPE,
20
+ stderr=subprocess.PIPE,
21
+ check=True,
22
+ )
23
+ return pdf_path
24
+ except subprocess.CalledProcessError as e:
25
+ print("PDF compilation failed:", e)
26
+ return None