KaiChen1998 committed on
Commit
e0d34c8
·
verified ·
1 Parent(s): f158243

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -7
app.py CHANGED
@@ -76,20 +76,18 @@ mode2func = dict(
76
 
77
  ##########################################
78
  # LLM part
79
- # TODO: 1) change model 2) change arguments
80
  ##########################################
81
  import torch
82
  from transformers import AutoModel, AutoProcessor, TextIteratorStreamer
83
  from threading import Thread
84
 
85
- model_name = "Emova-ollm/emova_llama3_1-8b"
86
  model = AutoModel.from_pretrained(
87
  model_name,
88
  torch_dtype=torch.bfloat16,
89
- use_flash_attn=True,
90
  low_cpu_mem_usage=True,
91
- trust_remote_code=True,
92
- token=auth_token).eval().cuda()
93
  processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, token=auth_token)
94
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
95
 
@@ -235,8 +233,8 @@ def http_bot(state, temperature, top_p, max_new_tokens, speaker):
235
  # Process inputs
236
  inputs = processor(text=[prompt], images=all_images if len(all_images) > 0 else None, return_tensors="pt")
237
  inputs.to(model.device)
238
- if len(all_images) > 0:
239
- inputs['pixel_values'] = inputs['pixel_values'].to(model.dtype) # TODO
240
 
241
  # Process hyperparameters
242
  temperature = float(pload.get("temperature", 1.0))
 
76
 
77
  ##########################################
78
  # LLM part
 
79
  ##########################################
80
  import torch
81
  from transformers import AutoModel, AutoProcessor, TextIteratorStreamer
82
  from threading import Thread
83
 
84
+ model_name = "Emova-ollm/emova-qwen-2-5-7b-hf"
85
  model = AutoModel.from_pretrained(
86
  model_name,
87
  torch_dtype=torch.bfloat16,
88
+ attn_implementation='flash_attention_2',
89
  low_cpu_mem_usage=True,
90
+ trust_remote_code=True).eval().cuda()
 
91
  processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, token=auth_token)
92
  streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
93
 
 
233
  # Process inputs
234
  inputs = processor(text=[prompt], images=all_images if len(all_images) > 0 else None, return_tensors="pt")
235
  inputs.to(model.device)
236
+ # if len(all_images) > 0:
237
+ # inputs['pixel_values'] = inputs['pixel_values'].to(model.dtype)
238
 
239
  # Process hyperparameters
240
  temperature = float(pload.get("temperature", 1.0))