BenkHel committed
Commit c3c0c8e · verified · 1 Parent(s): c2b8ea8

Delete app.py

Files changed (1)
  1. app.py +0 -203
app.py DELETED
@@ -1,203 +0,0 @@
- # --- Imports remain unchanged ---
- import subprocess
- import sys
- import os
- import argparse
- import time
- import datetime
- import json
- import hashlib
- import io
- import requests
- import torch
- import gradio as gr
- import spaces
- from threading import Thread
- from PIL import Image
- from transformers import TextIteratorStreamer, AutoProcessor, AutoTokenizer, AutoImageProcessor
- import cumo.serve.gradio_web_server as gws
- from cumo.conversation import (default_conversation, conv_templates, SeparatorStyle)
- from cumo.constants import (LOGDIR, WORKER_HEART_BEAT_INTERVAL, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN)
- from cumo.model.language_model.llava_mistral import LlavaMistralForCausalLM
- from cumo.model.builder import load_pretrained_model
- from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
- from cumo.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg, pretty_print_semaphore)
-
- # --- Model Setup ---
- headers = {"User-Agent": "CuMo"}
- no_change_btn = gr.Button()
- enable_btn = gr.Button(interactive=True)
- disable_btn = gr.Button(interactive=False)
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_path = 'BenkHel/CumoThesis'
- model_base = 'mistralai/Mistral-7B-Instruct-v0.2'
- model_name = 'CuMo-mistral-7b'
- conv_mode = 'mistral_instruct_system'
- load_8bit = False
- load_4bit = False
-
- tokenizer, model, image_processor, context_len = load_pretrained_model(
-     model_path, model_base, model_name, load_8bit, load_4bit, device=device, use_flash_attn=False
- )
- model.config.training = False
-
- # --- Prompt ---
- FIXED_PROMPT = "<image>\nWhat material is this item and how to dispose of it?"
-
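- # The callbacks below always send FIXED_PROMPT; the visible textbox is
- # display-only (interactive=False), so typed text never reaches the model.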
- # --- Functions ---
- def clear_history():
-     state = default_conversation.copy()
-     # Button updates must match btn_list ([regenerate_btn, clear_btn]) in the event bindings.
-     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 2
-
- def add_text(state, imagebox, textbox, image_process_mode):
-     if state is None:
-         state = conv_templates[conv_mode].copy()
-     if imagebox is not None:
-         image = Image.open(imagebox).convert('RGB')
-         textbox = (FIXED_PROMPT, image, image_process_mode)
-     state.append_message(state.roles[0], textbox)
-     state.append_message(state.roles[1], None)
-     yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn, enable_btn)
-
- def delete_text(state, image_process_mode):
-     state.messages[-1][-1] = None
-     prev_human_msg = state.messages[-2]
-     if isinstance(prev_human_msg[1], (tuple, list)):
-         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
-     yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn, enable_btn)
-
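- # `generate` streams output: model.generate runs in a background Thread while
- # TextIteratorStreamer yields partial text, so the chatbot updates incrementally.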
- @spaces.GPU
- def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
-     prompt = FIXED_PROMPT
-     images = state.get_images(return_pil=True)
-     ori_prompt = prompt
-     num_image_tokens = 0
-
-     if images:
-         if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-             raise ValueError("Number of images does not match number of <image> tokens in prompt")
-         image_sizes = [image.size for image in images]
-         images = process_images(images, image_processor, model.config)
-         if isinstance(images, list):
-             images = [image.to(model.device, dtype=torch.float16) for image in images]
-         else:
-             images = images.to(model.device, dtype=torch.float16)
-         replace_token = DEFAULT_IMAGE_TOKEN
-         if getattr(model.config, 'mm_use_im_start_end', False):
-             replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
-         prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)
-         num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
-         image_args = {"images": images, "image_sizes": image_sizes}
-     else:
-         image_args = {}
-
-     max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
-     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
-     # Respect the user-selected token budget instead of a hard-coded 512.
-     max_new_tokens = min(int(max_output_tokens), max_context_length - input_ids.shape[-1] - num_image_tokens)
-     if max_new_tokens < 1:
-         # Report the overflow through the normal chat output rather than the
-         # worker-protocol byte stream the original code yielded here.
-         state.messages[-1][-1] = "Exceeds max token length. Please start a new conversation."
-         yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn, enable_btn)
-         return
-
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
-     thread = Thread(target=model.generate, kwargs=dict(
-         inputs=input_ids,
-         do_sample=(temperature > 0.001),
-         temperature=temperature,
-         top_p=top_p,
-         max_new_tokens=max_new_tokens,
-         streamer=streamer,
-         use_cache=True,
-         pad_token_id=tokenizer.eos_token_id,
-         **image_args
-     ))
-     thread.start()
-     generated_text = ''
-     stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
-
-     for new_text in streamer:
-         generated_text += new_text
-         if stop_str and generated_text.endswith(stop_str):
-             generated_text = generated_text[:-len(stop_str)]
-         state.messages[-1][-1] = generated_text
-         yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn, enable_btn)
-     yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 2
-     torch.cuda.empty_cache()
-
- # --- UI Setup ---
- textbox = gr.Textbox(
-     show_label=False,
-     placeholder="Prompt is fixed: What material is this item and how to dispose of it?",
-     container=False,
-     interactive=False
- )
-
- with gr.Blocks(title="CuMo", theme=gr.themes.Default(), css="""
- #buttons button {
-     min-width: min(120px,100%);
- }
- """) as demo:
-     state = gr.State()
-
-     gr.Markdown("# CuMo: Trained for waste management")
-     gr.Markdown(f"**Prompt:** `{FIXED_PROMPT}`")
-
-     with gr.Row():
-         with gr.Column(scale=3):
-             imagebox = gr.Image(label="Input Image", type="filepath")
-             image_process_mode = gr.Radio(
-                 ["Crop", "Resize", "Pad", "Default"],
-                 value="Default",
-                 label="Preprocess for non-square image", visible=False)
-
-             cur_dir = './cumo/serve'
-             gr.Examples(examples=[
-                 [f"{cur_dir}/examples/0165 CB.jpg"],
-                 [f"{cur_dir}/examples/0225 PA.jpg"],
-                 [f"{cur_dir}/examples/0787 GM.jpg"],
-                 [f"{cur_dir}/examples/1396 A.jpg"],
-                 [f"{cur_dir}/examples/2001 P.jpg"],
-                 [f"{cur_dir}/examples/2658 PE.jpg"],
-                 [f"{cur_dir}/examples/3113 R.jpg"],
-                 [f"{cur_dir}/examples/3750 RPC.jpg"],
-                 [f"{cur_dir}/examples/5033 CC.jpg"],
-                 [f"{cur_dir}/examples/5307 B.jpg"],
-             ], inputs=[imagebox], cache_examples=False)
-
-             with gr.Accordion("Parameters", open=False):
-                 temperature = gr.Slider(0.0, 1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
-                 top_p = gr.Slider(0.0, 1.0, value=0.7, step=0.1, interactive=True, label="Top P")
-                 max_output_tokens = gr.Slider(0, 1024, value=512, step=64, interactive=True, label="Max output tokens")
-
-         with gr.Column(scale=8):
-             chatbot = gr.Chatbot(elem_id="chatbot", label="CuMo Chatbot", height=650, layout="panel")
-             with gr.Row():
-                 with gr.Column(scale=8):
-                     textbox.render()
-                 with gr.Column(scale=1, min_width=50):
-                     submit_btn = gr.Button(value="Send", variant="primary")
-             with gr.Row(elem_id="buttons") as button_row:
-                 stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
-                 regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
-                 clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
-
-     # tos_markdown and learn_more_markdown live in cumo.serve.gradio_web_server,
-     # imported above as gws; the bare names were undefined here.
-     gr.Markdown(gws.tos_markdown)
-     gr.Markdown(gws.learn_more_markdown)
-     url_params = gr.JSON(visible=False)
-
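- # Each trigger first records the turn (add_text / delete_text), then chains
- # `generate` via .then(); the outputs list must match each yielded tuple:
- # (state, chatbot, textbox, imagebox) plus one update per button in btn_list.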
-     # --- Event Bindings ---
-     btn_list = [regenerate_btn, clear_btn]
-     clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list, queue=False)
-     regenerate_btn.click(
-         delete_text, [state, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
-     ).then(generate, [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens], [state, chatbot, textbox, imagebox] + btn_list)
-     textbox.submit(
-         add_text, [state, imagebox, textbox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
-     ).then(generate, [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens], [state, chatbot, textbox, imagebox] + btn_list)
-     submit_btn.click(
-         add_text, [state, imagebox, textbox, image_process_mode], [state, chatbot, textbox, imagebox] + btn_list
-     ).then(generate, [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens], [state, chatbot, textbox, imagebox] + btn_list)
-
- demo.queue(status_update_rate=10, api_open=False).launch()
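
For reference, the deleted inference path can be exercised outside Gradio. The sketch below is a minimal, untested adaptation of the code above: it assumes the cumo package and both checkpoints are available, uses a placeholder image path ('example.jpg'), and assumes process_images returns a single tensor for this model config (the app also handled the list case).

# Minimal standalone sketch of the deleted app's inference path.
import torch
from PIL import Image
from cumo.model.builder import load_pretrained_model
from cumo.mm_utils import process_images, tokenizer_image_token
from cumo.constants import IMAGE_TOKEN_INDEX

tokenizer, model, image_processor, _ = load_pretrained_model(
    'BenkHel/CumoThesis', 'mistralai/Mistral-7B-Instruct-v0.2', 'CuMo-mistral-7b',
    False, False, device='cuda', use_flash_attn=False
)
model.config.training = False

prompt = "<image>\nWhat material is this item and how to dispose of it?"
image = Image.open('example.jpg').convert('RGB')  # placeholder path
image_tensor = process_images([image], image_processor, model.config).to(model.device, dtype=torch.float16)
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)

with torch.inference_mode():
    output_ids = model.generate(
        inputs=input_ids, images=image_tensor, image_sizes=[image.size],
        do_sample=False, max_new_tokens=256, use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))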