czczup commited on
Commit
026bc32
·
verified ·
1 Parent(s): 360c87d

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_internvl_chat.py +13 -7
modeling_internvl_chat.py CHANGED
@@ -375,17 +375,24 @@ class InternVLChatModel(PreTrainedModel):
375
  vit_embeds = self.mlp1(vit_embeds)
376
  return vit_embeds
377
 
378
- def chat(self, tokenizer, pixel_values, question, generation_config,
379
  IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
380
 
381
  img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
382
  self.img_context_token_id = img_context_token_id
383
 
384
- from .conversation import get_conv_template
385
 
386
  template = get_conv_template(self.template)
387
- image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token + IMG_END_TOKEN
388
- template.append_message(template.roles[0], image_tokens + '\n' + question)
 
 
 
 
 
 
 
389
  template.append_message(template.roles[1], None)
390
  query = template.get_prompt()
391
  model_inputs = tokenizer(query, return_tensors='pt')
@@ -399,9 +406,8 @@ class InternVLChatModel(PreTrainedModel):
399
  **generation_config
400
  )
401
  response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
402
- query_to_print = query.replace(image_tokens, '<image>')
403
- print(query_to_print, response)
404
- return response
405
 
406
  @torch.no_grad()
407
  def generate(
 
375
  vit_embeds = self.mlp1(vit_embeds)
376
  return vit_embeds
377
 
378
+ def chat(self, tokenizer, pixel_values, question, generation_config, history=None,
379
  IMG_START_TOKEN='<img>', IMG_END_TOKEN='</img>', IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'):
380
 
381
  img_context_token_id = tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
382
  self.img_context_token_id = img_context_token_id
383
 
384
+ from internvl.conversation import get_conv_template
385
 
386
  template = get_conv_template(self.template)
387
+ if history is None:
388
+ history = []
389
+ image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token + IMG_END_TOKEN
390
+ question = image_tokens + '\n' + question
391
+ else:
392
+ for (old_question, old_answer) in history:
393
+ template.append_message(template.roles[0], old_question)
394
+ template.append_message(template.roles[1], old_answer)
395
+ template.append_message(template.roles[0], question)
396
  template.append_message(template.roles[1], None)
397
  query = template.get_prompt()
398
  model_inputs = tokenizer(query, return_tensors='pt')
 
406
  **generation_config
407
  )
408
  response = tokenizer.batch_decode(generation_output, skip_special_tokens=True)[0]
409
+ history.append((question, response))
410
+ return response, history
 
411
 
412
  @torch.no_grad()
413
  def generate(