harshithasudhakar commited on
Commit
d151635
·
verified ·
1 Parent(s): 884190d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -3
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import torch
2
- from transformers import pipeline
3
  import streamlit as st
4
  import fitz # PyMuPDF for PDF text extraction
5
 
@@ -19,8 +19,14 @@ model_name = model_options[model_choice]
19
 
20
  @st.cache_resource(show_spinner=True)
21
  def load_model(name):
22
- task = "text2text-generation" if "t5" in name.lower() or "pegasus" in name.lower() else "text-generation"
23
- return pipeline(task, model=name)
 
 
 
 
 
 
24
 
25
  simplifier = load_model(model_name)
26
 
 
1
  import torch
2
+ from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
3
  import streamlit as st
4
  import fitz # PyMuPDF for PDF text extraction
5
 
 
19
 
20
  @st.cache_resource(show_spinner=True)
21
  def load_model(name):
22
+ if "t5" in name.lower():
23
+ tokenizer = AutoTokenizer.from_pretrained(name, use_fast=False)
24
+ model = AutoModelForSeq2SeqLM.from_pretrained(name)
25
+ return pipeline("text2text-generation", model=model, tokenizer=tokenizer)
26
+ elif "pegasus" in name.lower():
27
+ return pipeline("text2text-generation", model=name)
28
+ else:
29
+ return pipeline("text-generation", model=name)
30
 
31
  simplifier = load_model(model_name)
32