chabane committed on
Commit
67c4556
·
1 Parent(s): 6c3c19b

Modifying the loading of the models

Browse files
Files changed (1) hide show
  1. main.py +44 -17
main.py CHANGED
@@ -8,8 +8,10 @@ import re
8
  import io
9
  import base64
10
  import matplotlib.pyplot as plt
11
- from transformers import pipeline
12
-
 
 
13
  import fitz
14
  from docx import Document
15
  from pptx import Presentation
@@ -29,23 +31,45 @@ app.add_middleware(
29
  allow_headers=["*"],
30
  )
31
  try:
32
- interpreter = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
 
 
 
33
  except Exception as exp:
34
- print("[ERROR] Can't load nlpconnect/vit-gpt2-image-captioning ")
35
  print(str(exp))
 
 
 
 
 
 
 
36
  try:
37
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn",device=0)
38
- except Exception as exp:
39
- print("[ERROR] Can't load facebook/bart-large-cnn ")
40
- print(str(exp))
 
 
 
 
 
 
 
 
 
 
 
 
41
  try:
42
- generator = pipeline("text-generation", model="deepseek-ai/deepseek-coder-1.3b-instruct", device_map="auto")
43
- except Exception as exp:
 
44
  print("[ERROR] Can't load deepseek-ai/deepseek-coder-1.3b-instruct ")
45
  print(str(exp))
46
 
47
 
48
-
49
  app.mount("/static",StaticFiles(directory='static'),'static')
50
  templates = Jinja2Templates(directory='templates')
51
 
@@ -71,9 +95,12 @@ def caption(file:UploadFile=File(...)):
71
  if extension not in Supported_extensions:
72
  return {"error": "Unsupported file type"}
73
  image = Image.open(file.file)
74
-
75
- caption = interpreter(image)
76
- return {"caption": caption[0]['generated_text']}
 
 
 
77
 
78
  @app.post("/summerize")
79
  def summerzation(file:UploadFile=File(...)):
@@ -93,8 +120,8 @@ def summerzation(file:UploadFile=File(...)):
93
  return {"error": "File is empty"}
94
 
95
  result=""
96
- for i in range(0,len(text),1024):
97
- result+=summarizer(text, max_length=150, min_length=30, do_sample=False)[0]['summary_text']
98
  return {"summary": result}
99
 
100
 
@@ -133,7 +160,7 @@ error.
133
  ##Prompt: {prompt}.
134
  """
135
 
136
- output = generator(message, max_length=1000)
137
  match = re.search(r'```python(.*?)```', output[0]["generated_text"], re.DOTALL)
138
  code =''
139
  if not match:
 
8
  import io
9
  import base64
10
  import matplotlib.pyplot as plt
11
+ import torch
12
+ from transformers import pipeline,VisionEncoderDecoderModel,ViTImageProcessor,AutoTokenizer
13
+ from transformers import BartForConditionalGeneration, BartTokenizer
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
  import fitz
16
  from docx import Document
17
  from pptx import Presentation
 
31
  allow_headers=["*"],
32
  )
33
  try:
34
+ #interpreter = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
35
+ interpreter_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
36
+ interpreter_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
37
+ interpreter_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
38
  except Exception as exp:
39
+ print("[ERROR] Can't load nlpconnect/vit-gpt2-image-captioning")
40
  print(str(exp))
41
+
42
+ #try:
43
+ # summarizer = pipeline("summarization", model="facebook/bart-large-cnn",device=0)
44
+ #except Exception as exp:
45
+ # print("[ERROR] Can't load facebook/bart-large-cnn ")
46
+ # print(str(exp))
47
+
48
  try:
49
+ summarizer_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
50
+ except OSError as e:
51
+ print(f"[INFO] PyTorch weights not found. Falling back to TensorFlow weights.\n{e}")
52
+ summarizer_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn", from_tf=True)
53
+
54
+ summarizer_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
55
+
56
+
57
+
58
+ #try:
59
+ # generator = pipeline("text-generation", model="deepseek-ai/deepseek-coder-1.3b-instruct", device_map="auto")
60
+ #except Exception as exp:
61
+ # print("[ERROR] Can't load deepseek-ai/deepseek-coder-1.3b-instruct ")
62
+ # print(str(exp))
63
+
64
+
65
  try:
66
+ generator_model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True)
67
+ tokengenerator_modelizer = AutoTokenizer.from_pretrained("deepseek-ai/deepseek-coder-1.3b-instruct", trust_remote_code=True)
68
+ except Exception as exp :
69
  print("[ERROR] Can't load deepseek-ai/deepseek-coder-1.3b-instruct ")
70
  print(str(exp))
71
 
72
 
 
73
  app.mount("/static",StaticFiles(directory='static'),'static')
74
  templates = Jinja2Templates(directory='templates')
75
 
 
95
  if extension not in Supported_extensions:
96
  return {"error": "Unsupported file type"}
97
  image = Image.open(file.file)
98
+ #caption = interpreter(image)
99
+ pixel_values = interpreter_processor(images=image, return_tensors="pt").pixel_values
100
+ output_ids = interpreter_model.generate(pixel_values, max_length=16, num_beams=4)
101
+ caption = interpreter_tokenizer.decode(output_ids[0], skip_special_tokens=True)
102
+ return {"caption":caption}
103
+ #return {"caption": caption[0]['generated_text']}
104
 
105
  @app.post("/summerize")
106
  def summerzation(file:UploadFile=File(...)):
 
120
  return {"error": "File is empty"}
121
 
122
  result=""
123
+ #for i in range(0,len(text),1024):
124
+ # result+=summarizer(text, max_length=150, min_length=30, do_sample=False)[0]['summary_text']
125
  return {"summary": result}
126
 
127
 
 
160
  ##Prompt: {prompt}.
161
  """
162
 
163
+ output = [{"generated_text":""}]#generator(message, max_length=1000)
164
  match = re.search(r'```python(.*?)```', output[0]["generated_text"], re.DOTALL)
165
  code =''
166
  if not match: