File size: 4,394 Bytes
b72ab63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""This module should not be used directly as its API is subject to change. Instead,
please use the `gr.Interface.from_pipeline()` function."""

from __future__ import annotations

from typing import TYPE_CHECKING

from gradio.pipelines_utils import (
    handle_diffusers_pipeline,
    handle_transformers_js_pipeline,
    handle_transformers_pipeline,
)

if TYPE_CHECKING:
    import diffusers
    import transformers


def load_from_pipeline(
    pipeline: transformers.Pipeline | diffusers.DiffusionPipeline,  # type: ignore
) -> dict:
    """
    Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline or diffusers.DiffusionPipeline.
    Parameters:
        pipeline (transformers.Pipeline | diffusers.DiffusionPipeline): the pipeline from which to create an interface
    Returns:
        (dict): a dictionary of kwargs that can be used to construct an Interface object
        (the handler-provided entries minus "preprocess"/"postprocess", plus "fn" and "title")
    Raises:
        ValueError: if `pipeline` is neither a transformers nor a diffusers pipeline, or if
            the matching handler returned no pipeline info.
    """
    # Classify the pipeline once by the module its class is defined in; this avoids
    # importing transformers/diffusers just to decide which handler applies.
    module_name = str(type(pipeline).__module__)
    is_transformers = module_name.startswith("transformers.pipelines.")
    is_diffusers = module_name.startswith("diffusers.pipelines.")

    if is_transformers:
        pipeline_info = handle_transformers_pipeline(pipeline)
    elif is_diffusers:
        pipeline_info = handle_diffusers_pipeline(pipeline)
    else:
        raise ValueError(
            "pipeline must be a transformers.pipeline or diffusers.pipeline"
        )

    # Validate eagerly. Previously a falsy result fell through to `del` on an empty
    # dict and surfaced as an unhelpful KeyError instead of this message.
    if not pipeline_info:
        raise ValueError("pipeline_info can not be None.")

    def fn(*params):
        data = pipeline_info["preprocess"](*params)

        if not is_transformers:
            # Only diffusers pipelines reach this point (see dispatch above).
            data = pipeline(**data)  # type: ignore
            return pipeline_info["postprocess"](data)

        from transformers import pipelines

        # special cases that take their preprocessed inputs positionally
        if isinstance(
            pipeline,
            (
                pipelines.text_classification.TextClassificationPipeline,
                pipelines.text2text_generation.Text2TextGenerationPipeline,
                pipelines.text2text_generation.TranslationPipeline,
            ),
        ):
            data = pipeline(*data)
        else:
            data = pipeline(**data)  # type: ignore

        # special case for object-detection:
        # the original input image is also sent to the postprocess function
        if isinstance(pipeline, pipelines.object_detection.ObjectDetectionPipeline):
            return pipeline_info["postprocess"](data, params[0])
        return pipeline_info["postprocess"](data)

    interface_info = pipeline_info.copy()
    interface_info["fn"] = fn
    # preprocess/postprocess are consumed by `fn`; they are not Interface kwargs.
    del interface_info["preprocess"]
    del interface_info["postprocess"]

    # define the title/description of the Interface
    interface_info["title"] = (
        pipeline.model.config.name_or_path
        if is_transformers
        else pipeline.__class__.__name__
    )

    return interface_info


def load_from_js_pipeline(pipeline) -> dict:
    """
    Builds Interface kwargs for a pipeline object coming from `transformers_js_py`.
    Parameters:
        pipeline: a pipeline whose class is defined in a `transformers_js_py.*` module
    Returns:
        (dict): the "fn", "inputs", "outputs" and "title" kwargs for an Interface; "fn" is a coroutine function
    Raises:
        ValueError: if `pipeline` does not originate from `transformers_js_py`
    """
    if not str(type(pipeline).__module__).startswith("transformers_js_py."):
        raise ValueError("pipeline must be a transformers_js_py's pipeline")
    pipeline_info = handle_transformers_js_pipeline(pipeline)

    async def fn(*params):
        # Looked up on every call, so the handler's dict is read lazily.
        pre = pipeline_info["preprocess"]
        post = pipeline_info["postprocess"]
        pass_inputs = pipeline_info.get("postprocess_takes_inputs", False)

        call_args = params
        if pre:
            call_args = pre(*params)
        raw_result = await pipeline(*call_args)

        if not post:
            return raw_result
        # Some postprocessors also need the original (un-preprocessed) inputs.
        extra_args = params if pass_inputs else ()
        return post(raw_result, *extra_args)

    return {
        "fn": fn,
        "inputs": pipeline_info["inputs"],
        "outputs": pipeline_info["outputs"],
        "title": f"{pipeline.task} ({pipeline.model.config._name_or_path})",
    }