import os
import sys
import inspect
import json
from json import JSONDecodeError
import tiktoken
import random
import google.generativeai as palm

currentdir = os.path.dirname(os.path.abspath(
    inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.append(parentdir)

from prompt_catalog import PromptCatalog
from general_utils import num_tokens_from_string
| """ | |
| DEPRECATED: | |
| Safety setting regularly block a response, so set to 4 to disable | |
| class HarmBlockThreshold(Enum): | |
| HARM_BLOCK_THRESHOLD_UNSPECIFIED = 0 | |
| BLOCK_LOW_AND_ABOVE = 1 | |
| BLOCK_MEDIUM_AND_ABOVE = 2 | |
| BLOCK_ONLY_HIGH = 3 | |
| BLOCK_NONE = 4 | |
| """ | |
SAFETY_SETTINGS = [
    {
        "category": "HARM_CATEGORY_DEROGATORY",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_TOXICITY",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_VIOLENCE",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_SEXUAL",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_MEDICAL",
        "threshold": "BLOCK_NONE",
    },
    {
        "category": "HARM_CATEGORY_DANGEROUS",
        "threshold": "BLOCK_NONE",
    },
]
PALM_SETTINGS = {
    'model': 'models/text-bison-001',
    'temperature': 0,
    'candidate_count': 1,
    'top_k': 40,
    'top_p': 0.95,
    'max_output_tokens': 8000,
    'stop_sequences': [],
    'safety_settings': SAFETY_SETTINGS,
}

PALM_SETTINGS_REDO = {
    'model': 'models/text-bison-001',
    'temperature': 0.05,
    'candidate_count': 1,
    'top_k': 40,
    'top_p': 0.95,
    'max_output_tokens': 8000,
    'stop_sequences': [],
    'safety_settings': SAFETY_SETTINGS,
}
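
# Illustrative note (not part of the original module): the palm.generate_text()
# calls below assume the google.generativeai client has already been configured
# with an API key elsewhere in the application, e.g.:
#
#     palm.configure(api_key=os.environ.get('PALM_API_KEY'))
#
# 'PALM_API_KEY' is an assumed environment-variable name; use whatever key
# management the surrounding application provides.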
def OCR_to_dict_PaLM(logger, OCR, prompt_version, VVE):
    try:
        logger.info(f'Length of OCR raw -- {len(OCR)}')
    except Exception:
        print(f'Length of OCR raw -- {len(OCR)}')

    # PaLM needs at least two input:output examples, so prompts are built differently than for ChatGPT
    Prompt = PromptCatalog(OCR)
    if prompt_version in ['prompt_v2_palm2']:
        version = 'v2'
        prompt = Prompt.prompt_v2_palm2(OCR)
    elif prompt_version in ['prompt_v1_palm2']:
        version = 'v1'
        # Build input:output example pairs for PaLM from a similar record in the domain knowledge
        domain_knowledge_example = VVE.query_db(OCR, 4)
        similarity = VVE.get_similarity()
        domain_knowledge_example_string = json.dumps(domain_knowledge_example)
        in_list, out_list = create_OCR_analog_for_input(domain_knowledge_example)
        prompt = Prompt.prompt_v1_palm2(in_list, out_list, OCR)
    elif prompt_version in ['prompt_v1_palm2_noDomainKnowledge']:
        version = 'v1'
        prompt = Prompt.prompt_v1_palm2_noDomainKnowledge(OCR)
    else:
        version = 'custom'
        prompt, n_fields, xlsx_headers = Prompt.prompt_v2_custom(prompt_version, OCR=OCR, is_palm=True)

    nt = num_tokens_from_string(prompt, "cl100k_base")
    logger.info(f'Prompt token length --- {nt}')

    do_use_SOP = False
    if do_use_SOP:
        '''TODO: Check back later to see if LangChain will support PaLM'''
        # logger.info(f'Waiting for PaLM API call --- Using StructuredOutputParser')
        # response = structured_output_parser(OCR, prompt, logger)
        # return response['Dictionary']
        pass
    else:
        logger.info(f'Waiting for PaLM 2 API call')
        response = palm.generate_text(prompt=prompt, **PALM_SETTINGS)

        if response and response.result:
            if isinstance(response.result, (str, bytes)):
                response_valid = check_and_redo_JSON(response, logger, version)
            else:
                response_valid = {}
            logger.info(f'Candidate JSON\n{response.result}')
        else:
            response_valid = {}

        return response_valid, nt
def check_and_redo_JSON(response, logger, version):
    try:
        response_valid = json.loads(response.result)
        logger.info(f'Response --- First call passed')
        return response_valid
    except JSONDecodeError:
        try:
            # The model sometimes wraps its JSON in Markdown fences; strip them and re-parse
            response_valid = json.loads(response.result.strip('```').replace('json\n', '', 1).replace('json', '', 1))
            logger.info(f'Response --- Manual removal of ```json succeeded')
            return response_valid
        except Exception:
            logger.info(f'Response --- First call failed. Redo...')
            Prompt = PromptCatalog()
            if version == 'v1':
                prompt_redo = Prompt.prompt_palm_redo_v1(response.result)
            elif version == 'v2':
                prompt_redo = Prompt.prompt_palm_redo_v2(response.result)
            elif version == 'custom':
                prompt_redo = Prompt.prompt_v2_custom_redo(response.result, is_palm=True)

            try:
                response = palm.generate_text(prompt=prompt_redo, **PALM_SETTINGS)
                response_valid = json.loads(response.result)
                logger.info(f'Response --- Second call passed')
                return response_valid
            except JSONDecodeError:
                logger.info(f'Response --- Second call failed. Final redo. Temperature changed to 0.05')
                try:
                    response = palm.generate_text(prompt=prompt_redo, **PALM_SETTINGS_REDO)
                    response_valid = json.loads(response.result)
                    logger.info(f'Response --- Third call passed')
                    return response_valid
                except JSONDecodeError:
                    return None
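
# Illustrative example (assumed model output, not taken from this module): PaLM
# occasionally wraps its JSON answer in Markdown fences, e.g.
#
#     ```json
#     {"catalogNumber": "12345", "scientificName": "Quercus alba"}
#     ```
#
# json.loads() rejects that raw string, which is why check_and_redo_JSON() strips
# the backticks and the leading "json" tag before falling back to a redo prompt.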
def create_OCR_analog_for_input(domain_knowledge_example):
    in_list = []
    out_list = []
    # Iterate over the domain_knowledge_example (a list of dictionaries)
    for row_dict in domain_knowledge_example:
        # Convert the dictionary to a JSON string and add it to the out_list
        domain_knowledge_example_string = json.dumps(row_dict)
        out_list.append(domain_knowledge_example_string)

        # Create a single string from all values in the row_dict
        row_text = '||'.join(str(v) for v in row_dict.values())

        # Split the row text by '||', shuffle the parts, and re-join with single spaces
        parts = row_text.split('||')
        random.shuffle(parts)
        shuffled_text = ' '.join(parts)

        # Add the shuffled_text to the in_list
        in_list.append(shuffled_text)
    return in_list, out_list
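
# Illustrative sketch (assumed data, not from the original module): given a
# domain-knowledge row such as
#
#     {"genus": "Quercus", "species": "alba", "country": "USA"}
#
# create_OCR_analog_for_input() appends the JSON string of that dict to out_list
# and a shuffled, space-joined value string such as 'alba USA Quercus' to in_list,
# mimicking unstructured OCR text as the paired PaLM input example.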
def strip_problematic_chars(s):
    return ''.join(c for c in s if c.isprintable())
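
# Minimal usage sketch (assumptions: a configured API key, a logger, and OCR text
# from an upstream step; 'PALM_API_KEY' and the sample OCR string are placeholders,
# and the 'prompt_v2_palm2' path is used because it does not require a VVE
# domain-knowledge object):
if __name__ == '__main__':
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger('LLM_PaLM')

    palm.configure(api_key=os.environ.get('PALM_API_KEY'))  # assumed env-var name
    OCR_text = 'Quercus alba USA Michigan 12 May 1998 J. Smith 1234'  # placeholder OCR

    result, n_tokens = OCR_to_dict_PaLM(logger, OCR_text, 'prompt_v2_palm2', VVE=None)
    print(result, n_tokens)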