DaddyDaniel commited on
Commit
979c542
ยท
1 Parent(s): 30f49a1

Add args to model

Browse files

User can select args for inference.
Inference logic moved to qwen2_inference.py

Files changed (3) hide show
  1. main.py +0 -7
  2. qwen2_inference.py +58 -0
  3. sketch2diagram.py +16 -12
main.py CHANGED
@@ -1,11 +1,4 @@
1
  import streamlit as st
2
- from transformers import pipeline
3
-
4
- @st.cache_resource
5
- def get_model():
6
- # Load the model here
7
- model = pipeline("image-to-text", model="itsumi-st/imgtikz_qwen2vl")
8
- return model
9
 
10
  st.logo("NLP_Group_logo.svg", size="large")
11
  main_page = st.Page("main_page.py", title="Main Page", icon="๐Ÿ ")
 
1
  import streamlit as st
 
 
 
 
 
 
 
2
 
3
  st.logo("NLP_Group_logo.svg", size="large")
4
  main_page = st.Page("main_page.py", title="Main Page", icon="๐Ÿ ")
qwen2_inference.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
5
+
6
+
7
+ # Inference steps taken from https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct
8
+
9
+ @st.cache_resource
10
+ def get_model(model_path):
11
+ try:
12
+ with st.spinner(f"Loading model {model_path}"):
13
+ device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ # Load the model here
15
+ model_import = Qwen2VLForConditionalGeneration.from_pretrained(
16
+ model_path, torch_dtype="auto", device_map=device
17
+ )
18
+ processor_import = AutoProcessor.from_pretrained(model_path)
19
+
20
+ return model_import, processor_import
21
+ except Exception as e:
22
+ st.error(f"Error loading model: {e}")
23
+ return None, None
24
+
25
+
26
+ def run_inference(input_file, model_path, args):
27
+ model, processor = get_model(model_path)
28
+ if model is None or processor is None:
29
+ return "Error loading model."
30
+ image = Image.open(input_file)
31
+ conversation = [
32
+ {
33
+ "role": "user",
34
+ "content": [
35
+ {"type": "image"},
36
+ {"type": "text", "text": "Please generate TikZ code to draw the diagram of the given image."}
37
+ ],
38
+ }
39
+ ]
40
+ text_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
41
+ inputs = processor(image, text_prompt, return_tensors="pt").to("cuda")
42
+
43
+ output_ids = model.generate(**inputs,
44
+ max_new_tokens=args.max_length,
45
+ do_sample=True,
46
+ top_p=args.top_p,
47
+ top_k=args.top_k,
48
+ num_return_sequences=1,
49
+ temperature=args.temperature
50
+ )
51
+ generated_ids = [
52
+ output_ids[len(input_ids):]
53
+ for input_ids, output_ids in zip(inputs.input_ids, output_ids)
54
+ ]
55
+ output_text = processor.batch_decode(
56
+ generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
57
+ )
58
+ return output_text
sketch2diagram.py CHANGED
@@ -1,12 +1,19 @@
1
  import streamlit as st
2
- from PIL import Image
3
 
4
- from main import get_model
 
 
5
 
6
  # Sidebar Setup
7
  st.sidebar.title("Model Configuration")
8
- inference_strat = st.sidebar.selectbox("Inference Strategy", ["Iterative", "Multi-candidate"],
9
- help="Choose the inference strategy for the model. Iterative generates one candidate at a time until an output compiles, while Multi-candidate generates multiple candidates in parallel.")
 
 
 
 
 
 
10
 
11
  # Introduction Section
12
  st.title("Sketch2Diagram")
@@ -14,7 +21,7 @@ st.title("Sketch2Diagram")
14
  st.write("This is a runnable demo of ImgTikZ model introduced in the Sketch2Diagram paper.")
15
  st.write("Please refer to the [original paper](https://openreview.net/pdf?id=KvaDHPhhir) for more details.")
16
  st.write("The model is trained to convert sketches into TikZ code, which can be used to generate vectorized diagrams.")
17
- st.write(f"Inference Strategy: {inference_strat}")
18
 
19
  # User Input Section
20
  st.subheader("Upload your sketch")
@@ -35,12 +42,9 @@ if input_file is not None:
35
  st.image(input_file, caption="Uploaded Sketch")
36
  generate_command = st.button("Generate TikZ Code")
37
 
 
38
  if generate_command:
39
- model = get_model()
40
- image = Image.open(input_file)
41
  with st.spinner("Generating TikZ code..."):
42
- output = model(image)
43
-
44
- tikz_code = output[0]['generated_text']
45
- st.subheader("Generated TikZ Code")
46
- st.code(tikz_code, language='latex')
 
1
  import streamlit as st
 
2
 
3
+ from qwen2_inference import run_inference
4
+
5
+ args = {}
6
 
7
  # Sidebar Setup
8
  st.sidebar.title("Model Configuration")
9
+ model_name = st.sidebar.selectbox("Model Name", ['Itsumi-st/Imgtikz_Qwen2vl', 'Qwen/Qwen2-VL-7B-Instruct'])
10
+ args['inference_strat'] = st.sidebar.selectbox("Inference Strategy", ["Iterative", "Multi-candidate"],
11
+ help="Choose the inference strategy for the model. Iterative generates one candidate at a time until an output compiles, while Multi-candidate generates multiple candidates in parallel.")
12
+ args['max_length'] = st.sidebar.slider("Max Length", 1, 5096, 2048, help="Maximum length of the generated output. The model will generate text up to this length.")
13
+ args['seed'] = st.sidebar.number_input("Seed", min_value=0, value=42, step=1)
14
+ args['top_p'] = st.sidebar.slider("Top P", 0.0, 1.0, 1.0, step=0.01, help="Top P sampling parameter. The model will sample from the top P percentage of the probability distribution.")
15
+ args['temperature'] = st.sidebar.slider("Top P", 0.0, 1.0, 0.6, step=0.01, help="Temperature parameter for sampling. Higher values result in more random outputs.")
16
+ args['top_k'] = st.sidebar.slider("Top K", 0, 100, 50, step=1, help="Top K sampling parameter. The model will sample from the top K tokens with the highest probabilities.")
17
 
18
  # Introduction Section
19
  st.title("Sketch2Diagram")
 
21
  st.write("This is a runnable demo of ImgTikZ model introduced in the Sketch2Diagram paper.")
22
  st.write("Please refer to the [original paper](https://openreview.net/pdf?id=KvaDHPhhir) for more details.")
23
  st.write("The model is trained to convert sketches into TikZ code, which can be used to generate vectorized diagrams.")
24
+
25
 
26
  # User Input Section
27
  st.subheader("Upload your sketch")
 
42
  st.image(input_file, caption="Uploaded Sketch")
43
  generate_command = st.button("Generate TikZ Code")
44
 
45
+ # Run model inference
46
  if generate_command:
 
 
47
  with st.spinner("Generating TikZ code..."):
48
+ output = run_inference(input_file, model_name, args)
49
+ st.success("TikZ code generated successfully!")
50
+ st.code(output, language='latex')