BenkHel commited on
Commit
9b68ec8
·
verified ·
1 Parent(s): d717525

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -31
app.py CHANGED
@@ -51,34 +51,54 @@ conv_mode = 'mistral_instruct_system' # Diese Variable wird noch für die Konver
51
  load_8bit = False
52
  load_4bit = False
53
 
54
- # Laden Sie den Prozessor, der Tokenizer und Bildprozessor kombiniert
55
- processor = AutoProcessor.from_pretrained(model_path)
56
-
57
- # Laden Sie das Modell mit der korrekten Klasse
58
- model = LlavaMistralForCausalLM.from_pretrained(
59
- model_path,
60
- torch_dtype=torch.bfloat16, # Ihre config.json spezifiziert bfloat16
61
- low_cpu_mem_usage=True, # Empfohlen für große Modelle
62
- load_in_4bit=load_4bit,
63
- load_in_8bit=load_8bit,
64
- )
65
-
66
- # Weisen Sie die Komponenten den alten Variablennamen zu, damit der restliche Code funktioniert
67
- from transformers import AutoTokenizer, AutoImageProcessor
68
-
69
- tokenizer = AutoTokenizer.from_pretrained(model_path)
70
- try:
71
- image_processor = AutoImageProcessor.from_pretrained(model_path)
72
- if image_processor is None:
73
- raise Exception()
74
- except Exception:
75
- # Fallback: Passender Vision Tower!
76
- image_processor = AutoImageProcessor.from_pretrained("openai/clip-vit-large-patch14-336")
77
- assert image_processor is not None, "Could not load image_processor!"
78
-
79
-
80
- # Setzen Sie die Kontextlänge (falls der restliche Code sie benötigt)
81
- context_len = model.config.max_position_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  model.config.training = False
83
 
84
  def upvote_last_response(state):
@@ -146,10 +166,9 @@ def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, m
146
  images = process_images(images, image_processor, model.config)
147
 
148
  if type(images) is list:
149
- images = [image.to(model.device, dtype=torch.bfloat16) for image in images]
150
  else:
151
- images = images.to(model.device, dtype=torch.bfloat16)
152
-
153
 
154
  replace_token = DEFAULT_IMAGE_TOKEN
155
  if getattr(model.config, 'mm_use_im_start_end', False):
 
51
  load_8bit = False
52
  load_4bit = False
53
 
54
+ import sys
55
+ import os
56
+ import argparse
57
+ import time
58
+ import subprocess
59
+ import spaces
60
+ import cumo.serve.gradio_web_server as gws
61
+
62
+ import datetime
63
+ import json
64
+
65
+ import gradio as gr
66
+ import requests
67
+ from PIL import Image
68
+
69
+ from cumo.conversation import (default_conversation, conv_templates, SeparatorStyle)
70
+ from cumo.constants import LOGDIR
71
+ from cumo.utils import (build_logger, server_error_msg, violates_moderation, moderation_msg)
72
+ import hashlib
73
+
74
+ import torch
75
+ import io
76
+ from cumo.constants import WORKER_HEART_BEAT_INTERVAL
77
+ from cumo.utils import (build_logger, server_error_msg,
78
+ pretty_print_semaphore)
79
+ from cumo.model.builder import load_pretrained_model
80
+ from cumo.mm_utils import process_images, load_image_from_base64, tokenizer_image_token
81
+ from cumo.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
82
+ from transformers import TextIteratorStreamer
83
+ from threading import Thread
84
+
85
+ # Execute the pip install command with additional options
86
+ #subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
87
+
88
+ headers = {"User-Agent": "CuMo"}
89
+
90
+ no_change_btn = gr.Button()
91
+ enable_btn = gr.Button(interactive=True)
92
+ disable_btn = gr.Button(interactive=False)
93
+
94
+ device = "cuda" if torch.cuda.is_available() else "cpu"
95
+ model_path = 'BenkHel/CumoThesis'
96
+ model_base = 'mistralai/Mistral-7B-Instruct-v0.2'
97
+ model_name = 'CuMo-mistral-7b'
98
+ conv_mode = 'mistral_instruct_system'
99
+ load_8bit = False
100
+ load_4bit = False
101
+ tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, model_base, model_name, load_8bit, load_4bit, device=device, use_flash_attn=False)
102
  model.config.training = False
103
 
104
  def upvote_last_response(state):
 
166
  images = process_images(images, image_processor, model.config)
167
 
168
  if type(images) is list:
169
+ images = [image.to(model.device, dtype=torch.float16) for image in images]
170
  else:
171
+ images = images.to(model.device, dtype=torch.float16)
 
172
 
173
  replace_token = DEFAULT_IMAGE_TOKEN
174
  if getattr(model.config, 'mm_use_im_start_end', False):